├── LFCA ├── LFData │ ├── list │ │ ├── test.txt │ │ └── train.txt │ ├── read_eslf.m │ ├── PrepareDate_test.m │ └── PrepareDate_train.m ├── DeviceParameters.py ├── LFDataset.py ├── LFDatatest.py ├── MainNet_pfe_test.py ├── MainNet_pfe_ver0.py ├── MainNet_pfe_pretrain.py ├── Functions.py ├── LFCA-Test.ipynb ├── RefNet_pfe_pretrain.py ├── RefNet_pfe_test.py ├── LFCA-PFE-preTrain.ipynb ├── LFCA-PFE-Train-Orignal.ipynb └── RefNet_pfe_ver0.py ├── LFDN ├── LFDataset.py ├── LFDatatest.py ├── MainNet_pfe_pretrain.py ├── MainNet_pfe_test.py ├── MainNet_pfe_ver0.py ├── Functions.py ├── LFDN-PFE-preTrain.ipynb ├── LFDN-Test.ipynb ├── LFDN-PFE-Train-Orignal.ipynb ├── RefNet_pfe_pretrain.py ├── RefNet_pfe_test.py └── RefNet_pfe_ver0.py └── README.md /LFCA/LFData/list/test.txt: -------------------------------------------------------------------------------- 1 | Cars.png 2 | Flower1.png 3 | Flower2.png 4 | IMG_1085_eslf.png 5 | IMG_1086_eslf.png 6 | IMG_1184_eslf.png 7 | IMG_1187_eslf.png 8 | IMG_1306_eslf.png 9 | IMG_1312_eslf.png 10 | IMG_1316_eslf.png 11 | IMG_1317_eslf.png 12 | IMG_1320_eslf.png 13 | IMG_1321_eslf.png 14 | IMG_1324_eslf.png 15 | IMG_1325_eslf.png 16 | IMG_1327_eslf.png 17 | IMG_1328_eslf.png 18 | IMG_1340_eslf.png 19 | IMG_1389_eslf.png 20 | IMG_1390_eslf.png 21 | IMG_1411_eslf.png 22 | IMG_1419_eslf.png 23 | IMG_1528_eslf.png 24 | IMG_1541_eslf.png 25 | IMG_1554_eslf.png 26 | IMG_1555_eslf.png 27 | IMG_1586_eslf.png 28 | IMG_1743_eslf.png 29 | Rock.png 30 | Seahorse.png -------------------------------------------------------------------------------- /LFCA/LFData/read_eslf.m: -------------------------------------------------------------------------------- 1 | function lf = read_eslf(read_path, an_org, an_new) 2 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 3 | % read [h,w,3,ah,aw] data from eslf data 4 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 5 | 6 | eslf = im2uint8(imread(read_path)); 7 | 8 | H = size(eslf,1) / an_org; 9 | H = floor(H/4)*4; 10 | W = 
#Wrap a dataloader to move data to a device
class DeviceDataLoader():
    """Wrap a dataloader so that every batch it yields is moved to a device.

    Args:
        dl: Sized iterable of batches (``len(dl)`` must work).
        device: Target device, handed to ``to_device`` for every batch.
    """

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        # Batches are transferred one at a time, as they are consumed.
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        return len(self.dl)


#Move tensor(s) to chosen device
def to_device(data, device):
    """Recursively move ``data`` to ``device``.

    Args:
        data: A tensor-like object exposing ``.to``, or a list / tuple / dict
            nesting such objects.
        device: Target device passed to ``.to``.

    Returns:
        Lists and tuples come back as lists (unchanged historical behaviour),
        dicts come back as dicts with the same keys, and leaves without a
        ``.to`` method (e.g. the light-field name strings the datasets put
        into their samples) are returned untouched.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    if isinstance(data, dict):
        # The datasets in this project return dict samples, so batches are
        # dicts as well; previously this fell through to ``dict.to`` and failed.
        return {key: to_device(value, device) for key, value in data.items()}
    if not hasattr(data, 'to'):
        # Nothing to transfer for plain Python values such as str.
        return data
    return data.to(device, non_blocking=True)
# Loading data
class LFDatatest(Dataset):
    """Test-time light field dataset backed by a MATLAB ``.mat`` file.

    The file at ``opt.testPath`` must provide the variables ``lf`` (indexed
    as [ind, u, v, x, y, c]) and ``LF_name`` ([ind, 1], each name stored as a
    row of character codes).
    """

    def __init__(self, opt):
        super().__init__()
        mat = scio.loadmat(opt.testPath)
        self.LFSet = mat['lf']            #[ind, u, v, x, y, c]
        self.lfNameSet = mat['LF_name']   #[ind, 1] LF name represented by ASCII

    def __getitem__(self, idx):
        """Return ``{'LF': float32 tensor scaled by 1/255, 'lfName': str}``."""
        # The name cell wraps a 1 x L array of character codes.
        codes = self.lfNameSet[idx][0][0]
        name = ''.join(map(chr, codes))

        views = self.LFSet[idx].astype(np.float32) / 255   #[u, v, x, y, c]
        return {'LF': torch.from_numpy(views), 'lfName': name}

    def __len__(self):
        return self.LFSet.shape[0]
# Loading data
class LFDataset(Dataset):
    """Test-time denoising dataset: clean / noisy light field pairs plus names.

    Reads ``lf``, ``noilf_<noiselevel>`` and ``LF_name`` from the ``.mat``
    file at ``opt.dataPath``. The two light-field arrays are transposed so
    that their last axis becomes the sample index.
    """

    _AXES = (4, 0, 1, 2, 3)

    def __init__(self, opt):
        super().__init__()
        mat = scio.loadmat(opt.dataPath)
        noisyKey = 'noilf_{}'.format(opt.noiselevel)
        self.lfSet = mat['lf'].transpose(*self._AXES)
        self.noiselfSet = mat[noisyKey].transpose(*self._AXES)
        self.lfNameSet = mat['LF_name']
        self.patchSize = opt.patchSize

    @staticmethod
    def _as_unit_tensor(array):
        # float32 copy scaled by 1/255, wrapped as a torch tensor.
        return torch.from_numpy(array.astype(np.float32) / 255)

    def __getitem__(self, idx):
        """Return ``{'lf', 'noiself', 'lfname'}`` for sample ``idx`` (no cropping)."""
        clean = self._as_unit_tensor(self.lfSet[idx])
        noisy = self._as_unit_tensor(self.noiselfSet[idx])
        # The name cell wraps a 1 x L array of character codes.
        codes = self.lfNameSet[idx][0][0]
        name = ''.join(map(chr, codes))
        return {'lf': clean, 'noiself': noisy, 'lfname': name}

    def __len__(self):
        return self.lfSet.shape[0]
'Y:\LF_Dataset\Dataset_kalantari_SIG2016\SIGGRAPHAsia16_ViewSynthesis_Trainingset'; 9 | savepath = 'train_LFCA_Kalantari.mat'; 10 | an = 7; 11 | 12 | %%initilization 13 | lf = zeros(an,an,600,600,3,'uint8'); 14 | lfSize = zeros(2,1,'uint16'); 15 | count = 0; 16 | 17 | %% read datasets 18 | data_list = dir(data_folder); 19 | data_list = data_list(3:end); 20 | 21 | 22 | %% read lfs 23 | for i_lf = 1:length(data_list) 24 | lfname = data_list(i_lf).name; 25 | read_path = fullfile(data_folder,lfname); 26 | lf_rgb = read_eslf(read_path,14,an); 27 | 28 | H = size(lf_rgb,1); 29 | W = size(lf_rgb,2); 30 | 31 | count = count +1; 32 | lf(:,:,1:H,1:W,:,count) = permute(lf_rgb,[4,5,1,2,3]); 33 | lfSize(:,count)=[H,W]; 34 | end 35 | 36 | %% generate data 37 | order = randperm(count); 38 | lf = permute(lf(:, :, :, :, :, order),[6,1,2,3,4,5]); %[u,v,x,y,c,N] -> [N,u,v,x,y,c] 39 | lfSize = permute(lfSize(:,order),[2,1]); %[N,2] 40 | 41 | %% writing to mat 42 | if exist(savepath,'file') 43 | fprintf('Warning: replacing existing file %s \n', savepath); 44 | delete(savepath); 45 | end 46 | 47 | save(savepath,'lf','lfSize','-v7.3'); 48 | 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LFCA-CR-NET 2 | Repository for International Journal of Computer Vision paper "Probabilistic-based Feature Learning of Light Fields for Compressive Imaging and Denoising" 3 | 4 | https://trebuchet.public.springernature.app/get_content/6401293d-1745-4b94-9f65-b206fb1b5f3e?utm_source=rct_congratemailt&utm_medium=email&utm_campaign=nonoa_20240112&utm_content=10.1007/s11263-023-01974-9 5 | 6 | # Dataset 7 | You can download the dataset for LF denoising from 8 | 9 | https://drive.google.com/drive/folders/1emg1Ll2KPmqkMGuEvLOp7fA6i_kEBYtM?usp=sharing 10 | 11 | For the compressive LF imaging, we provide MATLAB code for preparing the training and test data. 
Please first download light field datasets, and put them into corresponding folders in LFData. 12 | 13 | 14 | # Requirements 15 | - Python 3.8.3 16 | - PyTorch 1.13.1 17 | 18 | 19 | # Training 20 | 21 | For the tasks of compressive imaging and denoising, it is necessary to pretrain a model initially and then utilize this pretrained model to obtain the final Probabilistic-based Feature Embedding (PFE) model. Let's consider the LFCA task as an example. 22 | 23 | First, pretrain the model by running 'LFCA-PFE-preTrain.ipynb'. 24 | 25 | Next, train the PFE model by running 'LFCA-PFE-Train-Orignal.ipynb'. 26 | 27 | You will need to configure the training data path and set the learning rate according to the type of data you are working with. The batch size can also be adjusted as needed. 28 | 29 | # Testing 30 | 31 | Set the testing configuration. 32 | 33 | And run LFCA-Test.ipynb or LFDN-Test.ipynb 34 | -------------------------------------------------------------------------------- /LFCA/LFData/list/train.txt: -------------------------------------------------------------------------------- 1 | IMG_0288_eslf 2 | IMG_0289_eslf 3 | IMG_0359_eslf 4 | IMG_0360_eslf 5 | IMG_0466_eslf 6 | IMG_0518_eslf 7 | IMG_0575_eslf 8 | IMG_0596_eslf 9 | IMG_0681_eslf 10 | IMG_0780_eslf 11 | IMG_0820_eslf 12 | IMG_1016_eslf 13 | IMG_1410_eslf 14 | IMG_1413_eslf 15 | IMG_1414_eslf 16 | IMG_1415_eslf 17 | IMG_1416_eslf 18 | IMG_1419_eslf 19 | IMG_1469_eslf 20 | IMG_1470_eslf 21 | IMG_1471_eslf 22 | IMG_1473_eslf 23 | IMG_1474_eslf 24 | IMG_1475_eslf 25 | IMG_1476_eslf 26 | IMG_1477_eslf 27 | IMG_1478_eslf 28 | IMG_1479_eslf 29 | IMG_1480_eslf 30 | IMG_1481_eslf 31 | IMG_1482_eslf 32 | IMG_1483_eslf 33 | IMG_1484_eslf 34 | IMG_1486_eslf 35 | IMG_1487_eslf 36 | IMG_1490_eslf 37 | IMG_1499_eslf 38 | IMG_1500_eslf 39 | IMG_1501_eslf 40 | IMG_1504_eslf 41 | IMG_1505_eslf 42 | IMG_1508_eslf 43 | IMG_1509_eslf 44 | IMG_1510_eslf 45 | IMG_1511_eslf 46 | IMG_1513_eslf 47 | IMG_1514_eslf 48 | 
class StageBlock(torch.nn.Module):
    """One iterative stage: a ``RefNet`` initialised through ``weights_init``."""

    def __init__(self, opt, bs):
        super().__init__()
        # Regularization sub-network
        refiner = RefNet(opt, bs)
        refiner.apply(weights_init)
        self.refnet = refiner

    def forward(self, mResidual, sampleLF, epoch):
        """Forward all three arguments to the sub-network and return its output."""
        return self.refnet(mResidual, sampleLF, epoch)  #[b,uv,c,x,y]
class MainNet(torch.nn.Module):
    """Denoising network: an initial ``RefNet`` followed by residual stages.

    Before every ``RefNet`` call the current light field is summarised per
    pixel by an angular average and by the 7 channels of the learned
    ``proj_init`` convolution; both are concatenated along dim 1 and passed
    on as ``sampleLF``.
    """

    def __init__(self, opt):
        super().__init__()
        ang = opt.angResolution
        self.kernelSize = [ang, ang]
        self.angularnum = ang
        # global average
        self.avglf = torch.nn.AvgPool2d(kernel_size=self.kernelSize, stride=None, padding=0)
        self.proj_init = torch.nn.Conv2d(in_channels=1, out_channels=7,
                                         kernel_size=self.kernelSize, bias=False)
        torch.nn.init.xavier_uniform_(self.proj_init.weight.data)
        self.initialRefnet = RefNet(opt, True)
        self.initialRefnet.apply(weights_init)
        # Iterative stages
        self.iterativeRecon = CascadeStages(StageBlock, opt, False)

    def _sample(self, lf, dims):
        # Angular summary of `lf`, using the dimensions of the network input.
        b, u, v, x, y = dims
        angular = lf.permute(0, 3, 4, 1, 2)
        mean = self.avglf(angular.reshape(b, x * y, u, v))
        mean = mean.reshape(b, x, y, 1).permute(0, 3, 1, 2)
        proj = self.proj_init(angular.reshape(b * x * y, 1, u, v))
        proj = proj.reshape(b, x, y, self.proj_init.out_channels).permute(0, 3, 1, 2)
        return torch.cat([mean, proj], 1)

    def forward(self, noiself, epoch):
        """Denoise ``noiself`` (unpacked as [b,u,v,x,y]); ``epoch`` goes to every RefNet."""
        dims = noiself.shape
        out = self.initialRefnet(noiself, self._sample(noiself, dims), epoch)
        # Reconstructing iteratively
        for stage in self.iterativeRecon:
            out = out + stage(out, self._sample(out, dims), epoch)
        return out
class StageBlock(torch.nn.Module):
    """A single refinement stage wrapping one ``RefNet``."""

    def __init__(self, opt, bs):
        super().__init__()
        # Regularization sub-network
        self.refnet = RefNet(opt, bs)
        self.refnet.apply(weights_init)

    def forward(self, mResidual, sampleLF, epoch):
        """Return the sub-network output for the given inputs."""
        return self.refnet(mResidual, sampleLF, epoch)


def CascadeStages(block, opt, bs):
    """Build ``opt.stageNum`` separate ``block(opt, bs)`` modules.

    Returns:
        A ``torch.nn.ModuleList`` holding the stages in creation order.
    """
    return torch.nn.ModuleList([block(opt, bs) for _ in range(opt.stageNum)])
class StageBlock(torch.nn.Module):
    """Refinement stage built around a freshly initialised ``RefNet``."""

    def __init__(self, opt, bs):
        super().__init__()
        # Regularization sub-network
        net = RefNet(opt, bs)
        net.apply(weights_init)
        self.refnet = net

    def forward(self, mResidual, sampleLF, epoch):
        """Delegate to the wrapped sub-network."""
        lfResidual = self.refnet(mResidual, sampleLF, epoch)  #[b,uv,c,x,y]
        return lfResidual


def CascadeStages(block, opt, bs):
    """Return a ``ModuleList`` with ``opt.stageNum`` instances of ``block(opt, bs)``."""
    stages = torch.nn.ModuleList()
    stages.extend(block(opt, bs) for _ in range(opt.stageNum))
    return stages
# Main Network construction
class MainNet(torch.nn.Module):
    """Learned compressive acquisition plus unrolled light-field reconstruction.

    ``proj_init`` acts as the camera shot: a bias-free convolution whose kernel
    is ``angResolution`` x ``angResolution``. An initial ``RefNet`` turns the
    measurements into a first estimate, and every stage in ``iterativeRecon``
    adds a correction computed from the measurement residual of the current
    estimate.

    Args:
        opt: Options object. This class reads ``angResolution``,
            ``channelNum`` and ``measurementNum`` and hands ``opt`` on to
            ``RefNet`` and ``CascadeStages``.
    """

    def __init__(self, opt):
        super(MainNet, self).__init__()
        # The kernel was previously assigned only when measurementNum was 1, 2
        # or 4 (always to this same value); any other measurement count left
        # the attribute unset and crashed on the next line.
        self.kernelSize = [opt.angResolution, opt.angResolution]

        # Shot layer
        self.proj_init = torch.nn.Conv2d(in_channels=opt.channelNum,
                                         out_channels=opt.measurementNum,
                                         kernel_size=self.kernelSize, bias=False)
        torch.nn.init.xavier_uniform_(self.proj_init.weight.data)
        # Initialize LF from measurements
        # `recon` is not used by forward() below. It is kept because removing
        # it would change the module's parameter set and the order in which
        # random initial weights are drawn.
        self.recon = torch.nn.ConvTranspose2d(in_channels=opt.channelNum,
                                              out_channels=opt.channelNum,
                                              kernel_size=self.kernelSize, bias=False)
        torch.nn.init.xavier_uniform_(self.recon.weight.data)
        self.initialRefnet = RefNet(opt, True)
        self.initialRefnet.apply(weights_init)
        # Iterative stages
        self.iterativeRecon = CascadeStages(StageBlock, opt, False)

    def _measure(self, lf, dims):
        # Apply the shot layer to `lf` and return measurements as [b,m,c,x,y].
        b, u, v, c, x, y = dims
        stacks = lf.reshape(b, u, v, c, x, y).permute(0, 4, 5, 3, 1, 2).reshape(b * x * y, c, u, v)
        shots = self.proj_init(stacks)
        m = shots.shape[1]
        # NOTE(review): with u == v == angResolution the conv output holds
        # b*x*y*m values, so this reshape appears to require c == 1 -- confirm
        # against the options used for training.
        return shots.reshape(b, x, y, m, c).permute(0, 3, 4, 1, 2)

    def forward(self, lf, epoch):
        """Simulate the shot on ``lf`` and reconstruct the light field.

        Args:
            lf: Tensor unpacked as [b,u,v,c,x,y].
            epoch: Passed through unchanged to every ``RefNet`` call.

        Returns:
            The final estimate reshaped to [b,u,v,c,x,y].
        """
        dims = lf.shape
        b, u, v, c, x, y = dims
        # Shot
        degLF = self._measure(lf, dims)

        # Initialize LF from measurements
        out = self.initialRefnet(degLF, epoch)
        # Reconstructing iteratively
        for stage in self.iterativeRecon:
            # Residual between the real measurements and those the current
            # estimate would produce under the same shot layer.
            mResidual = degLF - self._measure(out, dims)
            out = out + stage(mResidual, epoch)
        return out.reshape(b, u, v, c, x, y)
# Main Network construction
class MainNet(torch.nn.Module):
    """Learned compressive acquisition plus unrolled light-field reconstruction.

    ``proj_init`` acts as the camera shot: a bias-free convolution whose kernel
    is ``angResolution`` x ``angResolution``. An initial ``RefNet`` turns the
    measurements into a first estimate, and every stage in ``iterativeRecon``
    adds a correction computed from the measurement residual of the current
    estimate.

    Args:
        opt: Options object. This class reads ``angResolution``,
            ``channelNum`` and ``measurementNum`` and hands ``opt`` on to
            ``RefNet`` and ``CascadeStages``.
    """

    def __init__(self, opt):
        super(MainNet, self).__init__()
        # The kernel was previously assigned only when measurementNum was 1, 2
        # or 4 (always to this same value); any other measurement count left
        # the attribute unset and crashed on the next line.
        self.kernelSize = [opt.angResolution, opt.angResolution]

        # Shot layer
        self.proj_init = torch.nn.Conv2d(in_channels=opt.channelNum,
                                         out_channels=opt.measurementNum,
                                         kernel_size=self.kernelSize, bias=False)
        torch.nn.init.xavier_uniform_(self.proj_init.weight.data)
        # Initialize LF from measurements
        # `recon` is not used by forward() below. It is kept because removing
        # it would change the module's parameter set and the order in which
        # random initial weights are drawn.
        self.recon = torch.nn.ConvTranspose2d(in_channels=opt.channelNum,
                                              out_channels=opt.channelNum,
                                              kernel_size=self.kernelSize, bias=False)
        torch.nn.init.xavier_uniform_(self.recon.weight.data)
        self.initialRefnet = RefNet(opt, True)
        self.initialRefnet.apply(weights_init)
        # Iterative stages
        self.iterativeRecon = CascadeStages(StageBlock, opt, False)

    def _measure(self, lf, dims):
        # Apply the shot layer to `lf` and return measurements as [b,m,c,x,y].
        b, u, v, c, x, y = dims
        stacks = lf.reshape(b, u, v, c, x, y).permute(0, 4, 5, 3, 1, 2).reshape(b * x * y, c, u, v)
        shots = self.proj_init(stacks)
        m = shots.shape[1]
        # NOTE(review): with u == v == angResolution the conv output holds
        # b*x*y*m values, so this reshape appears to require c == 1 -- confirm
        # against the options used for training.
        return shots.reshape(b, x, y, m, c).permute(0, 3, 4, 1, 2)

    def forward(self, lf, epoch):
        """Simulate the shot on ``lf`` and reconstruct the light field.

        Args:
            lf: Tensor unpacked as [b,u,v,c,x,y].
            epoch: Passed through unchanged to every ``RefNet`` call.

        Returns:
            The final estimate reshaped to [b,u,v,c,x,y].
        """
        dims = lf.shape
        b, u, v, c, x, y = dims
        # Shot
        degLF = self._measure(lf, dims)

        # Initialize LF from measurements
        out = self.initialRefnet(degLF, epoch)
        # Reconstructing iteratively
        for stage in self.iterativeRecon:
            # Residual between the real measurements and those the current
            # estimate would produce under the same shot layer.
            mResidual = degLF - self._measure(out, dims)
            out = out + stage(mResidual, epoch)
        return out.reshape(b, u, v, c, x, y)
torch.nn.init.xavier_uniform_(self.proj_init.weight.data) 45 | 46 | # Initialize LF from measurements 47 | self.recon=torch.nn.ConvTranspose2d(in_channels=opt.channelNum,out_channels=opt.channelNum,kernel_size=self.kernelSize,bias=False) 48 | torch.nn.init.xavier_uniform_(self.recon.weight.data) 49 | self.initialRefnet=RefNet(opt, True) 50 | self.initialRefnet.apply(weights_init) 51 | 52 | # Iterative stages 53 | self.iterativeRecon = CascadeStages(StageBlock, opt, False) 54 | 55 | 56 | def forward(self, lf, epoch): 57 | b,u,v,c,x,y=lf.shape 58 | # Shot 59 | degLF=self.proj_init(lf.permute(0,4,5,3,1,2).reshape(b*x*y,c,u,v)) 60 | _,m,_,_ = degLF.shape 61 | degLF = degLF.reshape(b,x,y,m,c).permute(0,3,4,1,2) #[b,m,c,x,y] m = udeg*vdeg 62 | # Initialize LF from measurements 63 | initLF = self.initialRefnet(degLF, epoch) #[buv,c,x,y] 64 | out=initLF 65 | # Reconstructing iteratively 66 | for stage in self.iterativeRecon: 67 | mResidual = degLF -self.proj_init(out.reshape(b,u,v,c,x,y).permute(0,4,5,3,1,2).reshape(b*x*y,c,u,v)).reshape(b,x,y,m,c).permute(0,3,4,1,2) 68 | out = out + stage(mResidual,epoch) 69 | return out.reshape(b,u,v,c,x,y) 70 | -------------------------------------------------------------------------------- /LFCA/Functions.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from torch.autograd import Variable 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import torch.nn.functional as F 6 | import math 7 | import warnings 8 | from scipy import sparse 9 | import random 10 | import numpy as np 11 | 12 | warnings.filterwarnings("ignore") 13 | plt.ion() 14 | 15 | 16 | 17 | 18 | #Initiate parameters in model 19 | def weights_init(m): 20 | classname = m.__class__.__name__ 21 | if classname.find('Conv2d') != -1: 22 | torch.nn.init.xavier_uniform_(m.weight.data) 23 | #torch.nn.init.constant_(m.bias.data, 0.0) 24 | elif classname.find('ConvTranspose2d') != -1: 25 | 
def SetupSeed(seed):
    """Seed every random number generator this project draws from.

    Seeds torch (CPU and all CUDA devices), NumPy and Python's ``random``
    with the same value and asks cuDNN for deterministic algorithms.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def ExtractPatch(lf, H, W, patchSize):
    """Cut one random single-channel spatial patch out of a light field.

    Args:
        lf: Array indexed as [u, v, c, x, y].
        H: Valid extent of the x axis (may be a NumPy integer).
        W: Valid extent of the y axis (may be a NumPy integer).
        patchSize: Side length of the square patch.

    Returns:
        ``lf[:, :, k:k+1, ix:ix+patchSize, iy:iy+patchSize]`` where ``ix`` and
        ``iy`` are random multiples of 8 and ``k`` is drawn from {0, 1, 2}.

    Raises:
        ValueError: If the patch is larger than the valid area.
    """
    # Plain ints keep the arithmetic below safe for unsigned NumPy scalars.
    H, W = int(H), int(W)
    if patchSize > H or patchSize > W:
        raise ValueError('patchSize {} does not fit into a {}x{} light field'
                         .format(patchSize, H, W))
    # "+ 1" makes the last aligned offset selectable and lets an image that is
    # exactly patchSize large yield offset 0 instead of an empty-range error.
    # When H - patchSize is not a multiple of 8 the candidates are unchanged.
    indx = random.randrange(0, H - patchSize + 1, 8)
    indy = random.randrange(0, W - patchSize + 1, 8)
    indc = random.randint(0, 2)

    lfPatch = lf[:, :, indc:indc + 1,
                 indx:indx + patchSize,
                 indy:indy + patchSize]
    return lfPatch  #[u v c x y]


def ResizeLF(lf, scale_factor):
    """Resize every view of a light field with bicubic interpolation.

    Args:
        lf: Array unpacked as [u, v, x, y, c].
        scale_factor: Multiplier applied to both spatial axes.

    Returns:
        Integer array of shape [u, v, int(scale*x), int(scale*y), c].
    """
    # This function already depended on OpenCV, but the module never imported
    # it. Importing here keeps the rest of the file usable without OpenCV.
    import cv2

    u, v, x, y, c = lf.shape
    newX = int(scale_factor * x)
    newY = int(scale_factor * y)
    # `np.int` was only an alias of the builtin and was removed in NumPy 1.24.
    resizedLF = np.zeros((u, v, newX, newY, c), dtype=int)
    for ind_u in range(u):
        for ind_v in range(v):
            view = lf[ind_u, ind_v, :, :, :]
            # cv2 expects dsize as (width, height) = (columns, rows); the
            # previous (rows, columns) order scrambled non-square views.
            resizedView = cv2.resize(view, (newY, newX), interpolation=cv2.INTER_CUBIC)
            # reshape restores the channel axis cv2 drops for c == 1.
            resizedLF[ind_u, ind_v, :, :, :] = resizedView.reshape(newX, newY, -1)
    return resizedLF
lfPatch=lf[:,:,:,:,i*(patchSize-overlap):(i+1)*patchSize-i*overlap,j*(patchSize-overlap):(j+1)*patchSize-j*overlap] 79 | elif (i != numX-1)and(j == numY-1): 80 | lfPatch=lf[:,:,:,:,i*(patchSize-overlap):(i+1)*patchSize-i*overlap,-patchSize:] 81 | elif (i == numX-1)and(j != numY-1): 82 | lfPatch=lf[:,:,:,:,-patchSize:,j*(patchSize-overlap):(j+1)*patchSize-j*overlap] 83 | else : 84 | lfPatch=lf[:,:,:,:,-patchSize:,-patchSize:] 85 | # print(numX,numY,i,j,lfPatch.shape) 86 | lfStack[:,indCurrent,:,:,:,:,:]=lfPatch 87 | indCurrent=indCurrent+1 88 | return lfStack, [numX,numY] #lfStack [b,n,u,v,c,x,y] 89 | 90 | 91 | def MergeLF(lfStack, coordinate, overlap, x, y): 92 | b,n,u,v,c,patchSize,_=lfStack.shape 93 | lfMerged=torch.zeros(b,u,v,c,x-overlap,y-overlap) 94 | for i in range(coordinate[0]): 95 | for j in range(coordinate[1]): 96 | if (i != coordinate[0]-1)and(j != coordinate[1]-1): 97 | lfMerged[:,:,:,:, 98 | i*(patchSize-overlap):(i+1)*(patchSize-overlap), 99 | j*(patchSize-overlap):(j+1)*(patchSize-overlap)]=lfStack[:,i*coordinate[1]+j,:,:,:, 100 | overlap//2:-overlap//2, 101 | overlap//2:-overlap//2] 102 | elif (i == coordinate[0]-1)and(j != coordinate[1]-1): 103 | lfMerged[:,:,:,:,i*(patchSize-overlap):, 104 | j*(patchSize-overlap):(j+1)*(patchSize-overlap)]=lfStack[:,i*coordinate[1]+j,:,:,:, 105 | -((x-overlap)-i*(patchSize-overlap))-overlap//2:-overlap//2, 106 | overlap//2:-overlap//2] 107 | elif (i != coordinate[0]-1)and(j == coordinate[1]-1): 108 | lfMerged[:,:,:,:,i*(patchSize-overlap):(i+1)*(patchSize-overlap), 109 | j*(patchSize-overlap):]=lfStack[:,i*coordinate[1]+j,:,:,:, 110 | overlap//2:-overlap//2, 111 | -((y-overlap)-j*(patchSize-overlap))-overlap//2:-overlap//2] 112 | else: 113 | lfMerged[:,:,:,:,i*(patchSize-overlap):, 114 | j*(patchSize-overlap):]=lfStack[:,i*coordinate[1]+j,:,:,:, 115 | -((x-overlap)-i*(patchSize-overlap))-overlap//2:-overlap//2, 116 | -((y-overlap)-j*(patchSize-overlap))-overlap//2:-overlap//2] 117 | return lfMerged # 
[b,u,v,c,x,y] 118 | 119 | def ComptPSNR(img1, img2): 120 | mse = np.mean( (img1 - img2) ** 2 ) 121 | if mse == 0: 122 | return 100 123 | PIXEL_MAX = 1.0 124 | 125 | if mse > 1000: 126 | return -100 127 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 128 | 129 | def rgb2ycbcr(rgb): 130 | m = np.array([[ 65.481, 128.553, 24.966], 131 | [-37.797, -74.203, 112], 132 | [ 112, -93.786, -18.214]]) 133 | shape = rgb.shape 134 | if len(shape) == 3: 135 | rgb = rgb.reshape((shape[0] * shape[1], 3)) 136 | ycbcr = np.dot(rgb, m.transpose() / 255.) 137 | ycbcr[:,0] += 16. 138 | ycbcr[:,1:] += 128. 139 | return ycbcr.reshape(shape) 140 | -------------------------------------------------------------------------------- /LFDN/Functions.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from torch.autograd import Variable 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import torch.nn.functional as F 6 | import math 7 | # Ignore warnings 8 | #import cv2 9 | import warnings 10 | from scipy import sparse 11 | import random 12 | import numpy as np 13 | 14 | warnings.filterwarnings("ignore") 15 | plt.ion() 16 | 17 | 18 | 19 | 20 | #Initiate parameters in model 21 | def weights_init(m): 22 | classname = m.__class__.__name__ 23 | if classname.find('Conv2d') != -1: 24 | torch.nn.init.xavier_uniform_(m.weight.data) 25 | #torch.nn.init.constant_(m.bias.data, 0.0) 26 | elif classname.find('ConvTranspose2d') != -1: 27 | torch.nn.init.xavier_uniform_(m.weight.data) 28 | #torch.nn.init.constant_(m.bias.data, 0.0) 29 | if classname.find('Conv3d') != -1: 30 | torch.nn.init.xavier_uniform_(m.weight.data) 31 | #torch.nn.init.constant_(m.bias.data, 0.0) 32 | elif classname.find('ConvTranspose3d') != -1: 33 | torch.nn.init.xavier_uniform_(m.weight.data) 34 | #torch.nn.init.constant_(m.bias.data, 0.0) 35 | elif classname.find('Linear') != -1: 36 | torch.nn.init.xavier_uniform_(m.weight.data) 37 | 
# Class-name fragments that select a layer for Xavier initialisation.
_XAVIER_LAYER_TAGS = ('Conv2d', 'ConvTranspose2d', 'Conv3d', 'ConvTranspose3d', 'Linear')


#Initiate parameters in model
def weights_init(m):
    """Xavier-uniform initialise the weight of convolution and linear layers.

    Meant to be passed to ``module.apply``. A module is selected when its class
    name contains one of ``_XAVIER_LAYER_TAGS``; every other module is left
    untouched. Biases are never modified.
    """
    classname = m.__class__.__name__
    if any(tag in classname for tag in _XAVIER_LAYER_TAGS):
        torch.nn.init.xavier_uniform_(m.weight.data)
        #torch.nn.init.constant_(m.bias.data, 0.0)


def SetupSeed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``
    and ask cuDNN for deterministic kernels."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def ExtractPatch(lf, noiself, H, W, patchSize):
    """Cut the same random spatial patch out of a clean and a noisy light field.

    Args:
        lf, noiself: arrays/tensors indexed as [u, v, x, y].
        H, W: spatial extent available along x and y.
        patchSize: side length of the square patch.

    Returns:
        ``(lfPatch, noiselfPatch)``, both cut at the same offsets, which are
        random multiples of 8.

    Raises:
        ValueError: if ``patchSize`` is larger than ``H`` or ``W``.
    """
    if patchSize > H or patchSize > W:
        raise ValueError(
            "patchSize (%d) exceeds the spatial size (%d, %d)" % (patchSize, H, W))
    # "+ 1" keeps the last valid offset reachable and makes H == patchSize legal
    # (randrange raises on an empty range).
    indx = random.randrange(0, H - patchSize + 1, 8)
    indy = random.randrange(0, W - patchSize + 1, 8)

    lfPatch = lf[:, :, indx:indx + patchSize, indy:indy + patchSize]
    noiselfPatch = noiself[:, :, indx:indx + patchSize, indy:indy + patchSize]
    return lfPatch, noiselfPatch  #[u v x y]


def ResizeLF(lf, scale_factor):
    """Bicubically resize every view of a light field.

    Args:
        lf: NumPy array shaped [u, v, x, y, c].
        scale_factor: multiplicative factor applied to x and y.

    Returns:
        Integer NumPy array shaped [u, v, int(scale*x), int(scale*y), c].
    """
    # Local import: the module-level "import cv2" is commented out, so the
    # original code failed with NameError as soon as this function was called.
    import cv2

    u, v, x, y, c = lf.shape
    newX = int(scale_factor * x)
    newY = int(scale_factor * y)
    # np.int was removed in NumPy 1.24; it was only an alias of the builtin int.
    resizedLF = np.zeros((u, v, newX, newY, c), dtype=int)
    for ind_u in range(u):
        for ind_v in range(v):
            view = lf[ind_u, ind_v, :, :, :]
            # cv2.resize expects dsize as (width, height) == (columns, rows).
            resizedView = cv2.resize(view, (newY, newX), interpolation=cv2.INTER_CUBIC)
            # reshape restores the channel axis OpenCV drops when c == 1.
            resizedLF[ind_u, ind_v, :, :, :] = resizedView.reshape(newX, newY, -1)
    return resizedLF


def _patch_starts(length, patchSize, overlap):
    """Return the start offset of every patch along one spatial axis.

    Patches are laid out with stride ``patchSize - overlap``; any patch that
    would run past the border is shifted back so that it ends exactly at the
    border (the last patch always does).
    """
    if not 0 <= overlap < patchSize:
        raise ValueError("overlap must satisfy 0 <= overlap < patchSize")
    if length < patchSize:
        raise ValueError(
            "spatial size %d is smaller than patchSize %d" % (length, patchSize))
    stride = patchSize - overlap
    count = -(-length // stride)  # ceil(length / stride)
    last = length - patchSize
    return [min(i * stride, last) for i in range(count)]


def CropLF(lf, patchSize, overlap):  #lf [b,u,v,x,y]
    """Split a light field into overlapping square spatial patches.

    Args:
        lf: tensor whose first axis is the batch and whose last two axes are
            the spatial axes x and y (e.g. [b, u, v, x, y]).
        patchSize: side length of each patch.
        overlap: number of pixels shared by neighbouring patches.

    Returns:
        ``(lfStack, [numX, numY])``. ``lfStack`` is a float32 CPU tensor shaped
        [b, numX*numY, ..., patchSize, patchSize] with patches in row-major
        (x outer, y inner) order.

    Raises:
        ValueError: if ``overlap`` is not in ``[0, patchSize)`` or a spatial
            axis is shorter than ``patchSize``.
    """
    x, y = lf.shape[-2], lf.shape[-1]
    startsX = _patch_starts(x, patchSize, overlap)
    startsY = _patch_starts(y, patchSize, overlap)
    numX, numY = len(startsX), len(startsY)

    lfStack = torch.zeros(lf.shape[0], numX * numY, *lf.shape[1:-2], patchSize, patchSize)
    for i, sx in enumerate(startsX):
        for j, sy in enumerate(startsY):
            lfStack[:, i * numY + j] = lf[..., sx:sx + patchSize, sy:sy + patchSize]
    return lfStack, [numX, numY]  #lfStack [b,n,u,v,x,y]


def MergeLF(lfStack, coordinate, overlap, x, y):
    """Reassemble patches produced by ``CropLF`` into one light field.

    ``overlap // 2`` pixels are discarded on the leading side of the image (and
    the rest of ``overlap`` on the trailing side), so the result covers
    ``x - overlap`` by ``y - overlap`` pixels.

    Args:
        lfStack: tensor shaped [b, n, ..., patchSize, patchSize].
        coordinate: ``[numX, numY]`` as returned by ``CropLF``.
        overlap: the overlap that was used for cropping.
        x, y: spatial size of the light field that was cropped.

    Returns:
        float32 CPU tensor shaped [b, ..., x - overlap, y - overlap].
    """
    patchSize = lfStack.shape[-2]
    if not 0 <= overlap < patchSize:
        raise ValueError("overlap must satisfy 0 <= overlap < patchSize")
    numX, numY = coordinate
    stride = patchSize - overlap
    margin = overlap // 2
    outX, outY = x - overlap, y - overlap

    lfMerged = torch.zeros(lfStack.shape[0], *lfStack.shape[2:-2], outX, outY)
    for i in range(numX):
        x0, x1 = i * stride, min((i + 1) * stride, outX)
        if x0 >= x1:
            continue  # this patch row only covers the discarded margin
        # Offset of output row x0 inside the patch; the patch start is
        # recomputed exactly as CropLF placed it.
        px = x0 + margin - min(i * stride, x - patchSize)
        for j in range(numY):
            y0, y1 = j * stride, min((j + 1) * stride, outY)
            if y0 >= y1:
                continue
            py = y0 + margin - min(j * stride, y - patchSize)
            patch = lfStack[:, i * numY + j]
            lfMerged[..., x0:x1, y0:y1] = patch[..., px:px + (x1 - x0), py:py + (y1 - y0)]
    return lfMerged  # [b,u,v,x,y]


def ComptPSNR(img1, img2):
    """Return the PSNR in dB between two arrays, using a peak value of 1.0.

    Returns 100 for identical inputs and -100 when the mean squared error
    exceeds 1000.
    """
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    PIXEL_MAX = 1.0

    if mse > 1000:
        return -100
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))


def rgb2ycbcr(rgb):
    """Convert RGB to YCbCr with the fixed 3x3 matrix below.

    The matrix is divided by 255 and offsets of 16 (Y) and 128 (Cb, Cr) are
    added. A 3-D input is flattened to rows of three channels for the matrix
    product; the result always has the input's shape.
    """
    m = np.array([[65.481, 128.553, 24.966],
                  [-37.797, -74.203, 112],
                  [112, -93.786, -18.214]])
    shape = rgb.shape
    if len(shape) == 3:
        rgb = rgb.reshape((shape[0] * shape[1], 3))
    ycbcr = np.dot(rgb, m.transpose() / 255.)
    ycbcr[:, 0] += 16.
    ycbcr[:, 1:] += 128.
    return ycbcr.reshape(shape)
138 | ycbcr[:,0] += 16. 139 | ycbcr[:,1:] += 128. 140 | return ycbcr.reshape(shape) 141 | -------------------------------------------------------------------------------- /LFCA/LFCA-Test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import print_function, division\n", 10 | "import torch\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "from torch.utils.data import DataLoader\n", 13 | "from torchvision import utils\n", 14 | "import warnings\n", 15 | "from LFDatatest import LFDatatest\n", 16 | "from DeviceParameters import to_device\n", 17 | "from MainNet_pfe_test import MainNet\n", 18 | "from Functions import CropLF, MergeLF,ComptPSNR,rgb2ycbcr\n", 19 | "from skimage.measure import compare_ssim \n", 20 | "import numpy as np\n", 21 | "import scipy.io as scio \n", 22 | "import scipy.misc as scim\n", 23 | "import os\n", 24 | "import logging,argparse\n", 25 | "from datetime import datetime\n", 26 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "warnings.filterwarnings(\"ignore\")\n", 36 | "plt.ion()\n", 37 | "logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')\n", 38 | "log = logging.getLogger()\n", 39 | "fh = logging.FileHandler('Testing_original.log')\n", 40 | "log.addHandler(fh)\n", 41 | "\n", 42 | "# Testing settings\n", 43 | "parser = argparse.ArgumentParser(description=\"Light Field Denoising\")\n", 44 | "parser.add_argument(\"--stageNum\", type=int, default=6, help=\"The number of stages\")\n", 45 | "parser.add_argument(\"--sasLayerNum\", type=int, default=4, help=\"The number of stages\")\n", 46 | "parser.add_argument(\"--batchSize\", type=int, default=1, help=\"Batch size\")\n", 47 | 
"parser.add_argument(\"--patchSize\", type=int, default=32, help=\"The size of croped LF patch\")\n", 48 | "parser.add_argument(\"--overlap\", type=int, default=4, help=\"The size of croped LF patch\")\n", 49 | "parser.add_argument(\"--measurementNum\", type=int, default=4, help=\"The number of measurements\")\n", 50 | "parser.add_argument(\"--angResolution\", type=int, default=7, help=\"The angular resolution of original LF\")\n", 51 | "parser.add_argument(\"--channelNum\", type=int, default=1, help=\"The number of input channels\")\n", 52 | "parser.add_argument(\"--modelPath\", type=str, default='./model/***pfe model***', help=\"Path for loading trained model \")\n", 53 | "parser.add_argument(\"--dataPath\", type=str, default='path_to/test_LFCA_Kalantari_4-10.mat', help=\"Path for loading testing data \")\n", 54 | "parser.add_argument(\"--savePath\", type=str, default='./results/', help=\"Path for saving results \")\n", 55 | "\n", 56 | "\n", 57 | "opt = parser.parse_known_args()[0]\n", 58 | "logging.info(opt)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "lf_dataset = LFDatatest(opt)\n", 68 | "dataloader = DataLoader(lf_dataset, batch_size=opt.batchSize,shuffle=False)\n", 69 | "model=MainNet(opt)\n", 70 | "model.load_state_dict(torch.load(opt.modelPath))\n", 71 | "model.eval()\n", 72 | "model.cuda()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "with torch.no_grad():\n", 82 | " num = 0\n", 83 | " avg_psnr = 0\n", 84 | " avg_ssim = 0\n", 85 | " for _,sample in enumerate(dataloader):\n", 86 | " num=num+1\n", 87 | " LF=sample['LF'] #test lf [b,u,v,x,y,c]\n", 88 | " lfName=sample['lfName']\n", 89 | " b,u,v,x,y,c = LF.shape \n", 90 | "\n", 91 | " # Crop the input LF into patches \n", 92 | " LFStack,coordinate=CropLF(LF.permute(0,1,2,5,3,4),opt.patchSize, opt.overlap) 
#[b,n,u,v,c,x,y]\n", 93 | " n=LFStack.shape[1] \n", 94 | " estiLFStack=torch.zeros(b,n,u,v,c,opt.patchSize,opt.patchSize)#[b,n,u,v,c,x,y]\n", 95 | "\n", 96 | " # reconstruction\n", 97 | " for i in range(LFStack.shape[1]):\n", 98 | " codedImPatch=torch.zeros(b,opt.measurementNum,c,opt.patchSize,opt.patchSize)#[b,measurementNum,c,x,y]\n", 99 | " estiLFPatch=torch.zeros(b,u,v,c,opt.patchSize,opt.patchSize)#[b,u,v,c,x,y]\n", 100 | " for j in range(c):\n", 101 | " estiLFPatch[:,:,:,j:j+1,:,:]=model(LFStack[:,i,:,:,j:j+1,:,:].cuda()) #[b,measurementNum,c,x,y] [b,u,v,c,x,y].\n", 102 | " estiLFStack[:,i,:,:,:,:,:]=estiLFPatch #[b,n,u,v,c,x,y]\n", 103 | "\n", 104 | " # Merge the patches into LF\n", 105 | " estiLF=MergeLF(estiLFStack,coordinate,opt.overlap,x,y) \n", 106 | " b,u,v,c,xCrop,yCrop=estiLF.shape\n", 107 | " LF=LF[:,:,:, opt.overlap//2:opt.overlap//2+xCrop,opt.overlap//2:opt.overlap//2+yCrop,:]\n", 108 | " lf_psnr = 0\n", 109 | " lf_ssim = 0\n", 110 | "\n", 111 | " #evaluation\n", 112 | " for ind_uv in range(u*v):\n", 113 | " lf_psnr += ComptPSNR(rgb2ycbcr(estiLF.permute(0,1,2,4,5,3).reshape(b,u*v,xCrop,yCrop,c)[0,ind_uv].cpu().numpy())[:,:,0],\n", 114 | " rgb2ycbcr(LF.reshape(b,u*v,xCrop,yCrop,c)[0,ind_uv].cpu().numpy())[:,:,0]) / (u*v)\n", 115 | "\n", 116 | " lf_ssim += compare_ssim(rgb2ycbcr(estiLF.permute(0,1,2,4,5,3).reshape(b,u*v,xCrop,yCrop,c)[0,ind_uv].cpu().numpy()*255.0)[:,:,0].astype(np.uint8),\n", 117 | " rgb2ycbcr(LF.reshape(b,u*v,xCrop,yCrop,c)[0,ind_uv].cpu().numpy()*255.0)[:,:,0].astype(np.uint8),gaussian_weights=True,sigma=1.5,use_sample_covariance=False,multichannel=False) / (u*v)\n", 118 | "\n", 119 | " avg_psnr += lf_psnr / len(dataloader) \n", 120 | " avg_ssim += lf_ssim / len(dataloader)\n", 121 | " log.info('Index: %d Scene: %s PSNR: %.2f SSIM: %.3f'%(num,lfName[0],lf_psnr,lf_ssim))\n", 122 | " #save reconstructed LF\n", 123 | " scio.savemat(os.path.join(opt.savePath,lfName[0]+'.mat'),\n", 124 | " 
{'lf_recons':torch.squeeze(estiLF).numpy()})\n", 125 | "\n", 126 | "\n", 127 | " # #save coded mask\n", 128 | " if opt.measurementNum==1:\n", 129 | " plt.imsave(os.path.join(opt.savePath, 'mask.png'),\n", 130 | " torch.squeeze(255.0*model._modules['proj_init'].weight.data.reshape(-1,opt.angResolution,opt.angResolution).permute(1,2,0)).cpu().numpy())\n", 131 | " if opt.measurementNum==2:\n", 132 | " plt.imsave(os.path.join(opt.savePath, 'mask.png'),\n", 133 | " torch.squeeze(255.0*model._modules['proj_init'].weight.data.reshape(-1,opt.angResolution,opt.angResolution-1).permute(1,2,0)).cpu().numpy())\n", 134 | " if opt.measurementNum==4:\n", 135 | " plt.imsave(os.path.join(opt.savePath, 'mask.png'),\n", 136 | " torch.squeeze(255.0*model._modules['proj_init'].weight.data.reshape(-1,opt.angResolution,opt.angResolution).permute(1,2,0)).cpu().numpy())\n", 137 | "\n", 138 | " log.info('Average PSNR: %.2f SSIM: %.3f '%(avg_psnr,avg_ssim)) " 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "Python 3", 152 | "language": "python", 153 | "name": "python3" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.8.5" 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 4 170 | } 171 | -------------------------------------------------------------------------------- /LFCA/RefNet_pfe_pretrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import math 7 | 8 | class 
def _conv_relu(C_in, C_out, kernel_size, stride, padding, bias):
    """Build the Conv2d + in-place ReLU pair used by every directional convolution."""
    return nn.Sequential(
        nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=bias),
        nn.ReLU(inplace=True),
    )


class Conv_spa(nn.Module):
    """2-D convolution over the (h, w) axes, applied to each of the ``uv`` slices."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_spa, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        """Map [N, c, uv, h, w] to [N, C_out, uv, h, w]."""
        N, c, uv, h, w = x.shape
        out = self.op(x.permute(0, 2, 1, 3, 4).reshape(N * uv, c, h, w))
        # Channel count is read from the conv output instead of being fixed to 16.
        return out.reshape(N, uv, out.shape[1], h, w).permute(0, 2, 1, 3, 4)


class Conv_ang(nn.Module):
    """2-D convolution over the ``uv`` axis unfolded to (angular, angular),
    applied at each (h, w) position."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, angular, bias):
        super(Conv_ang, self).__init__()
        self.angular = angular
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        """Map [N, c, uv, h, w] to [N, C_out, uv, h, w]; needs uv == angular**2."""
        N, c, uv, h, w = x.shape
        out = self.op(x.permute(0, 3, 4, 1, 2).reshape(N * h * w, c, self.angular, self.angular))
        return out.reshape(N, h, w, out.shape[1], uv).permute(0, 3, 4, 1, 2)


class Conv_epi_h(nn.Module):
    """2-D convolution over the (uv, w) plane, applied for each h index."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_epi_h, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        """Map [N, c, uv, h, w] to [N, C_out, uv, h, w]."""
        N, c, uv, h, w = x.shape
        out = self.op(x.permute(0, 3, 1, 2, 4).reshape(N * h, c, uv, w))
        return out.reshape(N, h, out.shape[1], uv, w).permute(0, 2, 3, 1, 4)


class Conv_epi_v(nn.Module):
    """2-D convolution over the (uv, h) plane, applied for each w index."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_epi_v, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        """Map [N, c, uv, h, w] to [N, C_out, uv, h, w]."""
        N, c, uv, h, w = x.shape
        out = self.op(x.permute(0, 4, 1, 2, 3).reshape(N * w, c, uv, h))
        return out.reshape(N, w, out.shape[1], uv, h).permute(0, 2, 3, 4, 1)


class Autocovnlayer(nn.Module):
    """Densely connected block mixing four directional convolutions (pre-training variant).

    In this variant both gates are all-ones, so every earlier feature map and
    every directional branch contributes; ``dence_weight`` and
    ``component_weight`` are registered as parameters but not read in
    ``forward``.
    """

    def __init__(self, dence_num, component_num, angular, bs):
        super(Autocovnlayer, self).__init__()
        self.dence_num = dence_num
        self.component_num = component_num
        self.dence_weight = nn.Parameter(torch.rand(dence_num), requires_grad=True)  #[N,*,c,u,v,h,w]
        self.component_weight = nn.Parameter(torch.rand(component_num), requires_grad=True)  #[N,*,c,,,]
        self.angular = angular
        self.kernel_size = 3

        self.naslayers = nn.ModuleList([
            Conv_spa(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
            Conv_ang(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, angular=self.angular, bias=bs),
            Conv_epi_h(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
            Conv_epi_v(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs)
        ])
        self.Conv_all = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, bias=bs)
        self.softmax1 = nn.Softmax(1)
        self.softmax0 = nn.Softmax(0)
        self.Conv_mixdence = nn.Conv2d(in_channels=32 * self.dence_num, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.Conv_mixnas = nn.Conv2d(in_channels=16 * 4, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)  ## 1*1 padding!!
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, temperature_1, temperature_2):
        """Fuse the list of earlier feature maps and refine them.

        Args:
            x: list of ``dence_num`` tensors, each [N, 32, uv, h, w].
            temperature_1, temperature_2: accepted for interface compatibility;
                not used by this variant.

        Returns:
            Tensor [N, 32, uv, h, w].
        """
        x = torch.stack(x, dim=0)
        fn, N, C, uv, h, w = x.shape
        # All-ones gates, created on the input's device/dtype rather than via
        # .cuda() so the block also runs on CPU or a non-default GPU.
        dence_weight_soft = x.new_ones(self.dence_num, N, 1, 1, 1, 1)
        component_weight_gumbel = x.new_ones(self.component_num, N, 1, 1, 1, 1)

        x = x * dence_weight_soft
        x = x.permute([1, 3, 0, 2, 4, 5]).reshape([N * uv, fn * C, h, w])
        x = self.relu(self.Conv_mixdence(x))
        x_mix = x.reshape([N, uv, 32, h, w]).permute([0, 2, 1, 3, 4])

        nas = torch.stack([layer(x_mix) for layer in self.naslayers], dim=0)
        nas = nas * component_weight_gumbel
        nas = nas.permute([1, 3, 0, 2, 4, 5]).reshape([N * uv, self.component_num * 16, h, w])
        nas = self.relu(self.Conv_mixnas(nas))
        ####### add a spa conv
        nas = self.Conv_all(nas)
        nas = nas.reshape(N, uv, 32, h, w).permute(0, 2, 1, 3, 4)
        # Residual connection around the directional branches.
        return self.relu(nas + x_mix)


def make_autolayers(opt, bs):
    """Stack ``opt.sasLayerNum`` Autocovnlayers; the i-th (1-based) fuses i inputs."""
    return nn.Sequential(*[
        Autocovnlayer(i + 1, opt.component_num, opt.angResolution, bs)
        for i in range(opt.sasLayerNum)
    ])


class RefNet(nn.Module):
    """Maps ``measurementNum`` coded measurements to ``angResolution**2`` views.

    ``opt`` must provide angResolution, measurementNum, epochNum,
    temperature_1, temperature_2, sasLayerNum and component_num; ``bs`` toggles
    the bias of the convolutions.
    """

    def __init__(self, opt, bs):
        super(RefNet, self).__init__()
        self.angResolution = opt.angResolution
        self.measurementNum = opt.measurementNum
        self.lfNum = opt.angResolution * opt.angResolution
        self.epochNum = opt.epochNum
        self.temperature_1 = opt.temperature_1
        self.temperature_2 = opt.temperature_2

        self.relu = nn.ReLU(inplace=True)
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1, bias=bs)
        self.conv1 = nn.Conv2d(in_channels=self.measurementNum, out_channels=self.lfNum, kernel_size=3, stride=1, padding=1, bias=bs)

        self.dence_autolayers = make_autolayers(opt, bs)
        self.syn_conv2 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, bias=bs)

    def forward(self, input, epoch):
        """Reconstruct all views from the measurements.

        Args:
            input: tensor [N, uvInput, 1, h, w]; uvInput must equal measurementNum.
            epoch: current epoch, used only to anneal the two temperatures.

        Returns:
            Tensor [N, angResolution**2, 1, h, w].
        """
        N, uvInput, _, h, w = input.shape  #[N,uvInput,1,h,w]
        # Linear annealing up to epoch 3800, then a fixed floor of 0.05.
        if epoch <= 3800:
            temperature_1 = self.temperature_1 * (1 - epoch / 4000)
            temperature_2 = self.temperature_2 * (1 - epoch / 4000)
        else:
            temperature_1 = 0.05
            temperature_2 = 0.05

        # feature extraction
        feat = input.reshape(N * uvInput, 1, h, w)
        feat = self.relu(self.conv0(feat))

        # LF feature extraction: expand the measurement axis to lfNum views
        feat = feat.reshape(N, uvInput, 32, h, w).permute(0, 2, 1, 3, 4).reshape(N * 32, uvInput, h, w)
        feat = self.relu(self.conv1(feat))
        feat = feat.reshape(N, 32, self.lfNum, h, w)

        # autoConv: every layer receives the outputs of all previous layers
        feat = [feat]
        for layer in self.dence_autolayers:
            feat.append(layer(feat, temperature_1, temperature_2))
        feat = self.syn_conv2(feat[-1].permute(0, 2, 1, 3, 4).reshape(N * self.lfNum, 32, h, w))
        return feat.reshape(N, self.lfNum, 1, h, w)
def _conv_relu(C_in, C_out, kernel_size, stride, padding, bias):
    """Build the Conv2d + in-place ReLU pair used by every directional convolution."""
    return nn.Sequential(
        nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=bias),
        nn.ReLU(inplace=True),
    )


class Conv_spa(nn.Module):
    """2-D convolution over the (h, w) axes, applied to each of the ``uv`` slices."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_spa, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        """Map [N, c, uv, h, w] to [N, C_out, uv, h, w]."""
        N, c, uv, h, w = x.shape
        out = self.op(x.permute(0, 2, 1, 3, 4).reshape(N * uv, c, h, w))
        # Channel count is read from the conv output instead of being fixed to 16.
        return out.reshape(N, uv, out.shape[1], h, w).permute(0, 2, 1, 3, 4)


class Conv_ang(nn.Module):
    """2-D convolution over the ``uv`` axis unfolded to (angular, angular),
    applied at each (h, w) position."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, angular, bias):
        super(Conv_ang, self).__init__()
        self.angular = angular
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        """Map [N, c, uv, h, w] to [N, C_out, uv, h, w]; needs uv == angular**2."""
        N, c, uv, h, w = x.shape
        out = self.op(x.permute(0, 3, 4, 1, 2).reshape(N * h * w, c, self.angular, self.angular))
        return out.reshape(N, h, w, out.shape[1], uv).permute(0, 3, 4, 1, 2)


class Conv_epi_h(nn.Module):
    """2-D convolution over the (uv, w) plane, applied for each h index."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_epi_h, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        """Map [N, c, uv, h, w] to [N, C_out, uv, h, w]."""
        N, c, uv, h, w = x.shape
        out = self.op(x.permute(0, 3, 1, 2, 4).reshape(N * h, c, uv, w))
        return out.reshape(N, h, out.shape[1], uv, w).permute(0, 2, 3, 1, 4)


class Conv_epi_v(nn.Module):
    """2-D convolution over the (uv, h) plane, applied for each w index."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_epi_v, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        """Map [N, c, uv, h, w] to [N, C_out, uv, h, w]."""
        N, c, uv, h, w = x.shape
        out = self.op(x.permute(0, 4, 1, 2, 3).reshape(N * w, c, uv, h))
        return out.reshape(N, w, out.shape[1], uv, h).permute(0, 2, 3, 4, 1)


class Autocovnlayer(nn.Module):
    """Densely connected block mixing four directional convolutions (inference variant).

    The learned ``dence_weight`` / ``component_weight`` are binarised at 0.5:
    earlier feature maps and directional branches whose weight is not above
    0.5 are multiplied by zero.
    """

    def __init__(self, dence_num, component_num, angular, bs):
        super(Autocovnlayer, self).__init__()
        self.dence_num = dence_num
        self.component_num = component_num
        self.dence_weight = nn.Parameter(torch.rand(dence_num), requires_grad=True)
        self.component_weight = nn.Parameter(torch.rand(component_num), requires_grad=True)
        self.angular = angular
        self.kernel_size = 3

        self.naslayers = nn.ModuleList([
            Conv_spa(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
            Conv_ang(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, angular=self.angular, bias=bs),
            Conv_epi_h(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
            Conv_epi_v(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs)
        ])
        self.Conv_all = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, bias=bs)
        self.softmax1 = nn.Softmax(1)
        self.softmax0 = nn.Softmax(0)
        self.Conv_mixdence = nn.Conv2d(in_channels=32 * self.dence_num, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.Conv_mixnas = nn.Conv2d(in_channels=16 * 4, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)  ## 1*1 padding!!
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _binary_gate(weight, like):
        """Return a 0/1 gate shaped [len(weight), 1, 1, 1, 1, 1] on ``like``'s device/dtype.

        An entry is 1 where the weight, clamped to [0.02, 0.98], exceeds 0.5.
        """
        gate = weight.detach().clamp(0.02, 0.98) > 0.5
        return gate.to(device=like.device, dtype=like.dtype).reshape(-1, 1, 1, 1, 1, 1)

    def forward(self, x, temperature_1, temperature_2):
        """Fuse the selected earlier feature maps and refine them.

        Args:
            x: list of ``dence_num`` tensors, each [N, 32, uv, h, w].
            temperature_1, temperature_2: accepted for interface compatibility;
                not used by this variant.

        Returns:
            Tensor [N, 32, uv, h, w].
        """
        x = torch.stack(x, dim=0)
        fn, N, C, uv, h, w = x.shape

        # Gates are built directly on the input's device. The original built a
        # CPU tensor, indexed it with a mask derived from the (possibly CUDA)
        # parameter, and then forced it onto the default GPU with .cuda().
        dence_weight_soft = self._binary_gate(self.dence_weight, x)
        component_weight_gumbel = self._binary_gate(self.component_weight, x)

        x = x * dence_weight_soft
        x = x.permute([1, 3, 0, 2, 4, 5]).reshape([N * uv, fn * C, h, w])
        x = self.relu(self.Conv_mixdence(x))
        x_mix = x.reshape([N, uv, 32, h, w]).permute([0, 2, 1, 3, 4])

        nas = torch.stack([layer(x_mix) for layer in self.naslayers], dim=0)
        nas = nas * component_weight_gumbel
        nas = nas.permute([1, 3, 0, 2, 4, 5]).reshape([N * uv, self.component_num * 16, h, w])
        nas = self.relu(self.Conv_mixnas(nas))
        ####### add a spa conv
        nas = self.Conv_all(nas)
        nas = nas.reshape(N, uv, 32, h, w).permute(0, 2, 1, 3, 4)
        # Residual connection around the directional branches.
        return self.relu(nas + x_mix)


def make_autolayers(opt, bs):
    """Stack ``opt.sasLayerNum`` Autocovnlayers; the i-th (1-based) fuses i inputs."""
    return nn.Sequential(*[
        Autocovnlayer(i + 1, opt.component_num, opt.angResolution, bs)
        for i in range(opt.sasLayerNum)
    ])


class RefNet(nn.Module):
    """Maps ``measurementNum`` coded measurements to ``angResolution**2`` views.

    ``opt`` must provide angResolution, measurementNum, epochNum,
    temperature_1, temperature_2, sasLayerNum and component_num; ``bs`` toggles
    the bias of the convolutions.
    """

    def __init__(self, opt, bs):
        super(RefNet, self).__init__()

        self.angResolution = opt.angResolution
        self.measurementNum = opt.measurementNum
        self.lfNum = opt.angResolution * opt.angResolution
        self.epochNum = opt.epochNum
        self.temperature_1 = opt.temperature_1
        self.temperature_2 = opt.temperature_2

        self.relu = nn.ReLU(inplace=True)
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1, bias=bs)
        self.conv1 = nn.Conv2d(in_channels=self.measurementNum, out_channels=self.lfNum, kernel_size=3, stride=1, padding=1, bias=bs)

        self.dence_autolayers = make_autolayers(opt, bs)
        self.syn_conv2 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, bias=bs)

    def forward(self, input, epoch):
        """Reconstruct all views from the measurements.

        Args:
            input: tensor [N, uvInput, 1, h, w]; uvInput must equal measurementNum.
            epoch: current epoch, used only to anneal the two temperatures.

        Returns:
            Tensor [N, angResolution**2, 1, h, w].
        """
        N, uvInput, _, h, w = input.shape
        # Linear annealing up to epoch 3800, then a fixed floor of 0.05.
        if epoch <= 3800:
            temperature_1 = self.temperature_1 * (1 - epoch / 4000)
            temperature_2 = self.temperature_2 * (1 - epoch / 4000)
        else:
            temperature_1 = 0.05
            temperature_2 = 0.05

        # feature extraction
        feat = input.reshape(N * uvInput, 1, h, w)
        feat = self.relu(self.conv0(feat))

        # LF feature extraction: expand the measurement axis to lfNum views
        feat = feat.reshape(N, uvInput, 32, h, w).permute(0, 2, 1, 3, 4).reshape(N * 32, uvInput, h, w)
        feat = self.relu(self.conv1(feat))
        feat = feat.reshape(N, 32, self.lfNum, h, w)

        # autoConv: every layer receives the outputs of all previous layers
        feat = [feat]
        for layer in self.dence_autolayers:
            feat.append(layer(feat, temperature_1, temperature_2))
        feat = self.syn_conv2(feat[-1].permute(0, 2, 1, 3, 4).reshape(N * self.lfNum, 32, h, w))
        return feat.reshape(N, self.lfNum, 1, h, w)
# %% Imports
# NOTE(review): this script was reconstructed from a whitespace-collapsed
# notebook dump; "# %%" marks the original cell boundaries and the
# indentation was re-derived from syntax -- confirm against the notebook.
import torch
from torch.utils.data import DataLoader
from datetime import datetime
import logging
from LFDataset import LFDataset
from LFDatatest import LFDatatest
from Functions import weights_init, SetupSeed, CropLF, MergeLF, ComptPSNR, rgb2ycbcr
from DeviceParameters import to_device
from MainNet_pfe_pretrain import MainNet
import itertools
import argparse
from skimage.metrics import structural_similarity
import numpy as np
import scipy.io as scio
import scipy.misc as scim
import matplotlib.pyplot as plt
from collections import defaultdict
import os
from os.path import join  # fix: join() is used below when saving checkpoints

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# %% Training settings
parser = argparse.ArgumentParser(description="Light Field Compressed Sensing")
parser.add_argument("--learningRate", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--step", type=int, default=1000, help="Learning rate decay every n epochs")
parser.add_argument("--reduce", type=float, default=0.7, help="Learning rate decay")
parser.add_argument("--stageNum", type=int, default=6, help="The number of stages")
parser.add_argument("--sasLayerNum", type=int, default=8, help="The number of stages")
parser.add_argument("--temperature_1", type=float, default=1, help="The number of temperature_1")
parser.add_argument("--temperature_2", type=float, default=1, help="The number of temperature_2")
parser.add_argument("--component_num", type=int, default=4, help="The number of nas component")
parser.add_argument("--batchSize", type=int, default=5, help="Batch size")
parser.add_argument("--sampleNum", type=int, default=55, help="The number of LF in training set")
parser.add_argument("--patchSize", type=int, default=32, help="The size of croped LF patch")
parser.add_argument("--num_cp", type=int, default=1000, help="Number of epoches for saving checkpoint")
parser.add_argument("--measurementNum", type=int, default=2, help="The number of measurements")
parser.add_argument("--angResolution", type=int, default=5, help="The angular resolution of original LF")
parser.add_argument("--channelNum", type=int, default=1, help="The channel number of input LF")
parser.add_argument("--epochNum", type=int, default=10000, help="The number of epoches")
parser.add_argument("--overlap", type=int, default=4, help="The size of croped LF patch")
parser.add_argument("--summaryPath", type=str, default='./', help="Path for saving training log ")
parser.add_argument("--dataName", type=str, default='Synthetic', help="The name of dataset ")
parser.add_argument("--testPath", type=str, default='path_to/test_LFCA_synthetic_5.mat', help="Path for loading training data ")
parser.add_argument("--dataPath", type=str, default='path_to/train_LFCA_synthetic_5.mat', help="Path for loading training data ")

# parse_known_args() ignores the extra argv entries injected by Jupyter.
opt = parser.parse_known_args()[0]

# Log to the console (basicConfig) and to a per-configuration log file.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
log = logging.getLogger()
fh = logging.FileHandler('Training_pfe_{}_{}_{}_{}_pretrain.log'.format(
    opt.dataName, opt.measurementNum, opt.stageNum, opt.sasLayerNum))
log.addHandler(fh)
logging.info(opt)

# %% Pre-training loop
if __name__ == '__main__':

    SetupSeed(1)
    savePath = './model/lfca_{}_{}_{}_{}_{}_{}-pretrian'.format(
        opt.dataName, opt.measurementNum, opt.stageNum,
        opt.sasLayerNum, opt.epochNum, opt.learningRate)
    # fix: torch.save() does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs(savePath, exist_ok=True)

    lfDataset = LFDataset(opt)
    dataloader = DataLoader(lfDataset, batch_size=opt.batchSize, shuffle=True)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model = MainNet(opt)
    # Keep the initial weights of the 'proj_init' layer inside [0, 1].
    model._modules['proj_init'].weight.data.clamp_(0.0, 1.0)
    model = model.cuda()

    criterion = torch.nn.L1Loss()  # reconstruction loss
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learningRate)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=opt.step, gamma=opt.reduce)

    lossLogger = defaultdict(list)
    plotPath = 'Training_{}_{}_{}_{}_{}_{}_pretrian.jpg'.format(
        opt.dataName, opt.measurementNum, opt.stageNum,
        opt.sasLayerNum, opt.epochNum, opt.learningRate)

    for epoch in range(opt.epochNum):
        lossSum = 0
        for batch, sample in enumerate(dataloader, start=1):
            lf = sample['lf'].cuda()
            # The network reconstructs the light field from itself; the input
            # is also the regression target.
            estimatedLF = model(lf, epoch)
            loss = criterion(estimatedLF, lf)
            lossSum += loss.item()
            print("Epoch: %d Batch: %d Loss: %.6f" % (epoch, batch, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Re-project the 'proj_init' weights onto [0, 1] after every update.
            model._modules['proj_init'].weight.data.clamp_(0.0, 1.0)
        # StepLR is advanced once per epoch ("--step" counts epochs).
        scheduler.step()

        epochLoss = lossSum / len(dataloader)

        if epoch % opt.num_cp == 0:
            model_save_path = join(savePath, "pre_model_epoch_{}.pth".format(epoch))
            state = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, model_save_path)
            print("checkpoint saved to {}".format(model_save_path))
        log.info("Epoch: %d Loss: %.6f" % (epoch, epochLoss))

        # Record the training loss and refresh the loss curve on disk.
        lossLogger['Epoch'].append(epoch)
        lossLogger['Loss'].append(epochLoss)
        lossLogger['Lr'].append(optimizer.state_dict()['param_groups'][0]['lr'])
        plt.figure()
        plt.title('Loss')
        plt.plot(lossLogger['Epoch'], lossLogger['Loss'])
        plt.savefig(plotPath)
        plt.close()
# %% Imports
# NOTE(review): this script was reconstructed from a whitespace-collapsed
# notebook dump; "# %%" marks the original cell boundaries and the
# indentation was re-derived from syntax -- confirm against the notebook.
import torch
from torch.utils.data import DataLoader
from datetime import datetime
import logging
from LFDataset import LFDataset
from LFDatatest import LFDatatest
from Functions import weights_init, SetupSeed, CropLF, MergeLF, ComptPSNR, rgb2ycbcr
from DeviceParameters import to_device
from MainNet_pfe_ver0 import MainNet
import itertools
import argparse
from skimage.metrics import structural_similarity
import numpy as np
import scipy.io as scio
import scipy.misc as scim
import matplotlib.pyplot as plt
from collections import defaultdict
import os
from os.path import join  # fix: join() is used below when saving checkpoints

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# %% Training settings
parser = argparse.ArgumentParser(description="Light Field Compressed Sensing")
parser.add_argument("--learningRate", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--step", type=int, default=1000, help="Learning rate decay every n epochs")
parser.add_argument("--reduce", type=float, default=0.5, help="Learning rate decay")
parser.add_argument("--stageNum", type=int, default=6, help="The number of stages")
parser.add_argument("--sasLayerNum", type=int, default=8, help="The number of stages")
parser.add_argument("--temperature_1", type=float, default=1, help="The number of temperature_1")
parser.add_argument("--temperature_2", type=float, default=1, help="The number of temperature_2")
parser.add_argument("--component_num", type=int, default=4, help="The number of nas component")
parser.add_argument("--batchSize", type=int, default=5, help="Batch size")
parser.add_argument("--sampleNum", type=int, default=55, help="The number of LF in training set")
parser.add_argument("--patchSize", type=int, default=32, help="The size of croped LF patch")
parser.add_argument("--num_cp", type=int, default=1000, help="Number of epoches for saving checkpoint")
parser.add_argument("--measurementNum", type=int, default=2, help="The number of measurements")
parser.add_argument("--angResolution", type=int, default=5, help="The angular resolution of original LF")
parser.add_argument("--channelNum", type=int, default=1, help="The channel number of input LF")
parser.add_argument("--epochNum", type=int, default=10000, help="The number of epoches")
parser.add_argument("--overlap", type=int, default=4, help="The size of croped LF patch")
parser.add_argument("--summaryPath", type=str, default='./', help="Path for saving training log ")
parser.add_argument("--dataName", type=str, default='Synthetic', help="The name of dataset ")
parser.add_argument("--preTrain", type=str, default='./model/***pretrained model***', help="Path for loading pretrained model ")
parser.add_argument("--testPath", type=str, default='path_to/test_LFCA_synthetic_5.mat', help="Path for loading training data ")
parser.add_argument("--dataPath", type=str, default='path_to/train_LFCA_synthetic_5.mat', help="Path for loading training data ")

# parse_known_args() ignores the extra argv entries injected by Jupyter.
opt = parser.parse_known_args()[0]

# Log to the console (basicConfig) and to a per-configuration log file.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
log = logging.getLogger()
fh = logging.FileHandler('Training_pfe_{}_{}_{}_{}_pfe.log'.format(
    opt.dataName, opt.measurementNum, opt.stageNum, opt.sasLayerNum))
log.addHandler(fh)
logging.info(opt)

# %% Training loop (starts from the pre-trained weights)
if __name__ == '__main__':

    SetupSeed(1)
    savePath = './model/lfca_{}_{}_{}_{}_{}_{}-pfe'.format(
        opt.dataName, opt.measurementNum, opt.stageNum,
        opt.sasLayerNum, opt.epochNum, opt.learningRate)
    # fix: torch.save() does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs(savePath, exist_ok=True)

    lfDataset = LFDataset(opt)
    dataloader = DataLoader(lfDataset, batch_size=opt.batchSize, shuffle=True)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model = MainNet(opt)
    # The checkpoint is a dict; the weights live under its 'model' key.
    model.load_state_dict(torch.load(opt.preTrain)['model'])
    model = model.cuda()

    criterion = torch.nn.L1Loss()  # reconstruction loss
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learningRate)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=opt.learningRate, steps_per_epoch=len(dataloader),
        epochs=opt.epochNum, pct_start=0.2, div_factor=10, final_div_factor=10)

    lossLogger = defaultdict(list)
    plotPath = 'Training_{}_{}_{}_{}_{}_{}_pfe.jpg'.format(
        opt.dataName, opt.measurementNum, opt.stageNum,
        opt.sasLayerNum, opt.epochNum, opt.learningRate)

    for epoch in range(opt.epochNum):
        lossSum = 0
        for batch, sample in enumerate(dataloader, start=1):
            lf = sample['lf'].cuda()

            # The network reconstructs the light field from itself; the input
            # is also the regression target.
            estimatedLF = model(lf, epoch)
            loss = criterion(estimatedLF, lf)
            lossSum += loss.item()
            print("Epoch: %d Batch: %d Loss: %.6f" % (epoch, batch, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Re-project the 'proj_init' weights onto [0, 1] after every update.
            model._modules['proj_init'].weight.data.clamp_(0.0, 1.0)
            # NOTE(review): the scheduler is built with
            # steps_per_epoch=len(dataloader), i.e. one step per batch, so it
            # is advanced here. The original indentation of this call could
            # not be recovered from the dump -- confirm.
            scheduler.step()

        epochLoss = lossSum / len(dataloader)

        if epoch % opt.num_cp == 0:
            model_save_path = join(savePath, "pfe_model_epoch_{}.pth".format(epoch))
            state = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, model_save_path)
            print("checkpoint saved to {}".format(model_save_path))
        log.info("Epoch: %d Loss: %.6f" % (epoch, epochLoss))

        # Record the training loss and refresh the loss curve on disk.
        lossLogger['Epoch'].append(epoch)
        lossLogger['Loss'].append(epochLoss)
        lossLogger['Lr'].append(optimizer.state_dict()['param_groups'][0]['lr'])
        plt.figure()
        plt.title('Loss')
        plt.plot(lossLogger['Epoch'], lossLogger['Loss'])
        plt.savefig(plotPath)
        plt.close()
# %% Imports
# NOTE(review): this script was reconstructed from a whitespace-collapsed
# notebook dump; "# %%" marks the original cell boundaries and the
# indentation was re-derived from syntax -- confirm against the notebook.
import torch
from torch.utils.data import DataLoader
from datetime import datetime
import logging
from LFDataset import LFDataset
from Functions import weights_init, SetupSeed, CropLF, MergeLF, ComptPSNR, rgb2ycbcr
import itertools
import argparse
from skimage.metrics import structural_similarity
import numpy as np
import scipy.io as scio
import scipy.misc as scim
import matplotlib.pyplot as plt
from collections import defaultdict
import os
from os.path import join
from MainNet_pfe_pretrain import MainNet

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# %% Training settings
parser = argparse.ArgumentParser(description="Light Field Compressed Sensing")
parser.add_argument("--learningRate", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--step", type=int, default=1000, help="Learning rate decay every n epochs")
parser.add_argument("--reduce", type=float, default=0.5, help="Learning rate decay")
parser.add_argument("--stageNum", type=int, default=2, help="The number of stages")
parser.add_argument("--sasLayerNum", type=int, default=6, help="The number of stages")
parser.add_argument("--temperature_1", type=float, default=1, help="The number of temperature_1")
parser.add_argument("--temperature_2", type=float, default=1, help="The number of temperature_2")
parser.add_argument("--component_num", type=int, default=4, help="The number of pfe component")
parser.add_argument("--noiselevel", type=int, default=20, help="Noise level 10 20 50")
parser.add_argument("--batchSize", type=int, default=5, help="Batch size")
parser.add_argument("--sampleNum", type=int, default=100, help="The number of LF in training set")
parser.add_argument("--patchSize", type=int, default=32, help="The size of croped LF patch")
parser.add_argument("--num_cp", type=int, default=1000, help="Number of epoches for saving checkpoint")
parser.add_argument("--angResolution", type=int, default=7, help="The angular resolution of original LF")
parser.add_argument("--channelNum", type=int, default=1, help="The channel number of input LF")
parser.add_argument("--epochNum", type=int, default=10010, help="The number of epoches")
parser.add_argument("--overlap", type=int, default=4, help="The size of croped LF patch")
parser.add_argument("--summaryPath", type=str, default='./', help="Path for saving training log ")
parser.add_argument("--dataName", type=str, default='Synthetic', help="The name of dataset ")
parser.add_argument("--dataPath", type=str, default='/***dataroot**/train_synthetic_noiselevel_10_20_50.mat', help="Path for loading training data ")
parser.add_argument("--resume_epoch", type=int, default=0, help="resume from checkpoint epoch")

# parse_known_args() ignores the extra argv entries injected by Jupyter.
opt = parser.parse_known_args()[0]

# Log to the console (basicConfig) and to a per-configuration log file.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
log = logging.getLogger()
fh = logging.FileHandler('Training_pfe_{}_{}_{}_{}_{}_{}.log'.format(
    opt.dataName, opt.noiselevel, opt.stageNum,
    opt.sasLayerNum, opt.epochNum, opt.learningRate))
log.addHandler(fh)
logging.info(opt)

# %% Pre-training loop
if __name__ == '__main__':

    SetupSeed(50)
    savePath = './model/lfdn_{}_{}_{}_{}_{}_{}-pretrain'.format(
        opt.dataName, opt.noiselevel, opt.stageNum,
        opt.sasLayerNum, opt.epochNum, opt.learningRate)
    os.makedirs(savePath, exist_ok=True)

    lfDataset = LFDataset(opt)
    dataloader = DataLoader(lfDataset, batch_size=opt.batchSize, shuffle=True)
    # fix: len(dataloader) is the number of batches; the message reports the
    # number of samples, which is the dataset length.
    print('loaded {} LFIs from {}'.format(len(lfDataset), opt.dataPath))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model = MainNet(opt)
    model = model.cuda()
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info("Training parameters: %d" % total_trainable_params)

    criterion = torch.nn.L1Loss()  # reconstruction loss

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learningRate)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=opt.step, gamma=opt.reduce)

    # fix: the loss history and the starting epoch are initialised BEFORE the
    # resume branch so that a restored history is not thrown away afterwards.
    lossLogger = defaultdict(list)
    start_epoch = 0

    if opt.resume_epoch:
        resume_path = join(savePath, 'pre_model_epoch_{}.pth'.format(opt.resume_epoch))
        if os.path.isfile(resume_path):
            print("==>loading checkpoint 'epoch{}'".format(resume_path))
            checkpoint = torch.load(resume_path)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # fix: older checkpoints were written without a 'losslogger'
            # entry, so a plain [] lookup raised KeyError here.
            lossLogger.update(checkpoint.get('losslogger', {}))
            # fix: continue after the saved epoch instead of restarting at 0.
            start_epoch = checkpoint['epoch'] + 1
        else:
            print("==> no model found at 'epoch{}'".format(opt.resume_epoch))

    plotPath = 'Training_{}_{}_{}_{}_{}_{}_ver.jpg'.format(
        opt.dataName, opt.noiselevel, opt.stageNum,
        opt.sasLayerNum, opt.epochNum, opt.learningRate)

    for epoch in range(start_epoch, opt.epochNum):
        lossSum = 0
        for sample in dataloader:
            lf = sample['lf'].cuda()
            noiself = sample['noiself'].cuda()

            # Denoising objective: noisy light field in, clean one as target.
            estimatedLF = model(noiself, epoch)
            loss = criterion(estimatedLF, lf)
            lossSum += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # StepLR is advanced once per epoch ("--step" counts epochs).
        scheduler.step()

        # Record the training loss first so the checkpoint below carries the
        # history up to and including this epoch.
        epochLoss = lossSum / len(dataloader)
        lossLogger['Epoch'].append(epoch)
        lossLogger['Loss'].append(epochLoss)
        lossLogger['Lr'].append(optimizer.state_dict()['param_groups'][0]['lr'])

        if epoch % opt.num_cp == 0:
            model_save_path = join(savePath, "pre_model_epoch_{}.pth".format(epoch))
            state = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                # Stored as a plain dict of lists so the resume branch can
                # restore it.
                'losslogger': dict(lossLogger),
            }
            torch.save(state, model_save_path)
            print("checkpoint saved to {}".format(model_save_path))

        log.info("Epoch: %d Loss: %.6f" % (epoch, epochLoss))

        # Refresh the loss curve on disk.
        plt.figure()
        plt.title('Loss')
        plt.plot(lossLogger['Epoch'], lossLogger['Loss'])
        plt.savefig(plotPath)
        plt.close()
import torch
import torch.nn as nn
import torch.nn.functional as functional
from torch.autograd import Variable
import numpy as np
import math

# Floor applied to the uniform samples used for the gate noise so that
# log(0) can never produce inf/NaN.
_NOISE_EPS = 1e-10


def _conv_relu(C_in, C_out, kernel_size, stride, padding, bias):
    """Build the Conv2d + in-place ReLU pair shared by every Conv_* branch."""
    return nn.Sequential(
        nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=bias),
        nn.ReLU(inplace=True),
    )


def _gate_noise(count, batch, device):
    """Sample the noise added to the gate logits.

    Draws two uniform tensors of shape [count, batch] and returns
    log(log(r1) / log(r2)) with shape [count, batch, 1, 1, 1, 1] on `device`
    (algebraically, the difference of two Gumbel samples).
    """
    # Sampled on the CPU generator and then moved, exactly like the original
    # `.cuda()` calls, but following the input's device so the module also
    # runs on CPU.
    r1 = torch.rand((count, batch))[:, :, None, None, None, None].to(device)
    r2 = torch.rand((count, batch))[:, :, None, None, None, None].to(device)
    r1 = r1.clamp(min=_NOISE_EPS)
    r2 = r2.clamp(min=_NOISE_EPS)
    return torch.log(torch.log(r1) / torch.log(r2))


class Conv_spa(nn.Module):
    """2-D convolution over the spatial (h, w) plane of every view.

    Input and output are [N, C, uv, h, w]; the view axis is folded into the
    batch axis for the convolution.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_spa, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(N * uv, c, h, w)
        out = self.op(x)
        # fix: use the real channel count instead of a hard-coded 16.
        return out.reshape(N, uv, out.shape[1], h, w).permute(0, 2, 1, 3, 4)


class Conv_ang(nn.Module):
    """2-D convolution over the angular (angular x angular) plane of every pixel.

    Input and output are [N, C, uv, h, w] with uv == angular * angular; the
    spatial axes are folded into the batch axis for the convolution.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, angular, bias):
        super(Conv_ang, self).__init__()
        self.angular = angular
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 3, 4, 1, 2).reshape(N * h * w, c, self.angular, self.angular)
        out = self.op(x)
        # fix: use the real channel count instead of a hard-coded 16.
        return out.reshape(N, h, w, out.shape[1], uv).permute(0, 3, 4, 1, 2)


class Conv_epi_h(nn.Module):
    """2-D convolution over the (uv, w) plane, one slice per image row.

    Input and output are [N, C, uv, h, w]; the h axis is folded into the
    batch axis for the convolution.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_epi_h, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 3, 1, 2, 4).reshape(N * h, c, uv, w)
        out = self.op(x)
        # fix: use the real channel count instead of a hard-coded 16.
        return out.reshape(N, h, out.shape[1], uv, w).permute(0, 2, 3, 1, 4)


class Conv_epi_v(nn.Module):
    """2-D convolution over the (uv, h) plane, one slice per image column.

    Input and output are [N, C, uv, h, w]; the w axis is folded into the
    batch axis for the convolution.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_epi_v, self).__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 4, 1, 2, 3).reshape(N * w, c, uv, h)
        out = self.op(x)
        # fix: use the real channel count instead of a hard-coded 16.
        return out.reshape(N, w, out.shape[1], uv, h).permute(0, 2, 3, 4, 1)


class Autocovnlayer(nn.Module):
    """One densely connected layer with learnable, noisy gates.

    The layer receives the list of all previous feature maps, gates each of
    them with a sigmoid of (logit(weight) + noise) / temperature, mixes them
    with a 1x1 convolution, runs four convolution branches (spatial, angular
    and the two epi variants), gates and mixes the branch outputs, and adds
    the result back onto the mixed input.

    Args:
        dence_num: number of feature maps passed to ``forward``.
        component_num: number of branch gates; the branch list and
            ``Conv_mixnas`` are fixed at four branches, so this is expected
            to be 4.
        angular: angular resolution (views per side).
        bs: whether the convolutions use a bias term.
    """

    def __init__(self, dence_num, component_num, angular, bs):
        super(Autocovnlayer, self).__init__()
        self.dence_num = dence_num
        self.component_num = component_num
        self.dence_weight = nn.Parameter(torch.rand(dence_num), requires_grad=True)
        self.component_weight = nn.Parameter(torch.rand(component_num), requires_grad=True)
        self.angular = angular
        self.kernel_size = 3

        self.naslayers = nn.ModuleList([
            Conv_spa(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
            Conv_ang(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, angular=self.angular, bias=bs),
            Conv_epi_h(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
            Conv_epi_v(C_in=32, C_out=16, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
        ])
        self.Conv_all = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, bias=bs)
        self.softmax1 = nn.Softmax(1)
        self.softmax0 = nn.Softmax(0)
        self.Conv_mixdence = nn.Conv2d(in_channels=32 * self.dence_num, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        # 1x1 convolution, so no padding is needed.
        self.Conv_mixnas = nn.Conv2d(in_channels=16 * 4, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, temperature_1, temperature_2):
        """Apply the layer.

        Args:
            x: list of ``dence_num`` tensors, each [N, 32, uv, h, w].
            temperature_1: temperature of the gates over the input features.
            temperature_2: temperature of the gates over the branches.

        Returns:
            Tensor of shape [N, 32, uv, h, w].
        """
        x = torch.stack(x, dim=0)
        fn, N, C, uv, h, w = x.shape

        # Clamp the gate probabilities away from 0/1 so the logit stays finite.
        dence_weight = self.dence_weight.clamp(0.02, 0.98)[:, None, None, None, None, None]
        component_weight = self.component_weight.clamp(0.02, 0.98)[:, None, None, None, None, None]

        # One independent gate per (feature map, sample) and per (branch, sample).
        dence_logits = torch.log(dence_weight / (1 - dence_weight))
        dence_gate = torch.sigmoid(
            (dence_logits + _gate_noise(self.dence_num, N, x.device)) / temperature_1)

        component_logits = torch.log(component_weight / (1 - component_weight))
        component_gate = torch.sigmoid(
            (component_logits + _gate_noise(self.component_num, N, x.device)) / temperature_2)

        # Gate and fuse the incoming feature maps with a 1x1 convolution.
        x = x * dence_gate
        x = x.permute([1, 3, 0, 2, 4, 5]).reshape([N * uv, fn * C, h, w])
        x = self.relu(self.Conv_mixdence(x))
        x_mix = x.reshape([N, uv, 32, h, w]).permute([0, 2, 1, 3, 4])

        # Run every branch on the fused features, then gate and fuse them.
        nas = torch.stack([layer(x_mix) for layer in self.naslayers], dim=0)
        nas = nas * component_gate
        nas = nas.permute([1, 3, 0, 2, 4, 5]).reshape([N * uv, self.component_num * 16, h, w])
        nas = self.relu(self.Conv_mixnas(nas))
        # Extra spatial convolution before the residual connection.
        nas = self.Conv_all(nas)
        nas = nas.reshape(N, uv, 32, h, w).permute(0, 2, 1, 3, 4)
        return self.relu(nas + x_mix)


def make_autolayers(opt, bs):
    """Stack ``opt.sasLayerNum`` Autocovnlayer modules.

    Layer ``i`` (0-based) is built with ``dence_num = i + 1`` because it
    receives the initial feature map plus the outputs of all earlier layers.
    """
    layers = [
        Autocovnlayer(i + 1, opt.component_num, opt.angResolution, bs)
        for i in range(opt.sasLayerNum)
    ]
    return nn.Sequential(*layers)


class RefNet(nn.Module):
    """Map ``measurementNum`` input images to ``angResolution**2`` views.

    Args:
        opt: options object providing ``angResolution``, ``measurementNum``,
            ``epochNum``, ``temperature_1``, ``temperature_2``,
            ``sasLayerNum`` and ``component_num``.
        bs: whether the convolutions use a bias term.
    """

    def __init__(self, opt, bs):
        super(RefNet, self).__init__()

        self.angResolution = opt.angResolution
        self.measurementNum = opt.measurementNum
        self.lfNum = opt.angResolution * opt.angResolution
        self.epochNum = opt.epochNum
        self.temperature_1 = opt.temperature_1
        self.temperature_2 = opt.temperature_2

        self.relu = nn.ReLU(inplace=True)
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1, bias=bs)
        self.conv1 = nn.Conv2d(in_channels=self.measurementNum, out_channels=self.lfNum, kernel_size=3, stride=1, padding=1, bias=bs)

        self.dence_autolayers = make_autolayers(opt, bs)
        self.syn_conv2 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, bias=bs)

    def forward(self, input, epoch):
        """Reconstruct the light field.

        Args:
            input: tensor of shape [N, measurementNum, 1, h, w].
            epoch: current epoch, used only to anneal the gate temperatures.

        Returns:
            Tensor of shape [N, angResolution**2, 1, h, w].
        """
        N, uvInput, _, h, w = input.shape
        if epoch <= 3800:
            # Linear annealing from T (epoch 0) down to 0.05 * T (epoch 3800).
            temperature_1 = self.temperature_1 * (1 - epoch / 4000)
            temperature_2 = self.temperature_2 * (1 - epoch / 4000)
        else:
            # Afterwards the temperatures are fixed, independent of T.
            temperature_1 = 0.05
            temperature_2 = 0.05

        # Per-image feature extraction.
        feat = input.reshape(N * uvInput, 1, h, w)
        feat = self.relu(self.conv0(feat))

        # Expand the measurement axis to the full set of views, per feature channel.
        feat = feat.reshape(N, uvInput, 32, h, w).permute(0, 2, 1, 3, 4).reshape(N * 32, uvInput, h, w)
        feat = self.relu(self.conv1(feat))
        feat = feat.reshape(N, 32, self.lfNum, h, w)

        # Dense chain: every layer sees all feature maps produced so far.
        feats = [feat]
        for layer in self.dence_autolayers:
            feats.append(layer(feats, temperature_1, temperature_2))

        out = self.syn_conv2(feats[-1].permute(0, 2, 1, 3, 4).reshape(N * self.lfNum, 32, h, w))
        return out.reshape(N, self.lfNum, 1, h, w)
# %% Testing settings
# NOTE(review): this script was reconstructed from a whitespace-collapsed
# notebook dump; "# %%" marks the original cell boundaries and the
# indentation was re-derived from syntax -- confirm against the notebook.
# The names used here (argparse, torch, DataLoader, LFDataset, MainNet, ...)
# come from the notebook's import cell.
parser = argparse.ArgumentParser(description="Light Field Compressed Sensing")
parser.add_argument("--learningRate", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--step", type=int, default=1000, help="Learning rate decay every n epochs")
parser.add_argument("--reduce", type=float, default=0.5, help="Learning rate decay")
parser.add_argument("--stageNum", type=int, default=2, help="The number of stages")
parser.add_argument("--sasLayerNum", type=int, default=6, help="The number of stages")
parser.add_argument("--temperature_1", type=float, default=1, help="The number of temperature_1")
parser.add_argument("--temperature_2", type=float, default=1, help="The number of temperature_2")
parser.add_argument("--component_num", type=int, default=4, help="The number of nas component")
parser.add_argument("--noiselevel", type=int, default=20, help="Noise level 10 20 50")
parser.add_argument("--batchSize", type=int, default=1, help="Batch size")
parser.add_argument("--sampleNum", type=int, default=55, help="The number of LF in training set")
parser.add_argument("--patchSize", type=int, default=32, help="The size of croped LF patch")

parser.add_argument("--angResolution", type=int, default=7, help="The angular resolution of original LF")
parser.add_argument("--channelNum", type=int, default=1, help="The channel number of input LF")
parser.add_argument("--epochNum", type=int, default=11000, help="The number of epoches")
parser.add_argument("--overlap", type=int, default=4, help="The size of croped LF patch")
parser.add_argument("--summaryPath", type=str, default='./', help="Path for saving training log ")
parser.add_argument("--dataName", type=str, default='Synthetic', help="The name of dataset ")
parser.add_argument("--modelPath", type=str, default='./model/*** model path***', help="Path for loading trained model ")
parser.add_argument("--dataPath", type=str, default='/***dataroot***/test_synthetic_noiselevel_10_20_50.mat', help="Path for loading training data ")
parser.add_argument("--savePath", type=str, default='./results/', help="Path for saving results ")
# parse_known_args() ignores the extra argv entries injected by Jupyter.
opt = parser.parse_known_args()[0]

warnings.filterwarnings("ignore")
plt.ion()
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
log = logging.getLogger()
fh = logging.FileHandler('Testing_original.log')
log.addHandler(fh)

# %% Data and model
lf_dataset = LFDataset(opt)
dataloader = DataLoader(lf_dataset, batch_size=opt.batchSize, shuffle=False)
# fix: the import cell sets CUDA_VISIBLE_DEVICES='0', so only device index 0
# is visible; the previous "cuda:3" named a device that cannot exist.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = MainNet(opt)
# The checkpoint is a dict; the weights live under its 'model' key.
model.load_state_dict(torch.load(opt.modelPath)['model'])
model.eval()
model.to(device)

# %% Evaluation
# fix: savemat() does not create missing directories.
os.makedirs(opt.savePath, exist_ok=True)

with torch.no_grad():
    avg_psnr = 0
    avg_ssim = 0
    for num, sample in enumerate(dataloader, start=1):
        LF = sample['lf']
        noilf = sample['noiself']
        lfName = sample['lfname']
        b, u, v, x, y = LF.shape

        # Crop the noisy LF into overlapping patches: [b, n, u, v, x, y].
        LFStack, coordinate = CropLF(noilf, opt.patchSize, opt.overlap)
        n = LFStack.shape[1]
        estiLFStack = torch.zeros(b, n, u, v, opt.patchSize, opt.patchSize)

        # Denoise patch by patch; results are collected on the CPU.
        # opt.epochNum is passed as the model's epoch argument.
        for i in range(n):
            estiLFStack[:, i] = model(LFStack[:, i].to(device), opt.epochNum)

        # Stitch the patches back together and crop the ground truth to match.
        estiLF = MergeLF(estiLFStack, coordinate, opt.overlap, x, y)  # [b, u, v, x, y]
        b, u, v, xCrop, yCrop = estiLF.shape
        half = opt.overlap // 2
        LF = LF[:, :, :, half:half + xCrop, half:half + yCrop]

        # Per-view arrays of the first sample in the batch, converted once
        # instead of once per view.
        estiViews = estiLF.reshape(b, u * v, xCrop, yCrop)[0].cpu().numpy()
        gtViews = LF.reshape(b, u * v, xCrop, yCrop)[0].cpu().numpy()

        lf_psnr = 0
        lf_ssim = 0
        for ind_uv in range(u * v):
            lf_psnr += ComptPSNR(estiViews[ind_uv], gtViews[ind_uv]) / (u * v)
            # `multichannel=False` was dropped: single-channel input is the
            # default.
            # NOTE(review): newer scikit-image releases are believed to
            # reject the keyword -- confirm against the installed version.
            lf_ssim += structural_similarity(
                (estiViews[ind_uv] * 255.0).astype(np.uint8),
                (gtViews[ind_uv] * 255.0).astype(np.uint8),
                gaussian_weights=True, sigma=1.5,
                use_sample_covariance=False) / (u * v)

        avg_psnr += lf_psnr / len(dataloader)
        avg_ssim += lf_ssim / len(dataloader)
        log.info('Index: %d Scene: %s PSNR: %.2f SSIM: %.3f' % (num, lfName[0], lf_psnr, lf_ssim))

        # Save the reconstructed LF.
        scio.savemat(os.path.join(opt.savePath, lfName[0] + '.mat'),
                     {'lf_recons': torch.squeeze(estiLF).numpy()})
    log.info('Average PSNR: %.2f SSIM: %.3f ' % (avg_psnr, avg_ssim))
[] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "id": "bbafae8e", 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "d49ee67e", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "624e0722", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "id": "78077dce", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "0e208787", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "\n", 201 | "\n" 202 | ] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python 3 (ipykernel)", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.8.11" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 5 226 | } 227 | -------------------------------------------------------------------------------- /LFDN/LFDN-PFE-Train-Orignal.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from torch.utils.data import DataLoader\n", 11 | "from datetime import datetime\n", 12 | "import logging\n", 13 | "from LFDataset import LFDataset\n", 14 | "from Functions import weights_init,SetupSeed,CropLF, MergeLF,ComptPSNR,rgb2ycbcr\n", 15 | "import 
itertools,argparse\n", 16 | "from skimage.metrics import structural_similarity\n", 17 | "import numpy as np\n", 18 | "import scipy.io as scio \n", 19 | "import scipy.misc as scim\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from collections import defaultdict\n", 22 | "import os\n", 23 | "from os.path import join\n", 24 | "from MainNet_pfe_ver0 import MainNet\n", 25 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "# Training settings\n", 35 | "parser = argparse.ArgumentParser(description=\"Light Field Compressed Sensing\")\n", 36 | "parser.add_argument(\"--learningRate\", type=float, default=1e-3, help=\"Learning rate\")\n", 37 | "parser.add_argument(\"--step\", type=int, default=1000, help=\"Learning rate decay every n epochs\")\n", 38 | "parser.add_argument(\"--reduce\", type=float, default=0.5, help=\"Learning rate decay\")\n", 39 | "parser.add_argument(\"--stageNum\", type=int, default=2, help=\"The number of stages\")\n", 40 | "parser.add_argument(\"--sasLayerNum\", type=int, default=6, help=\"The number of stages\")\n", 41 | "parser.add_argument(\"--temperature_1\", type=float, default=1, help=\"The number of temperature_1\")\n", 42 | "parser.add_argument(\"--temperature_2\", type=float, default=1, help=\"The number of temperature_2\")\n", 43 | "parser.add_argument(\"--component_num\", type=int, default=4, help=\"The number of pfe component\")\n", 44 | "parser.add_argument(\"--noiselevel\", type=int, default=20, help=\"Noise level 10 20 50\")\n", 45 | "parser.add_argument(\"--batchSize\", type=int, default=5, help=\"Batch size\")\n", 46 | "parser.add_argument(\"--sampleNum\", type=int, default=55, help=\"The number of LF in training set\")\n", 47 | "parser.add_argument(\"--patchSize\", type=int, default=32, help=\"The size of croped LF patch\")\n", 48 | "parser.add_argument(\"--num_cp\", type=int, default=1000, 
help=\"Number of epoches for saving checkpoint\")\n", 49 | "parser.add_argument(\"--angResolution\", type=int, default=7, help=\"The angular resolution of original LF\")\n", 50 | "parser.add_argument(\"--channelNum\", type=int, default=1, help=\"The channel number of input LF\")\n", 51 | "parser.add_argument(\"--epochNum\", type=int, default=10010, help=\"The number of epoches\")\n", 52 | "parser.add_argument(\"--overlap\", type=int, default=4, help=\"The size of croped LF patch\")\n", 53 | "parser.add_argument(\"--summaryPath\", type=str, default='./', help=\"Path for saving training log \")\n", 54 | "parser.add_argument(\"--dataName\", type=str, default='Synthetic', help=\"The name of dataset \")\n", 55 | "parser.add_argument(\"--preTrain\", type=str, default='./model/***pretrained model***', help=\"Path for loading pretrained model \")\n", 56 | "parser.add_argument(\"--dataPath\", type=str, default='/***dataroot***/train_synthetic_noiselevel_10_20_50.mat', help=\"Path for loading training data \")\n", 57 | "parser.add_argument(\"--resume_epoch\", type=int, default=0, help=\"resume from checkpoint epoch\")\n", 58 | "\n", 59 | "opt = parser.parse_known_args()[0]\n", 60 | "logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')\n", 61 | "log = logging.getLogger()\n", 62 | "fh = logging.FileHandler('Training_pfe_{}_{}_{}_{}_{}_{}.log'.format(opt.dataName, opt.noiselevel, opt.stageNum, opt.sasLayerNum, opt.epochNum, opt.learningRate))\n", 63 | "log.addHandler(fh)\n", 64 | "logging.info(opt)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "if __name__ == '__main__':\n", 74 | " SetupSeed(50)\n", 75 | " savePath = './model/lfdn_{}_{}_{}_{}_{}_{}_pfe-v1'.format(opt.dataName, opt.noiselevel, opt.stageNum, opt.sasLayerNum, opt.epochNum, opt.learningRate)\n", 76 | " if not os.path.exists(savePath):\n", 77 | " os.makedirs(savePath)\n", 78 | " \n", 79 | " 
lfDataset = LFDataset(opt)\n", 80 | " dataloader = DataLoader(lfDataset, batch_size=opt.batchSize,shuffle=True)\n", 81 | " print('loaded {} LFIs from {}'.format(len(dataloader), opt.dataPath))\n", 82 | " device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 83 | " torch.backends.cudnn.deterministic = True\n", 84 | " torch.backends.cudnn.benchmark = False\n", 85 | "\n", 86 | "\n", 87 | " model=MainNet(opt)\n", 88 | " model.load_state_dict(torch.load(opt.preTrain)['model'])\n", 89 | " model = model.cuda()\n", 90 | " total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 91 | " log.info(\"Training parameters: %d\" %total_trainable_params)\n", 92 | " criterion = torch.nn.L1Loss() # Loss \n", 93 | "\n", 94 | " optimizer = torch.optim.Adam(itertools.chain(model.parameters()), lr=opt.learningRate) #optimizer\n", 95 | " scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer,max_lr = opt.learningRate,steps_per_epoch=len(dataloader),\n", 96 | " epochs=opt.epochNum,pct_start = 0.2, div_factor = 10, final_div_factor = 10)\n", 97 | "\n", 98 | " if opt.resume_epoch:\n", 99 | " resume_path = join(savePath,'pfe_model_epoch_{}.pth'.format(opt.resume_epoch))\n", 100 | " if os.path.isfile(resume_path):\n", 101 | " print(\"==>loading checkpoint 'epoch{}'\".format(resume_path))\n", 102 | " checkpoint = torch.load(resume_path)\n", 103 | " model.load_state_dict(checkpoint['model'])\n", 104 | " optimizer.load_state_dict(checkpoint['optimizer'])\n", 105 | " scheduler.load_state_dict(checkpoint['scheduler'])\n", 106 | " losslogger = checkpoint['losslogger']\n", 107 | " else:\n", 108 | " print(\"==> no model found at 'epoch{}'\".format(opt.resume_epoch)) \n", 109 | "\n", 110 | "\n", 111 | " lossLogger = defaultdict(list)\n", 112 | " for epoch in range(opt.epochNum):\n", 113 | " batch = 0\n", 114 | " lossSum = 0\n", 115 | " for _,sample in enumerate(dataloader):\n", 116 | " batch = batch +1\n", 117 | " lf=sample['lf']\n", 118 | " 
lf = lf.cuda()\n", 119 | " noiself=sample['noiself'].cuda()\n", 120 | " estimatedLF=model(noiself,epoch)\n", 121 | " loss = criterion(estimatedLF,lf)\n", 122 | " lossSum += loss.item()\n", 123 | "\n", 124 | " optimizer.zero_grad()\n", 125 | " loss.backward()\n", 126 | " optimizer.step()\n", 127 | " scheduler.step() #ONE\n", 128 | "\n", 129 | " if epoch % opt.num_cp == 0:\n", 130 | " model_save_path = join(savePath,\"pfe_model_epoch_{}.pth\".format(epoch))\n", 131 | " state = {'epoch':epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(),\n", 132 | " 'scheduler': scheduler.state_dict()}\n", 133 | " torch.save(state,model_save_path)\n", 134 | " print(\"checkpoint saved to {}\".format(model_save_path))\n", 135 | "\n", 136 | " log.info(\"Epoch: %d Loss: %.6f\" %(epoch,lossSum/len(dataloader)))\n", 137 | "\n", 138 | " #Record the training loss\n", 139 | " lossLogger['Epoch'].append(epoch)\n", 140 | " lossLogger['Loss'].append(lossSum/len(dataloader))\n", 141 | " lossLogger['Lr'].append(optimizer.state_dict()['param_groups'][0]['lr'])\n", 142 | " plt.figure()\n", 143 | " plt.title('Loss')\n", 144 | " plt.plot(lossLogger['Epoch'],lossLogger['Loss'])\n", 145 | " plt.savefig('Training_{}_{}_{}_{}_{}_{}_ver.jpg'.format(opt.dataName, opt.noiselevel, opt.stageNum,opt.sasLayerNum, opt.epochNum, opt.learningRate))\n", 146 | " plt.close()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python 3 (ipykernel)", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.8.11" 174 | } 175 | }, 176 | 
"nbformat": 4, 177 | "nbformat_minor": 4 178 | } 179 | -------------------------------------------------------------------------------- /LFDN/RefNet_pfe_pretrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import math 7 | 8 | class ChannelSELayer3D(nn.Module): 9 | def __init__(self, num_channels, reduction_ratio=2): 10 | """ 11 | :param num_channels: No of input channels 12 | :param reduction_ratio: By how much should the num_channels should be reduced 13 | """ 14 | super(ChannelSELayer3D, self).__init__() 15 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 16 | num_channels_reduced = num_channels // reduction_ratio 17 | self.reduction_ratio = reduction_ratio 18 | self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True) 19 | self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True) 20 | self.relu = nn.ReLU() 21 | self.sigmoid = nn.Sigmoid() 22 | 23 | def forward(self, input_tensor): 24 | """ 25 | :param input_tensor: X, shape = (batch_size, num_channels, D, H, W) 26 | :return: output tensor 27 | """ 28 | batch_size, num_channels, D, H, W = input_tensor.size() 29 | # Average along each channel 30 | squeeze_tensor = self.avg_pool(input_tensor) 31 | 32 | # channel excitation 33 | fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels))) 34 | fc_out_2 = self.sigmoid(self.fc2(fc_out_1)) 35 | 36 | output_tensor = torch.mul(input_tensor, fc_out_2.view(batch_size, num_channels, 1, 1, 1)) 37 | 38 | return output_tensor 39 | 40 | class ChannelSELayer(nn.Module): 41 | def __init__(self, num_channels, reduction_ratio=2): 42 | """ 43 | :param num_channels: No of input channels 44 | :param reduction_ratio: By how much should the num_channels should be reduced 45 | """ 46 | super(ChannelSELayer, self).__init__() 47 | num_channels_reduced = num_channels // 
class SpatialSELayer(nn.Module):
    """Spatial squeeze-and-excitation gate for 4-D feature maps.

    A 1x1 convolution collapses the channels to a single map, a sigmoid turns
    it into a per-pixel weight, and that weight rescales every channel of the
    input.
    """

    def __init__(self, num_channels):
        """
        :param num_channels: No of input channels
        """
        super(SpatialSELayer, self).__init__()
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor, weights=None):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :param weights: optional external weights; they are averaged over
            dim 0 and used as the 1x1 kernel instead of ``self.conv``
        :return: output_tensor, same shape as ``input_tensor``
        """
        n, c, rows, cols = input_tensor.size()

        # spatial squeeze: channels -> one logit per pixel
        if weights is None:
            logits = self.conv(input_tensor)
        else:
            kernel = torch.mean(weights, dim=0).view(1, c, 1, 1)
            logits = F.conv2d(input_tensor, kernel)

        # spatial excitation: broadcast the per-pixel gate over all channels
        gate = self.sigmoid(logits).view(n, 1, rows, cols)
        return torch.mul(input_tensor, gate)
def _conv_relu(C_in, C_out, kernel_size, stride, padding, bias):
    """Build the Conv2d + in-place ReLU pair shared by the branch wrappers."""
    return nn.Sequential(
        nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=bias),
        nn.ReLU(inplace=True),
    )


class Conv_spa(nn.Module):
    """2-D convolution over the two trailing axes (h, w) of every view.

    Input and output are 5-D tensors laid out as (N, channels, uv, h, w).
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super().__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(N * uv, c, h, w)
        out = self.op(x)
        # Take the channel count from the conv output instead of assuming 32.
        out = out.reshape(N, uv, *out.shape[1:]).permute(0, 2, 1, 3, 4)
        return out


class Conv_ang(nn.Module):
    """2-D convolution over the (angular, angular) grid of every pixel.

    The ``uv`` axis of the (N, channels, uv, h, w) input is unfolded into an
    ``angular x angular`` grid, so ``uv`` must equal ``angular ** 2``.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, angular, bias):
        super().__init__()
        self.angular = angular
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 3, 4, 1, 2).reshape(N * h * w, c, self.angular, self.angular)
        out = self.op(x)
        # (N*h*w, C_out, a, a) -> (N, C_out, uv, h, w); C_out read from the output.
        out = out.reshape(N, h, w, out.shape[1], -1).permute(0, 3, 4, 1, 2)
        return out


class Conv_epi_h(nn.Module):
    """2-D convolution over the (uv, w) plane for every row ``h``.

    Input and output are 5-D tensors laid out as (N, channels, uv, h, w).
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super().__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 3, 1, 2, 4).reshape(N * h, c, uv, w)
        out = self.op(x)
        # (N*h, C_out, uv, w) -> (N, C_out, uv, h, w)
        out = out.reshape(N, h, *out.shape[1:]).permute(0, 2, 3, 1, 4)
        return out


class Conv_epi_v(nn.Module):
    """2-D convolution over the (uv, h) plane for every column ``w``.

    Input and output are 5-D tensors laid out as (N, channels, uv, h, w).
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super().__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 4, 1, 2, 3).reshape(N * w, c, uv, h)
        out = self.op(x)
        # (N*w, C_out, uv, h) -> (N, C_out, uv, h, w)
        out = out.reshape(N, w, *out.shape[1:]).permute(0, 2, 3, 4, 1)
        return out


class Autocovnlayer(nn.Module):
    """Densely connected mixing layer with four parallel branch convolutions.

    The layer receives the list of all earlier feature maps, fuses them with a
    1x1 convolution, runs the fused map through four branches (``Conv_spa``,
    ``Conv_ang``, ``Conv_epi_h``, ``Conv_epi_v``), fuses the branch outputs
    with another 1x1 convolution followed by a 3x3 convolution, and adds the
    result back to the fused input.

    In this variant every dense input and every branch is weighted by 1;
    ``dence_weight`` and ``component_weight`` are registered as parameters but
    are not read in ``forward``.

    :param dence_num: number of feature maps passed to ``forward``
    :param component_num: number of branch weights; the branch list below has
        four entries and ``Conv_mixnas`` expects 4 * 32 channels, so 4 is the
        only value that works in ``forward``
    :param angular: angular resolution forwarded to ``Conv_ang``
    :param bs: whether the 3x3 convolutions use a bias term
    """

    def __init__(self, dence_num, component_num, angular, bs):
        super().__init__()
        self.dence_num = dence_num
        self.component_num = component_num
        self.dence_weight = nn.Parameter(torch.rand(dence_num), requires_grad=True)
        self.component_weight = nn.Parameter(torch.rand(component_num), requires_grad=True)
        self.angular = angular
        self.kernel_size = 3

        branch_kwargs = dict(C_in=64, C_out=32, kernel_size=self.kernel_size,
                             stride=1, padding=1, bias=bs)
        self.naslayers = nn.ModuleList([
            Conv_spa(**branch_kwargs),
            Conv_ang(angular=self.angular, **branch_kwargs),
            Conv_epi_h(**branch_kwargs),
            Conv_epi_v(**branch_kwargs),
        ])
        self.Conv_all = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,
                                  stride=1, padding=1, bias=bs)
        self.softmax1 = nn.Softmax(1)
        self.softmax0 = nn.Softmax(0)
        self.Conv_mixdence = nn.Conv2d(in_channels=64 * self.dence_num, out_channels=64,
                                       kernel_size=1, stride=1, padding=0, bias=False)
        self.Conv_mixnas = nn.Conv2d(in_channels=32 * 4, out_channels=64,
                                     kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, temperature_1, temperature_2):
        """Fuse the dense inputs, apply the four branches and a residual add.

        :param x: list of ``dence_num`` tensors, each (N, 64, uv, h, w)
        :param temperature_1: accepted for interface compatibility; not read here
        :param temperature_2: accepted for interface compatibility; not read here
        :return: tensor of shape (N, 64, uv, h, w)
        """
        x = torch.stack(x, dim=0)
        fn, N, C, uv, h, w = x.shape

        # All-ones gates, allocated on the input's device and dtype so the
        # layer also runs on CPU or on a non-default GPU (was a hard `.cuda()`).
        dence_weight_soft = x.new_ones(self.dence_num, N, 1, 1, 1, 1)
        component_weight_gumbel = x.new_ones(self.component_num, N, 1, 1, 1, 1)

        x = x * dence_weight_soft
        x = x.permute(1, 3, 0, 2, 4, 5).reshape(N * uv, fn * C, h, w)
        x = self.relu(self.Conv_mixdence(x))
        x_mix = x.reshape(N, uv, C, h, w).permute(0, 2, 1, 3, 4)

        nas = torch.stack([layer(x_mix) for layer in self.naslayers], dim=0)
        nas = nas * component_weight_gumbel
        # (branches, N, 32, uv, h, w) -> (N*uv, branches*32, h, w)
        nas = nas.permute(1, 3, 0, 2, 4, 5).reshape(N * uv, -1, h, w)
        nas = self.relu(self.Conv_mixnas(nas))
        ####### add a spa conv #######
        nas = self.Conv_all(nas)
        nas = nas.reshape(N, uv, C, h, w).permute(0, 2, 1, 3, 4)
        return self.relu(nas + x_mix)


def make_autolayers(opt, bs):
    """Stack ``opt.sasLayerNum`` Autocovnlayers; layer ``i`` takes ``i + 1`` inputs.

    :param opt: options object providing ``sasLayerNum``, ``component_num``
        and ``angResolution``
    :param bs: bias flag forwarded to every layer
    :return: ``nn.Sequential`` holding the layers in order (callers iterate
        over it rather than calling it directly)
    """
    return nn.Sequential(*[
        Autocovnlayer(i + 1, opt.component_num, opt.angResolution, bs)
        for i in range(opt.sasLayerNum)
    ])
class ChannelSELayer3D(nn.Module):
    """Channel squeeze-and-excitation gate for 5-D feature volumes.

    Each channel is globally average-pooled over its three trailing axes, fed
    through a two-layer bottleneck (Linear-ReLU-Linear-Sigmoid), and the
    resulting per-channel weight rescales the input.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: No of input channels
        :param reduction_ratio: By how much should the num_channels should be reduced
        """
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
        :return: output tensor, same shape as ``input_tensor``
        """
        # Unpacking five sizes keeps the original "must be 5-D" check.
        batch_size, num_channels, _, _, _ = input_tensor.size()
        # Average along each channel
        squeeze_tensor = self.avg_pool(input_tensor).reshape(batch_size, num_channels)

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        return torch.mul(input_tensor, fc_out_2[:, :, None, None, None])


class ChannelSELayer(nn.Module):
    """Channel squeeze-and-excitation gate for 4-D feature maps.

    Same scheme as ``ChannelSELayer3D`` but the squeeze averages over the two
    trailing spatial axes.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: No of input channels
        :param reduction_ratio: By how much should the num_channels should be reduced
        """
        super().__init__()
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: output tensor, same shape as ``input_tensor``
        """
        # Unpacking four sizes keeps the original "must be 4-D" check.
        _, _, _, _ = input_tensor.size()
        # Average along each channel. Reducing over the spatial axes directly
        # (instead of `.view(B, C, -1)`) also works for non-contiguous inputs
        # such as the permuted/transposed tensors used elsewhere in this file.
        squeeze_tensor = input_tensor.mean(dim=(2, 3))

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        return torch.mul(input_tensor, fc_out_2[:, :, None, None])
def _conv_relu(C_in, C_out, kernel_size, stride, padding, bias):
    """Build the Conv2d + in-place ReLU pair shared by the branch wrappers."""
    return nn.Sequential(
        nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=bias),
        nn.ReLU(inplace=True),
    )


class Conv_spa(nn.Module):
    """2-D convolution over the two trailing axes (h, w) of every view.

    Input and output are 5-D tensors laid out as (N, channels, uv, h, w).
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super().__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 2, 1, 3, 4).reshape(N * uv, c, h, w)
        out = self.op(x)
        # Take the channel count from the conv output instead of assuming 32.
        out = out.reshape(N, uv, *out.shape[1:]).permute(0, 2, 1, 3, 4)
        return out


class Conv_ang(nn.Module):
    """2-D convolution over the (angular, angular) grid of every pixel.

    The ``uv`` axis of the (N, channels, uv, h, w) input is unfolded into an
    ``angular x angular`` grid, so ``uv`` must equal ``angular ** 2``.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, angular, bias):
        super().__init__()
        self.angular = angular
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 3, 4, 1, 2).reshape(N * h * w, c, self.angular, self.angular)
        out = self.op(x)
        # (N*h*w, C_out, a, a) -> (N, C_out, uv, h, w); C_out read from the output.
        out = out.reshape(N, h, w, out.shape[1], -1).permute(0, 3, 4, 1, 2)
        return out


class Conv_epi_h(nn.Module):
    """2-D convolution over the (uv, w) plane for every row ``h``.

    Input and output are 5-D tensors laid out as (N, channels, uv, h, w).
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super().__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 3, 1, 2, 4).reshape(N * h, c, uv, w)
        out = self.op(x)
        # (N*h, C_out, uv, w) -> (N, C_out, uv, h, w)
        out = out.reshape(N, h, *out.shape[1:]).permute(0, 2, 3, 1, 4)
        return out


class Conv_epi_v(nn.Module):
    """2-D convolution over the (uv, h) plane for every column ``w``.

    Input and output are 5-D tensors laid out as (N, channels, uv, h, w).
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super().__init__()
        self.op = _conv_relu(C_in, C_out, kernel_size, stride, padding, bias)

    def forward(self, x):
        N, c, uv, h, w = x.shape
        x = x.permute(0, 4, 1, 2, 3).reshape(N * w, c, uv, h)
        out = self.op(x)
        # (N*w, C_out, uv, h) -> (N, C_out, uv, h, w)
        out = out.reshape(N, w, *out.shape[1:]).permute(0, 2, 3, 4, 1)
        return out


class Autocovnlayer(nn.Module):
    """Densely connected mixing layer with hard-gated inputs and branches.

    The layer receives the list of all earlier feature maps, fuses them with a
    1x1 convolution, runs the fused map through four branches (``Conv_spa``,
    ``Conv_ang``, ``Conv_epi_h``, ``Conv_epi_v``), fuses the branch outputs
    with another 1x1 convolution followed by a 3x3 convolution, and adds the
    result back to the fused input.

    In this variant each dense input and each branch is multiplied by a 0/1
    gate: 1 where the matching entry of ``dence_weight`` /
    ``component_weight`` (clamped to [0.02, 0.98]) exceeds 0.1, else 0.

    :param dence_num: number of feature maps passed to ``forward``
    :param component_num: number of branch weights; the branch list below has
        four entries and ``Conv_mixnas`` expects 4 * 32 channels, so 4 is the
        only value that works in ``forward``
    :param angular: angular resolution forwarded to ``Conv_ang``
    :param bs: whether the 3x3 convolutions use a bias term
    """

    def __init__(self, dence_num, component_num, angular, bs):
        super().__init__()
        self.dence_num = dence_num
        self.component_num = component_num
        self.dence_weight = nn.Parameter(torch.rand(dence_num), requires_grad=True)
        self.component_weight = nn.Parameter(torch.rand(component_num), requires_grad=True)
        self.angular = angular
        self.kernel_size = 3

        branch_kwargs = dict(C_in=64, C_out=32, kernel_size=self.kernel_size,
                             stride=1, padding=1, bias=bs)
        self.naslayers = nn.ModuleList([
            Conv_spa(**branch_kwargs),
            Conv_ang(angular=self.angular, **branch_kwargs),
            Conv_epi_h(**branch_kwargs),
            Conv_epi_v(**branch_kwargs),
        ])
        self.Conv_all = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,
                                  stride=1, padding=1, bias=bs)
        self.softmax1 = nn.Softmax(1)
        self.softmax0 = nn.Softmax(0)
        self.Conv_mixdence = nn.Conv2d(in_channels=64 * self.dence_num, out_channels=64,
                                       kernel_size=1, stride=1, padding=0, bias=False)
        self.Conv_mixnas = nn.Conv2d(in_channels=32 * 4, out_channels=64,
                                     kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _hard_gate(weight, like):
        """Return a 0/1 gate of shape (len(weight), 1, 1, 1, 1, 1) matching ``like``.

        The comparison runs on the parameter's own device and the result is
        then moved to the device and dtype of ``like``, so mask and gate never
        live on different devices.
        """
        keep = weight.detach().clamp(0.02, 0.98) > 0.1
        gate = keep.to(device=like.device, dtype=like.dtype)
        return gate[:, None, None, None, None, None]

    def forward(self, x, temperature_1, temperature_2):
        """Fuse the gated dense inputs, apply the gated branches, add residual.

        :param x: list of ``dence_num`` tensors, each (N, 64, uv, h, w)
        :param temperature_1: accepted for interface compatibility; not read here
        :param temperature_2: accepted for interface compatibility; not read here
        :return: tensor of shape (N, 64, uv, h, w)
        """
        x = torch.stack(x, dim=0)
        fn, N, C, uv, h, w = x.shape

        # Gates follow the input's device (was: built on CPU, then `.cuda()`).
        dence_weight_soft = self._hard_gate(self.dence_weight, x)
        component_weight_gumbel = self._hard_gate(self.component_weight, x)

        x = x * dence_weight_soft
        x = x.permute(1, 3, 0, 2, 4, 5).reshape(N * uv, fn * C, h, w)
        x = self.relu(self.Conv_mixdence(x))
        x_mix = x.reshape(N, uv, C, h, w).permute(0, 2, 1, 3, 4)

        nas = torch.stack([layer(x_mix) for layer in self.naslayers], dim=0)
        nas = nas * component_weight_gumbel
        # (branches, N, 32, uv, h, w) -> (N*uv, branches*32, h, w)
        nas = nas.permute(1, 3, 0, 2, 4, 5).reshape(N * uv, -1, h, w)
        nas = self.relu(self.Conv_mixnas(nas))
        ####### add a spa conv ########
        nas = self.Conv_all(nas)
        nas = nas.reshape(N, uv, C, h, w).permute(0, 2, 1, 3, 4)
        return self.relu(nas + x_mix)


def make_autolayers(opt, bs):
    """Stack ``opt.sasLayerNum`` Autocovnlayers; layer ``i`` takes ``i + 1`` inputs.

    :param opt: options object providing ``sasLayerNum``, ``component_num``
        and ``angResolution``
    :param bs: bias flag forwarded to every layer
    :return: ``nn.Sequential`` holding the layers in order (callers iterate
        over it rather than calling it directly)
    """
    return nn.Sequential(*[
        Autocovnlayer(i + 1, opt.component_num, opt.angResolution, bs)
        for i in range(opt.sasLayerNum)
    ])
class ChannelSELayer3D(nn.Module):
    """Channel squeeze-and-excitation for 5-D inputs ``(batch, channels, D, H, W)``."""

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: number of input channels
        :param reduction_ratio: divisor applied to ``num_channels`` to get the
            width of the hidden fully connected layer
        """
        super(ChannelSELayer3D, self).__init__()
        hidden = num_channels // reduction_ratio
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
        :return: X scaled per channel by a sigmoid gate, same shape as X
        """
        batch, channels, _, _, _ = input_tensor.size()
        # Squeeze: one average per (sample, channel).
        pooled = self.avg_pool(input_tensor).view(batch, channels)
        # Excite: bottleneck MLP followed by a sigmoid.
        hidden = self.relu(self.fc1(pooled))
        gate = self.sigmoid(self.fc2(hidden))
        return torch.mul(input_tensor, gate.view(batch, channels, 1, 1, 1))


class ChannelSELayer(nn.Module):
    """Channel squeeze-and-excitation for 4-D inputs ``(batch, channels, H, W)``."""

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: number of input channels
        :param reduction_ratio: divisor applied to ``num_channels`` to get the
            width of the hidden fully connected layer
        """
        super(ChannelSELayer, self).__init__()
        hidden = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: X scaled per channel by a sigmoid gate, same shape as X
        """
        batch, channels, _, _ = input_tensor.size()
        # Squeeze: mean over all spatial positions of each channel.
        pooled = input_tensor.view(batch, channels, -1).mean(dim=2)
        # Excite: bottleneck MLP followed by a sigmoid.
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(pooled))))
        rows, cols = pooled.size()
        return torch.mul(input_tensor, gate.view(rows, cols, 1, 1))
class Conv_spa(nn.Module):
    """Conv2d + ReLU applied over the spatial plane (h, w) of every view.

    Input and output are ``[N, c, uv, h, w]`` / ``[N, C_out, uv, h', w']``.
    The output shape is read from the convolution result, so ``C_out`` is no
    longer required to be 32.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_spa, self).__init__()
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        N, c, uv, h, w = x.shape
        # Fold the view axis into the batch so each view is one 2-D image.
        x = x.permute(0, 2, 1, 3, 4).reshape(N * uv, c, h, w)
        out = self.op(x)
        # Unfold using the actual output channels / spatial size.
        out = out.reshape(N, uv, *out.shape[1:]).permute(0, 2, 1, 3, 4)
        return out


class Conv_ang(nn.Module):
    """Conv2d + ReLU applied over the angular grid at every pixel.

    The ``uv`` axis is reshaped to ``angular x angular``, so ``uv`` must equal
    ``angular ** 2``. Input is ``[N, c, uv, h, w]``; output is
    ``[N, C_out, uv', h, w]`` where ``uv'`` is the flattened angular output.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, angular, bias):
        super(Conv_ang, self).__init__()
        self.angular = angular
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        N, c, uv, h, w = x.shape
        # Fold the pixel axes into the batch; each pixel becomes one angular image.
        x = x.permute(0, 3, 4, 1, 2).reshape(N * h * w, c, self.angular, self.angular)
        out = self.op(x)
        # Flatten the angular output back to a single view axis.
        out = out.reshape(N, h, w, out.shape[1], out.shape[2] * out.shape[3])
        return out.permute(0, 3, 4, 1, 2)


class Conv_epi_h(nn.Module):
    """Conv2d + ReLU applied over the (uv, w) plane for every row h.

    Input is ``[N, c, uv, h, w]``; output is ``[N, C_out, uv', h, w']``.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, bias):
        super(Conv_epi_h, self).__init__()
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        N, c, uv, h, w = x.shape
        # Fold the row axis into the batch; each row becomes one (uv, w) image.
        x = x.permute(0, 3, 1, 2, 4).reshape(N * h, c, uv, w)
        out = self.op(x)
        out = out.reshape(N, h, *out.shape[1:]).permute(0, 2, 3, 1, 4)
        return out
class Autocovnlayer(nn.Module):
    """Densely connected layer that mixes four candidate convolutions with learned soft gates.

    ``dence_weight`` gates the ``dence_num`` earlier feature maps and
    ``component_weight`` gates the candidate branches (spatial, angular,
    horizontal-EPI, vertical-EPI). During ``forward`` both gates are sampled
    with a noisy, temperature-scaled sigmoid, so they stay differentiable.

    NOTE: ``naslayers`` always holds four 32-channel branches and
    ``Conv_mixnas`` expects ``32 * 4`` input channels, so ``component_num``
    has to be 4 for ``forward`` to run.
    """

    def __init__(self, dence_num, component_num, angular, bs):
        """
        :param dence_num: number of earlier feature maps this layer receives
        :param component_num: number of candidate branches (see class note)
        :param angular: angular resolution handed to ``Conv_ang``
        :param bs: whether the 3x3 convolutions use a bias term
        """
        super(Autocovnlayer, self).__init__()
        self.dence_num = dence_num
        self.component_num = component_num
        self.dence_weight = nn.Parameter(torch.rand(dence_num), requires_grad=True)
        self.component_weight = nn.Parameter(torch.rand(component_num), requires_grad=True)
        self.angular = angular
        self.kernel_size = 3

        self.naslayers = nn.ModuleList([
            Conv_spa(C_in=64, C_out=32, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
            Conv_ang(C_in=64, C_out=32, kernel_size=self.kernel_size, stride=1, padding=1,
                     angular=self.angular, bias=bs),
            Conv_epi_h(C_in=64, C_out=32, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
            Conv_epi_v(C_in=64, C_out=32, kernel_size=self.kernel_size, stride=1, padding=1, bias=bs),
        ])
        self.Conv_all = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=bs)
        self.softmax1 = nn.Softmax(1)
        self.softmax0 = nn.Softmax(0)
        # 1x1 convolutions (no padding, no bias) that fold the gated stacks back to 64 channels.
        self.Conv_mixdence = nn.Conv2d(in_channels=64 * self.dence_num, out_channels=64,
                                       kernel_size=1, stride=1, padding=0, bias=False)
        self.Conv_mixnas = nn.Conv2d(in_channels=32 * 4, out_channels=64,
                                     kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _sample_gate(weight, batch_size, temperature, device):
        """Return a noisy sigmoid gate shaped ``[len(weight), batch_size, 1, 1, 1, 1]``.

        The noise ``log(log(r1) / log(r2))`` with uniform ``r1``, ``r2`` is
        added to the logit of the clamped weight and divided by ``temperature``.
        The uniforms are drawn with the default (CPU) generator, as before,
        and then moved to ``device`` instead of a hard-coded ``.cuda()``.
        """
        prob = weight.clamp(0.02, 0.98)[:, None, None, None, None, None]
        shape = (weight.shape[0], batch_size)
        r1 = torch.rand(shape)[:, :, None, None, None, None].to(device)
        r2 = torch.rand(shape)[:, :, None, None, None, None].to(device)
        noise = torch.log(torch.log(r1) / torch.log(r2))
        return torch.sigmoid((torch.log(prob / (1 - prob)) + noise) / temperature)

    def forward(self, x, temperature_1, temperature_2):
        """
        :param x: list of ``dence_num`` tensors, each ``[N, C, uv, h, w]``
        :param temperature_1: temperature of the dense-connection gate
        :param temperature_2: temperature of the branch gate
        :return: tensor shaped ``[N, C, uv, h, w]``
        """
        x = torch.stack(x, dim=0)
        fn, N, C, uv, h, w = x.shape

        # Sampling order (dense gate first, branch gate second) is kept so the
        # random stream matches the previous implementation.
        dence_weight_soft = self._sample_gate(self.dence_weight, N, temperature_1, x.device)
        component_weight_gumbel = self._sample_gate(self.component_weight, N, temperature_2, x.device)

        # Gate the earlier feature maps and fold fn*C channels to 64 per view.
        x = x * dence_weight_soft
        x = x.permute([1, 3, 0, 2, 4, 5]).reshape([N * uv, fn * C, h, w])
        x = self.relu(self.Conv_mixdence(x))
        x_mix = x.reshape([N, uv, C, h, w]).permute([0, 2, 1, 3, 4])

        # Evaluate every branch, gate it, and mix the branches with a 1x1 conv.
        nas = torch.stack([layer(x_mix) for layer in self.naslayers], dim=0)
        nas = nas * component_weight_gumbel
        nas = nas.permute([1, 3, 0, 2, 4, 5]).reshape([N * uv, self.component_num * 32, h, w])
        nas = self.relu(self.Conv_mixnas(nas))

        # Extra spatial convolution followed by the residual connection.
        nas = self.Conv_all(nas)
        nas = nas.reshape(N, uv, C, h, w).permute(0, 2, 1, 3, 4)
        nas = self.relu(nas + x_mix)
        return nas
class RefNet(nn.Module):
    """Network that combines per-view features with features of ``sampleLF``.

    Each view of the input is embedded with ``conv0`` and ``sampleLF`` with
    ``conv1``. Both 32-channel embeddings are concatenated, re-weighted by the
    channel SE layer, passed through the densely connected auto layers, and
    projected to one channel per view by ``syn_conv2``.
    """

    def __init__(self, opt, bs):
        """
        :param opt: options object; reads ``angResolution``, ``epochNum``,
            ``temperature_1`` and ``temperature_2`` (``make_autolayers`` reads more)
        :param bs: whether the convolutions use a bias term
        """
        super(RefNet, self).__init__()
        ang = opt.angResolution
        self.angResolution = ang
        self.lfNum = ang * ang
        self.epochNum = opt.epochNum
        self.temperature_1 = opt.temperature_1
        self.temperature_2 = opt.temperature_2

        self.relu = nn.ReLU(inplace=True)
        # conv0 embeds one single-channel view; conv1 embeds the 8-channel sample input.
        self.conv0 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=bs)
        self.conv1 = nn.Conv2d(8, 32, kernel_size=3, stride=1, padding=1, bias=bs)
        self.dence_autolayers = make_autolayers(opt, bs)
        self.sptialSE = SpatialSELayer(32)
        self.channelSE = ChannelSELayer3D(32 * 2, 2)
        self.syn_conv2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=bs)

    def forward(self, input, sampleLF, epoch):
        """
        :param input: tensor shaped ``[N, u, v, h, w]``
        :param sampleLF: 4-D tensor fed to ``conv1`` (8 input channels)
        :param epoch: current epoch, drives the temperature schedule
        :return: tensor shaped ``[N, u, v, h, w]``
        """
        N, u, v, h, w = input.shape
        _, _, _, _ = sampleLF.shape  # sampleLF has to be 4-D
        views = u * v

        # Linear decay of both temperatures up to epoch 3800, constant 0.05 afterwards.
        if epoch <= 3800:
            decay = 1 - epoch / 4000
            temperature_1 = self.temperature_1 * decay
            temperature_2 = self.temperature_2 * decay
        else:
            temperature_1 = temperature_2 = 0.05

        # Sample branch: embed once, then broadcast over the view axis.
        sample_feat = self.sptialSE(self.relu(self.conv1(sampleLF)))
        sample_feat = sample_feat.unsqueeze(2).expand(-1, -1, views, -1, -1)

        # View branch: every view is embedded independently.
        view_feat = self.relu(self.conv0(input.reshape(N * views, 1, h, w)))
        view_feat = view_feat.reshape(N, views, 32, h, w).permute(0, 2, 1, 3, 4)

        fused = self.channelSE(torch.cat([view_feat, sample_feat], 1))

        # Dense connectivity: each layer receives every feature map produced so far.
        feats = [fused]
        for layer in self.dence_autolayers:
            feats.append(layer(feats, temperature_1, temperature_2))

        last = feats[-1].permute(0, 2, 1, 3, 4).reshape(N * self.lfNum, 64, h, w)
        return self.syn_conv2(last).reshape(N, u, v, h, w)
--------------------------------------------------------------------------------