# Repository layout:
# ├── 2D_random_points.py
# ├── 2D_separable_points.py
# ├── 3D_random_points_GD.py
# ├── 3D_separable_points_GD.py
# ├── LICENSE
# ├── MLPs.py
# ├── README.md
# ├── complex_encoding.ipynb
# ├── data.py
# ├── imgs
# │   └── simple_complex_encoding.png
# ├── trainer.py
# └── utils.py


# ---- /2D_random_points.py ----
import logging
import os
import sys
from argparse import ArgumentParser

import numpy as np
import torch
from matplotlib import pyplot as plt

from trainer import *
from utils import *
from MLPs import *


def to8b(x):
    """Map values in [0, 1] to uint8 in [0, 255], clipping anything outside that range."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to `filename` (truncated) and to stderr.

    verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
    name: passed to logging.getLogger; None selects the root logger.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger


def _methods_and_params(rff, linf, logf, gau, tri):
    """Flatten per-encoding parameter lists into parallel (method names, params) lists."""
    methods = (['RFF'] * len(rff) + ['LinF'] * len(linf) + ['LogF'] * len(logf)
               + ['Gau'] * len(gau) + ['Tri'] * len(tri))
    return methods, rff + linf + logf + gau + tri


def _random_masks(n_signals, n_repeat, n_points, ratio):
    """Draw boolean observation masks of shape (n_signals, n_repeat, n_points).

    Each mask keeps int(ratio * n_points) positions chosen by torch.randperm, so the
    masks follow the torch seed. Draw order is signal-major, then repeat.
    """
    n_keep = int(ratio * n_points)
    masks = []
    for _ in range(n_signals):
        per_signal = []
        for _ in range(n_repeat):
            m = np.zeros(n_points, dtype=bool)
            m[torch.randperm(n_points)[:n_keep].numpy()] = True
            per_signal.append(m)
        masks.append(np.stack(per_signal, 0))
    return np.stack(masks, 0)


def _report(args, logger, file_name, em, param, result, n_signals):
    """Optionally save each reconstruction as a PDF, then log summary statistics.

    result: the (time, train_psnr, test_psnr, reconstruction) tuple returned by the
    trainer; PSNR/time are indexed with the signal as the first axis.
    """
    time_, _trn_psnr, tst_psnr_, rec_ = result
    if args.save_flag:
        for i in range(n_signals):
            plt.imshow(rec_[i])
            plt.axis('off')
            plt.savefig(os.path.join(args.save_path, 'I{}'.format(i) + file_name + '{}.pdf'.format(tst_psnr_[i, -1])),
                        bbox_inches='tight', pad_inches=0)
            # Without closing, every imshow stacks onto the same axes for the whole run.
            plt.close()
    logger.info('embedding method:{}, param:{}, psnr:{}, std:{}, time:{}.'.format(
        em, param, np.mean(tst_psnr_[:, :]), np.std(tst_psnr_[:, :]), np.mean(time_[:, :])))


def main(args):
    """Fit 2D images observed through a random 25% mask with simple and complex encodings.

    Returns 1 (and does nothing) if args.save_path already exists.
    """
    if os.path.exists(args.save_path):
        print('Path already exists!')
        return 1
    os.mkdir(args.save_path)
    logger = get_logger(os.path.join(args.save_path, args.logger))
    logger.info(args)

    # Set the CUDA flag
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info('device is: {}'.format(device))

    buf = np.load(args.data_path)
    test_data = buf['test_data']

    # NOTE: --mask_path is not read (loading it was disabled upstream); a fresh
    # random mask is drawn on every run instead.
    image_size = test_data.shape[1]
    ratio = 0.25
    mask = _random_masks(test_data.shape[0], args.N_repeat, image_size**2, ratio)

    signals = torch.from_numpy(test_data / 255.)
    n_signals = signals.shape[0]

    logger.info('################ Simple Encoding ################')

    ez = 512
    encoding_methods, params = _methods_and_params([8], [4], [4.5], [0.006], [2.5/128])

    print(encoding_methods)
    print(params)

    for depth in [4, 1, 0]:
        for lr in [5e-3]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_2D(em, [param, ez])
                result = train_random_simple_2D(signals, ef, mask=mask, N_repeat=args.N_repeat, lr=lr,
                                                epochs=2000, depth=depth, device=device, logger=None)
                _report(args, logger, 'RD{}{}'.format(depth, em), em, param, result, n_signals)

    logger.info('################ Complex Encoding ################')

    ez = 256
    encoding_methods, params = _methods_and_params([30], [4.5], [6], [0.005], [2.5/256])

    for depth in [0, 1]:
        for lr in [1e-1]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_1D(em, [param, ez])
                bl = blending_func_2D(ef)
                result = train_index_blend_kron_2D(signals, bl, ef, mask=mask, N_repeat=args.N_repeat, lr=lr,
                                                   epochs=2000, depth=depth, device=device, logger=None)
                _report(args, logger, 'RKD{}{}'.format(depth, em), em, param, result, n_signals)


if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20220222)
    np.random.seed(20220222)

    parser = ArgumentParser()

    parser.add_argument("--data_path", type=str, default="data_div2k.npz")
    parser.add_argument("--mask_path", type=str, default="mask_2d_0.25_16_1_512.npy")
    parser.add_argument("--N_repeat", type=int, default=1)
    parser.add_argument("--save_path", type=str, default="2D_random_points/")
    parser.add_argument("--logger", type=str, default="log.log")
    parser.add_argument("--save_flag", type=int, default=0, choices=[0, 1])

    args = parser.parse_args()

    # Propagate main's return code (1 when save_path already exists) to the shell.
    sys.exit(main(args))


# ---- /2D_separable_points.py ----
import logging
import os
import sys
from argparse import ArgumentParser

import numpy as np
import torch
from matplotlib import pyplot as plt

from trainer import *
from utils import *
from MLPs import *


def to8b(x):
    """Map values in [0, 1] to uint8 in [0, 255], clipping anything outside that range."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to `filename` (truncated) and to stderr.

    verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
    name: passed to logging.getLogger; None selects the root logger.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger


def _methods_and_params(rff, linf, logf, gau, tri):
    """Flatten per-encoding parameter lists into parallel (method names, params) lists."""
    methods = (['RFF'] * len(rff) + ['LinF'] * len(linf) + ['LogF'] * len(logf)
               + ['Gau'] * len(gau) + ['Tri'] * len(tri))
    return methods, rff + linf + logf + gau + tri


def _report(args, logger, file_name, em, param, result, n_signals):
    """Optionally save each reconstruction as a PDF, then log summary statistics.

    result: the (time, train_psnr, test_psnr, reconstruction) tuple returned by the
    trainer; PSNR/time are indexed with the signal as the first axis.
    """
    time_, _trn_psnr, tst_psnr_, rec_ = result
    if args.save_flag:
        for i in range(n_signals):
            plt.imshow(rec_[i])
            plt.axis('off')
            plt.savefig(os.path.join(args.save_path, 'I{}'.format(i) + file_name + '{}.pdf'.format(tst_psnr_[i, -1])),
                        bbox_inches='tight', pad_inches=0)
            # Without closing, every imshow stacks onto the same axes for the whole run.
            plt.close()
    logger.info('embedding method:{}, param:{}, psnr:{}, std:{}, time:{}.'.format(
        em, param, np.mean(tst_psnr_[:, :]), np.std(tst_psnr_[:, :]), np.mean(time_[:, :])))


def main(args):
    """Fit fully-observed 2D images with simple, complex and closed-form complex encodings.

    Returns 1 (and does nothing) if args.save_path already exists.
    """
    if os.path.exists(args.save_path):
        print('Path already exists!')
        return 1
    os.mkdir(args.save_path)
    logger = get_logger(os.path.join(args.save_path, args.logger))
    logger.info(args)

    # Set the CUDA flag
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info('device is: {}'.format(device))

    buf = np.load(args.data_path)

    signals = torch.from_numpy(buf['test_data'] / 255.)
    n_signals = signals.shape[0]

    logger.info('################ Simple Encoding ################')

    ez = 512
    encoding_methods, params = _methods_and_params([14], [5], [5.5], [0.006], [2.5/128])

    print(encoding_methods)
    print(params)

    for depth in [4, 1, 0]:
        for lr in [5e-3]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_2D(em, [param, ez])
                result = train_simple_2D(signals, ef, N_repeat=args.N_repeat, lr=lr, epochs=2000,
                                         depth=depth, device=device, logger=None)
                _report(args, logger, 'D{}{}'.format(depth, em), em, param, result, n_signals)

    logger.info('################ Complex Encoding ################')

    ez = 256
    encoding_methods, params = _methods_and_params([42], [4.5], [6.5], [0.004], [2/256])

    for depth in [0, 1]:
        for lr in [1e-1]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_1D(em, [param, ez])
                result = train_kron_2D(signals, ef, N_repeat=args.N_repeat, lr=lr, epochs=2000,
                                       depth=depth, device=device, logger=None)
                _report(args, logger, 'KD{}{}'.format(depth, em), em, param, result, n_signals)

    logger.info('################ Closed Form Complex Encoding ################')

    ez = 256
    encoding_methods, params = _methods_and_params([42], [6], [6.5], [0.0025], [1/256])

    for em, param in zip(encoding_methods, params):
        # Only Gau/Tri use the selected device; the other encodings run on the CPU.
        # Choose per method so the CPU fallback does not leak into later iterations.
        cf_device = device if em in ('Gau', 'Tri') else 'cpu'
        ef = encoding_func_1D(em, [param, ez])
        result = train_closed_form_2D(signals, ef, N_repeat=args.N_repeat, device=cf_device, logger=None)
        # The closed-form fit has no network depth, so none is encoded in the name.
        _report(args, logger, 'CKD{}'.format(em), em, param, result, n_signals)


if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20220222)
    np.random.seed(20220222)

    parser = ArgumentParser()

    parser.add_argument("--data_path", type=str, default="data_div2k.npz")
    parser.add_argument("--N_repeat", type=int, default=1)
    parser.add_argument("--save_path", type=str, default="2D_separable_points/")
    parser.add_argument("--logger", type=str, default="log.log")
    parser.add_argument("--save_flag", type=int, default=0, choices=[0, 1])

    args = parser.parse_args()
    # Propagate main's return code (1 when save_path already exists) to the shell.
    sys.exit(main(args))


# ---- /3D_random_points_GD.py ----
import logging
import os
import sys
from argparse import ArgumentParser

import imageio
import numpy as np
import torch

from trainer import *
from utils import *
from MLPs import *


def to8b(x):
    """Map values in [0, 1] to uint8 in [0, 255], clipping anything outside that range."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to `filename` (truncated) and to stderr.

    verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
    name: passed to logging.getLogger; None selects the root logger.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger


def _methods_and_params(rff, linf, logf, gau, tri):
    """Flatten per-encoding parameter lists into parallel (method names, params) lists."""
    methods = (['RFF'] * len(rff) + ['LinF'] * len(linf) + ['LogF'] * len(logf)
               + ['Gau'] * len(gau) + ['Tri'] * len(tri))
    return methods, rff + linf + logf + gau + tri


def _random_masks(n_signals, n_repeat, n_points, ratio):
    """Draw boolean observation masks of shape (n_signals, n_repeat, n_points).

    Each mask keeps int(ratio * n_points) positions chosen by torch.randperm, so the
    masks follow the torch seed. Draw order is signal-major, then repeat.
    """
    n_keep = int(ratio * n_points)
    masks = []
    for _ in range(n_signals):
        per_signal = []
        for _ in range(n_repeat):
            m = np.zeros(n_points, dtype=bool)
            m[torch.randperm(n_points)[:n_keep].numpy()] = True
            per_signal.append(m)
        masks.append(np.stack(per_signal, 0))
    return np.stack(masks, 0)


def _report(args, logger, file_name, em, param, result, n_signals):
    """Optionally save each reconstruction as an mp4, then log summary statistics.

    result: the (time, train_psnr, test_psnr, reconstruction) tuple returned by the
    trainer; PSNR/time are indexed with the signal as the first axis.
    """
    time_, _trn_psnr, tst_psnr_, rec_ = result
    if args.save_flag:
        for i in range(n_signals):
            imageio.mimwrite(os.path.join(args.save_path, 'V{}'.format(i) + file_name + 'R{:.2f}.mp4'.format(tst_psnr_[i, -1])),
                             to8b(rec_[i]), fps=10, quality=8)
    logger.info('embedding method:{}, param:{}, psnr:{}, std:{}, time:{}.'.format(
        em, param, np.mean(tst_psnr_[:, :]), np.std(tst_psnr_[:, :]), np.mean(time_[:, :])))


def main(args):
    """Fit 3D volumes observed through a random 12.5% mask with simple and complex encodings.

    Returns 1 (and does nothing) if args.save_path already exists.
    """
    if os.path.exists(args.save_path):
        print('Path already exists!')
        return 1
    os.mkdir(args.save_path)
    logger = get_logger(os.path.join(args.save_path, args.logger))
    logger.info(args)

    # Set the CUDA flag
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info('device is: {}'.format(device))

    buf = np.load(args.data_path)

    signals = torch.from_numpy(buf['train'])
    n_signals = signals.shape[0]

    # NOTE: --mask_path is not read (loading it was disabled upstream); a fresh
    # random mask is drawn on every run instead.
    image_size = signals.shape[1]
    ratio = 0.125
    mask = _random_masks(n_signals, args.N_repeat, image_size**3, ratio)
    # Written to the current working directory, not to save_path.
    np.save('mask_3d_{}_{}_{}_{}.npy'.format(ratio, signals.shape[0], args.N_repeat, signals.shape[1]), mask)

    logger.info('################ Simple Encoding ################')

    ez = 192
    encoding_methods, params = _methods_and_params([2], [3], [3], [0.02], [3/64])

    for depth in [5, 1, 0]:
        for lr in [5e-3]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_3D(em, [param, ez])
                result = train_random_simple_3D(signals, ef, mask=mask, N_repeat=args.N_repeat, lr=lr,
                                                epochs=500, depth=depth, device=device, logger=None)
                _report(args, logger, 'RD{}{}'.format(depth, em), em, param, result, n_signals)

    logger.info('################ Complex Encoding ################')

    ez = 64
    encoding_methods, params = _methods_and_params([6], [8], [8], [0.02], [3.5/64])

    for depth in [0, 1]:
        for lr in [1e-1]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_1D(em, [param, ez])
                bl = blending_func_3D(ef)
                result = train_index_blend_kron_3D(signals, bl, ef, mask=mask, N_repeat=args.N_repeat, lr=lr,
                                                   epochs=500, depth=depth, device=device, logger=None)
                _report(args, logger, 'RKD{}{}'.format(depth, em), em, param, result, n_signals)


if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20220222)
    np.random.seed(20220222)

    parser = ArgumentParser()

    parser.add_argument("--data_path", type=str, default="video_16_128.npz")
    parser.add_argument("--mask_path", type=str, default="mask_3d_0.125_5_1_128.npy")
    parser.add_argument("--N_repeat", type=int, default=1)
    parser.add_argument("--save_path", type=str, default="3D_random_points/")
    parser.add_argument("--logger", type=str, default="log.log")
    parser.add_argument("--save_flag", type=int, default=0, choices=[0, 1])

    args = parser.parse_args()
    # Propagate main's return code (1 when save_path already exists) to the shell.
    sys.exit(main(args))


# ---- /3D_separable_points_GD.py ----
import logging
import os
import sys
from argparse import ArgumentParser

import imageio
import numpy as np
import torch

from trainer import *
from utils import *
from MLPs import *


def to8b(x):
    """Map values in [0, 1] to uint8 in [0, 255], clipping anything outside that range."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to `filename` (truncated) and to stderr.

    verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
    name: passed to logging.getLogger; None selects the root logger.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger


def _methods_and_params(rff, linf, logf, gau, tri):
    """Flatten per-encoding parameter lists into parallel (method names, params) lists."""
    methods = (['RFF'] * len(rff) + ['LinF'] * len(linf) + ['LogF'] * len(logf)
               + ['Gau'] * len(gau) + ['Tri'] * len(tri))
    return methods, rff + linf + logf + gau + tri


def _report(args, logger, file_name, em, param, result, n_signals):
    """Optionally save each reconstruction as an mp4, then log summary statistics.

    result: the (time, train_psnr, test_psnr, reconstruction) tuple returned by the
    trainer; PSNR/time are indexed with the signal as the first axis.
    """
    time_, _trn_psnr, tst_psnr_, rec_ = result
    if args.save_flag:
        for i in range(n_signals):
            imageio.mimwrite(os.path.join(args.save_path, 'V{}'.format(i) + file_name + 'R{:.2f}.mp4'.format(tst_psnr_[i, -1])),
                             to8b(rec_[i]), fps=10, quality=8)
    logger.info('embedding method:{}, param:{}, psnr:{}, std:{}, time:{}.'.format(
        em, param, np.mean(tst_psnr_[:, :]), np.std(tst_psnr_[:, :]), np.mean(time_[:, :])))


def main(args):
    """Fit fully-observed 3D volumes with simple, complex and closed-form complex encodings.

    Returns 1 (and does nothing) if args.save_path already exists.
    """
    if os.path.exists(args.save_path):
        print('Path already exists!')
        return 1
    os.mkdir(args.save_path)
    logger = get_logger(os.path.join(args.save_path, args.logger))
    logger.info(args)

    # Set the CUDA flag
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info('device is: {}'.format(device))

    buf = np.load(args.data_path)

    signals = torch.from_numpy(buf['test'])
    n_signals = signals.shape[0]

    logger.info('################ Simple Encoding ################')

    ez = 192
    encoding_methods, params = _methods_and_params([3], [3], [3], [0.015], [3/64])

    for depth in [5, 1, 0]:
        for lr in [5e-3]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_3D(em, [param, ez])
                result = train_simple_3D(signals, ef, N_repeat=args.N_repeat, lr=lr, epochs=500,
                                         depth=depth, device=device, logger=None)
                _report(args, logger, 'D{}{}'.format(depth, em), em, param, result, n_signals)

    logger.info('################ Complex Encoding ################')

    ez = 64
    encoding_methods, params = _methods_and_params([6], [3], [3], [0.015], [2/64])

    for depth in [0, 1]:
        for lr in [1e-1]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_1D(em, [param, ez])
                result = train_kron_3D(signals, ef, N_repeat=args.N_repeat, lr=lr, epochs=500,
                                       depth=depth, device=device, logger=None)
                _report(args, logger, 'KD{}{}'.format(depth, em), em, param, result, n_signals)

    logger.info('################ Closed Form Complex Encoding ################')

    ez = 64
    # Rebuild the method/param lists from the closed-form settings; previously these
    # values were defined but the complex-encoding lists above were reused.
    encoding_methods, params = _methods_and_params([6], [4], [3], [0.009], [3/64])

    for em, param in zip(encoding_methods, params):
        # Only Gau/Tri use the selected device; the other encodings run on the CPU.
        # Choose per method so the CPU fallback does not leak into later iterations.
        cf_device = device if em in ('Gau', 'Tri') else 'cpu'
        ef = encoding_func_1D(em, [param, ez])
        result = train_closed_form_3D(signals, ef, N_repeat=args.N_repeat, device=cf_device, logger=None)
        # The closed-form fit has no network depth, so none is encoded in the name.
        _report(args, logger, 'CKD{}'.format(em), em, param, result, n_signals)


if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20220222)
    np.random.seed(20220222)

    parser = ArgumentParser()

    parser.add_argument("--data_path", type=str, default="video_16_128.npz")
    parser.add_argument("--N_repeat", type=int, default=1)
    parser.add_argument("--save_path", type=str, default="3D_separable_points/")
    parser.add_argument("--logger", type=str, default="log.log")
    parser.add_argument("--save_flag", type=int, default=0, choices=[0, 1])

    args = parser.parse_args()
    # Propagate main's return code (1 when save_path already exists) to the shell.
    sys.exit(main(args))


# ---- /LICENSE ----
# MIT License
#
# Copyright (c) 2022 Jianqiao Zheng
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ---- /MLPs.py ----
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# regular MLP
class MLP(nn.Module):
    """Fully-connected network mapping `input_dim` features to `output_dim` outputs.

    depth == 0 yields a single Linear layer. Otherwise `depth` hidden Linear+ReLU
    blocks of size `width` precede the output Linear. A Sigmoid is appended last
    when `use_sigmoid` is True.
    """

    def __init__(self, input_dim=2, output_dim=3, depth=0, width=256, bias=True, use_sigmoid=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.mm = nn.ModuleList([])
        if depth == 0:
            self.mm.append(nn.Linear(input_dim, output_dim, bias=bias))
        else:
            # Hidden blocks: input_dim -> width, then (depth - 1) blocks of width -> width.
            for fan_in in [input_dim] + [width] * (depth - 1):
                self.mm.append(nn.Sequential(nn.Linear(fan_in, width, bias=bias), nn.ReLU(True)))
            # The output Linear stays wrapped in a Sequential, keeping state_dict keys as mm.<k>.0.*.
            self.mm.append(nn.Sequential(nn.Linear(width, output_dim, bias=bias)))
        if use_sigmoid:
            self.mm.append(nn.Sigmoid())

    def forward(self, x):
        """Apply each module of `self.mm` in order and return the result."""
        out = x
        for layer in self.mm:
            out = layer(out)
        return out

    def name(self):
        return "MLP"


# If the input to the MLP has the form kron(x, y), the model below computes it efficiently.
30 | class Kron_MLP(nn.Module): 31 | def __init__(self,input_dim=2, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True): 32 | super(Kron_MLP, self).__init__() 33 | self.use_sigmoid = use_sigmoid 34 | 35 | if depth==0: width0 = output_dim 36 | 37 | self.mm = nn.ModuleList([]) 38 | self.first = nn.ParameterDict({ 39 | 'weight': nn.Parameter(2/np.sqrt(width0)*torch.rand(width0, input_dim, input_dim)-1/np.sqrt(width0))}) 40 | 41 | if depth == 1: 42 | self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, output_dim,bias=bias))) 43 | if depth > 1: 44 | self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, width,bias=bias),nn.ReLU(True))) 45 | for i in range(depth-2): 46 | self.mm.append(nn.Sequential(nn.Linear(width, width,bias=bias),nn.ReLU(True))) 47 | self.mm.append(nn.Linear(width, output_dim,bias=bias)) 48 | 49 | if use_sigmoid: self.mm.append(nn.Sigmoid()) 50 | 51 | def forward(self, x, y=None): 52 | if y is None: y=x.detach().clone() 53 | 54 | ### calculate x*W*y.T instead of kron(x,y)*vec(W); W = self.first.weight has shape (width0, d, d) with d = input_dim, and y defaults to a detached copy of x ### 55 | # naive way 56 | x = x@self.first.weight # assuming x is (N_x, d): -> (width0, N_x, d) 57 | x = x@(y.transpose(0,1)) # -> (width0, N_x, N_y) 58 | 59 | # einsum implementation 60 | # x = torch.einsum('ij,wjk,lk->wil',x,self.first.weight,y) 61 | ########################################################## 62 | 63 | x = x.flatten(1,2).transpose(0,1) # -> (N_x*N_y, width0): one row per (x, y) pair, then the MLP head in self.mm 64 | for m in self.mm: 65 | x = m(x) 66 | return x 67 | def name(self): 68 | return "Kron_MLP" 69 | 70 | 71 | # MLP with the blending matrix applied as a (sparse) matrix multiplication; not used because it is a bit slow 72 | class Blend_Kron_MLP(nn.Module): # forward(B, x, y): B mixes the N_x*N_y Kronecker features before the MLP head 73 | def __init__(self, input_dim=2, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True): 74 | 
super(Blend_Kron_MLP, self).__init__() 75 | self.use_sigmoid = use_sigmoid 76 | 77 | if depth==0: width0 = output_dim # no hidden layers: the Kronecker layer outputs output_dim directly 78 | 79 | self.mm = nn.ModuleList([]) 80 | self.first = nn.Parameter(2/np.sqrt(width0)*torch.rand(width0, input_dim, input_dim)-1/np.sqrt(width0)) # plain Parameter (unlike Kron_MLP's ParameterDict): use self.first, not self.first.weight 81 | 82 | if depth == 1: 83 |
self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, output_dim,bias=bias))) 84 | if depth > 1: 85 | self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, width,bias=bias),nn.ReLU(True))) 86 | for i in range(depth-2): 87 | self.mm.append(nn.Sequential(nn.Linear(width, width,bias=bias),nn.ReLU(True))) 88 | self.mm.append(nn.Linear(width, output_dim,bias=bias)) 89 | 90 | if use_sigmoid: self.mm.append(nn.Sigmoid()) 91 | 92 | def forward(self, B, x, y=None): 93 | if y is None: y=x.detach().clone() 94 | 95 | ### calculate x*W*y.T instead of kron(x,y)*vec(W) ### 96 | # naive way 97 | x = x@self.first # -> (width0, N_x, d) 98 | x = x@(y.transpose(0,1)) # -> (width0, N_x, N_y) 99 | 100 | # einsum implementation 101 | # x = torch.einsum('ij,wjk,lk->wil',x,self.first,y) 102 | ########################################################## 103 | 104 | x = x.flatten(1,2).transpose(0,1) # -> (N_x*N_y, width0) 105 | x = B@x # blend grid features per sample; B presumably (N_samples, N_x*N_y), dense or sparse - confirm 106 | for m in self.mm: 107 | x = m(x) 108 | return x 109 | def name(self): 110 | return "Blend_Kron_MLP" 111 | 112 | 113 | 114 | # MLP with the blending matrix implemented by indexing plus a weighted sum (used by train_index_blend_kron_2D) 115 | class Indexing_Blend_Kron_MLP(nn.Module): # forward(B, x, y): B = [indices, weights] 116 | def __init__(self, input_dim=2, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True): 117 | super(Indexing_Blend_Kron_MLP, self).__init__() 118 | self.use_sigmoid = use_sigmoid 119 | 120 | if depth==0: width0 = output_dim # no hidden layers: the Kronecker layer outputs output_dim directly 121 | 122 | self.mm = nn.ModuleList([]) 123 | self.first = nn.Parameter(2/np.sqrt(width0)*torch.rand(width0, input_dim, input_dim)-1/np.sqrt(width0)) # plain Parameter (unlike Kron_MLP's ParameterDict): use self.first, not self.first.weight 124 | 125 | if depth == 1: 126 | self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, output_dim,bias=bias))) 127 | if depth > 1: 128 | 
self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, width,bias=bias),nn.ReLU(True))) 129 | for i in range(depth-2): 130 | self.mm.append(nn.Sequential(nn.Linear(width, width,bias=bias),nn.ReLU(True))) 131 | self.mm.append(nn.Linear(width, output_dim,bias=bias)) 132 | 133 | if use_sigmoid: self.mm.append(nn.Sigmoid()) 134 |
135 | def forward(self, B, x, y=None): 136 | if y is None: y=x.detach().clone() 137 | 138 | ### calculate x*W*y.T instead of kron(x,y)*vec(W) and the indexing weighted sum ### 139 | # naive way 140 | x = x@self.first 141 | x = x@(y.transpose(0,1)) 142 | x = x.flatten(1,2).transpose(0,1) # -> (N_x*N_y, width0) 143 | x = (x[B[0]]*(B[1].unsqueeze(-1))).sum(1) # gather rows B[0] and sum them weighted by B[1] over dim 1; presumably both are (N_samples, K) -> (N_samples, width0) 144 | 145 | # einsum implementation 146 | # x = torch.einsum('ij,wjk,lk->wil',x,self.first,y) 147 | # x = x.flatten(1,2).transpose(0,1) 148 | # x = torch.einsum('ijk,ij->ik',x[B[0]],B[1]) 149 | ########################################################## 150 | 151 | for m in self.mm: 152 | x = m(x) 153 | return x 154 | def name(self): 155 | return "Indexing_Blend_Kron_MLP" 156 | 157 | 158 | # The following are the 3D versions 159 | class Kron3_MLP(nn.Module): 160 | def __init__(self,input_dim=3, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True): 161 | super(Kron3_MLP, self).__init__() 162 | self.use_sigmoid = use_sigmoid 163 | 164 | if depth==0: width0 = output_dim # no hidden layers: the Kronecker layer outputs output_dim directly 165 | 166 | self.mm = nn.ModuleList([]) 167 | self.first = nn.Parameter(2/np.sqrt(width0)*torch.rand(width0, input_dim, input_dim, input_dim)-1/np.sqrt(width0)) # plain Parameter of shape (width0, d, d, d) 168 | 169 | if depth == 1: 170 | self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, output_dim, bias=bias))) 171 | if depth > 1: 172 | self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, width, bias=bias),nn.ReLU(True))) 173 | for i in range(depth-2): 174 | self.mm.append(nn.Sequential(nn.Linear(width, width, bias=bias),nn.ReLU(True))) 175 | self.mm.append(nn.Linear(width, output_dim, bias=bias)) 176 | 177 | if use_sigmoid: self.mm.append(nn.Sigmoid()) 178 | def forward(self, x, y=None,z=None): 179 | if y is None: y=x.detach().clone() 180 | if z is None: z=x.detach().clone() 181 | 182 | ### calculate mode-n multiplication instead of kron(x,y,z)*vec(W) ### 
183 | # naive way 184 | x = x@self.first # assuming x is (N_x, d): -> (width0, d, N_x, d) 185 | x = x@(y.transpose(0,1)) # -> (width0, d, N_x, N_y) 186 | x = z@(x.transpose(1,2)) # -> (width0, N_x, N_z, N_y) 187 | 188 | # einsum
implementation 189 | # x = torch.einsum('ai,bj,ck,wijk->wabc',x,y,z,self.first) NOTE(review): this einsum contracts x/y/z with W axes 1/2/3 and outputs (a,b,c), while the naive path above contracts them with axes 2/3/1 and outputs (a,c,b); the two agree only up to a permutation of W and of the flattened row order 190 | ########################################################## 191 | 192 | x = x.flatten(1,3).transpose(0,1) # -> (N_x*N_z*N_y, width0): rows ordered x-major, then z, then y 193 | for m in self.mm: 194 | x = m(x) 195 | return x 196 | def name(self): 197 | return "Kron3_MLP" 198 | 199 | # Disabled draft. NOTE(review): its hidden-layer loop uses range(depth-1), i.e. one more hidden layer than the live classes' range(depth-2), and it has no bias option 200 | # class Blend_Kron3_MLP(nn.Module): 201 | # def __init__(self,input_dim=2, output_dim=3, depth=0,width0=256,width=256, use_sigmoid=True): 202 | # super(Blend_Kron3_MLP, self).__init__() 203 | # self.use_sigmoid = use_sigmoid 204 | 205 | # if depth==0: width0 = output_dim 206 | 207 | # self.mm = nn.ModuleList([]) 208 | # self.first = nn.ParameterDict({ 209 | # 'weight': nn.Parameter(2/np.sqrt(width0)*torch.rand(width0, input_dim, input_dim, input_dim)-1/np.sqrt(width0))}) 210 | # if depth == 1: 211 | # self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, output_dim))) 212 | # if depth > 1: 213 | # self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, width),nn.ReLU(True))) 214 | # for i in range(depth-1): 215 | # self.mm.append(nn.Sequential(nn.Linear(width, width),nn.ReLU(True))) 216 | # self.mm.append(nn.Linear(width, output_dim)) 217 | # if use_sigmoid: self.mm.append(nn.Sigmoid()) 218 | # def forward(self, B, x, y=None,z=None): 219 | # if y is None: y=x.detach().clone() 220 | # if z is None: z=x.detach().clone() 221 | # x = x@self.first.weight 222 | # x = x@(y.transpose(0,1)) 223 | # x = z@(x.transpose(1,2)) 224 | # x = x.flatten(1,3).transpose(0,1) 225 | # x = B@x 226 | # for m in self.mm: 227 | # x = m(x) 228 | # return x 229 | # def name(self): 230 | # return "Blend_Kron3_MLP" 231 | 232 | 233 | 234 | class 
Indexing_Blend_Kron3_MLP(nn.Module): # forward(B, x, y, z): B = [indices, weights], as in Indexing_Blend_Kron_MLP 235 | def __init__(self,input_dim=3, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True): 236 | super(Indexing_Blend_Kron3_MLP, self).__init__() 237 | self.use_sigmoid = use_sigmoid 238 | 239 | if depth==0: width0 = output_dim # no hidden layers: the Kronecker layer outputs output_dim directly 240 | 241 | self.mm = nn.ModuleList([]) 242 |
self.first = nn.Parameter(2/np.sqrt(width0)*torch.rand(width0, input_dim, input_dim, input_dim)-1/np.sqrt(width0)) 243 | 244 | if depth == 1: 245 | self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, output_dim, bias=bias))) 246 | if depth > 1: 247 | self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, width, bias=bias),nn.ReLU(True))) 248 | for i in range(depth-2): 249 | self.mm.append(nn.Sequential(nn.Linear(width, width, bias=bias),nn.ReLU(True))) 250 | self.mm.append(nn.Linear(width, output_dim, bias=bias)) 251 | 252 | if use_sigmoid: self.mm.append(nn.Sigmoid()) 253 | def forward(self, B, x, y=None,z=None): 254 | if y is None: y=x.detach().clone() 255 | if z is None: z=x.detach().clone() 256 | 257 | ### calculate mode-n multiplication instead of kron(x,y,z)*vec(W) and the indexing weighted sum ### 258 | # naive way 259 | x = x@self.first 260 | x = x@(y.transpose(0,1)) 261 | x = z@(x.transpose(1,2)) 262 | x = x.flatten(1,3).transpose(0,1) # -> (N_x*N_z*N_y, width0) 263 | x = (x[B[0]]*(B[1].unsqueeze(-1))).sum(1) # gather rows B[0] and sum them weighted by B[1] over dim 1 -> one row per sample 264 | 265 | # einsum implementation 266 | # x = torch.einsum('ai,bj,ck,wijk->wabc',x,y,z,self.first) (assigns W's axes differently from the naive path; see the NOTE in Kron3_MLP) 267 | # x = x.flatten(1,3).transpose(0,1) 268 | # x = torch.einsum('ijk,ij->ik',x[B[0]],B[1]) 269 | ########################################################## 270 | for m in self.mm: 271 | x = m(x) 272 | return x 273 | def name(self): 274 | return "Indexing_Blend_Kron3_MLP" 275 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Trading Positional Complexity vs Deepness in Coordinate Networks 2 | ### [Project Page](https://osiriszjq.github.io/complex_encoding) | [Paper](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136870142.pdf) 3 | 
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | 5 | 6 | [Jianqiao Zheng](https://github.com/osiriszjq/), 7 | [Sameera
Ramasinghe](https://scholar.google.pl/citations?user=-j0m9aMAAAAJ&hl=en), 8 | [Xueqian Li](https://lilac-lee.github.io/), 9 | [Simon Lucey](https://www.adelaide.edu.au/directory/simon.lucey)
10 | The University of Adelaide 11 | 12 | This is the official implementation of the paper "Trading Positional Complexity vs Deepness in Coordinate Networks", which has been accepted to ECCV 2022. 13 | 14 | #### Illustration of different methods to extend 1D encoding 15 | ![Illustration of different methods to extend 1D encoding](imgs/simple_complex_encoding.png) 16 | 17 | 18 | ## Google Colab 19 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/osiriszjq/complex_encoding/blob/main/complex_encoding.ipynb)
20 | If you want to try out our new complex encoding, we have written a [Colab](https://colab.research.google.com/github/osiriszjq/complex_encoding/blob/main/complex_encoding.ipynb) with the following experiments: 21 | * simple encoding for 2D image reconstruction with separable coordinates 22 | * complex encoding for 2D image reconstruction with separable coordinates 23 | * closed-form solution of complex encoding for 2D image reconstruction with separable coordinates 24 | * simple encoding for 3D video reconstruction with non-separable coordinates 25 | * complex encoding for 3D video reconstruction with non-separable coordinates 26 | 27 | 28 | ## Dataset 29 | The datasets used to reproduce our results can be found on [Google Drive](https://drive.google.com/drive/folders/1yLVG1WT5i9PxchNqAb84nJdHuLlNPCj4?usp=sharing); running `python data.py` (requires `gdown`) downloads both files, `data_div2k.npz` and `video_16_128.npz`. The image data is from [Fourier Feature Networks](https://github.com/tancik/fourier-feature-networks) and the video data is from [YouTube-BB](https://research.google.com/youtube-bb/). 30 | 31 | 32 | ## Citation 33 | ```bibtex 34 | @inproceedings{zheng2022trading, 35 | title={Trading Positional Complexity vs Deepness in Coordinate Networks}, 36 | author={Zheng, Jianqiao and Ramasinghe, Sameera and Li, Xueqian and Lucey, Simon}, 37 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 38 | year={2022} 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | 3 | datasets = { 4 | 'data_div2k.npz':'115l577qCHEbGN-GzE0MpnyHFoGJDy5Xw', 5 | 'video_16_128.npz':'1-5TOdbH4j6Fr4TV1bqppo9H5JgAhvJUs', 6 | } 7 | 8 | for name in datasets: 9 | gdown.download(id=datasets[name], output=name, quiet=False) -------------------------------------------------------------------------------- /imgs/simple_complex_encoding.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/osiriszjq/complex_encoding/5be131f47d3a6b7616ff2d9fa4eb5d055e280bb1/imgs/simple_complex_encoding.png -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time, torch, logging 3 | from MLPs import * 4 | from utils import * 5 | 6 | 7 | def train_simple_2D(signals,encoding_func,epochs=2000,lr=1e-2,criterion=nn.MSELoss(),device="cpu", 8 | depth=4,width=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None): 9 | 10 | ''' 11 | Trainer of simple encoding for 2D image reconstruction with separable coordinates. 12 | 13 | The input args are four parts: 14 | 15 | signals: a pytorch tensor shape in N_image x sample_N x sample_N x 3 16 | encoding_func: encoding function 17 | 18 | epochs,lr,criterion are optimization settings 19 | 20 | depth,width,bias,use_sigmoid are network cofiguration 21 | 22 | N_repeat: number of repeat times for each example 23 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 24 | logger: if is None, loggers will print out. 
25 | 26 | Returns: 27 | 28 | running time: N_test x N_repeat 29 | train psnr: N_test x N_repeat 30 | test psnr: N_test x N_repeat 31 | reconstuction results: same size as input signals 32 | ''' 33 | 34 | 35 | sample_N =signals.shape[1] 36 | N_test = signals.shape[0] 37 | 38 | trn_psnr_ = np.zeros((N_test,N_repeat)) 39 | tst_psnr_ = np.zeros((N_test,N_repeat)) 40 | time_ = np.zeros((N_test,N_repeat)) 41 | rec_ = np.zeros((N_test,sample_N,sample_N,3)) 42 | 43 | # build the meshgrid for coordinates 44 | x1 = np.linspace(0, 1, sample_N+1)[:-1] 45 | all_data = np.stack(np.meshgrid(x1,x1), axis=-1) 46 | 47 | 48 | for i in range(N_test): 49 | # Prepare the targets 50 | all_target = signals[i].squeeze() 51 | train_label = all_target[::2,::2].reshape(-1,3).type(torch.FloatTensor) 52 | if test_mode == 'mid': 53 | test_label = all_target[1::2,1::2].reshape(-1,3).type(torch.FloatTensor) 54 | elif test_mode == 'rest': 55 | test_label1 = all_target[1::2,::2].reshape(-1,3).type(torch.FloatTensor) 56 | test_label2 = all_target[1::2,1::2].reshape(-1,3).type(torch.FloatTensor) 57 | test_label3 = all_target[::2,1::2].reshape(-1,3).type(torch.FloatTensor) 58 | test_label = torch.cat([test_label1,test_label2,test_label3],0) 59 | elif test_mode == 'all': 60 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor) 61 | 62 | 63 | for k in range(N_repeat): 64 | start_time = time.time() 65 | train_data = encoding_func(torch.from_numpy(all_data[::2,::2].reshape(-1,2)).type(torch.FloatTensor)) 66 | 67 | 68 | train_data, train_label = train_data.to(device),train_label.to(device) 69 | 70 | # regular MLP 71 | model = MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device) 72 | # Set the optimization 73 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8) 74 | 75 | 76 | for epoch in range(epochs): 77 | model.train() 78 | optimizer.zero_grad() 79 | 80 | out = model(train_data) 81 | loss = 
criterion(out, train_label) 82 | 83 | loss.backward() 84 | optimizer.step() 85 | 86 | 87 | time_[i,k] = time.time() - start_time 88 | 89 | 90 | if test_mode == 'mid': 91 | test_data = encoding_func(torch.from_numpy(all_data[1::2,1::2].reshape(-1,2)).type(torch.FloatTensor)) 92 | elif test_mode == 'rest': 93 | test_data1 = encoding_func(torch.from_numpy(all_data[1::2,::2].reshape(-1,2)).type(torch.FloatTensor)) 94 | test_data2 = encoding_func(torch.from_numpy(all_data[1::2,1::2].reshape(-1,2)).type(torch.FloatTensor)) 95 | test_data3 = encoding_func(torch.from_numpy(all_data[::2,1::2].reshape(-1,2)).type(torch.FloatTensor)) 96 | test_data = torch.cat([test_data1,test_data2,test_data3],0) 97 | elif test_mode == 'all': 98 | test_data = encoding_func(torch.from_numpy(all_data.reshape(-1,2)).type(torch.FloatTensor)) 99 | else: 100 | print('Wrong test_mode!') 101 | return -1 102 | 103 | 104 | test_data, test_label = test_data.to(device),test_label.to(device) 105 | 106 | model.eval() 107 | with torch.no_grad(): 108 | trn_psnr_[i,k] = psnr_func(model(train_data),train_label) 109 | tst_psnr_[i,k] = psnr_func(model(test_data),test_label) 110 | if logger == None: 111 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 112 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 113 | else: 114 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 115 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 116 | 117 | rec_[i] = model(encoding_func(torch.from_numpy(all_data.reshape(-1,2)).type(torch.FloatTensor)).to(device)).reshape(sample_N,sample_N,3).detach().cpu() 118 | return time_, trn_psnr_,tst_psnr_,rec_ 119 | 120 | 121 | 122 | 123 | def train_kron_2D(signals,encoding_func,epochs=2000,lr=1e-1,criterion=nn.MSELoss(),device="cpu", 124 | depth=0,width=256,width0=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None): 125 
| 126 | ''' 127 | Trainer of complex encoding for 2D image reconstruction with separable coordinates. 128 | 129 | The input args are four parts: 130 | 131 | signals: a pytorch tensor shape in N_image x sample_N x sample_N x 3 132 | encoding_func: encoding function 133 | 134 | epochs,lr,criterion are optimization settings 135 | 136 | depth,width,width0,bias,use_sigmoid are network cofiguration 137 | 138 | N_repeat: number of repeat times for each example 139 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 140 | logger: if is None, loggers will print out. 141 | 142 | Returns: 143 | 144 | running time: N_test x N_repeat 145 | train psnr: N_test x N_repeat 146 | test psnr: N_test x N_repeat 147 | reconstuction results: same size as input signals 148 | ''' 149 | 150 | 151 | sample_N =signals.shape[1] 152 | N_test = signals.shape[0] 153 | 154 | trn_psnr_ = np.zeros((N_test,N_repeat)) 155 | tst_psnr_ = np.zeros((N_test,N_repeat)) 156 | time_ = np.zeros((N_test,N_repeat)) 157 | rec_ = np.zeros((N_test,sample_N,sample_N,3)) 158 | 159 | # Here we only use 1D grids 160 | all_data = np.linspace(0, 1, sample_N+1)[:-1] 161 | 162 | 163 | for i in range(N_test): 164 | # Prepare the targets 165 | all_target = signals[i].squeeze() 166 | train_label = all_target[::2,::2].reshape(-1,3).type(torch.FloatTensor).to(device) 167 | if test_mode == 'mid': 168 | test_label = all_target[1::2,::2].reshape(-1,3).type(torch.FloatTensor).to(device) 169 | test_mask = torch.tensor([[True]],device=device).repeat(int(sample_N/2),int(sample_N/2)).reshape(-1) 170 | elif test_mode == 'rest': 171 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor).to(device) 172 | test_mask = torch.tensor([[False,True],[True,True]],device=device).repeat(int(sample_N/2),int(sample_N/2)).reshape(-1) 173 | elif test_mode == 'all': 174 | test_label = 
all_target.reshape(-1,3).type(torch.FloatTensor).to(device) 175 | test_mask = torch.tensor([[True]],device=device).repeat(sample_N,sample_N).reshape(-1) 176 | 177 | 178 | for k in range(N_repeat): 179 | start_time = time.time() 180 | train_data = torch.from_numpy(all_data[::2].reshape(-1,1)).type(torch.FloatTensor) 181 | 182 | # Initialize classification model to learn 183 | model = Kron_MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width0=width0,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device) 184 | # Set the optimization 185 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8) 186 | 187 | train_data = encoding_func(train_data).to(device) 188 | 189 | for epoch in range(epochs): 190 | model.train() 191 | optimizer.zero_grad() 192 | 193 | out = model(train_data) 194 | loss = criterion(out, train_label) 195 | 196 | loss.backward() 197 | optimizer.step() 198 | 199 | 200 | time_[i,k] = time.time() - start_time 201 | model.eval() 202 | 203 | if test_mode == 'mid': 204 | test_data = torch.from_numpy(all_data[1::2].reshape(-1,1)).type(torch.FloatTensor) 205 | elif test_mode == 'rest': 206 | test_data = torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor) 207 | elif test_mode == 'all': 208 | test_data = torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor) 209 | else: 210 | print('Wrong test_mode!') 211 | return -1 212 | 213 | 214 | test_data = encoding_func(test_data).to(device) 215 | 216 | with torch.no_grad(): 217 | trn_psnr_[i,k] = psnr_func(model(train_data),train_label) 218 | tst_psnr_[i,k] = psnr_func(model(test_data)[test_mask],test_label[test_mask]) 219 | 220 | if logger == None: 221 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 222 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 223 | else: 224 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 225 | % (i,k, 
np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 226 | rec_[i] = model(encoding_func(torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor)).to(device)).reshape(sample_N,sample_N,3).detach().cpu() 227 | return time_,trn_psnr_,tst_psnr_,rec_ 228 | 229 | 230 | 231 | 232 | def train_closed_form_2D(signals,encoding_func,device="cpu",test_mode='rest',N_repeat=1,logger=None): 233 | 234 | ''' 235 | Closed form solution of complex encoding for 2D image reconstruction with separable coordinates. No training. 236 | 237 | The input args are two parts: 238 | 239 | signals: a pytorch tensor shape in N_image x sample_N x sample_N x 3 240 | encoding_func: encoding function 241 | 242 | N_repeat: number of repeat times for each example 243 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 244 | logger: if is None, loggers will print out. 245 | 246 | Returns: 247 | 248 | running time: N_test x N_repeat 249 | train psnr: N_test x N_repeat 250 | test psnr: N_test x N_repeat 251 | reconstuction results: same size as input signals 252 | ''' 253 | 254 | 255 | sample_N =signals.shape[1] 256 | N_test = signals.shape[0] 257 | 258 | trn_psnr_ = np.zeros((N_test,N_repeat)) 259 | tst_psnr_ = np.zeros((N_test,N_repeat)) 260 | time_ = np.zeros((N_test,N_repeat)) 261 | rec_ = np.zeros((N_test,sample_N,sample_N,3)) 262 | 263 | # Here we only use 1D grids 264 | all_data = np.linspace(0, 1, sample_N+1)[:-1] 265 | 266 | 267 | for i in range(N_test): 268 | # Prepare the targets 269 | all_target = signals[i].squeeze() 270 | train_label = all_target[::2,::2].type(torch.FloatTensor).to(device) 271 | if test_mode == 'mid': 272 | test_label = all_target[1::2,1::2].reshape(-1,3).type(torch.FloatTensor).to(device) 273 | test_mask = torch.tensor([[True]],device=device).repeat(int(sample_N/2),int(sample_N/2)).reshape(-1) 274 | elif test_mode == 
'rest': 275 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor).to(device) 276 | test_mask = torch.tensor([[False,True],[True,True]],device=device).repeat(int(sample_N/2),int(sample_N/2)).reshape(-1) 277 | elif test_mode == 'all': 278 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor).to(device) 279 | test_mask = torch.tensor([[True]],device=device).repeat(sample_N,sample_N).reshape(-1) 280 | 281 | 282 | for k in range(N_repeat): 283 | start_time = time.time() 284 | train_data = torch.from_numpy(all_data[::2].reshape(-1,1)).type(torch.FloatTensor) 285 | 286 | train_data = encoding_func(train_data).to(device) 287 | 288 | # two ways to calculate the inverse of a matrix 289 | #ix = torch.linalg.pinv(train_data) 290 | ix = torch.linalg.lstsq(train_data, torch.eye(train_data.shape[0]).to(device)).solution 291 | W= ix@(train_label.transpose(0,2))@ix.T 292 | 293 | time_[i,k] = time.time() - start_time 294 | 295 | if test_mode == 'mid': 296 | test_data = torch.from_numpy(all_data[1::2].reshape(-1,1)).type(torch.FloatTensor) 297 | elif test_mode == 'rest': 298 | test_data = torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor) 299 | elif test_mode == 'all': 300 | test_data = torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor) 301 | else: 302 | print('Wrong test_mode!') 303 | return -1 304 | 305 | 306 | test_data = encoding_func(test_data).to(device) 307 | 308 | trn_psnr_[i,k] = psnr_func((train_data@W@train_data.T).transpose(0,2),train_label).detach().cpu().numpy() 309 | tst_psnr_[i,k] = psnr_func((test_data@W@test_data.T).transpose(0,2).reshape(-1,3)[test_mask],test_label[test_mask]).detach().cpu().numpy() 310 | 311 | if logger == None: 312 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 313 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 314 | else: 315 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds 
---" 316 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 317 | rec_[i] = (test_data@W@test_data.T).transpose(0,2).detach().cpu() 318 | return time_,trn_psnr_,tst_psnr_,rec_ 319 | 320 | 321 | 322 | 323 | def train_random_simple_2D(signals,encoding_func,ratio=0.25,mask=None,epochs=2000,lr=1e-2,criterion=nn.MSELoss(),device="cpu", 324 | depth=4,width=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None): 325 | 326 | ''' 327 | Trainer of simple encoding for 2D image reconstruction with randomly sampled points. 328 | 329 | The input args are four parts: 330 | 331 | signals: a torch tensor shape in N_image x sample_N x sample_N x 3 332 | encoding_func: encoding function 333 | ratio: if mask is None, randomly sample points for training with this ratio of number of all points 334 | mask: a boolean matrix shape in N_image x N_repeat x (sample_N x sample_N) 335 | 336 | epochs,lr,criterion are optimization settings 337 | 338 | depth,width,bias,use_sigmoid are network cofiguration 339 | 340 | N_repeat: number of repeat times for each example 341 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 342 | logger: if is None, loggers will print out. 
343 | 344 | Returns: 345 | 346 | running time: N_test x N_repeat 347 | train psnr: N_test x N_repeat 348 | test psnr: N_test x N_repeat 349 | reconstuction results: same size as input signals 350 | ''' 351 | 352 | 353 | sample_N =signals.shape[1] 354 | N_test = signals.shape[0] 355 | 356 | trn_psnr_ = np.zeros((N_test,N_repeat)) 357 | tst_psnr_ = np.zeros((N_test,N_repeat)) 358 | time_ = np.zeros((N_test,N_repeat)) 359 | rec_ = np.zeros((N_test,sample_N,sample_N,3)) 360 | 361 | # build the meshgrid for coordinates 362 | x1 = np.linspace(0, 1, sample_N+1)[:-1] 363 | all_data = np.stack(np.meshgrid(x1,x1), axis=-1) 364 | 365 | 366 | 367 | for i in range(N_test): 368 | 369 | for k in range(N_repeat): 370 | 371 | if mask is None: 372 | idx = torch.randperm(sample_N**2)[:int(ratio*sample_N**2)] 373 | mask_np_N2 = np.zeros((sample_N**2)) 374 | mask_np_N2[idx] = 1 375 | mask_np_N2 = mask_np_N2==1 376 | else: 377 | mask_np_N2 = mask[i,k] 378 | 379 | # Prepare the targets 380 | all_target = signals[i].squeeze() 381 | train_label = all_target.reshape(-1,3)[mask_np_N2].type(torch.FloatTensor) 382 | 383 | start_time = time.time() 384 | train_data = encoding_func(torch.from_numpy(all_data.reshape(-1,2)[mask_np_N2]).type(torch.FloatTensor)) 385 | train_data, train_label = train_data.to(device),train_label.to(device) 386 | 387 | 388 | # Initialize classification model to learn 389 | model = MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device) 390 | # Set the optimization 391 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8) 392 | 393 | 394 | for epoch in range(epochs): 395 | model.train() 396 | optimizer.zero_grad() 397 | 398 | out = model(train_data) 399 | loss = criterion(out, train_label) 400 | 401 | loss.backward() 402 | optimizer.step() 403 | 404 | 405 | time_[i,k] = time.time() - start_time 406 | 407 | 408 | if test_mode == 'rest': 409 | test_label = 
all_target.reshape(-1,3)[~mask_np_N2].type(torch.FloatTensor) 410 | test_data = encoding_func(torch.from_numpy(all_data.reshape(-1,2)[~mask_np_N2]).type(torch.FloatTensor)) 411 | elif test_mode == 'all': 412 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor) 413 | test_data = encoding_func(torch.from_numpy(all_data.reshape(-1,2)).type(torch.FloatTensor)) 414 | else: 415 | print('Wrong test_mode!') 416 | return -1 417 | 418 | test_data, test_label = test_data.to(device),test_label.to(device) 419 | 420 | trn_psnr_[i,k] = psnr_func(model(train_data),train_label) 421 | tst_psnr_[i,k] = psnr_func(model(test_data),test_label) 422 | 423 | if logger == None: 424 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 425 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 426 | else: 427 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 428 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 429 | rec_[i] = model(encoding_func(torch.from_numpy(all_data.reshape(-1,2)).type(torch.FloatTensor)).to(device)).reshape(sample_N,sample_N,3).detach().cpu() 430 | 431 | return time_, trn_psnr_,tst_psnr_,rec_ 432 | 433 | 434 | 435 | 436 | def train_index_blend_kron_2D(signals,outter_blending,inner_encoding,ratio=0.25,mask=None,sm=0.0,epochs=2000,lr=1e-1,criterion=torch.nn.MSELoss(),device="cpu", 437 | depth=0,width=256,width0=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None): 438 | 439 | ''' 440 | Trainer of complex encoding for 2D image reconstruction with randomly sampled points. 
441 | 442 | The input args are four parts: 443 | 444 | signals: a pytorch tensor shape in N_image x sample_N x sample_N x 3 445 | inner encoding: encoding function function of the virtual grid points, so the encoded results is static 446 | outter_blending: blending funcion for interpolation index and weights of random samples 447 | ratio: if mask is None, randomly sample points for training with this ratio of number of all points 448 | mask: a boolean matrix shape in N_image x N_repeat x (sample_N x sample_N) 449 | sm: how much we punish on total variation loss 450 | 451 | epochs,lr,criterion are optimization settings 452 | 453 | depth,width,width0,bias,use_sigmoid are network cofiguration 454 | 455 | N_repeat: number of repeat times for each example 456 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 457 | logger: if is None, loggers will print out. 458 | 459 | Returns: 460 | 461 | running time: N_test x N_repeat 462 | train psnr: N_test x N_repeat 463 | test psnr: N_test x N_repeat 464 | reconstuction results: same size as input signals 465 | ''' 466 | 467 | 468 | sample_N =signals.shape[1] 469 | N_test = signals.shape[0] 470 | 471 | trn_psnr_ = np.zeros((N_test,N_repeat)) 472 | tst_psnr_ = np.zeros((N_test,N_repeat)) 473 | time_ = np.zeros((N_test,N_repeat)) 474 | rec_ = np.zeros((N_test,sample_N,sample_N,3)) 475 | 476 | # the length of all_data is same as 1D signal length 477 | x1 = np.linspace(0, 1, sample_N+1)[:-1] 478 | all_data = np.stack(np.meshgrid(x1,x1), axis=-1) 479 | all_grid = torch.linspace(0, 1, outter_blending.dim+1)[:-1] 480 | encoded_grid = inner_encoding(all_grid.reshape(-1,1)) 481 | 482 | filter=torch.tensor([[[1.0,-1.0]]]).to(device) 483 | 484 | 485 | for i in range(N_test): 486 | 487 | for k in range(N_repeat): 488 | 489 | if mask is None: 490 | idx = torch.randperm(sample_N**2)[:int(ratio*sample_N**2)] 491 | 
mask_np_N2 = np.zeros((sample_N**2)) 492 | mask_np_N2[idx] = 1 493 | mask_np_N2 = mask_np_N2==1 494 | else: 495 | mask_np_N2 = mask[i,k] 496 | 497 | 498 | # Prepare the targets 499 | all_target = signals[i].squeeze() 500 | train_label = all_target.reshape(-1,3)[mask_np_N2].type(torch.FloatTensor) 501 | 502 | start_time = time.time() 503 | train_idx, train_data = outter_blending(torch.from_numpy(all_data.reshape(-1,2)[mask_np_N2]).type(torch.FloatTensor)) 504 | train_idx,train_data, train_label = train_idx.to(device),train_data.to(device),train_label.to(device) 505 | 506 | encoded_grid = encoded_grid.to(device) 507 | 508 | # Initialize classification model to learn 509 | model = Indexing_Blend_Kron_MLP(input_dim=inner_encoding.dim,output_dim=3,depth=depth,width0=width0,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device) 510 | # Set the optimization 511 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8) 512 | 513 | for epoch in range(epochs): 514 | model.train() 515 | optimizer.zero_grad() 516 | 517 | out = model([train_idx,train_data],encoded_grid) 518 | loss = criterion(out, train_label) 519 | if sm>0: 520 | for oo in range(model.first.weight.shape[0]): 521 | loss += sm*smooth(model.first.weight[oo],filter) 522 | 523 | loss.backward() 524 | optimizer.step() 525 | 526 | model.eval() 527 | 528 | time_[i,k] = time.time() - start_time 529 | 530 | 531 | model.eval() 532 | if test_mode == 'rest': 533 | test_label = all_target.reshape(-1,3)[~mask_np_N2].type(torch.FloatTensor) 534 | test_idx,test_data = outter_blending(torch.from_numpy(all_data.reshape(-1,2)[~mask_np_N2]).type(torch.FloatTensor)) 535 | elif test_mode == 'all': 536 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor) 537 | test_idx,test_data = outter_blending(torch.from_numpy(all_data.reshape(-1,2)).type(torch.FloatTensor)) 538 | else: 539 | print('Wrong test_mode!') 540 | return -1 541 | 542 | test_idx,test_data, test_label = 
test_idx.to(device),test_data.to(device),test_label.to(device) 543 | with torch.no_grad(): 544 | trn_psnr_[i,k] = psnr_func(model([train_idx,train_data],encoded_grid),train_label) 545 | 546 | tst_psnr_[i,k] = psnr_func(model([test_idx,test_data],encoded_grid),test_label) 547 | 548 | if logger == None: 549 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 550 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 551 | else: 552 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 553 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 554 | 555 | rec_[i] = model([outter_blending(torch.from_numpy(all_data.reshape(-1,2)).type(torch.FloatTensor))[0].to(device), 556 | outter_blending(torch.from_numpy(all_data.reshape(-1,2)).type(torch.FloatTensor))[1].to(device)],encoded_grid).reshape(sample_N,sample_N,3).detach().cpu() 557 | return time_,trn_psnr_,tst_psnr_,rec_ 558 | 559 | 560 | 561 | 562 | def train_simple_3D(signals,encoding_func,epochs=500,lr=5e-3,criterion=nn.MSELoss(),device="cpu", 563 | depth=5,width=512,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None): 564 | 565 | ''' 566 | Trainer of simple encoding for 3D video reconstruction with separable coordinates. 567 | 568 | The input args are four parts: 569 | 570 | signals: a pytorch tensor shape in N_video x sample_N x sample_N x sample_N x 3 571 | encoding_func: encoding function 572 | 573 | epochs,lr,criterion are optimization settings 574 | 575 | depth,width,bias,use_sigmoid are network cofiguration 576 | 577 | N_repeat: number of repeat times for each example 578 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 579 | logger: if is None, loggers will print out. 
580 | 581 | Returns: 582 | 583 | running time: N_test x N_repeat 584 | train psnr: N_test x N_repeat 585 | test psnr: N_test x N_repeat 586 | reconstuction results: same size as input signals 587 | ''' 588 | 589 | 590 | sample_N =signals.shape[1] 591 | N_test = signals.shape[0] 592 | 593 | trn_psnr_ = np.zeros((N_test,N_repeat)) 594 | tst_psnr_ = np.zeros((N_test,N_repeat)) 595 | time_ = np.zeros((N_test,N_repeat)) 596 | rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3)) 597 | 598 | # build the meshgrid for 3D coordinates 599 | x1 = torch.linspace(0, 1, sample_N+1)[:-1] 600 | all_data = torch.stack(torch.meshgrid(x1,x1,x1), axis=-1) 601 | 602 | 603 | for i in range(N_test): 604 | # Prepare the targets 605 | all_target = signals[i].squeeze() 606 | train_label = all_target[::2,::2,::2].reshape(-1,3).type(torch.FloatTensor) 607 | if test_mode == 'mid': 608 | test_label = all_target[1::2,1::2,1::2].reshape(-1,3).type(torch.FloatTensor) 609 | elif test_mode == 'rest': 610 | test_label1 = all_target[::2,1::2,::2].reshape(-1,3).type(torch.FloatTensor) 611 | test_label2 = all_target[::2,1::2,1::2].reshape(-1,3).type(torch.FloatTensor) 612 | test_label3 = all_target[::2,::2,1::2].reshape(-1,3).type(torch.FloatTensor) 613 | test_label4 = all_target[1::2,::2,::2].reshape(-1,3).type(torch.FloatTensor) 614 | test_label5 = all_target[1::2,1::2,::2].reshape(-1,3).type(torch.FloatTensor) 615 | test_label6 = all_target[1::2,1::2,1::2].reshape(-1,3).type(torch.FloatTensor) 616 | test_label7 = all_target[1::2,::2,1::2].reshape(-1,3).type(torch.FloatTensor) 617 | test_label = torch.cat([test_label1,test_label2,test_label3,test_label4,test_label5,test_label6,test_label7],0) 618 | elif test_mode == 'all': 619 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor) 620 | 621 | for k in range(N_repeat): 622 | start_time = time.time() 623 | train_data = all_data[::2,::2,::2].reshape(-1,3) 624 | 625 | 626 | if test_mode == 'mid': 627 | test_data = 
all_data[1::2,1::2,1::2].reshape(-1,3) 628 | elif test_mode == 'rest': 629 | test_data1 = all_data[::2,1::2,::2].reshape(-1,3) 630 | test_data2 = all_data[::2,1::2,1::2].reshape(-1,3) 631 | test_data3 = all_data[::2,::2,1::2].reshape(-1,3) 632 | test_data4 = all_data[1::2,::2,::2].reshape(-1,3) 633 | test_data5 = all_data[1::2,1::2,::2].reshape(-1,3) 634 | test_data6 = all_data[1::2,1::2,1::2].reshape(-1,3) 635 | test_data7 = all_data[1::2,::2,1::2].reshape(-1,3) 636 | test_data = torch.cat([test_data1,test_data2,test_data3,test_data4,test_data5,test_data6,test_data7],0) 637 | elif test_mode == 'all': 638 | test_data = all_data.reshape(-1,3) 639 | 640 | else: 641 | print('Wrong test_mode!') 642 | return -1 643 | 644 | train_data, train_label = encoding_func(train_data).to(device),train_label.to(device) 645 | test_data, test_label = encoding_func(test_data).to(device),test_label.to(device) 646 | 647 | # Initialize classification model to learn 648 | model = MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device) 649 | # Set the optimization 650 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8) 651 | 652 | 653 | for epoch in range(epochs): 654 | model.train() 655 | optimizer.zero_grad() 656 | 657 | out = model(train_data) 658 | loss = criterion(out, train_label) 659 | 660 | loss.backward() 661 | optimizer.step() 662 | 663 | 664 | time_[i,k] = time.time() - start_time 665 | model.eval() 666 | with torch.no_grad(): 667 | trn_psnr_[i,k] = psnr_func(model(train_data),train_label) 668 | tst_psnr_[i,k] = psnr_func(model(test_data),test_label) 669 | if logger == None: 670 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 671 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 672 | else: 673 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 674 | % (i,k, 
np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 675 | 676 | with torch.no_grad(): 677 | for ff in range(128): 678 | rec_[i,ff] = model(encoding_func((all_data[ff,:,:].reshape(-1,3)).type(torch.FloatTensor)).to(device)).reshape(sample_N,sample_N,3).detach().cpu() 679 | 680 | return time_, trn_psnr_,tst_psnr_,rec_ 681 | 682 | 683 | 684 | 685 | def train_kron_3D(signals,encoding_func,epochs=500,lr=1e-1,criterion=nn.MSELoss(),device="cpu", 686 | depth=0,width=256,width0=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None): 687 | 688 | ''' 689 | Trainer of complex encoding for 3D video reconstruction with separable coordinates. 690 | 691 | The input args are four parts: 692 | 693 | signals: a pytorch tensor shape in N_video x sample_N x sample_N x sample_N x 3 694 | encoding_func: encoding function 695 | 696 | epochs,lr,criterion are optimization settings 697 | 698 | depth,width,width0,bias,use_sigmoid are network cofiguration 699 | 700 | N_repeat: number of repeat times for each example 701 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 702 | logger: if is None, loggers will print out. 
703 | 704 | Returns: 705 | 706 | running time: N_test x N_repeat 707 | train psnr: N_test x N_repeat 708 | test psnr: N_test x N_repeat 709 | reconstuction results: same size as input signals 710 | ''' 711 | 712 | 713 | sample_N =signals.shape[1] 714 | N_test = signals.shape[0] 715 | 716 | trn_psnr_ = np.zeros((N_test,N_repeat)) 717 | tst_psnr_ = np.zeros((N_test,N_repeat)) 718 | time_ = np.zeros((N_test,N_repeat)) 719 | rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3)) 720 | 721 | # Here we only use 1D grids 722 | all_data = np.linspace(0, 1, sample_N+1)[:-1] 723 | 724 | 725 | for i in range(N_test): 726 | # Prepare the targets 727 | all_target = signals[i].squeeze() 728 | train_label = all_target[::2,::2,::2].reshape(-1,3).type(torch.FloatTensor).to(device) 729 | if test_mode == 'mid': 730 | test_label = all_target[1::2,1::2,1::2].reshape(-1,3).type(torch.FloatTensor).to(device) 731 | test_mask = torch.tensor([[True]],device=device).repeat(int(sample_N/2),int(sample_N/2),int(sample_N/2)).reshape(-1) 732 | elif test_mode == 'rest': 733 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor).to(device) 734 | test_mask = torch.tensor([[[False,True],[True,True]],[[True,True],[True,True]]],device=device).repeat(int(sample_N/2),int(sample_N/2),int(sample_N/2)).reshape(-1) 735 | elif test_mode == 'all': 736 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor).to(device) 737 | test_mask = torch.tensor([[True]],device=device).repeat(sample_N,sample_N,sample_N).reshape(-1) 738 | 739 | 740 | for k in range(N_repeat): 741 | start_time = time.time() 742 | # train data are sampled by , test data is either in middle or all data, both are encoded by encoding func 743 | train_data = torch.from_numpy(all_data[::2].reshape(-1,1)).type(torch.FloatTensor) 744 | 745 | if test_mode == 'mid': 746 | test_data = torch.from_numpy(all_data[1::2].reshape(-1,1)).type(torch.FloatTensor) 747 | elif test_mode == 'rest': 748 | test_data = 
torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor) 749 | elif test_mode == 'all': 750 | test_data = torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor) 751 | else: 752 | print('Wrong test_mode!') 753 | return -1 754 | 755 | 756 | # Initialize classification model to learn 757 | model = Kron3_MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width0=width0,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device) 758 | # Set the optimization 759 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8) 760 | 761 | train_data = encoding_func(train_data).to(device) 762 | test_data = encoding_func(test_data).to(device) 763 | 764 | for epoch in range(epochs): 765 | model.train() 766 | optimizer.zero_grad() 767 | 768 | out = model(train_data) 769 | loss = criterion(out, train_label) 770 | 771 | loss.backward() 772 | optimizer.step() 773 | 774 | 775 | time_[i,k] = time.time() - start_time 776 | model.eval() 777 | with torch.no_grad(): 778 | trn_psnr_[i,k] = psnr_func(model(train_data),train_label) 779 | 780 | tst_psnr_[i,k] = psnr_func(model(test_data)[test_mask],test_label[test_mask]) 781 | 782 | if logger == None: 783 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 784 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 785 | else: 786 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 787 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 788 | with torch.no_grad(): 789 | rec_[i] = model(encoding_func(torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor)).to(device)).reshape(sample_N,sample_N,sample_N,3).detach().cpu() 790 | return time_,trn_psnr_,tst_psnr_,rec_ 791 | 792 | 793 | 794 | 795 | def train_closed_form_3D(signals,encoding_func,device="cpu",test_mode='rest',N_repeat=1,logger=None): 796 | 797 | ''' 798 | Closed form solution of 
complex encoding for 3D video reconstruction with separable coordinates. No training. 799 | 800 | The input args are two parts: 801 | 802 | signals: a pytorch tensor shape in N_video x sample_N x sample_N x sample_N x 3 803 | encoding_func: encoding function 804 | 805 | N_repeat: number of repeat times for each example 806 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 807 | logger: if is None, loggers will print out. 808 | 809 | Returns: 810 | 811 | running time: N_test x N_repeat 812 | train psnr: N_test x N_repeat 813 | test psnr: N_test x N_repeat 814 | reconstuction results: same size as input signals 815 | ''' 816 | 817 | 818 | sample_N =signals.shape[1] 819 | N_test = signals.shape[0] 820 | 821 | trn_psnr_ = np.zeros((N_test,N_repeat)) 822 | tst_psnr_ = np.zeros((N_test,N_repeat)) 823 | time_ = np.zeros((N_test,N_repeat)) 824 | rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3)) 825 | 826 | # Here we only use 1D grids 827 | all_data = np.linspace(0, 1, sample_N+1)[:-1] 828 | 829 | 830 | for i in range(N_test): 831 | # Prepare the targets 832 | all_target = signals[i].squeeze() 833 | train_label = all_target[::2,::2,::2].type(torch.FloatTensor).to(device) 834 | if test_mode == 'mid': 835 | test_label = all_target[1::2,1::2,1::2].reshape(-1,3).type(torch.FloatTensor).to(device) 836 | test_mask = torch.tensor([[True]],device=device).repeat(int(sample_N/2),int(sample_N/2),int(sample_N/2)).reshape(-1) 837 | elif test_mode == 'rest': 838 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor).to(device) 839 | test_mask = torch.tensor([[[False,True],[True,True]],[[True,True],[True,True]]],device=device).repeat(int(sample_N/2),int(sample_N/2),int(sample_N/2)).reshape(-1) 840 | elif test_mode == 'all': 841 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor).to(device) 842 | test_mask = 
torch.tensor([[True]],device=device).repeat(sample_N,sample_N,sample_N).reshape(-1) 843 | 844 | 845 | for k in range(N_repeat): 846 | start_time = time.time() 847 | # train data are sampled by , test data is either in middle or all data, both are encoded by encoding func 848 | train_data = torch.from_numpy(all_data[::2].reshape(-1,1)).type(torch.FloatTensor) 849 | 850 | if test_mode == 'mid': 851 | test_data = torch.from_numpy(all_data[1::2].reshape(-1,1)).type(torch.FloatTensor) 852 | elif test_mode == 'rest': 853 | test_data = torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor) 854 | elif test_mode == 'all': 855 | test_data = torch.from_numpy(all_data.reshape(-1,1)).type(torch.FloatTensor) 856 | else: 857 | print('Wrong test_mode!') 858 | return -1 859 | 860 | 861 | train_data = encoding_func(train_data).to(device) 862 | 863 | test_data = encoding_func(test_data).to(device) 864 | 865 | #ix = torch.linalg.pinv(train_data) 866 | ix = torch.linalg.lstsq(train_data, torch.eye(train_data.shape[0]).to(device)).solution 867 | W= (ix@(train_label.transpose(0,3))@ix.T).transpose(1,3)@ix.T 868 | 869 | 870 | time_[i,k] = time.time() - start_time 871 | trn_psnr_[i,k] = psnr_func((train_data@((W@train_data.T).transpose(1,3))@train_data.T).transpose(0,3),train_label).detach().cpu().numpy() 872 | tst_psnr_[i,k] = psnr_func((test_data@((W@test_data.T).transpose(1,3))@test_data.T).transpose(0,3).reshape(-1,3)[test_mask],test_label[test_mask]).detach().cpu().numpy() 873 | 874 | if logger == None: 875 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 876 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 877 | else: 878 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 879 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 880 | rec_[i] = 
(test_data@((W@test_data.T).transpose(1,3))@test_data.T).transpose(0,3).detach().cpu() 881 | return time_,trn_psnr_,tst_psnr_,rec_ 882 | 883 | 884 | 885 | 886 | def train_random_simple_3D(signals,encoding_func,ratio=0.125,mask=None,epochs=500,lr=5e-3,criterion=nn.MSELoss(),device="cpu", 887 | depth=5,width=512,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None): 888 | 889 | ''' 890 | Trainer of simple encoding for 3D video reconstruction with randomly sampled points. 891 | 892 | The input args are four parts: 893 | 894 | signals: a torch tensor shape in N_video x sample_N x sample_N x sample_N x 3 895 | encoding_func: encoding function 896 | ratio: if mask is None, randomly sample points for training with this ratio of number of all points 897 | mask: a boolean matrix shape in N_video x N_repeat x (sample_N x sample_N x sample_N) 898 | 899 | epochs,lr,criterion are optimization settings 900 | 901 | depth,width,bias,use_sigmoid are network cofiguration 902 | 903 | N_repeat: number of repeat times for each example 904 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 905 | logger: if is None, loggers will print out. 
906 | 907 | Returns: 908 | 909 | running time: N_test x N_repeat 910 | train psnr: N_test x N_repeat 911 | test psnr: N_test x N_repeat 912 | reconstuction results: same size as input signals 913 | ''' 914 | 915 | 916 | sample_N =signals.shape[1] 917 | N_test = signals.shape[0] 918 | 919 | 920 | trn_psnr_ = np.zeros((N_test,N_repeat)) 921 | tst_psnr_ = np.zeros((N_test,N_repeat)) 922 | time_ = np.zeros((N_test,N_repeat)) 923 | rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3)) 924 | 925 | # build 3D coordinate meshgrid 926 | x1 = np.linspace(0, 1, sample_N+1)[:-1] 927 | all_data = np.stack(np.meshgrid(x1,x1,x1), axis=-1) 928 | 929 | 930 | for i in range(N_test): 931 | for k in range(N_repeat): 932 | if mask is None: 933 | idx = torch.randperm(sample_N**3)[:int(ratio*sample_N**3)] 934 | mask_np_N2 = np.zeros((sample_N**3)) 935 | mask_np_N2[idx] = 1 936 | mask_np_N2 = mask_np_N2==1 937 | else: 938 | mask_np_N2 = mask[i,k] 939 | 940 | # Prepare the targets 941 | all_target = signals[i].squeeze() 942 | train_label = all_target.reshape(-1,3)[mask_np_N2].type(torch.FloatTensor) 943 | if test_mode == 'rest': 944 | test_label = all_target.reshape(-1,3)[~mask_np_N2].type(torch.FloatTensor) 945 | elif test_mode == 'all': 946 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor) 947 | 948 | 949 | start_time = time.time() 950 | train_data = encoding_func(torch.from_numpy(all_data.reshape(-1,3)[mask_np_N2]).type(torch.FloatTensor)) 951 | 952 | 953 | if test_mode == 'rest': 954 | test_data = encoding_func(torch.from_numpy(all_data.reshape(-1,3)[~mask_np_N2]).type(torch.FloatTensor)) 955 | elif test_mode == 'all': 956 | test_data = encoding_func(torch.from_numpy(all_data.reshape(-1,3)).type(torch.FloatTensor)) 957 | 958 | else: 959 | print('Wrong test_mode!') 960 | return -1 961 | 962 | train_data, train_label = train_data.to(device),train_label.to(device) 963 | test_data, test_label = test_data.to(device),test_label.to(device) 964 | 965 | 966 | # Initialize 
classification model to learn 967 | model = MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device) 968 | # Set the optimization 969 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8) 970 | 971 | 972 | for epoch in range(epochs): 973 | model.train() 974 | optimizer.zero_grad() 975 | 976 | out = model(train_data) 977 | loss = criterion(out, train_label) 978 | 979 | loss.backward() 980 | optimizer.step() 981 | time_[i,k] = time.time() - start_time 982 | model.eval() 983 | with torch.no_grad(): 984 | trn_psnr_[i,k] = psnr_func(model(train_data),train_label) 985 | tst_psnr_[i,k] = psnr_func(model(test_data),test_label) 986 | 987 | 988 | if logger == None: 989 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 990 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 991 | else: 992 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 993 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 994 | with torch.no_grad(): 995 | for ff in range(128): 996 | rec_[i,ff] = model(encoding_func(torch.from_numpy(all_data[ff,:,:].reshape(-1,3)).type(torch.FloatTensor)).to(device)).reshape(sample_N,sample_N,3).detach().cpu() 997 | return time_, trn_psnr_,tst_psnr_,rec_ 998 | 999 | 1000 | 1001 | 1002 | def train_index_blend_kron_3D(signals,outter_blending,inner_encoding,ratio=0.125,mask=None,sm=0.0,epochs=500,lr=1e-1,criterion=torch.nn.MSELoss(),device="cpu", 1003 | depth=0,width=256,width0=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None): 1004 | 1005 | ''' 1006 | Trainer of complex encoding for 3D video reconstruction with randomly sampled points. 
1007 | 1008 | The input args are four parts: 1009 | 1010 | signals: a torch tensor shape in N_video x sample_N x sample_N x sample_N x 3 1011 | encoding_func: encoding function 1012 | ratio: if mask is None, randomly sample points for training with this ratio of number of all points 1013 | mask: a boolean matrix shape in N_video x N_repeat x (sample_N x sample_N x sample_N) 1014 | sm: how much we punish on total variation loss 1015 | 1016 | epochs,lr,criterion are optimization settings 1017 | 1018 | depth,width,width0,bias,use_sigmoid are network cofiguration 1019 | 1020 | N_repeat: number of repeat times for each example 1021 | test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all ponts including training ones. 1022 | logger: if is None, loggers will print out. 1023 | 1024 | Returns: 1025 | 1026 | running time: N_test x N_repeat 1027 | train psnr: N_test x N_repeat 1028 | test psnr: N_test x N_repeat 1029 | reconstuction results: same size as input signals 1030 | ''' 1031 | 1032 | 1033 | sample_N =signals.shape[1] 1034 | N_test = signals.shape[0] 1035 | 1036 | 1037 | trn_psnr_ = np.zeros((N_test,N_repeat)) 1038 | tst_psnr_ = np.zeros((N_test,N_repeat)) 1039 | time_ = np.zeros((N_test,N_repeat)) 1040 | rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3)) 1041 | 1042 | # the length of all_data is same as 1D signal length 1043 | x1 = np.linspace(0, 1, sample_N+1)[:-1] 1044 | all_data = np.stack(np.meshgrid(x1,x1,x1), axis=-1) 1045 | all_grid = torch.linspace(0, 1, outter_blending.dim+1)[:-1] 1046 | encoded_grid = inner_encoding(all_grid.reshape(-1,1)) 1047 | 1048 | filter=torch.tensor([[[1.0,-1.0]]]).to(device) 1049 | 1050 | 1051 | for i in range(N_test): 1052 | 1053 | for k in range(N_repeat): 1054 | 1055 | if mask is None: 1056 | idx = torch.randperm(sample_N**3)[:int(ratio*sample_N**3)] 1057 | mask_np_N2 = np.zeros((sample_N**3)) 1058 | mask_np_N2[idx] = 1 1059 | 
mask_np_N2 = mask_np_N2==1 1060 | else: 1061 | mask_np_N2 = mask[i,k] 1062 | 1063 | 1064 | # Prepare the targets 1065 | all_target = signals[i].squeeze() 1066 | train_label = all_target.reshape(-1,3)[mask_np_N2].type(torch.FloatTensor) 1067 | if test_mode == 'rest': 1068 | test_label = all_target.reshape(-1,3)[~mask_np_N2].type(torch.FloatTensor) 1069 | elif test_mode == 'all': 1070 | test_label = all_target.reshape(-1,3).type(torch.FloatTensor) 1071 | 1072 | 1073 | start_time = time.time() 1074 | train_idx,train_data = outter_blending(torch.from_numpy(all_data.reshape(-1,3)[mask_np_N2]).type(torch.FloatTensor)) 1075 | 1076 | 1077 | if test_mode == 'rest': 1078 | test_idx, test_data = outter_blending(torch.from_numpy(all_data.reshape(-1,3)[~mask_np_N2]).type(torch.FloatTensor)) 1079 | elif test_mode == 'all': 1080 | test_idx, test_data = outter_blending(torch.from_numpy(all_data.reshape(-1,3)).type(torch.FloatTensor)) 1081 | 1082 | else: 1083 | print('Wrong test_mode!') 1084 | return -1 1085 | 1086 | train_idx, train_data, train_label = train_idx.to(device), train_data.to(device),train_label.to(device) 1087 | test_idx, test_data, test_label = test_idx.to(device), test_data.to(device),test_label.to(device) 1088 | encoded_grid = encoded_grid.to(device) 1089 | 1090 | # Initialize classification model to learn 1091 | model = Indexing_Blend_Kron3_MLP(input_dim=inner_encoding.dim,output_dim=3,depth=depth,width0=width0,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device) 1092 | # Set the optimization 1093 | optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8) 1094 | 1095 | for epoch in range(epochs): 1096 | model.train() 1097 | optimizer.zero_grad() 1098 | 1099 | out = model([train_idx,train_data],encoded_grid) 1100 | loss = criterion(out, train_label) 1101 | if sm>0: 1102 | for oo in range(model.first.weight.shape[0]): 1103 | loss += sm*smooth_3D(model.first.weight[oo],filter) 1104 | 1105 | loss.backward() 1106 | 
optimizer.step() 1107 | 1108 | time_[i,k] = time.time() - start_time 1109 | model.eval() 1110 | with torch.no_grad(): 1111 | trn_psnr_[i,k] = psnr_func(model([train_idx,train_data],encoded_grid),train_label) 1112 | tst_psnr_[i,k] = psnr_func(model([test_idx,test_data],encoded_grid),test_label) 1113 | 1114 | if logger == None: 1115 | print("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 1116 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 1117 | else: 1118 | logger.info("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---" 1119 | % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)) 1120 | with torch.no_grad(): 1121 | rec_[i] = model([outter_blending(torch.from_numpy(all_data.reshape(-1,3)).type(torch.FloatTensor))[0].to(device), 1122 | outter_blending(torch.from_numpy(all_data.reshape(-1,3)).type(torch.FloatTensor))[1].to(device)],encoded_grid).reshape(sample_N,sample_N,sample_N,3).detach().cpu() 1123 | return time_,trn_psnr_,tst_psnr_,rec_ 1124 | 1125 | 1126 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | # encoding funcs for 1D, 2D and 3D 8 | class encoding_func_1D: 9 | def __init__(self, name, param=None): 10 | # param should be in form of [sigma,feature_dimension] 11 | self.name = name 12 | 13 | if name == 'none': self.dim=1 14 | elif name == 'basic': self.dim=2 15 | else: 16 | self.dim = param[1] 17 | if name == 'RFF': 18 | self.sig = param[0] 19 | self.b = param[0]*torch.randn((int(param[1]/2),1)) 20 | elif name == 'rffb': 21 | self.b = param[0] 22 | elif name == 'LinF': 23 | self.b = torch.linspace(2.**0., 2.**param[0], steps=int(param[1]/2)).reshape(-1,1) 24 | elif name == 'LogF': 25 | self.b = 
2.**torch.linspace(0., param[0], steps=int(param[1]/2)).reshape(-1,1) 26 | elif name == 'Gau': 27 | self.dic = torch.linspace(0., 1, steps=param[1]+1)[:-1].reshape(1,-1) 28 | self.sig = param[0] 29 | elif name == 'Tri': 30 | self.dic = torch.linspace(0., 1, steps=param[1]+1)[:-1].reshape(1,-1) 31 | if param[0] is None: self.d = 1/param[1] 32 | else: self.d = param[0] 33 | else: 34 | print('Undifined encoding!') 35 | def __call__(self, x): 36 | if self.name == 'none': 37 | return x 38 | elif self.name == 'basic': 39 | emb = torch.cat((torch.sin((2.*np.pi*x)),torch.cos((2.*np.pi*x))),1) 40 | emb = emb/(emb.norm(dim=1).max()) 41 | return emb 42 | elif (self.name == 'RFF')|(self.name == 'rffb')|(self.name == 'LinF')|(self.name == 'LogF'): 43 | emb = torch.cat((torch.sin((2.*np.pi*x) @ self.b.T),torch.cos((2.*np.pi*x) @ self.b.T)),1) 44 | emb = emb/(emb.norm(dim=1).max()) 45 | return emb 46 | elif self.name == 'Gau': 47 | emb = (-0.5*(x-self.dic)**2/(self.sig**2)).exp() 48 | emb = emb/(emb.norm(dim=1).max()) 49 | return emb 50 | elif self.name == 'Tri': 51 | emb = (1-(x-self.dic).abs()/self.d) 52 | emb = emb*(emb>0) 53 | emb = emb/(emb.norm(dim=1).max()) 54 | return emb 55 | 56 | 57 | # simple 2D encoding. For Fourier based methods we use 2 directions. For shifted verision encoder we use 4 directions. The total feature dimension is same. 
class encoding_func_2D:
    '''
    Simple 2D positional encoding.

    Fourier-based methods ('RFF', 'rffb', 'LinF', 'LogF') encode the two
    coordinates directly. Shifted-kernel methods come in a 2-direction
    variant ('Gau2', 'Tri2': the x and y axes) and a 4-direction variant
    ('Gau4'/'Gau', 'Tri4'/'Tri': x, y and the two diagonals
    0.5*(x+y) and 0.5*(x-y+1)).

    name  : encoding method ('none', 'basic', 'RFF', 'rffb', 'LinF', 'LogF',
            'Gau2', 'Gau4', 'Gau', 'Tri2', 'Tri4', 'Tri').
    param : [sigma, feature_dimension] for every method except 'none' and
            'basic'. For 'rffb', param[0] is the frequency matrix itself.
            For 'Tri*', sigma=None selects the default width 1/feature_dimension.

    self.dim is the number of columns __call__ actually produces. When
    feature_dimension is not a multiple of the number of concatenated blocks
    (2 for RFF/Gau2/Tri2, 4 for LinF/LogF/Gau4/Tri4), it is rounded down to
    that multiple. Every non-'none' embedding is divided by the largest row
    norm in the batch, so the largest row has unit norm.
    '''

    _KERNEL_NAMES = ('Gau2', 'Gau4', 'Gau', 'Tri2', 'Tri4', 'Tri')

    def __init__(self, name, param=None):
        self.name = name

        if name == 'none':
            self.dim = 2
        elif name == 'basic':
            # sin and cos of each of the two coordinates
            self.dim = 4
        else:
            self.dim = param[1]
            if name == 'RFF':
                # stored for consistency with encoding_func_1D; the blending funcs read .sig for RFF
                self.sig = param[0]
                self.b = param[0]*torch.randn((int(param[1]/2), 2))
                # sin block + cos block, one column per frequency row
                self.dim = 2*self.b.shape[0]
            elif name == 'rffb':
                # user-supplied frequency matrix; dimension taken as given
                self.b = param[0]
            elif name in ('LinF', 'LogF'):
                n_freq = int(param[1]/4)
                if name == 'LinF':
                    self.b = torch.linspace(2.**0., 2.**param[0], steps=n_freq).reshape(-1, 1)
                else:
                    self.b = 2.**torch.linspace(0., param[0], steps=n_freq).reshape(-1, 1)
                # (sin, cos) for each of the x and y axes
                self.dim = 4*n_freq
            elif name in self._KERNEL_NAMES:
                n_dir = 2 if name.endswith('2') else 4
                n_center = int(param[1]/n_dir)
                # centers evenly spaced on [0, 1), endpoint excluded
                self.dic = torch.linspace(0., 1, steps=n_center+1)[:-1].reshape(1, -1)
                self.dim = n_dir*n_center
                if name.startswith('Gau'):
                    self.sig = param[0]
                else:
                    self.d = 1/param[1] if param[0] is None else param[0]
            else:
                print('Undifined encoding!')

    @staticmethod
    def _directions(x, n_dir):
        '''Return the first n_dir 1D projections (N x 1 each): x, y, then both diagonals mapped into [0, 1].'''
        u, v = x[:, :1], x[:, 1:2]
        if n_dir == 2:
            return [u, v]
        return [u, v, 0.5*(u+v), 0.5*(u-v+1)]

    def _fourier(self, t):
        '''sin/cos features of coordinates t against the frequency matrix self.b.'''
        proj = (2.*np.pi*t) @ self.b.T
        return torch.cat((torch.sin(proj), torch.cos(proj)), 1)

    def _gauss(self, t):
        '''Gaussian bumps of width self.sig centered at self.dic.'''
        return (-0.5*(t-self.dic)**2/(self.sig**2)).exp()

    def _tri(self, t):
        '''Triangle (hat) bumps of half-width self.d centered at self.dic, clipped at 0.'''
        emb = 1-(t-self.dic).abs()/self.d
        return emb*(emb > 0)

    def __call__(self, x):
        '''
        Encode 2D coordinates.

        x : tensor whose first two columns are the coordinates
            (presumably in [0, 1) - TODO confirm against callers).
        Returns x unchanged for 'none', None for an unknown name, otherwise an
        N x self.dim tensor scaled so that its largest row norm is 1.
        '''
        name = self.name
        if name == 'none':
            return x
        if name == 'basic':
            emb = torch.cat((torch.sin(2.*np.pi*x), torch.cos(2.*np.pi*x)), 1)
        elif name in ('RFF', 'rffb'):
            emb = self._fourier(x)
        elif name in ('LinF', 'LogF'):
            # the same 1D frequency set applied to each axis separately
            emb = torch.cat([self._fourier(t) for t in self._directions(x, 2)], 1)
        elif name in self._KERNEL_NAMES:
            kernel = self._gauss if name.startswith('Gau') else self._tri
            n_dir = 2 if name.endswith('2') else 4
            emb = torch.cat([kernel(t) for t in self._directions(x, n_dir)], 1)
        else:
            return None
        return emb/(emb.norm(dim=1).max())


# simple 3D encoding. All methods use 3 directions.
class encoding_func_3D:
    '''
    Fixed feature encodings for 3D coordinates.

    name  : 'none' (identity), 'basic' (sin/cos of each coordinate), 'RFF'
            (random Fourier features), 'rffb' (Fourier features with a given
            frequency matrix), 'LinF' / 'LogF' (per-axis Fourier features with
            linearly / log-spaced frequencies), 'Gau' / 'Tri' (per-axis
            Gaussian / triangular bumps on equally spaced centres in [0, 1)).
    param : ignored for 'none' and 'basic'; otherwise [p, dim] where dim is the
            total output width and p is
              RFF         : std of the Gaussian random frequencies,
              rffb        : the frequency matrix itself (3 columns),
              LinF / LogF : log2 of the highest frequency,
              Gau         : Gaussian width sigma,
              Tri         : half-width d of the triangles (None -> 1/dim).
    '''
    def __init__(self, name, param=None):
        self.name = name

        if name == 'none':
            self.dim = 3  # the raw (x, y, z) coordinates are returned unchanged
        elif name == 'basic':
            self.dim = 6  # sin and cos of each of the three coordinates
        else:
            self.dim = param[1]
            if name == 'RFF':
                self.b = param[0]*torch.randn((int(param[1]/2), 3))
                # Frequency std; the closed-form RFF kernel in the blending
                # classes reads `encoding_func.sig`.
                self.sig = param[0]
            elif name == 'rffb':
                self.b = param[0]
            elif name == 'LinF':
                self.b = torch.linspace(2.**0., 2.**param[0], steps=int(param[1]/6)).reshape(-1, 1)
            elif name == 'LogF':
                self.b = 2.**torch.linspace(0., param[0], steps=int(param[1]/6)).reshape(-1, 1)
            elif name in ('Gau', 'Tri'):
                # dim/3 equally spaced centres per axis, covering [0, 1).
                self.dic = torch.linspace(0., 1, steps=int(param[1]/3)+1)[:-1].reshape(1, -1)
                if name == 'Gau':
                    self.sig = param[0]
                else:
                    self.d = 1/param[1] if param[0] is None else param[0]
            else:
                print('Undifined encoding!')

    def __call__(self, x):
        '''
        Encode coordinates x (assumed (N, 3) -- TODO confirm range is [0, 1]).

        Returns an (N, self.dim) tensor scaled so that its largest row norm is
        1 ('none' returns x untouched; an unknown name returns None).
        '''
        if self.name == 'none':
            return x
        if self.name == 'basic':
            emb = torch.cat((torch.sin(2.*np.pi*x), torch.cos(2.*np.pi*x)), 1)
        elif self.name in ('RFF', 'rffb'):
            proj = (2.*np.pi*x) @ self.b.T
            emb = torch.cat((torch.sin(proj), torch.cos(proj)), 1)
        elif self.name in ('LinF', 'LogF', 'Gau', 'Tri'):
            # Separable encodings: encode each axis on its own, then concatenate.
            emb = torch.cat([self._encode_axis(x[:, k:k+1]) for k in range(3)], 1)
        else:
            return None
        return emb/(emb.norm(dim=1).max())

    def _encode_axis(self, xk):
        '''Encode one coordinate column xk of shape (N, 1) with a separable encoding.'''
        if self.name in ('LinF', 'LogF'):
            proj = (2.*np.pi*xk) @ self.b.T
            return torch.cat((torch.sin(proj), torch.cos(proj)), 1)
        if self.name == 'Gau':
            return (-0.5*(xk-self.dic)**2/(self.sig**2)).exp()
        # 'Tri': hat functions of half-width d, clipped at zero.
        emb = 1-(xk-self.dic).abs()/self.d
        return emb*(emb > 0)


def _pair_kernel(blender, encoding_func):
    '''
    Build D(x1, x2), the similarity between the inner encodings at grid
    positions x1 and x2 (in grid units), for a blending object.

    'RFF', 'Gau' and 'Tri' use closed forms that depend only on x1 - x2; any
    other encoding is evaluated explicitly as the inner product of its
    features. `blender.dim` is read at call time, not captured here.
    '''
    name = encoding_func.name
    if name == 'RFF':
        return lambda x1, x2: (-2*(np.pi*(x1-x2)/(blender.dim-1)*encoding_func.sig)**2).exp()
    if name == 'Gau':
        return lambda x1, x2: (-0.25*((x1-x2)/(blender.dim-1))**2/(encoding_func.sig**2)).exp()
    if name == 'Tri':
        return lambda x1, x2: 0.25*torch.maximum(2*encoding_func.d-(x1-x2).abs()/(blender.dim-1), torch.tensor(0))**2
    return lambda x1, x2: (encoding_func(x1/(blender.dim-1))*encoding_func(x2/(blender.dim-1))).sum(-1).unsqueeze(1)


def _axis_weights(D, vmin, v, consts=None):
    '''
    Per-axis interpolation weights for the grid neighbours vmin and vmin+1.

    Solves [[D(a,a), D(a,b)], [D(a,b), D(b,b)]] [wa, wb] = [D(a,v), D(b,v)]
    with a = vmin and b = vmin + 1, then renormalises so that wa + wb = 1.
    `consts=(d0, dd)` supplies precomputed D(a,a) = D(b,b) and D(a,b) for
    kernels that depend only on the distance.
    '''
    if consts is None:
        d0, dd, d1 = D(vmin, vmin), D(vmin, vmin+1), D(vmin+1, vmin+1)
    else:
        d0, dd = consts
        d1 = d0
    da = D(vmin, v)
    db = D(vmin+1, v)
    ff = d0*d1-dd**2
    wa = (da*d1-db*dd)/ff
    wb = (db*d0-da*dd)/ff
    total = wa+wb
    return wa/total, wb/total


def _pack_corners(corners, n_rows, n_cols, indexing):
    '''
    Pack a list of (weight (N,1), flat grid index (N,1) int64) corner pairs.

    indexing=True  -> [LongTensor indices (N, K), weights (N, K)]
    indexing=False -> sparse (n_rows, n_cols) COO matrix of the weights.
    The `.type(torch.LongTensor)` calls keep the original CPU placement of
    the indices.
    '''
    weights = [w for w, _ in corners]
    cols = [i for _, i in corners]
    if indexing:
        return [torch.cat(cols, 1).type(torch.LongTensor), torch.cat(weights, 1)]
    rows = torch.arange(n_rows).reshape(-1, 1).repeat(len(corners), 1)
    cols = torch.cat(cols, 0).type(torch.LongTensor)
    values = torch.cat(weights, 0).reshape(-1)
    return torch.sparse_coo_tensor(torch.cat([rows, cols], 1).T, values, (n_rows, n_cols))


# blending matrix for 2D random samples
class blending_func_2D:
    '''
    Interpolate an inner encoding on a dim x dim grid, using the 4 grid
    neighbours of each 2D sample point.

    encoding_func : inner encoding func. Its name selects a closed-form
                    distance func (much faster than evaluating the encoding).
    dim           : number of grid points per axis, usually the feature
                    dimension of the inner encoding (None -> encoding_func.dim).
    indexing      : default True, returns [grid indices (N, 4), weights (N, 4)];
                    False returns a sparse (N, dim**2) matrix instead (slower).
    '''
    def __init__(self, encoding_func, dim=256, indexing=True):
        self.name = encoding_func.name
        self.indexing = indexing
        self.dim = encoding_func.dim if dim is None else dim
        self.D = _pair_kernel(self, encoding_func)

    def __call__(self, x):
        # Map x into grid units; the clamp keeps floor(x) + 1 inside the grid.
        g = (self.dim-1)*x.clamp(0, 1-1e-3)
        gx, gy = g[:, :1], g[:, 1:2]
        xmin, ymin = torch.floor(gx), torch.floor(gy)
        xa, xb = _axis_weights(self.D, xmin, gx)
        ya, yb = _axis_weights(self.D, ymin, gy)

        # Flat indices in int64 so they stay exact for any grid size.
        base = xmin.long()*self.dim + ymin.long()
        corners = [(wx*wy, base+ox+oy)
                   for wx, ox in ((xa, 0), (xb, self.dim))
                   for wy, oy in ((ya, 0), (yb, 1))]
        return _pack_corners(corners, g.shape[0], self.dim**2, self.indexing)


class blending_func_3D:
    '''
    Interpolate an inner encoding on a dim x dim x dim grid, using the 8 grid
    neighbours of each 3D sample point.

    encoding_func : inner encoding func. Its name selects a closed-form
                    distance func (much faster than evaluating the encoding).
    dim           : number of grid points per axis, usually the feature
                    dimension of the inner encoding (None -> encoding_func.dim).
    indexing      : default True, returns [grid indices (N, 8), weights (N, 8)];
                    False returns a sparse (N, dim**3) matrix instead (slower).
    '''
    def __init__(self, encoding_func, dim=None, indexing=True):
        self.name = encoding_func.name
        self.indexing = indexing
        self.dim = encoding_func.dim if dim is None else dim
        self.D = _pair_kernel(self, encoding_func)

    def __call__(self, x):
        # Map x into grid units; the clamp keeps floor(x) + 1 inside the grid.
        g = (self.dim-1)*x.clamp(0, 1-1e-3)

        # Distance-only kernels: D(a,a) and D(a,a+1) are the same on every axis.
        if self.name in ('RFF', 'Gau', 'Tri'):
            consts = (self.D(torch.tensor([0]), torch.tensor([0])),
                      self.D(torch.tensor([0]), torch.tensor([1])))
        else:
            consts = None

        axis_w, axis_min = [], []
        for k in range(3):
            gk = g[:, k:k+1]
            kmin = torch.floor(gk)
            axis_w.append(_axis_weights(self.D, kmin, gk, consts))
            axis_min.append(kmin.long())
        (xa, xb), (ya, yb), (za, zb) = axis_w

        # Flat indices in int64: float32 is only exact up to 2**24 (~256**3).
        N = self.dim
        base = axis_min[0]*N**2 + axis_min[1]*N + axis_min[2]
        corners = [(wx*wy*wz, base+ox+oy+oz)
                   for wx, ox in ((xa, 0), (xb, N*N))
                   for wy, oy in ((ya, 0), (yb, N))
                   for wz, oz in ((za, 0), (zb, 1))]
        return _pack_corners(corners, g.shape[0], N**3, self.indexing)


class fast_blending_func_3D:
    '''
    A fast version (may not really) of blending_func_3D's indexing output,
    without the extra options. Only supports the closed-form 'RFF', 'Gau'
    and 'Tri' kernels.

    encoding_func : inner encoding func; its name selects the closed-form
                    distance func.
    dim           : number of grid points per axis, usually the feature
                    dimension of the inner encoding (None -> encoding_func.dim).
    normalize     : if True, renormalise the per-axis weights to sum to 1, as
                    blending_func_3D does. The default False keeps the original
                    un-normalised weights.

    Raises ValueError if the encoding has no closed-form kernel.
    '''
    def __init__(self, encoding_func, dim=None, normalize=False):
        self.name = encoding_func.name
        self.normalize = normalize
        self.dim = encoding_func.dim if dim is None else dim

        if self.name not in ('RFF', 'Gau', 'Tri'):
            raise ValueError('No fast version!')
        pair = _pair_kernel(self, encoding_func)
        # These kernels depend only on the offset t = x1 - x2.
        self.D = lambda t: pair(t, 0)

        # Constant inverse of the per-axis 2x2 Gram matrix [[d0, dd], [dd, d0]].
        d0 = self.D(torch.tensor([0]))
        dd = self.D(torch.tensor([1]))
        ff = d0**2-dd**2
        self.coe = torch.stack([torch.cat([d0, -dd]), torch.cat([-dd, d0])])/ff

        self.dim2 = self.dim**2
        # Kept for backward compatibility; indices are computed from the int64 stride.
        self.idx_matrix = torch.tensor([self.dim**2, self.dim, 1.0]).reshape(3, 1)
        self._idx_stride = torch.tensor([self.dim**2, self.dim, 1])

    def __call__(self, x):
        '''Return [grid indices (N, 8) LongTensor, weights (N, 8)] for x (assumed (N, 3)).'''
        # Map x into grid units; the clamp keeps floor(x) + 1 inside the grid.
        g = (self.dim-1)*x.clamp(0, 1-1e-3)
        gmin = torch.floor(g)

        # (N, 3, 2): per-axis weights of the lower / upper neighbour.
        d_ratio = torch.stack([self.D(gmin-g), self.D(gmin+1-g)], -1)
        itp_weights = d_ratio @ self.coe.to(g.device)
        if self.normalize:
            itp_weights = itp_weights/itp_weights.sum(-1, keepdim=True)

        # Mixing products, ordered (x, y, z) with z varying fastest.
        xy = itp_weights[:, 0, 0]*itp_weights[:, 1, 0]
        xy1 = itp_weights[:, 0, 0]*itp_weights[:, 1, 1]
        x1y = itp_weights[:, 0, 1]*itp_weights[:, 1, 0]
        x1y1 = itp_weights[:, 0, 1]*itp_weights[:, 1, 1]
        z = itp_weights[:, 2, 0]
        z1 = itp_weights[:, 2, 1]
        c = torch.stack([xy*z, xy*z1, xy1*z, xy1*z1, x1y*z, x1y*z1, x1y1*z, x1y1*z1], 1)

        # Flat indices in int64 (float32 loses exactness above 2**24).
        base = (gmin.long()*self._idx_stride.to(g.device)).sum(1, keepdim=True)
        y = torch.cat([base, base+1], 1)
        y = torch.cat([y, y+self.dim], 1)
        y = torch.cat([y, y+self.dim2], 1)
        return [y.type(torch.LongTensor), c]


def srank_func(X, return_l1=False):
    '''
    Stable rank of matrix X: sum(s**2) / s_max**2 over its singular values s.

    With return_l1=True, also return sum(s) / s_max, as the pair (sr1, sr2).
    torch.svd returns singular values in descending order, so s[0] is s_max.
    '''
    (_, s, _) = torch.svd(X)
    sr2 = (s*s).sum()/s[0]/s[0]
    sr1 = s.sum()/s[0]
    if return_l1:
        return sr1, sr2
    return sr2


def psnr_func(x, y, return_mse=False):
    '''
    PSNR between x and y, computed as -10 * log10(MSE).

    This assumes a peak signal value of 1. With return_mse=True, returns
    (psnr, mse).
    '''
    diff = x - y
    mse = (diff*diff).flatten().mean()
    if return_mse:
        return -10*(mse.log10()), mse
    return -10*(mse.log10())


def smooth(X, filter):
    '''
    Mean squared response of the 1D conv `filter` along both axes of 2D X.

    `filter` is a conv1d weight, presumably of shape (out, 1, k).
    '''
    sx = (torch.nn.functional.conv1d(X.unsqueeze(1), filter)**2).mean()
    sy = (torch.nn.functional.conv1d(X.T.unsqueeze(1), filter)**2).mean()
    return sx+sy


def smooth_3D(X, filter):
    '''Mean squared response of the 1D conv `filter` along each of the three axes of 3D X.'''
    sx = (torch.nn.functional.conv1d(X.flatten(0, 1).unsqueeze(1), filter)**2).mean()
    sy = (torch.nn.functional.conv1d(X.transpose(1, 2).flatten(0, 1).unsqueeze(1), filter)**2).mean()
    sz = (torch.nn.functional.conv1d(X.transpose(0, 2).flatten(0, 1).unsqueeze(1), filter)**2).mean()
    return sx+sy+sz