├── 2D_random_points.py
├── 2D_separable_points.py
├── 3D_random_points_GD.py
├── 3D_separable_points_GD.py
├── LICENSE
├── MLPs.py
├── README.md
├── complex_encoding.ipynb
├── data.py
├── imgs
└── simple_complex_encoding.png
├── trainer.py
└── utils.py
/2D_random_points.py:
--------------------------------------------------------------------------------
1 | import torch, os, logging
2 | from argparse import ArgumentParser
3 | import numpy as np
4 | from matplotlib import pyplot as plt
5 |
6 | from trainer import *
7 | from utils import *
8 | from MLPs import *
9 |
10 |
def to8b(x):
    """Convert values nominally in [0, 1] to ``uint8`` in [0, 255].

    Inputs are clamped to [0, 1] first; the scaled result is cast to
    ``np.uint8`` (fractional parts are dropped by the cast).
    """
    clamped = np.clip(x, 0, 1)
    return (255 * clamped).astype(np.uint8)
12 |
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to both ``filename`` and the console.

    Args:
        filename: Path of the log file. It is opened in ``"w"`` mode, so an
            existing file is truncated.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING. Any other value
            raises ``KeyError``.
        name: Name handed to ``logging.getLogger``; ``None`` selects the
            root logger.

    Returns:
        The configured ``logging.Logger``. Two handlers are attached on every
        call, so calling this repeatedly for the same name adds more handlers.
    """
    levels = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    fmt = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )

    log = logging.getLogger(name)
    log.setLevel(levels[verbosity])

    # File handler first, then console handler, sharing one format.
    for handler in (logging.FileHandler(filename, "w"), logging.StreamHandler()):
        handler.setFormatter(fmt)
        log.addHandler(handler)

    return log
30 |
31 |
def main(args):
    """Run the 2D random-point reconstruction experiments.

    Creates ``args.save_path``, draws one random boolean training mask per
    image and repeat (25% of the pixels), then trains every listed encoding
    method: first the "simple" 2D encoders, then the "complex" variant that
    combines a 1D encoder with a blending function. Summary PSNR/time
    statistics are logged for every run and, when ``args.save_flag`` is set,
    each reconstruction is saved as a PDF inside ``args.save_path``.

    The encoders and ``train_*`` routines are star-imported project helpers
    that are not visible here; they are assumed to return
    ``(time, train_psnr, test_psnr, reconstructions)`` with 2D metric arrays.

    Args:
        args: Parsed command-line namespace providing ``save_path``,
            ``logger``, ``data_path``, ``N_repeat`` and ``save_flag``.

    Returns:
        ``1`` if ``args.save_path`` already exists (nothing is run),
        otherwise ``None``.
    """
    if os.path.exists(args.save_path):
        print('Path already exists!')
        return 1
    os.mkdir(args.save_path)
    # os.path.join keeps the log inside the directory even when save_path
    # is given without a trailing slash.
    logger = get_logger(os.path.join(args.save_path, args.logger))
    logger.info(args)

    # Set the CUDA flag
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info('device is: {}'.format(device))

    buf = np.load(args.data_path)
    # Read the array once; each access to the archive decodes it again.
    test_data = buf['test_data']

    ########### generate random mask ###########
    # NOTE: args.mask_path is currently unused; masks are regenerated on
    # every run from the torch RNG.
    image_size = test_data.shape[1]  # assumes square images -- TODO confirm
    n_pixels = image_size**2
    ratio = 0.25
    n_train = int(ratio*n_pixels)
    mask = []
    for _ in range(test_data.shape[0]):
        mask_tmp = []
        for _ in range(args.N_repeat):
            idx = torch.randperm(n_pixels)[:n_train]
            mask_np_N2 = np.zeros(n_pixels, dtype=bool)
            mask_np_N2[idx.numpy()] = True
            mask_tmp.append(mask_np_N2)
        mask.append(np.stack(mask_tmp, 0))
    mask = np.stack(mask, 0)
    ###########################################

    signals = torch.from_numpy(test_data / 255.)

    def save_and_log(file_name, em, param, time_, tst_psnr_, rec_):
        """Optionally save every reconstruction as a PDF, then log summary stats."""
        if args.save_flag:
            for i in range(signals.shape[0]):
                plt.imshow(rec_[i])
                plt.axis('off')
                out_name = 'I{}{}{}.pdf'.format(i, file_name, tst_psnr_[i, -1])
                plt.savefig(os.path.join(args.save_path, out_name), bbox_inches='tight', pad_inches=0)
                # Close the figure; otherwise every later image is drawn on
                # top of (and saved together with) all earlier ones.
                plt.close()

        logger.info('embedding method:{}, param:{}, psnr:{}, std:{}, time:{}.'.format(
            em, param, np.mean(tst_psnr_[:, :]), np.std(tst_psnr_[:, :]), np.mean(time_[:, :])))

    logger.info('################ Simple Encoding ################')

    ez = 512
    rff_params = [8]
    linearf_params = [4]
    logf_params = [4.5]
    gaussian4_params = [0.006]
    linear4_params = [2.5/128]

    encoding_methods = (['RFF']*len(rff_params) + ['LinF']*len(linearf_params)
                        + ['LogF']*len(logf_params) + ['Gau']*len(gaussian4_params)
                        + ['Tri']*len(linear4_params))
    params = rff_params + linearf_params + logf_params + gaussian4_params + linear4_params

    print(encoding_methods)
    print(params)

    for depth in [4, 1, 0]:
        for lr in [5e-3]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))

            for em, param in zip(encoding_methods, params):
                ef = encoding_func_2D(em, [param, ez])
                time_, _trn_psnr, tst_psnr_, rec_ = train_random_simple_2D(
                    signals, ef, mask=mask, N_repeat=args.N_repeat, lr=lr,
                    epochs=2000, depth=depth, device=device, logger=None)
                save_and_log('RD{}{}'.format(depth, em), em, param, time_, tst_psnr_, rec_)

    logger.info('################ Complex Encoding ################')

    ez = 256
    rff_params = [30]
    linearf_params = [4.5]
    logf_params = [6]
    gaussian_params = [0.005]
    linear_params = [2.5/256]

    encoding_methods = (['RFF']*len(rff_params) + ['LinF']*len(linearf_params)
                        + ['LogF']*len(logf_params) + ['Gau']*len(gaussian_params)
                        + ['Tri']*len(linear_params))
    params = rff_params + linearf_params + logf_params + gaussian_params + linear_params

    for depth in [0, 1]:
        for lr in [1e-1]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_1D(em, [param, ez])
                bl = blending_func_2D(ef)
                time_, _trn_psnr, tst_psnr_, rec_ = train_index_blend_kron_2D(
                    signals, bl, ef, mask=mask, N_repeat=args.N_repeat, lr=lr,
                    epochs=2000, depth=depth, device=device, logger=None)
                save_and_log('RKD{}{}'.format(depth, em), em, param, time_, tst_psnr_, rec_)
138 |
139 |
140 |
if __name__ == "__main__":
    # Fix the default dtype and seed both RNGs with the same value.
    seed = 20220222
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Command-line options: (flag, type, default).
    parser = ArgumentParser()
    for flag, kind, default in (
        ("--data_path", str, "data_div2k.npz"),
        ("--mask_path", str, "mask_2d_0.25_16_1_512.npy"),
        ("--N_repeat", int, 1),
        ("--save_path", str, "2D_random_points/"),
        ("--logger", str, "log.log"),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--save_flag", type=int, default=0, choices=[0, 1])

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/2D_separable_points.py:
--------------------------------------------------------------------------------
1 | import torch, os, logging
2 | from argparse import ArgumentParser
3 | import numpy as np
4 | from matplotlib import pyplot as plt
5 |
6 | from trainer import *
7 | from utils import *
8 | from MLPs import *
9 |
10 |
def to8b(x):
    """Convert values nominally in [0, 1] to ``uint8`` in [0, 255].

    Inputs are clamped to [0, 1] first; the scaled result is cast to
    ``np.uint8`` (fractional parts are dropped by the cast).
    """
    clamped = np.clip(x, 0, 1)
    return (255 * clamped).astype(np.uint8)
12 |
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to both ``filename`` and the console.

    Args:
        filename: Path of the log file. It is opened in ``"w"`` mode, so an
            existing file is truncated.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING. Any other value
            raises ``KeyError``.
        name: Name handed to ``logging.getLogger``; ``None`` selects the
            root logger.

    Returns:
        The configured ``logging.Logger``. Two handlers are attached on every
        call, so calling this repeatedly for the same name adds more handlers.
    """
    levels = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    fmt = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )

    log = logging.getLogger(name)
    log.setLevel(levels[verbosity])

    # File handler first, then console handler, sharing one format.
    for handler in (logging.FileHandler(filename, "w"), logging.StreamHandler()):
        handler.setFormatter(fmt)
        log.addHandler(handler)

    return log
30 |
31 |
def main(args):
    """Run the 2D separable-point reconstruction experiments.

    Creates ``args.save_path`` and then evaluates every listed encoding
    method in three settings: "simple" 2D encoders trained by gradient
    descent, "complex" (1D encoder, Kronecker-structured) training, and the
    closed-form complex solver. Summary PSNR/time statistics are logged for
    every run and, when ``args.save_flag`` is set, each reconstruction is
    saved as a PDF inside ``args.save_path``.

    The encoders and ``train_*`` routines are star-imported project helpers
    that are not visible here; they are assumed to return
    ``(time, train_psnr, test_psnr, reconstructions)`` with 2D metric arrays.

    Args:
        args: Parsed command-line namespace providing ``save_path``,
            ``logger``, ``data_path``, ``N_repeat`` and ``save_flag``.

    Returns:
        ``1`` if ``args.save_path`` already exists (nothing is run),
        otherwise ``None``.
    """
    if os.path.exists(args.save_path):
        print('Path already exists!')
        return 1
    os.mkdir(args.save_path)
    # os.path.join keeps the log inside the directory even when save_path
    # is given without a trailing slash.
    logger = get_logger(os.path.join(args.save_path, args.logger))
    logger.info(args)

    # Set the CUDA flag
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info('device is: {}'.format(device))

    buf = np.load(args.data_path)

    signals = torch.from_numpy(buf['test_data'] / 255.)

    def save_and_log(file_name, em, param, time_, tst_psnr_, rec_):
        """Optionally save every reconstruction as a PDF, then log summary stats."""
        if args.save_flag:
            for i in range(signals.shape[0]):
                plt.imshow(rec_[i])
                plt.axis('off')
                out_name = 'I{}{}{}.pdf'.format(i, file_name, tst_psnr_[i, -1])
                plt.savefig(os.path.join(args.save_path, out_name), bbox_inches='tight', pad_inches=0)
                # Close the figure; otherwise every later image is drawn on
                # top of (and saved together with) all earlier ones.
                plt.close()

        logger.info('embedding method:{}, param:{}, psnr:{}, std:{}, time:{}.'.format(
            em, param, np.mean(tst_psnr_[:, :]), np.std(tst_psnr_[:, :]), np.mean(time_[:, :])))

    logger.info('################ Simple Encoding ################')

    ez = 512
    rff_params = [14]
    linearf_params = [5]
    logf_params = [5.5]
    gaussian4_params = [0.006]
    linear4_params = [2.5/128]

    encoding_methods = (['RFF']*len(rff_params) + ['LinF']*len(linearf_params)
                        + ['LogF']*len(logf_params) + ['Gau']*len(gaussian4_params)
                        + ['Tri']*len(linear4_params))
    params = rff_params + linearf_params + logf_params + gaussian4_params + linear4_params

    print(encoding_methods)
    print(params)

    for depth in [4, 1, 0]:
        for lr in [5e-3]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))

            for em, param in zip(encoding_methods, params):
                ef = encoding_func_2D(em, [param, ez])
                time_, _trn_psnr, tst_psnr_, rec_ = train_simple_2D(
                    signals, ef, N_repeat=args.N_repeat, lr=lr, epochs=2000,
                    depth=depth, device=device, logger=None)
                save_and_log('D{}{}'.format(depth, em), em, param, time_, tst_psnr_, rec_)

    logger.info('################ Complex Encoding ################')

    ez = 256
    rff_params = [42]
    linearf_params = [4.5]
    logf_params = [6.5]
    gaussian_params = [0.004]
    linear_params = [2/256]

    encoding_methods = (['RFF']*len(rff_params) + ['LinF']*len(linearf_params)
                        + ['LogF']*len(logf_params) + ['Gau']*len(gaussian_params)
                        + ['Tri']*len(linear_params))
    params = rff_params + linearf_params + logf_params + gaussian_params + linear_params

    for depth in [0, 1]:
        for lr in [1e-1]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))

            for em, param in zip(encoding_methods, params):
                ef = encoding_func_1D(em, [param, ez])
                time_, _trn_psnr, tst_psnr_, rec_ = train_kron_2D(
                    signals, ef, N_repeat=args.N_repeat, lr=lr, epochs=2000,
                    depth=depth, device=device, logger=None)
                save_and_log('KD{}{}'.format(depth, em), em, param, time_, tst_psnr_, rec_)

    logger.info('################ Closed Form Complex Encoding ################')

    ez = 256
    rff_params = [42]
    linearf_params = [6]
    logf_params = [6.5]
    gaussian_params = [0.0025]
    linear_params = [1/256]

    encoding_methods = (['RFF']*len(rff_params) + ['LinF']*len(linearf_params)
                        + ['LogF']*len(logf_params) + ['Gau']*len(gaussian_params)
                        + ['Tri']*len(linear_params))
    params = rff_params + linearf_params + logf_params + gaussian_params + linear_params

    # The closed-form solver has no depth. Output names historically carried
    # the value left over from the loop above (1); keep it explicit so the
    # file names do not change and no leaked loop variable is relied on.
    closed_form_tag = 1

    for em, param in zip(encoding_methods, params):
        # Only Gau/Tri use the selected device; every other encoding is
        # solved on the CPU. The choice is made per method so that an earlier
        # CPU-only method no longer forces the later ones onto the CPU.
        run_device = device if em in ('Gau', 'Tri') else 'cpu'

        ef = encoding_func_1D(em, [param, ez])
        time_, _trn_psnr, tst_psnr_, rec_ = train_closed_form_2D(
            signals, ef, N_repeat=args.N_repeat, device=run_device, logger=None)
        save_and_log('CKD{}{}'.format(closed_form_tag, em), em, param, time_, tst_psnr_, rec_)
153 |
154 |
155 |
if __name__ == "__main__":
    # Fix the default dtype and seed both RNGs with the same value.
    seed = 20220222
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Command-line options: (flag, type, default).
    parser = ArgumentParser()
    for flag, kind, default in (
        ("--data_path", str, "data_div2k.npz"),
        ("--N_repeat", int, 1),
        ("--save_path", str, "2D_separable_points/"),
        ("--logger", str, "log.log"),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--save_flag", type=int, default=0, choices=[0, 1])

    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/3D_random_points_GD.py:
--------------------------------------------------------------------------------
1 | import torch, os, logging
2 | from argparse import ArgumentParser
3 | import numpy as np
4 |
5 | from trainer import *
6 | from utils import *
7 | from MLPs import *
8 |
9 |
10 | import imageio
11 |
def to8b(x):
    """Convert values nominally in [0, 1] to ``uint8`` in [0, 255].

    Inputs are clamped to [0, 1] first; the scaled result is cast to
    ``np.uint8`` (fractional parts are dropped by the cast).
    """
    clamped = np.clip(x, 0, 1)
    return (255 * clamped).astype(np.uint8)
13 |
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to both ``filename`` and the console.

    Args:
        filename: Path of the log file. It is opened in ``"w"`` mode, so an
            existing file is truncated.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING. Any other value
            raises ``KeyError``.
        name: Name handed to ``logging.getLogger``; ``None`` selects the
            root logger.

    Returns:
        The configured ``logging.Logger``. Two handlers are attached on every
        call, so calling this repeatedly for the same name adds more handlers.
    """
    levels = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    fmt = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )

    log = logging.getLogger(name)
    log.setLevel(levels[verbosity])

    # File handler first, then console handler, sharing one format.
    for handler in (logging.FileHandler(filename, "w"), logging.StreamHandler()):
        handler.setFormatter(fmt)
        log.addHandler(handler)

    return log
31 |
32 |
def main(args):
    """Run the 3D random-point reconstruction experiments.

    Creates ``args.save_path``, draws one random boolean training mask per
    signal and repeat (12.5% of the voxels), writes the masks to a ``.npy``
    file in the current working directory, then trains every listed encoding
    method: first the "simple" 3D encoders, then the "complex" variant that
    combines a 1D encoder with a blending function. Summary PSNR/time
    statistics are logged for every run and, when ``args.save_flag`` is set,
    each reconstruction is written as an MP4 inside ``args.save_path``.

    The encoders and ``train_*`` routines are star-imported project helpers
    that are not visible here; they are assumed to return
    ``(time, train_psnr, test_psnr, reconstructions)`` with 2D metric arrays.

    Args:
        args: Parsed command-line namespace providing ``save_path``,
            ``logger``, ``data_path``, ``N_repeat`` and ``save_flag``.

    Returns:
        ``1`` if ``args.save_path`` already exists (nothing is run),
        otherwise ``None``.
    """
    if os.path.exists(args.save_path):
        print('Path already exists!')
        return 1
    os.mkdir(args.save_path)
    # os.path.join keeps the log inside the directory even when save_path
    # is given without a trailing slash.
    logger = get_logger(os.path.join(args.save_path, args.logger))
    logger.info(args)

    # Set the CUDA flag
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info('device is: {}'.format(device))

    buf = np.load(args.data_path)

    # NOTE(review): this script reads the 'train' array while
    # 3D_separable_points_GD.py reads 'test' from the same default file --
    # confirm this is intended.
    signals = torch.from_numpy(buf['train'])

    ########### generate random mask ###########
    # NOTE: args.mask_path is currently unused; masks are regenerated on
    # every run from the torch RNG.
    image_size = signals.shape[1]  # assumes cubic volumes -- TODO confirm
    n_voxels = image_size**3
    ratio = 0.125
    n_train = int(ratio*n_voxels)
    mask = []
    for _ in range(signals.shape[0]):
        mask_tmp = []
        for _ in range(args.N_repeat):
            idx = torch.randperm(n_voxels)[:n_train]
            # Allocate as bool directly instead of float64-then-compare.
            mask_np_N2 = np.zeros(n_voxels, dtype=bool)
            mask_np_N2[idx.numpy()] = True
            mask_tmp.append(mask_np_N2)
        mask.append(np.stack(mask_tmp, 0))
    mask = np.stack(mask, 0)
    np.save('mask_3d_{}_{}_{}_{}.npy'.format(ratio, signals.shape[0], args.N_repeat, signals.shape[1]), mask)
    ###########################################

    def save_and_log(file_name, em, param, time_, tst_psnr_, rec_):
        """Optionally write every reconstruction as an MP4, then log summary stats."""
        if args.save_flag:
            for i in range(signals.shape[0]):
                out_name = 'V{}{}R{:.2f}.mp4'.format(i, file_name, tst_psnr_[i, -1])
                imageio.mimwrite(os.path.join(args.save_path, out_name), to8b(rec_[i]), fps=10, quality=8)

        logger.info('embedding method:{}, param:{}, psnr:{}, std:{}, time:{}.'.format(
            em, param, np.mean(tst_psnr_[:, :]), np.std(tst_psnr_[:, :]), np.mean(time_[:, :])))

    logger.info('################ Simple Encoding ################')

    ez = 192
    rff_params = [2]
    linearf_params = [3]
    logf_params = [3]
    gaussian_params = [0.02]
    linear_params = [3/64]

    encoding_methods = (['RFF']*len(rff_params) + ['LinF']*len(linearf_params)
                        + ['LogF']*len(logf_params) + ['Gau']*len(gaussian_params)
                        + ['Tri']*len(linear_params))
    params = rff_params + linearf_params + logf_params + gaussian_params + linear_params

    for depth in [5, 1, 0]:
        for lr in [5e-3]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_3D(em, [param, ez])
                time_, _trn_psnr, tst_psnr_, rec_ = train_random_simple_3D(
                    signals, ef, mask=mask, N_repeat=args.N_repeat, lr=lr,
                    epochs=500, depth=depth, device=device, logger=None)
                save_and_log('RD{}{}'.format(depth, em), em, param, time_, tst_psnr_, rec_)

    logger.info('################ Complex Encoding ################')

    ez = 64
    rff_params = [6]
    linearf_params = [8]
    logf_params = [8]
    gaussian_params = [0.02]
    linear_params = [3.5/64]

    encoding_methods = (['RFF']*len(rff_params) + ['LinF']*len(linearf_params)
                        + ['LogF']*len(logf_params) + ['Gau']*len(gaussian_params)
                        + ['Tri']*len(linear_params))
    params = rff_params + linearf_params + logf_params + gaussian_params + linear_params

    for depth in [0, 1]:
        for lr in [1e-1]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_1D(em, [param, ez])
                bl = blending_func_3D(ef)
                time_, _trn_psnr, tst_psnr_, rec_ = train_index_blend_kron_3D(
                    signals, bl, ef, mask=mask, N_repeat=args.N_repeat, lr=lr,
                    epochs=500, depth=depth, device=device, logger=None)
                save_and_log('RKD{}{}'.format(depth, em), em, param, time_, tst_psnr_, rec_)
132 |
133 |
134 |
if __name__ == "__main__":
    # Fix the default dtype and seed both RNGs with the same value.
    seed = 20220222
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Command-line options: (flag, type, default).
    parser = ArgumentParser()
    for flag, kind, default in (
        ("--data_path", str, "video_16_128.npz"),
        ("--mask_path", str, "mask_3d_0.125_5_1_128.npy"),
        ("--N_repeat", int, 1),
        ("--save_path", str, "3D_random_points/"),
        ("--logger", str, "log.log"),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--save_flag", type=int, default=0, choices=[0, 1])

    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/3D_separable_points_GD.py:
--------------------------------------------------------------------------------
1 | import torch, os, logging
2 | from argparse import ArgumentParser
3 | import numpy as np
4 |
5 | from trainer import *
6 | from utils import *
7 | from MLPs import *
8 |
9 |
10 | import imageio
11 |
def to8b(x):
    """Convert values nominally in [0, 1] to ``uint8`` in [0, 255].

    Inputs are clamped to [0, 1] first; the scaled result is cast to
    ``np.uint8`` (fractional parts are dropped by the cast).
    """
    clamped = np.clip(x, 0, 1)
    return (255 * clamped).astype(np.uint8)
13 |
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to both ``filename`` and the console.

    Args:
        filename: Path of the log file. It is opened in ``"w"`` mode, so an
            existing file is truncated.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING. Any other value
            raises ``KeyError``.
        name: Name handed to ``logging.getLogger``; ``None`` selects the
            root logger.

    Returns:
        The configured ``logging.Logger``. Two handlers are attached on every
        call, so calling this repeatedly for the same name adds more handlers.
    """
    levels = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    fmt = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )

    log = logging.getLogger(name)
    log.setLevel(levels[verbosity])

    # File handler first, then console handler, sharing one format.
    for handler in (logging.FileHandler(filename, "w"), logging.StreamHandler()):
        handler.setFormatter(fmt)
        log.addHandler(handler)

    return log
31 |
32 |
def main(args):
    """Run the 3D separable-point reconstruction experiments.

    Creates ``args.save_path`` and then evaluates every listed encoding
    method in three settings: "simple" 3D encoders trained by gradient
    descent, "complex" (1D encoder, Kronecker-structured) training, and the
    closed-form complex solver. Summary PSNR/time statistics are logged for
    every run and, when ``args.save_flag`` is set, each reconstruction is
    written as an MP4 inside ``args.save_path``.

    The encoders and ``train_*`` routines are star-imported project helpers
    that are not visible here; they are assumed to return
    ``(time, train_psnr, test_psnr, reconstructions)`` with 2D metric arrays.

    Args:
        args: Parsed command-line namespace providing ``save_path``,
            ``logger``, ``data_path``, ``N_repeat`` and ``save_flag``.

    Returns:
        ``1`` if ``args.save_path`` already exists (nothing is run),
        otherwise ``None``.
    """
    if os.path.exists(args.save_path):
        print('Path already exists!')
        return 1
    os.mkdir(args.save_path)
    # os.path.join keeps the log inside the directory even when save_path
    # is given without a trailing slash.
    logger = get_logger(os.path.join(args.save_path, args.logger))
    logger.info(args)

    # Set the CUDA flag
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info('device is: {}'.format(device))

    buf = np.load(args.data_path)

    signals = torch.from_numpy(buf['test'])

    def build_methods(rff, linearf, logf, gaussian, linear):
        """Pair each parameter with its encoding-method tag, in a fixed order."""
        methods = (['RFF']*len(rff) + ['LinF']*len(linearf) + ['LogF']*len(logf)
                   + ['Gau']*len(gaussian) + ['Tri']*len(linear))
        return methods, rff + linearf + logf + gaussian + linear

    def save_and_log(file_name, em, param, time_, tst_psnr_, rec_):
        """Optionally write every reconstruction as an MP4, then log summary stats."""
        if args.save_flag:
            for i in range(signals.shape[0]):
                out_name = 'V{}{}R{:.2f}.mp4'.format(i, file_name, tst_psnr_[i, -1])
                imageio.mimwrite(os.path.join(args.save_path, out_name), to8b(rec_[i]), fps=10, quality=8)

        logger.info('embedding method:{}, param:{}, psnr:{}, std:{}, time:{}.'.format(
            em, param, np.mean(tst_psnr_[:, :]), np.std(tst_psnr_[:, :]), np.mean(time_[:, :])))

    logger.info('################ Simple Encoding ################')

    ez = 192
    encoding_methods, params = build_methods(
        rff=[3], linearf=[3], logf=[3], gaussian=[0.015], linear=[3/64])

    for depth in [5, 1, 0]:
        for lr in [5e-3]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_3D(em, [param, ez])
                time_, _trn_psnr, tst_psnr_, rec_ = train_simple_3D(
                    signals, ef, N_repeat=args.N_repeat, lr=lr, epochs=500,
                    depth=depth, device=device, logger=None)
                save_and_log('D{}{}'.format(depth, em), em, param, time_, tst_psnr_, rec_)

    logger.info('################ Complex Encoding ################')

    ez = 64
    encoding_methods, params = build_methods(
        rff=[6], linearf=[3], logf=[3], gaussian=[0.015], linear=[2/64])

    for depth in [0, 1]:
        for lr in [1e-1]:
            logger.info('######## Network Depth = {}, Learning Rate = {} ########'.format(depth, lr))
            for em, param in zip(encoding_methods, params):
                ef = encoding_func_1D(em, [param, ez])
                time_, _trn_psnr, tst_psnr_, rec_ = train_kron_3D(
                    signals, ef, N_repeat=args.N_repeat, lr=lr, epochs=500,
                    depth=depth, device=device, logger=None)
                save_and_log('KD{}{}'.format(depth, em), em, param, time_, tst_psnr_, rec_)

    logger.info('################ Closed Form Complex Encoding ################')

    ez = 64
    # Rebuild the method/parameter lists from the closed-form values. They
    # were previously declared here but never combined, so this section
    # silently reused the Complex Encoding parameters.
    encoding_methods, params = build_methods(
        rff=[6], linearf=[4], logf=[3], gaussian=[0.009], linear=[3/64])

    # The closed-form solver has no depth. Output names historically carried
    # the value left over from the loop above (1); keep it explicit so the
    # file names do not change and no leaked loop variable is relied on.
    closed_form_tag = 1

    for em, param in zip(encoding_methods, params):
        # Only Gau/Tri use the selected device; every other encoding is
        # solved on the CPU. The choice is made per method so that an earlier
        # CPU-only method no longer forces the later ones onto the CPU.
        run_device = device if em in ('Gau', 'Tri') else 'cpu'

        ef = encoding_func_1D(em, [param, ez])
        time_, _trn_psnr, tst_psnr_, rec_ = train_closed_form_3D(
            signals, ef, N_repeat=args.N_repeat, device=run_device, logger=None)
        save_and_log('CKD{}{}'.format(closed_form_tag, em), em, param, time_, tst_psnr_, rec_)
138 |
139 |
140 |
if __name__ == "__main__":
    # Fix the default dtype and seed both RNGs with the same value.
    seed = 20220222
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Command-line options: (flag, type, default).
    parser = ArgumentParser()
    for flag, kind, default in (
        ("--data_path", str, "video_16_128.npz"),
        ("--N_repeat", int, 1),
        ("--save_path", str, "3D_separable_points/"),
        ("--logger", str, "log.log"),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--save_flag", type=int, default=0, choices=[0, 1])

    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Jianqiao Zheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MLPs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | # regular MLP
class MLP(nn.Module):
    """Fully connected network with ReLU hidden layers.

    With ``depth == 0`` the model is a single linear map from ``input_dim``
    to ``output_dim``. Otherwise it has ``depth`` hidden layers of ``width``
    units (at least one), each followed by an in-place ReLU, and a linear
    output layer. A sigmoid is appended when ``use_sigmoid`` is true.
    """

    def __init__(self, input_dim=2, output_dim=3, depth=0, width=256, bias=True, use_sigmoid=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid

        stages = []
        if depth == 0:
            stages.append(nn.Linear(input_dim, output_dim, bias=bias))
        else:
            # Input sizes of the hidden layers: the first consumes the raw
            # input, the remaining depth-1 consume the previous hidden layer.
            hidden_inputs = [input_dim] + [width for _ in range(depth - 1)]
            for fan_in in hidden_inputs:
                stages.append(nn.Sequential(nn.Linear(fan_in, width, bias=bias), nn.ReLU(True)))
            stages.append(nn.Sequential(nn.Linear(width, output_dim, bias=bias)))
        if use_sigmoid:
            stages.append(nn.Sigmoid())

        self.mm = nn.ModuleList(stages)

    def forward(self, x):
        """Apply every stage in order and return the result."""
        out = x
        for stage in self.mm:
            out = stage(out)
        return out

    def name(self):
        """Return the model's identifier string."""
        return "MLP"
27 |
28 |
29 | # If the input to MLP is in form of kron(x,y), this model can calculate it efficiently.
class Kron_MLP(nn.Module):
    """MLP whose first layer acts on ``kron(x, y)`` without forming it.

    The first layer is stored as a ``(width0, input_dim, input_dim)`` tensor
    ``W`` and evaluated as ``x @ W @ y.T``, which yields one feature vector
    of length ``width0`` per (row of ``x``, row of ``y``) pair. With
    ``depth == 0`` that bilinear map is the whole network (``width0`` is
    forced to ``output_dim``); deeper variants append ReLU/linear layers.
    A sigmoid is appended when ``use_sigmoid`` is true.
    """

    def __init__(self, input_dim=2, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid

        if depth == 0:
            width0 = output_dim

        # Uniform init in [-1/sqrt(width0), 1/sqrt(width0)].
        root = np.sqrt(width0)
        first_weight = 2/root*torch.rand(width0, input_dim, input_dim)-1/root

        self.mm = nn.ModuleList([])
        self.first = nn.ParameterDict({'weight': nn.Parameter(first_weight)})

        tail = []
        if depth == 1:
            tail.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, output_dim, bias=bias)))
        elif depth > 1:
            tail.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, width, bias=bias), nn.ReLU(True)))
            tail.extend(
                nn.Sequential(nn.Linear(width, width, bias=bias), nn.ReLU(True))
                for _ in range(depth - 2)
            )
            tail.append(nn.Linear(width, output_dim, bias=bias))
        if use_sigmoid:
            tail.append(nn.Sigmoid())
        for layer in tail:
            self.mm.append(layer)

    def forward(self, x, y=None):
        """Return one output row per (row of ``x``, row of ``y``) pair.

        ``y`` defaults to a detached copy of ``x``.
        """
        if y is None:
            y = x.detach().clone()

        # Compute x*W*y.T instead of kron(x,y)*vec(W).
        # Equivalent einsum: torch.einsum('ij,wjk,lk->wil', x, W, y)
        core = self.first['weight']
        feats = torch.matmul(torch.matmul(x, core), y.transpose(0, 1))

        # (width0, n_x, n_y) -> (n_x*n_y, width0)
        feats = feats.flatten(1, 2).transpose(0, 1)
        for layer in self.mm:
            feats = layer(feats)
        return feats

    def name(self):
        """Return the model's identifier string."""
        return "Kron_MLP"
69 |
70 |
71 | # MLP with blending matrix in sparse matrix multiplication, not using because it's a bit slow
class Blend_Kron_MLP(nn.Module):
    """Kronecker-structured MLP with a blending-matrix product after layer one.

    The first layer is a ``(width0, input_dim, input_dim)`` tensor ``W``
    evaluated as ``x @ W @ y.T``; the resulting per-pair features are then
    left-multiplied by the blending matrix ``B`` before the remaining layers.
    With ``depth == 0``, ``width0`` is forced to ``output_dim`` and no
    further linear layers are added. A sigmoid is appended when
    ``use_sigmoid`` is true.
    """

    def __init__(self, input_dim=2, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid

        if depth == 0:
            width0 = output_dim

        # Uniform init in [-1/sqrt(width0), 1/sqrt(width0)].
        root = np.sqrt(width0)
        first_weight = 2/root*torch.rand(width0, input_dim, input_dim)-1/root

        self.mm = nn.ModuleList([])
        self.first = nn.Parameter(first_weight)

        tail = []
        if depth == 1:
            tail.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, output_dim, bias=bias)))
        elif depth > 1:
            tail.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, width, bias=bias), nn.ReLU(True)))
            tail.extend(
                nn.Sequential(nn.Linear(width, width, bias=bias), nn.ReLU(True))
                for _ in range(depth - 2)
            )
            tail.append(nn.Linear(width, output_dim, bias=bias))
        if use_sigmoid:
            tail.append(nn.Sigmoid())
        for layer in tail:
            self.mm.append(layer)

    def forward(self, B, x, y=None):
        """Return ``B @ features`` passed through the remaining layers.

        ``B`` must be matrix-multipliable with the ``(n_x*n_y, width0)``
        feature matrix; presumably it is a sparse blending matrix -- verify
        against the caller. ``y`` defaults to a detached copy of ``x``.
        """
        if y is None:
            y = x.detach().clone()

        # Compute x*W*y.T instead of kron(x,y)*vec(W).
        # Equivalent einsum: torch.einsum('ij,wjk,lk->wil', x, W, y)
        feats = torch.matmul(torch.matmul(x, self.first), y.transpose(0, 1))

        # (width0, n_x, n_y) -> (n_x*n_y, width0), then blend rows with B.
        feats = feats.flatten(1, 2).transpose(0, 1)
        feats = B@feats
        for layer in self.mm:
            feats = layer(feats)
        return feats

    def name(self):
        """Return the model's identifier string."""
        return "Blend_Kron_MLP"
111 |
112 |
113 |
114 | # MLP with blending matrix in indexing and weight sum implementation
class Indexing_Blend_Kron_MLP(nn.Module):
    """Kronecker-structured MLP with blending done by indexing and weighting.

    The first layer is a ``(width0, input_dim, input_dim)`` tensor ``W``
    evaluated as ``x @ W @ y.T``. Instead of a matrix product with a blending
    matrix, rows of the per-pair feature matrix are gathered with ``B[0]``,
    scaled by ``B[1]`` and summed over dimension 1. With ``depth == 0``,
    ``width0`` is forced to ``output_dim`` and no further linear layers are
    added. A sigmoid is appended when ``use_sigmoid`` is true.
    """

    def __init__(self, input_dim=2, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid

        if depth == 0:
            width0 = output_dim

        # Uniform init in [-1/sqrt(width0), 1/sqrt(width0)].
        root = np.sqrt(width0)
        first_weight = 2/root*torch.rand(width0, input_dim, input_dim)-1/root

        self.mm = nn.ModuleList([])
        self.first = nn.Parameter(first_weight)

        tail = []
        if depth == 1:
            tail.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, output_dim, bias=bias)))
        elif depth > 1:
            tail.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, width, bias=bias), nn.ReLU(True)))
            tail.extend(
                nn.Sequential(nn.Linear(width, width, bias=bias), nn.ReLU(True))
                for _ in range(depth - 2)
            )
            tail.append(nn.Linear(width, output_dim, bias=bias))
        if use_sigmoid:
            tail.append(nn.Sigmoid())
        for layer in tail:
            self.mm.append(layer)

    def forward(self, B, x, y=None):
        """Return the blended features passed through the remaining layers.

        ``B[0]`` indexes rows of the ``(n_x*n_y, width0)`` feature matrix and
        ``B[1]`` holds the matching weights; presumably both are shaped
        ``(n_points, n_neighbours)`` -- verify against the caller. ``y``
        defaults to a detached copy of ``x``.
        """
        if y is None:
            y = x.detach().clone()

        # Compute x*W*y.T instead of kron(x,y)*vec(W).
        # Equivalent einsum: torch.einsum('ij,wjk,lk->wil', x, W, y)
        feats = torch.matmul(torch.matmul(x, self.first), y.transpose(0, 1))
        feats = feats.flatten(1, 2).transpose(0, 1)

        # Weighted sum of the gathered rows.
        # Equivalent einsum: torch.einsum('ijk,ij->ik', feats[B[0]], B[1])
        gathered = feats[B[0]]
        weights = B[1].unsqueeze(-1)
        feats = (gathered*weights).sum(1)

        for layer in self.mm:
            feats = layer(feats)
        return feats

    def name(self):
        """Return the model's identifier string."""
        return "Indexing_Blend_Kron_MLP"
156 |
157 |
158 | # The following are the 3D versions
class Kron3_MLP(nn.Module):
    """3D Kronecker-style MLP for separable coordinates.

    The first layer is a learnable stack of ``width0`` trilinear forms of size
    ``input_dim x input_dim x input_dim``, evaluated on every combination of
    the three encoded coordinate sets and followed by an ordinary MLP.
    """

    def __init__(self, input_dim=3, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid

        # Without any linear layer the trilinear forms must emit the outputs
        # directly, so there is exactly one form per output channel.
        if depth == 0:
            width0 = output_dim

        # Entries are drawn uniformly from [-1/sqrt(width0), 1/sqrt(width0)).
        scale = 2/np.sqrt(width0)
        shift = 1/np.sqrt(width0)
        self.first = nn.Parameter(scale*torch.rand(width0, input_dim, input_dim, input_dim) - shift)

        layers = []
        if depth == 1:
            layers.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, output_dim, bias=bias)))
        elif depth > 1:
            layers.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, width, bias=bias), nn.ReLU(True)))
            for _ in range(depth - 2):
                layers.append(nn.Sequential(nn.Linear(width, width, bias=bias), nn.ReLU(True)))
            layers.append(nn.Linear(width, output_dim, bias=bias))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.mm = nn.ModuleList(layers)

    def forward(self, x, y=None, z=None):
        """Evaluate the network on every (x, z, y) coordinate combination.

        Args:
            x: encoded coordinates (2D tensor); contracted with the
                second-to-last axis of ``self.first``.
            y: encoded coordinates contracted with the last axis of
                ``self.first``; a detached copy of ``x`` when omitted.
            z: encoded coordinates contracted with axis 1 of ``self.first``;
                a detached copy of ``x`` when omitted.

        Returns:
            One output row per combination. Rows are ordered with the x index
            slowest, then z, then y (this is the order produced by the
            matmul chain below, not the ``wabc`` order of a plain einsum).
        """
        if y is None:
            y = x.detach().clone()
        if z is None:
            z = x.detach().clone()

        # Mode-n products instead of kron(x, y, z) * vec(W).
        along_x = x @ self.first                       # (w, d, n_x, d)
        along_xy = along_x @ y.transpose(0, 1)         # (w, d, n_x, n_y)
        along_xyz = z @ along_xy.transpose(1, 2)       # (w, n_x, n_z, n_y)

        features = along_xyz.flatten(1, 3).transpose(0, 1)
        for layer in self.mm:
            features = layer(features)
        return features

    def name(self):
        """Return the identifier string of this architecture."""
        return "Kron3_MLP"
198 |
199 |
200 | # class Blend_Kron3_MLP(nn.Module):
201 | # def __init__(self,input_dim=2, output_dim=3, depth=0,width0=256,width=256, use_sigmoid=True):
202 | # super(Blend_Kron3_MLP, self).__init__()
203 | # self.use_sigmoid = use_sigmoid
204 |
205 | # if depth==0: width0 = output_dim
206 |
207 | # self.mm = nn.ModuleList([])
208 | # self.first = nn.ParameterDict({
209 | # 'weight': nn.Parameter(2/np.sqrt(width0)*torch.rand(width0, input_dim, input_dim, input_dim)-1/np.sqrt(width0))})
210 | # if depth == 1:
211 | # self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, output_dim)))
212 | # if depth > 1:
213 | # self.mm.append(nn.Sequential(nn.ReLU(True),nn.Linear(width0, width),nn.ReLU(True)))
214 | # for i in range(depth-1):
215 | # self.mm.append(nn.Sequential(nn.Linear(width, width),nn.ReLU(True)))
216 | # self.mm.append(nn.Linear(width, output_dim))
217 | # if use_sigmoid: self.mm.append(nn.Sigmoid())
218 | # def forward(self, B, x, y=None,z=None):
219 | # if y is None: y=x.detach().clone()
220 | # if z is None: z=x.detach().clone()
221 | # x = x@self.first.weight
222 | # x = x@(y.transpose(0,1))
223 | # x = z@(x.transpose(1,2))
224 | # x = x.flatten(1,3).transpose(0,1)
225 | # x = B@x
226 | # for m in self.mm:
227 | # x = m(x)
228 | # return x
229 | # def name(self):
230 | # return "Blend_Kron3_MLP"
231 |
232 |
233 |
class Indexing_Blend_Kron3_MLP(nn.Module):
    """3D Kronecker-style MLP whose blending is given as (indices, weights).

    The first layer is a learnable stack of ``width0`` trilinear forms of size
    ``input_dim x input_dim x input_dim``. ``forward`` gathers rows of the
    flattened feature table by index and sums them with the supplied weights
    before running the remaining layers.
    """

    def __init__(self, input_dim=3, output_dim=3, depth=0, width0=256, width=256, bias=True, use_sigmoid=True):
        super().__init__()
        self.use_sigmoid = use_sigmoid

        # Without any linear layer the trilinear forms must emit the outputs
        # directly, so there is exactly one form per output channel.
        if depth == 0:
            width0 = output_dim

        # Entries are drawn uniformly from [-1/sqrt(width0), 1/sqrt(width0)).
        scale = 2/np.sqrt(width0)
        shift = 1/np.sqrt(width0)
        self.first = nn.Parameter(scale*torch.rand(width0, input_dim, input_dim, input_dim) - shift)

        layers = []
        if depth == 1:
            layers.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, output_dim, bias=bias)))
        elif depth > 1:
            layers.append(nn.Sequential(nn.ReLU(True), nn.Linear(width0, width, bias=bias), nn.ReLU(True)))
            for _ in range(depth - 2):
                layers.append(nn.Sequential(nn.Linear(width, width, bias=bias), nn.ReLU(True)))
            layers.append(nn.Linear(width, output_dim, bias=bias))

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.mm = nn.ModuleList(layers)

    def forward(self, B, x, y=None, z=None):
        """Evaluate the network on index-blended trilinear features.

        Args:
            B: pair ``(indices, weights)``; ``B[0]`` selects rows of the
                flattened feature table and ``B[1]`` holds the matching
                weights, which are summed over dimension 1.
            x: encoded coordinates (2D tensor); contracted with the
                second-to-last axis of ``self.first``.
            y: encoded coordinates contracted with the last axis of
                ``self.first``; a detached copy of ``x`` when omitted.
            z: encoded coordinates contracted with axis 1 of ``self.first``;
                a detached copy of ``x`` when omitted.
        """
        if y is None:
            y = x.detach().clone()
        if z is None:
            z = x.detach().clone()

        indices, weights = B[0], B[1]

        # Mode-n products instead of kron(x, y, z) * vec(W).
        along_x = x @ self.first                       # (w, d, n_x, d)
        along_xy = along_x @ y.transpose(0, 1)         # (w, d, n_x, n_y)
        along_xyz = z @ along_xy.transpose(1, 2)       # (w, n_x, n_z, n_y)
        table = along_xyz.flatten(1, 3).transpose(0, 1)

        # Weighted sum of the gathered rows; same result as
        # torch.einsum('ijk,ij->ik', table[indices], weights).
        features = (table[indices] * weights.unsqueeze(-1)).sum(1)

        for layer in self.mm:
            features = layer(features)
        return features

    def name(self):
        """Return the identifier string of this architecture."""
        return "Indexing_Blend_Kron3_MLP"
275 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Trading Positional Complexity vs Deepness in Coordinate Networks
2 | ### [Project Page](https://osiriszjq.github.io/complex_encoding) | [Paper](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136870142.pdf)
3 | [](https://opensource.org/licenses/MIT)
4 |
5 |
6 | [Jianqiao Zheng](https://github.com/osiriszjq/),
7 | [Sameera Ramasinghe](https://scholar.google.pl/citations?user=-j0m9aMAAAAJ&hl=en),
8 | [Xueqian Li](https://lilac-lee.github.io/),
9 | [Simon Lucey](https://www.adelaide.edu.au/directory/simon.lucey)
10 | The University of Adelaide
11 |
12 | This is the official implementation of the paper "Trading Positional Complexity vs Deepness in Coordinate Networks", which has been accepted to ECCV 2022.
13 |
14 | #### Illustration of different methods to extend 1D encoding
15 | 
16 |
17 |
18 | ## Google Colab
19 | [](https://colab.research.google.com/github/osiriszjq/complex_encoding/blob/main/complex_encoding.ipynb)
20 | If you want to try out our new complex encoding, we have written a [Colab](https://colab.research.google.com/github/osiriszjq/complex_encoding/blob/main/complex_encoding.ipynb) with the following experiments:
21 | * simple encoding for 2D image reconstruction with separable coordinates
22 | * complex encoding for 2D image reconstruction with separable coordinates
23 | * closed form solution of complex encoding for 2D image reconstruction with separable coordinates.
24 | * simple encoding for 3D video reconstruction with non-separable coordinates
25 | * complex encoding for 3D video reconstruction with non-separable coordinates
26 |
27 |
28 | ## Dataset
29 | The dataset used to reproduce our results can be found in [Google Drive](https://drive.google.com/drive/folders/1yLVG1WT5i9PxchNqAb84nJdHuLlNPCj4?usp=sharing). The image data is from [Random Fourier Frequency](https://github.com/tancik/fourier-feature-networks) and the video data is from [Youtube](https://research.google.com/youtube-bb/).
30 |
31 |
32 | ## Citation
33 | ```bibtex
34 | @inproceedings{zheng2022trading,
35 | title={Trading Positional Complexity vs Deepness in Coordinate Networks},
36 | author={Zheng, Jianqiao and Ramasinghe, Sameera and Li, Xueqian and Lucey, Simon},
37 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
38 | year={2022}
39 | }
40 | ```
41 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import gdown
2 |
# Google Drive file ids, keyed by the local file name each download is saved as.
datasets = {
    'data_div2k.npz': '115l577qCHEbGN-GzE0MpnyHFoGJDy5Xw',
    'video_16_128.npz': '1-5TOdbH4j6Fr4TV1bqppo9H5JgAhvJUs',
}

# Fetch every dataset into the current working directory, showing progress.
for name, file_id in datasets.items():
    gdown.download(id=file_id, output=name, quiet=False)
--------------------------------------------------------------------------------
/imgs/simple_complex_encoding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/osiriszjq/complex_encoding/5be131f47d3a6b7616ff2d9fa4eb5d055e280bb1/imgs/simple_complex_encoding.png
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time, torch, logging
3 | from MLPs import *
4 | from utils import *
5 |
6 |
def train_simple_2D(signals,encoding_func,epochs=2000,lr=1e-2,criterion=nn.MSELoss(),device="cpu",
                    depth=4,width=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None):

    '''
    Trainer of simple encoding for 2D image reconstruction with separable coordinates.

    The network is trained on every second pixel in both directions and
    evaluated on the pixels selected by test_mode.

    The input args are four parts:

    signals: a pytorch tensor shape in N_image x sample_N x sample_N x 3
    encoding_func: encoding function; its .dim attribute is used as the MLP input size

    epochs,lr,criterion are optimization settings

    depth,width,bias,use_sigmoid are network configuration

    N_repeat: number of repeat times for each example
    test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all points including training ones.
    logger: if is None, logs are printed instead.

    Returns:

    running time: N_test x N_repeat
    train psnr: N_test x N_repeat
    test psnr: N_test x N_repeat
    reconstruction results: same size as input signals (from the last repeat of each example)

    If test_mode is not one of 'mid', 'rest', 'all', prints 'Wrong test_mode!' and returns -1.
    '''

    # Reject an unknown test_mode before doing any work; previously this was
    # only noticed after a complete training run.
    if test_mode not in ('mid', 'rest', 'all'):
        print('Wrong test_mode!')
        return -1

    sample_N = signals.shape[1]
    N_test = signals.shape[0]

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,3))

    # build the meshgrid for coordinates
    x1 = np.linspace(0, 1, sample_N+1)[:-1]
    all_data = np.stack(np.meshgrid(x1,x1), axis=-1)

    def encode(coords):
        # Flatten a (..., 2) coordinate array and pass it through the encoding.
        return encoding_func(torch.from_numpy(coords.reshape(-1,2)).type(torch.FloatTensor))

    def flat_rgb(pixels):
        return pixels.reshape(-1,3).type(torch.FloatTensor)

    # (row, column) sub-grids that make up the test set of each mode. The same
    # slices index both the targets and the coordinates so they stay aligned.
    even, odd, full = slice(None,None,2), slice(1,None,2), slice(None)
    test_grids = {
        'mid': [(odd,odd)],
        'rest': [(odd,even),(odd,odd),(even,odd)],
        'all': [(full,full)],
    }[test_mode]

    log_format = "==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"

    for i in range(N_test):
        # Prepare the targets
        all_target = signals[i].squeeze()
        train_label = flat_rgb(all_target[::2,::2]).to(device)
        test_label = torch.cat([flat_rgb(all_target[rows,cols]) for rows,cols in test_grids],0).to(device)

        for k in range(N_repeat):
            start_time = time.time()
            train_data = encode(all_data[::2,::2]).to(device)

            # regular MLP
            model = MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device)
            # Set the optimization
            optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8)

            for epoch in range(epochs):
                model.train()
                optimizer.zero_grad()

                out = model(train_data)
                loss = criterion(out, train_label)

                loss.backward()
                optimizer.step()

            # Only encoding + optimisation is timed; evaluation is excluded.
            time_[i,k] = time.time() - start_time

            test_data = torch.cat([encode(all_data[rows,cols]) for rows,cols in test_grids],0).to(device)

            model.eval()
            with torch.no_grad():
                trn_psnr_[i,k] = psnr_func(model(train_data),train_label)
                tst_psnr_[i,k] = psnr_func(model(test_data),test_label)

            message = log_format % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)
            if logger is None:
                print(message)
            else:
                logger.info(message)

            # Full-resolution reconstruction; no autograd graph is needed here.
            # Later repeats overwrite earlier ones, so the last repeat is kept.
            with torch.no_grad():
                rec_[i] = model(encode(all_data).to(device)).reshape(sample_N,sample_N,3).detach().cpu()
    return time_, trn_psnr_,tst_psnr_,rec_
119 |
120 |
121 |
122 |
def train_kron_2D(signals,encoding_func,epochs=2000,lr=1e-1,criterion=nn.MSELoss(),device="cpu",
                  depth=0,width=256,width0=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None):

    '''
    Trainer of complex encoding for 2D image reconstruction with separable coordinates.

    Only a 1D coordinate grid is encoded; the model receives that single
    encoded axis and is trained against the labels of every second pixel in
    both directions.

    The input args are four parts:

    signals: a pytorch tensor shape in N_image x sample_N x sample_N x 3
    encoding_func: encoding function; its .dim attribute is used as the model input size

    epochs,lr,criterion are optimization settings

    depth,width,width0,bias,use_sigmoid are network configuration

    N_repeat: number of repeat times for each example
    test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all points including training ones.
    logger: if is None, logs are printed instead.

    Returns:

    running time: N_test x N_repeat
    train psnr: N_test x N_repeat
    test psnr: N_test x N_repeat
    reconstruction results: same size as input signals (from the last repeat of each example)

    If test_mode is not one of 'mid', 'rest', 'all', prints 'Wrong test_mode!' and returns -1.
    '''

    # Reject an unknown test_mode before doing any work; previously this was
    # only noticed after a complete training run.
    if test_mode not in ('mid', 'rest', 'all'):
        print('Wrong test_mode!')
        return -1

    sample_N = signals.shape[1]
    N_test = signals.shape[0]
    half_N = sample_N//2

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,3))

    # Here we only use 1D grids
    all_data = np.linspace(0, 1, sample_N+1)[:-1]

    def encode(coords):
        # Turn a 1D coordinate array into an encoded (n, dim) tensor on `device`.
        return encoding_func(torch.from_numpy(coords.reshape(-1,1)).type(torch.FloatTensor)).to(device)

    def flat_rgb(pixels):
        return pixels.reshape(-1,3).type(torch.FloatTensor).to(device)

    log_format = "==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"

    for i in range(N_test):
        # Prepare the targets
        all_target = signals[i].squeeze()
        train_label = flat_rgb(all_target[::2,::2])
        if test_mode == 'mid':
            # The model is evaluated on the odd grid along BOTH axes, so the
            # labels must be the odd rows and odd columns (this used to take
            # the even columns, pairing predictions with the wrong pixels).
            test_label = flat_rgb(all_target[1::2,1::2])
            test_mask = torch.ones(half_N*half_N,dtype=torch.bool,device=device)
            test_axis = all_data[1::2]
        elif test_mode == 'rest':
            # Evaluate on the full grid and mask out the training pixels
            # (even row and even column).
            test_label = flat_rgb(all_target)
            test_mask = torch.tensor([[False,True],[True,True]],device=device).repeat(half_N,half_N).reshape(-1)
            test_axis = all_data
        else:  # 'all'
            test_label = flat_rgb(all_target)
            test_mask = torch.ones(sample_N*sample_N,dtype=torch.bool,device=device)
            test_axis = all_data

        for k in range(N_repeat):
            start_time = time.time()

            # Initialize the model to learn
            model = Kron_MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width0=width0,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device)
            # Set the optimization
            optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8)

            train_data = encode(all_data[::2])

            for epoch in range(epochs):
                model.train()
                optimizer.zero_grad()

                out = model(train_data)
                loss = criterion(out, train_label)

                loss.backward()
                optimizer.step()

            # Only model setup + encoding + optimisation is timed.
            time_[i,k] = time.time() - start_time
            model.eval()

            test_data = encode(test_axis)

            with torch.no_grad():
                trn_psnr_[i,k] = psnr_func(model(train_data),train_label)
                tst_psnr_[i,k] = psnr_func(model(test_data)[test_mask],test_label[test_mask])

            message = log_format % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)
            if logger is None:
                print(message)
            else:
                logger.info(message)

            # Full-resolution reconstruction; no autograd graph is needed here.
            # Later repeats overwrite earlier ones, so the last repeat is kept.
            with torch.no_grad():
                rec_[i] = model(encode(all_data)).reshape(sample_N,sample_N,3).detach().cpu()
    return time_,trn_psnr_,tst_psnr_,rec_
228 |
229 |
230 |
231 |
def train_closed_form_2D(signals,encoding_func,device="cpu",test_mode='rest',N_repeat=1,logger=None):

    '''
    Closed form solution of complex encoding for 2D image reconstruction with separable coordinates. No training.

    The weights are obtained by applying the least-squares (pseudo-)inverse of
    the encoded 1D training grid on both sides of the training image.

    The input args are two parts:

    signals: a pytorch tensor shape in N_image x sample_N x sample_N x 3
    encoding_func: encoding function

    N_repeat: number of repeat times for each example
    test_mode: default as 'rest' to use all points except training ones for test. 'mid' will use a shift grid of training points. 'all' will use all points including training ones.
    logger: if is None, logs are printed instead.

    Returns:

    running time: N_test x N_repeat
    train psnr: N_test x N_repeat
    test psnr: N_test x N_repeat
    reconstruction results: same size as input signals (from the last repeat of each example)

    If test_mode is not one of 'mid', 'rest', 'all', prints 'Wrong test_mode!' and returns -1.
    '''

    # Reject an unknown test_mode before doing any work.
    if test_mode not in ('mid', 'rest', 'all'):
        print('Wrong test_mode!')
        return -1

    sample_N = signals.shape[1]
    N_test = signals.shape[0]
    half_N = sample_N//2

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,3))

    # Here we only use 1D grids
    all_data = np.linspace(0, 1, sample_N+1)[:-1]

    def encode(coords):
        # Turn a 1D coordinate array into an encoded (n, dim) tensor on `device`.
        return encoding_func(torch.from_numpy(coords.reshape(-1,1)).type(torch.FloatTensor)).to(device)

    def flat_rgb(pixels):
        return pixels.reshape(-1,3).type(torch.FloatTensor).to(device)

    log_format = "==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"

    for i in range(N_test):
        # Prepare the targets; the training labels keep their image layout.
        all_target = signals[i].squeeze()
        train_label = all_target[::2,::2].type(torch.FloatTensor).to(device)
        if test_mode == 'mid':
            test_label = flat_rgb(all_target[1::2,1::2])
            test_mask = torch.ones(half_N*half_N,dtype=torch.bool,device=device)
        elif test_mode == 'rest':
            # Evaluate on the full grid and mask out the training pixels
            # (even row and even column).
            test_label = flat_rgb(all_target)
            test_mask = torch.tensor([[False,True],[True,True]],device=device).repeat(half_N,half_N).reshape(-1)
        else:  # 'all'
            test_label = flat_rgb(all_target)
            test_mask = torch.ones(sample_N*sample_N,dtype=torch.bool,device=device)

        for k in range(N_repeat):
            start_time = time.time()

            train_data = encode(all_data[::2])

            # Least-squares inverse of the encoded grid; torch.linalg.pinv(train_data)
            # is the alternative way to obtain it.
            ix = torch.linalg.lstsq(train_data, torch.eye(train_data.shape[0]).to(device)).solution
            # Channels first so the inverse is applied on both spatial axes per channel.
            W = ix@(train_label.transpose(0,2))@ix.T

            time_[i,k] = time.time() - start_time

            # 'mid' is scored on the shifted half-resolution grid; the other
            # modes are scored on the full grid. The reconstruction always
            # needs the full grid, otherwise its shape cannot match rec_[i].
            if test_mode == 'mid':
                test_data = encode(all_data[1::2])
                full_data = encode(all_data)
            else:
                test_data = full_data = encode(all_data)

            trn_psnr_[i,k] = psnr_func((train_data@W@train_data.T).transpose(0,2),train_label).detach().cpu().numpy()
            tst_psnr_[i,k] = psnr_func((test_data@W@test_data.T).transpose(0,2).reshape(-1,3)[test_mask],test_label[test_mask]).detach().cpu().numpy()

            message = log_format % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)
            if logger is None:
                print(message)
            else:
                logger.info(message)

            # Later repeats overwrite earlier ones, so the last repeat is kept.
            rec_[i] = (full_data@W@full_data.T).transpose(0,2).detach().cpu()
    return time_,trn_psnr_,tst_psnr_,rec_
319 |
320 |
321 |
322 |
def train_random_simple_2D(signals,encoding_func,ratio=0.25,mask=None,epochs=2000,lr=1e-2,criterion=nn.MSELoss(),device="cpu",
                           depth=4,width=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None):

    '''
    Trainer of simple encoding for 2D image reconstruction with randomly sampled points.

    The input args are four parts:

    signals: a torch tensor shape in N_image x sample_N x sample_N x 3
    encoding_func: encoding function; its .dim attribute is used as the MLP input size
    ratio: if mask is None, randomly sample points for training with this ratio of number of all points
    mask: a boolean matrix shape in N_image x N_repeat x (sample_N x sample_N); True marks training points

    epochs,lr,criterion are optimization settings

    depth,width,bias,use_sigmoid are network configuration

    N_repeat: number of repeat times for each example
    test_mode: default as 'rest' to use all points except training ones for test. 'all' will use all points including training ones. No other mode is supported by this trainer.
    logger: if is None, logs are printed instead.

    Returns:

    running time: N_test x N_repeat
    train psnr: N_test x N_repeat
    test psnr: N_test x N_repeat
    reconstruction results: same size as input signals (from the last repeat of each example)

    If test_mode is neither 'rest' nor 'all', prints 'Wrong test_mode!' and returns -1.
    '''

    # Reject an unknown test_mode before doing any work; previously this was
    # only noticed after a complete training run.
    if test_mode not in ('rest', 'all'):
        print('Wrong test_mode!')
        return -1

    sample_N = signals.shape[1]
    N_test = signals.shape[0]

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,3))

    # build the meshgrid for coordinates
    x1 = np.linspace(0, 1, sample_N+1)[:-1]
    all_data = np.stack(np.meshgrid(x1,x1), axis=-1)
    flat_coords = all_data.reshape(-1,2)

    def encode(coords):
        # Pass an (n, 2) coordinate array through the encoding.
        return encoding_func(torch.from_numpy(coords).type(torch.FloatTensor))

    log_format = "==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"

    for i in range(N_test):

        for k in range(N_repeat):

            if mask is None:
                # Draw a fresh random training subset for this repeat.
                idx = torch.randperm(sample_N**2)[:int(ratio*sample_N**2)]
                mask_np_N2 = np.zeros(sample_N**2,dtype=bool)
                mask_np_N2[idx.numpy()] = True
            else:
                mask_np_N2 = mask[i,k]

            # Prepare the targets
            all_target = signals[i].squeeze()
            flat_target = all_target.reshape(-1,3)
            train_label = flat_target[mask_np_N2].type(torch.FloatTensor)

            start_time = time.time()
            train_data = encode(flat_coords[mask_np_N2])
            train_data, train_label = train_data.to(device),train_label.to(device)

            # Initialize the model to learn
            model = MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device)
            # Set the optimization
            optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8)

            for epoch in range(epochs):
                model.train()
                optimizer.zero_grad()

                out = model(train_data)
                loss = criterion(out, train_label)

                loss.backward()
                optimizer.step()

            # Only encoding + model setup + optimisation is timed.
            time_[i,k] = time.time() - start_time

            if test_mode == 'rest':
                test_label = flat_target[~mask_np_N2].type(torch.FloatTensor)
                test_data = encode(flat_coords[~mask_np_N2])
            else:  # 'all'
                test_label = flat_target.type(torch.FloatTensor)
                test_data = encode(flat_coords)

            test_data, test_label = test_data.to(device),test_label.to(device)

            # Evaluate in eval mode without autograd, like the other trainers.
            model.eval()
            with torch.no_grad():
                trn_psnr_[i,k] = psnr_func(model(train_data),train_label)
                tst_psnr_[i,k] = psnr_func(model(test_data),test_label)

            message = log_format % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)
            if logger is None:
                print(message)
            else:
                logger.info(message)

            # Full-resolution reconstruction; later repeats overwrite earlier
            # ones, so the last repeat is kept.
            with torch.no_grad():
                rec_[i] = model(encode(flat_coords).to(device)).reshape(sample_N,sample_N,3).detach().cpu()

    return time_, trn_psnr_,tst_psnr_,rec_
432 |
433 |
434 |
435 |
def train_index_blend_kron_2D(signals,outter_blending,inner_encoding,ratio=0.25,mask=None,sm=0.0,epochs=2000,lr=1e-1,criterion=torch.nn.MSELoss(),device="cpu",
                              depth=0,width=256,width0=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None):

    '''
    Trainer of complex encoding for 2D image reconstruction with randomly sampled points.

    The input args are four parts:

    signals: a pytorch tensor shape in N_image x sample_N x sample_N x 3
    inner_encoding: encoding function of the virtual grid points, so the encoded result is static; its .dim attribute is used as the model input size
    outter_blending: blending function returning interpolation indices and weights of the random samples; its .dim attribute sets the number of virtual grid points
    ratio: if mask is None, randomly sample points for training with this ratio of number of all points
    mask: a boolean matrix shape in N_image x N_repeat x (sample_N x sample_N); True marks training points
    sm: how much we punish on total variation loss (applied to every slice of the model's first-layer weights when sm > 0)

    epochs,lr,criterion are optimization settings

    depth,width,width0,bias,use_sigmoid are network configuration

    N_repeat: number of repeat times for each example
    test_mode: default as 'rest' to use all points except training ones for test. 'all' will use all points including training ones. No other mode is supported by this trainer.
    logger: if is None, logs are printed instead.

    Returns:

    running time: N_test x N_repeat
    train psnr: N_test x N_repeat
    test psnr: N_test x N_repeat
    reconstruction results: same size as input signals (from the last repeat of each example)

    If test_mode is neither 'rest' nor 'all', prints 'Wrong test_mode!' and returns -1.
    '''

    # Reject an unknown test_mode before doing any work; previously this was
    # only noticed after a complete training run.
    if test_mode not in ('rest', 'all'):
        print('Wrong test_mode!')
        return -1

    sample_N = signals.shape[1]
    N_test = signals.shape[0]

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,3))

    # the length of all_data is same as 1D signal length
    x1 = np.linspace(0, 1, sample_N+1)[:-1]
    all_data = np.stack(np.meshgrid(x1,x1), axis=-1)
    flat_coords = all_data.reshape(-1,2)

    # The virtual grid does not depend on the samples, so encode it once.
    all_grid = torch.linspace(0, 1, outter_blending.dim+1)[:-1]
    encoded_grid = inner_encoding(all_grid.reshape(-1,1)).to(device)

    # Finite-difference kernel handed to smooth() for the total variation term.
    tv_filter = torch.tensor([[[1.0,-1.0]]]).to(device)

    def blend(coords):
        # Interpolation indices and weights of an (n, 2) coordinate array, on `device`.
        idx, weights = outter_blending(torch.from_numpy(coords).type(torch.FloatTensor))
        return idx.to(device), weights.to(device)

    log_format = "==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"

    for i in range(N_test):

        for k in range(N_repeat):

            if mask is None:
                # Draw a fresh random training subset for this repeat.
                idx = torch.randperm(sample_N**2)[:int(ratio*sample_N**2)]
                mask_np_N2 = np.zeros(sample_N**2,dtype=bool)
                mask_np_N2[idx.numpy()] = True
            else:
                mask_np_N2 = mask[i,k]

            # Prepare the targets
            all_target = signals[i].squeeze()
            flat_target = all_target.reshape(-1,3)
            train_label = flat_target[mask_np_N2].type(torch.FloatTensor).to(device)

            start_time = time.time()
            train_idx, train_data = blend(flat_coords[mask_np_N2])

            # Initialize the model to learn
            model = Indexing_Blend_Kron_MLP(input_dim=inner_encoding.dim,output_dim=3,depth=depth,width0=width0,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device)
            # Set the optimization
            optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8)

            for epoch in range(epochs):
                model.train()
                optimizer.zero_grad()

                out = model([train_idx,train_data],encoded_grid)
                loss = criterion(out, train_label)
                if sm>0:
                    # model.first is itself the weight tensor (an nn.Parameter),
                    # so it is indexed directly; it has no `.weight` attribute.
                    for oo in range(model.first.shape[0]):
                        loss = loss + sm*smooth(model.first[oo],tv_filter)

                loss.backward()
                optimizer.step()

            model.eval()

            # Only blending + model setup + optimisation is timed.
            time_[i,k] = time.time() - start_time

            if test_mode == 'rest':
                test_label = flat_target[~mask_np_N2].type(torch.FloatTensor)
                test_idx,test_data = blend(flat_coords[~mask_np_N2])
            else:  # 'all'
                test_label = flat_target.type(torch.FloatTensor)
                test_idx,test_data = blend(flat_coords)

            test_label = test_label.to(device)
            with torch.no_grad():
                trn_psnr_[i,k] = psnr_func(model([train_idx,train_data],encoded_grid),train_label)
                tst_psnr_[i,k] = psnr_func(model([test_idx,test_data],encoded_grid),test_label)

            message = log_format % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time)
            if logger is None:
                print(message)
            else:
                logger.info(message)

            # Full-resolution reconstruction: blend the whole grid once and
            # reuse both outputs. Later repeats overwrite earlier ones, so the
            # last repeat is kept.
            with torch.no_grad():
                rec_idx, rec_weights = blend(flat_coords)
                rec_[i] = model([rec_idx,rec_weights],encoded_grid).reshape(sample_N,sample_N,3).detach().cpu()
    return time_,trn_psnr_,tst_psnr_,rec_
558 |
559 |
560 |
561 |
def train_simple_3D(signals,encoding_func,epochs=500,lr=5e-3,criterion=nn.MSELoss(),device="cpu",
                    depth=5,width=512,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None):
    '''
    Trainer of simple encoding for 3D video reconstruction with separable coordinates.

    Training points are the even-indexed samples along every axis
    (``[::2, ::2, ::2]``). A fresh MLP is created and fitted for every
    signal and every repeat.

    Args:
        signals: a pytorch tensor of shape N_video x sample_N x sample_N x sample_N x 3
        encoding_func: encoding function. It is called on (n, 3) coordinate
            tensors and its ``dim`` attribute is used as the MLP input width.
        epochs, lr, criterion: optimization settings
        device: device the data and the model are moved to
        depth, width, bias, use_sigmoid: network configuration (forwarded to MLP)
        N_repeat: number of repeats for each example
        test_mode: 'rest' (default) tests on every grid point that is not a
            training point, 'mid' tests on the grid shifted by one sample along
            every axis, 'all' tests on all points including the training ones.
        logger: if None, progress is printed; otherwise ``logger.info`` is used.

    Returns:
        running time: N_test x N_repeat
        train psnr: N_test x N_repeat
        test psnr: N_test x N_repeat
        reconstruction results: same size as input signals
        If test_mode is not recognised, a message is printed and -1 is returned.
    '''

    # Parity offsets (per axis) of the sub-grids that make up the test set.
    if test_mode == 'mid':
        test_offsets = [(1,1,1)]
    elif test_mode == 'rest':
        # every parity combination except the all-even training grid
        test_offsets = [(0,1,0),(0,1,1),(0,0,1),(1,0,0),(1,1,0),(1,1,1),(1,0,1)]
    elif test_mode == 'all':
        test_offsets = None
    else:
        print('Wrong test_mode!')
        return -1

    def gather_test(volume):
        # Flatten the test subset of a (.., .., .., 3) tensor into rows of 3 values.
        if test_offsets is None:
            return volume.reshape(-1,3)
        return torch.cat([volume[a::2,b::2,c::2].reshape(-1,3) for a,b,c in test_offsets],0)

    sample_N = signals.shape[1]
    N_test = signals.shape[0]

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3))

    # build the meshgrid for 3D coordinates
    x1 = torch.linspace(0, 1, sample_N+1)[:-1]
    all_data = torch.stack(torch.meshgrid(x1,x1,x1), axis=-1)

    for i in range(N_test):
        # Prepare the targets
        all_target = signals[i].squeeze()
        train_label = all_target[::2,::2,::2].reshape(-1,3).type(torch.FloatTensor)
        test_label = gather_test(all_target).type(torch.FloatTensor)

        model = None
        for k in range(N_repeat):
            start_time = time.time()

            # Encoding and device transfer stay inside the timed region,
            # so the reported running time keeps its original meaning.
            train_data = encoding_func(all_data[::2,::2,::2].reshape(-1,3)).to(device)
            test_data = encoding_func(gather_test(all_data)).to(device)
            train_label = train_label.to(device)
            test_label = test_label.to(device)

            # Initialize the model to learn
            model = MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device)
            # Set the optimization
            optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8)

            # full-batch training
            for epoch in range(epochs):
                model.train()
                optimizer.zero_grad()

                out = model(train_data)
                loss = criterion(out, train_label)

                loss.backward()
                optimizer.step()

            time_[i,k] = time.time() - start_time
            model.eval()
            with torch.no_grad():
                trn_psnr_[i,k] = psnr_func(model(train_data),train_label)
                tst_psnr_[i,k] = psnr_func(model(test_data),test_label)

            message = ("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"
                       % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time))
            if logger is None:
                print(message)
            else:
                logger.info(message)

        # Reconstruct the whole volume with the model of the last repeat,
        # one slab along the first axis at a time. The slab count follows
        # sample_N instead of a fixed 128.
        if model is not None:
            with torch.no_grad():
                for ff in range(sample_N):
                    slab = all_data[ff,:,:].reshape(-1,3).type(torch.FloatTensor)
                    rec_[i,ff] = model(encoding_func(slab).to(device)).reshape(sample_N,sample_N,3).detach().cpu()

    return time_, trn_psnr_,tst_psnr_,rec_
681 |
682 |
683 |
684 |
def train_kron_3D(signals,encoding_func,epochs=500,lr=1e-1,criterion=nn.MSELoss(),device="cpu",
                  depth=0,width=256,width0=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None):
    '''
    Trainer of complex encoding for 3D video reconstruction with separable coordinates.

    Only a 1D coordinate grid is encoded; the encoded 1D samples are handed
    to a Kron3_MLP whose output is compared against the flattened labels.
    Training uses the even-indexed samples of the grid.

    Args:
        signals: a pytorch tensor of shape N_video x sample_N x sample_N x sample_N x 3
        encoding_func: encoding function (its ``dim`` attribute sizes the model input)
        epochs, lr, criterion: optimization settings
        device: device the data and the model are moved to
        depth, width, width0, bias, use_sigmoid: network configuration
        N_repeat: number of repeats for each example
        test_mode: 'rest' (default) tests on all points except the training ones,
            'mid' tests on the grid shifted by one sample, 'all' tests on all
            points including the training ones.
        logger: if None, progress is printed; otherwise ``logger.info`` is used.

    Returns:
        running time: N_test x N_repeat
        train psnr: N_test x N_repeat
        test psnr: N_test x N_repeat
        reconstruction results: same size as input signals
        If test_mode is not recognised, a message is printed and -1 is returned.
    '''

    n_videos, sample_N = signals.shape[0], signals.shape[1]
    half_N = int(sample_N/2)

    time_ = np.zeros((n_videos,N_repeat))
    trn_psnr_ = np.zeros((n_videos,N_repeat))
    tst_psnr_ = np.zeros((n_videos,N_repeat))
    rec_ = np.zeros((n_videos,sample_N,sample_N,sample_N,3))

    # 1D coordinates shared by the three axes
    axis_coords = np.linspace(0, 1, sample_N+1)[:-1]

    def as_column(values):
        # numpy vector -> float32 column tensor of shape (n, 1)
        return torch.from_numpy(values.reshape(-1,1)).type(torch.FloatTensor)

    for i in range(n_videos):
        # Targets: even grid for training, test subset selected through a boolean mask
        volume = signals[i].squeeze()
        train_label = volume[::2,::2,::2].reshape(-1,3).type(torch.FloatTensor).to(device)
        if test_mode == 'mid':
            test_label = volume[1::2,1::2,1::2].reshape(-1,3).type(torch.FloatTensor).to(device)
            test_mask = torch.tensor([[True]],device=device).repeat(half_N,half_N,half_N).reshape(-1)
        elif test_mode == 'rest':
            test_label = volume.reshape(-1,3).type(torch.FloatTensor).to(device)
            # in every 2x2x2 cell only the all-even corner is a training point
            cell = [[[False,True],[True,True]],[[True,True],[True,True]]]
            test_mask = torch.tensor(cell,device=device).repeat(half_N,half_N,half_N).reshape(-1)
        elif test_mode == 'all':
            test_label = volume.reshape(-1,3).type(torch.FloatTensor).to(device)
            test_mask = torch.tensor([[True]],device=device).repeat(sample_N,sample_N,sample_N).reshape(-1)

        for k in range(N_repeat):
            start_time = time.time()

            train_coords = as_column(axis_coords[::2])
            if test_mode == 'mid':
                test_coords = as_column(axis_coords[1::2])
            elif test_mode in ('rest', 'all'):
                test_coords = as_column(axis_coords)
            else:
                print('Wrong test_mode!')
                return -1

            # A new model and optimizer for every repeat
            model = Kron3_MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width0=width0,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8)

            train_data = encoding_func(train_coords).to(device)
            test_data = encoding_func(test_coords).to(device)

            # full-batch training
            for _ in range(epochs):
                model.train()
                optimizer.zero_grad()
                loss = criterion(model(train_data), train_label)
                loss.backward()
                optimizer.step()

            time_[i,k] = time.time() - start_time
            model.eval()
            with torch.no_grad():
                trn_psnr_[i,k] = psnr_func(model(train_data),train_label)
                tst_psnr_[i,k] = psnr_func(model(test_data)[test_mask],test_label[test_mask])

            report = ("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"
                      % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time))
            if logger is None:
                print(report)
            else:
                logger.info(report)

        # Reconstruction over the full grid with the model of the last repeat
        with torch.no_grad():
            full_code = encoding_func(as_column(axis_coords)).to(device)
            rec_[i] = model(full_code).reshape(sample_N,sample_N,sample_N,3).detach().cpu()

    return time_,trn_psnr_,tst_psnr_,rec_
791 |
792 |
793 |
794 |
def train_closed_form_3D(signals,encoding_func,device="cpu",test_mode='rest',N_repeat=1,logger=None):
    '''
    Closed form solution of complex encoding for 3D video reconstruction with separable coordinates. No training.

    The weights are obtained by applying a least-squares inverse of the
    encoded 1D training grid along each of the three axes of the training
    labels (the even-indexed samples of the volume).

    Args:
        signals: a pytorch tensor of shape N_video x sample_N x sample_N x sample_N x 3
        encoding_func: encoding function applied to (n, 1) coordinate columns
        device: device the tensors are moved to
        test_mode: 'rest' (default) tests on all points except the training ones,
            'mid' tests on the grid shifted by one sample, 'all' tests on all
            points including the training ones.
        N_repeat: number of repeats for each example
        logger: if None, progress is printed; otherwise ``logger.info`` is used.

    Returns:
        running time: N_test x N_repeat
        train psnr: N_test x N_repeat
        test psnr: N_test x N_repeat
        reconstruction results: same size as input signals
        If test_mode is not recognised, a message is printed and -1 is returned.
    '''

    if test_mode not in ('mid', 'rest', 'all'):
        print('Wrong test_mode!')
        return -1

    sample_N = signals.shape[1]
    N_test = signals.shape[0]
    half_N = int(sample_N/2)

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3))

    # Here we only use 1D grids
    all_data = np.linspace(0, 1, sample_N+1)[:-1]

    def to_column(coords):
        # numpy vector -> float32 column tensor of shape (n, 1)
        return torch.from_numpy(coords.reshape(-1,1)).type(torch.FloatTensor)

    def evaluate(weights, feats):
        # Contract the weights with the encoded 1D grid along each axis;
        # the colour channel ends up last after the final transpose.
        return (feats@((weights@feats.T).transpose(1,3))@feats.T).transpose(0,3)

    # Boolean mask selecting the test points among the evaluated points.
    if test_mode == 'mid':
        test_mask = torch.tensor([[True]],device=device).repeat(half_N,half_N,half_N).reshape(-1)
    elif test_mode == 'rest':
        # in every 2x2x2 cell only the all-even corner is a training point
        test_mask = torch.tensor([[[False,True],[True,True]],[[True,True],[True,True]]],device=device).repeat(half_N,half_N,half_N).reshape(-1)
    else:
        test_mask = torch.tensor([[True]],device=device).repeat(sample_N,sample_N,sample_N).reshape(-1)

    for i in range(N_test):
        # Prepare the targets; train_label keeps its 4D layout for the solve below
        all_target = signals[i].squeeze()
        train_label = all_target[::2,::2,::2].type(torch.FloatTensor).to(device)
        if test_mode == 'mid':
            test_label = all_target[1::2,1::2,1::2].reshape(-1,3).type(torch.FloatTensor).to(device)
        else:
            test_label = all_target.reshape(-1,3).type(torch.FloatTensor).to(device)

        W = None
        for k in range(N_repeat):
            start_time = time.time()
            # train data are the even samples; test data are the odd samples ('mid') or the full grid
            train_data = encoding_func(to_column(all_data[::2])).to(device)
            if test_mode == 'mid':
                test_data = encoding_func(to_column(all_data[1::2])).to(device)
            else:
                test_data = encoding_func(to_column(all_data)).to(device)

            # Least-squares solution of train_data @ X = I, used in place of torch.linalg.pinv
            ix = torch.linalg.lstsq(train_data, torch.eye(train_data.shape[0]).to(device)).solution
            W = (ix@(train_label.transpose(0,3))@ix.T).transpose(1,3)@ix.T

            time_[i,k] = time.time() - start_time
            trn_psnr_[i,k] = psnr_func(evaluate(W,train_data),train_label).detach().cpu().numpy()
            tst_psnr_[i,k] = psnr_func(evaluate(W,test_data).reshape(-1,3)[test_mask],test_label[test_mask]).detach().cpu().numpy()

            message = ("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"
                       % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time))
            if logger is None:
                print(message)
            else:
                logger.info(message)

        # Reconstruct on the FULL grid. In 'mid' mode test_data only covers the
        # odd samples, so it cannot fill the sample_N^3 reconstruction buffer.
        if W is not None:
            full_code = encoding_func(to_column(all_data)).to(device)
            rec_[i] = evaluate(W,full_code).detach().cpu()

    return time_,trn_psnr_,tst_psnr_,rec_
882 |
883 |
884 |
885 |
def train_random_simple_3D(signals,encoding_func,ratio=0.125,mask=None,epochs=500,lr=5e-3,criterion=nn.MSELoss(),device="cpu",
                           depth=5,width=512,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None):
    '''
    Trainer of simple encoding for 3D video reconstruction with randomly sampled points.

    Args:
        signals: a torch tensor of shape N_video x sample_N x sample_N x sample_N x 3
        encoding_func: encoding function. It is called on (n, 3) coordinate
            tensors and its ``dim`` attribute is used as the MLP input width.
        ratio: if mask is None, the fraction of all points randomly sampled for training
        mask: a boolean matrix of shape N_video x N_repeat x (sample_N x sample_N x sample_N);
            True marks a training point
        epochs, lr, criterion: optimization settings
        device: device the data and the model are moved to
        depth, width, bias, use_sigmoid: network configuration (forwarded to MLP)
        N_repeat: number of repeats for each example
        test_mode: 'rest' (default) tests on all points except the training ones,
            'all' tests on all points including the training ones. Any other
            value is rejected.
        logger: if None, progress is printed; otherwise ``logger.info`` is used.

    Returns:
        running time: N_test x N_repeat
        train psnr: N_test x N_repeat
        test psnr: N_test x N_repeat
        reconstruction results: same size as input signals
        If test_mode is not recognised, a message is printed and -1 is returned.
    '''

    if test_mode not in ('rest', 'all'):
        print('Wrong test_mode!')
        return -1

    sample_N = signals.shape[1]
    N_test = signals.shape[0]
    n_points = sample_N**3

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3))

    # build 3D coordinate meshgrid
    # NOTE: np.meshgrid defaults to indexing='xy'; the same grid is used for
    # training, testing and reconstruction, so the convention is self-consistent.
    x1 = np.linspace(0, 1, sample_N+1)[:-1]
    all_data = np.stack(np.meshgrid(x1,x1,x1), axis=-1)
    flat_coords = all_data.reshape(-1,3)

    for i in range(N_test):
        flat_target = signals[i].squeeze().reshape(-1,3)

        model = None
        for k in range(N_repeat):
            # Boolean selector of the training points
            if mask is None:
                idx = torch.randperm(n_points)[:int(ratio*n_points)]
                train_mask = np.zeros((n_points))
                train_mask[idx] = 1
                train_mask = train_mask==1
            else:
                train_mask = mask[i,k]

            # Prepare the targets
            train_label = flat_target[train_mask].type(torch.FloatTensor)
            if test_mode == 'rest':
                test_label = flat_target[~train_mask].type(torch.FloatTensor)
            else:
                test_label = flat_target.type(torch.FloatTensor)

            start_time = time.time()
            train_data = encoding_func(torch.from_numpy(flat_coords[train_mask]).type(torch.FloatTensor))
            if test_mode == 'rest':
                test_data = encoding_func(torch.from_numpy(flat_coords[~train_mask]).type(torch.FloatTensor))
            else:
                test_data = encoding_func(torch.from_numpy(flat_coords).type(torch.FloatTensor))

            train_data, train_label = train_data.to(device),train_label.to(device)
            test_data, test_label = test_data.to(device),test_label.to(device)

            # Initialize the model to learn
            model = MLP(input_dim=encoding_func.dim,output_dim=3,depth=depth,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device)
            # Set the optimization
            optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8)

            # full-batch training
            for epoch in range(epochs):
                model.train()
                optimizer.zero_grad()

                out = model(train_data)
                loss = criterion(out, train_label)

                loss.backward()
                optimizer.step()

            time_[i,k] = time.time() - start_time
            model.eval()
            with torch.no_grad():
                trn_psnr_[i,k] = psnr_func(model(train_data),train_label)
                tst_psnr_[i,k] = psnr_func(model(test_data),test_label)

            message = ("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"
                       % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time))
            if logger is None:
                print(message)
            else:
                logger.info(message)

        # Reconstruct the whole volume with the model of the last repeat,
        # one slab along the first axis at a time. The slab count follows
        # sample_N instead of a fixed 128.
        if model is not None:
            with torch.no_grad():
                for ff in range(sample_N):
                    slab = torch.from_numpy(all_data[ff,:,:].reshape(-1,3)).type(torch.FloatTensor)
                    rec_[i,ff] = model(encoding_func(slab).to(device)).reshape(sample_N,sample_N,3).detach().cpu()

    return time_, trn_psnr_,tst_psnr_,rec_
998 |
999 |
1000 |
1001 |
def train_index_blend_kron_3D(signals,outter_blending,inner_encoding,ratio=0.125,mask=None,sm=0.0,epochs=500,lr=1e-1,criterion=torch.nn.MSELoss(),device="cpu",
                              depth=0,width=256,width0=256,bias=True,use_sigmoid=False,N_repeat=1,test_mode='rest',logger=None):
    '''
    Trainer of complex encoding for 3D video reconstruction with randomly sampled points.

    Args:
        signals: a torch tensor of shape N_video x sample_N x sample_N x sample_N x 3
        outter_blending: blending function. Called on (n, 3) coordinates it
            returns a pair that is passed to the model as ``[idx, data]``;
            its ``dim`` attribute sets the number of 1D grid points to encode.
        inner_encoding: encoding function applied to the 1D grid; its ``dim``
            attribute sizes the model input.
        ratio: if mask is None, the fraction of all points randomly sampled for training
        mask: a boolean matrix of shape N_video x N_repeat x (sample_N x sample_N x sample_N);
            True marks a training point
        sm: weight of the smoothness penalty (``smooth_3D``) on the rows of
            ``model.first.weight``; disabled when not positive
        epochs, lr, criterion: optimization settings
        device: device the data and the model are moved to
        depth, width, width0, bias, use_sigmoid: network configuration
        N_repeat: number of repeats for each example
        test_mode: 'rest' (default) tests on all points except the training ones,
            'all' tests on all points including the training ones. Any other
            value is rejected.
        logger: if None, progress is printed; otherwise ``logger.info`` is used.

    Returns:
        running time: N_test x N_repeat
        train psnr: N_test x N_repeat
        test psnr: N_test x N_repeat
        reconstruction results: same size as input signals
        If test_mode is not recognised, a message is printed and -1 is returned.
    '''

    if test_mode not in ('rest', 'all'):
        print('Wrong test_mode!')
        return -1

    sample_N = signals.shape[1]
    N_test = signals.shape[0]
    n_points = sample_N**3

    trn_psnr_ = np.zeros((N_test,N_repeat))
    tst_psnr_ = np.zeros((N_test,N_repeat))
    time_ = np.zeros((N_test,N_repeat))
    rec_ = np.zeros((N_test,sample_N,sample_N,sample_N,3))

    # 3D coordinates of every sample
    # NOTE: np.meshgrid defaults to indexing='xy'; the same grid is used for
    # training, testing and reconstruction, so the convention is self-consistent.
    x1 = np.linspace(0, 1, sample_N+1)[:-1]
    all_data = np.stack(np.meshgrid(x1,x1,x1), axis=-1)
    flat_coords = all_data.reshape(-1,3)

    # Encoded 1D grid shared by all runs; moved to the device once
    all_grid = torch.linspace(0, 1, outter_blending.dim+1)[:-1]
    encoded_grid = inner_encoding(all_grid.reshape(-1,1)).to(device)

    # Two-tap difference kernel handed to smooth_3D
    diff_filter = torch.tensor([[[1.0,-1.0]]]).to(device)

    for i in range(N_test):
        flat_target = signals[i].squeeze().reshape(-1,3)

        model = None
        for k in range(N_repeat):
            # Boolean selector of the training points
            if mask is None:
                idx = torch.randperm(n_points)[:int(ratio*n_points)]
                train_mask = np.zeros((n_points))
                train_mask[idx] = 1
                train_mask = train_mask==1
            else:
                train_mask = mask[i,k]

            # Prepare the targets
            train_label = flat_target[train_mask].type(torch.FloatTensor)
            if test_mode == 'rest':
                test_label = flat_target[~train_mask].type(torch.FloatTensor)
            else:
                test_label = flat_target.type(torch.FloatTensor)

            start_time = time.time()
            train_idx,train_data = outter_blending(torch.from_numpy(flat_coords[train_mask]).type(torch.FloatTensor))
            if test_mode == 'rest':
                test_idx, test_data = outter_blending(torch.from_numpy(flat_coords[~train_mask]).type(torch.FloatTensor))
            else:
                test_idx, test_data = outter_blending(torch.from_numpy(flat_coords).type(torch.FloatTensor))

            train_idx, train_data, train_label = train_idx.to(device), train_data.to(device),train_label.to(device)
            test_idx, test_data, test_label = test_idx.to(device), test_data.to(device),test_label.to(device)

            # Initialize the model to learn
            model = Indexing_Blend_Kron3_MLP(input_dim=inner_encoding.dim,output_dim=3,depth=depth,width0=width0,width=width,bias=bias,use_sigmoid=use_sigmoid).to(device)
            # Set the optimization
            optimizer = torch.optim.Adam(model.parameters(), lr, betas=(0.9, 0.999),weight_decay=1e-8)

            # full-batch training
            for epoch in range(epochs):
                model.train()
                optimizer.zero_grad()

                out = model([train_idx,train_data],encoded_grid)
                loss = criterion(out, train_label)
                if sm>0:
                    # smoothness penalty on every row of the first layer
                    for oo in range(model.first.weight.shape[0]):
                        loss += sm*smooth_3D(model.first.weight[oo],diff_filter)

                loss.backward()
                optimizer.step()

            time_[i,k] = time.time() - start_time
            model.eval()
            with torch.no_grad():
                trn_psnr_[i,k] = psnr_func(model([train_idx,train_data],encoded_grid),train_label)
                tst_psnr_[i,k] = psnr_func(model([test_idx,test_data],encoded_grid),test_label)

            message = ("==>>> E: %g, T: %g, train psnr: %g--- , test psnr: %g--- , time: %g seconds ---"
                       % (i,k, np.mean(trn_psnr_[i,k]),np.mean(tst_psnr_[i,k]),time.time() - start_time))
            if logger is None:
                print(message)
            else:
                logger.info(message)

        # Reconstruct the full volume with the model of the last repeat.
        # The full grid is blended once and both returned parts are reused.
        if model is not None:
            with torch.no_grad():
                full_idx, full_data = outter_blending(torch.from_numpy(flat_coords).type(torch.FloatTensor))
                rec_[i] = model([full_idx.to(device),full_data.to(device)],encoded_grid).reshape(sample_N,sample_N,sample_N,3).detach().cpu()

    return time_,trn_psnr_,tst_psnr_,rec_
1124 |
1125 |
1126 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | # encoding funcs for 1D, 2D and 3D
class encoding_func_1D:
    '''
    Positional encoding for 1D coordinates.

    name selects the encoding: 'none', 'basic', 'RFF', 'rffb', 'LinF',
    'LogF', 'Gau' or 'Tri'. For every name except 'none' and 'basic',
    param must be given as [sigma, feature_dimension] (for 'rffb' the first
    entry is the frequency matrix itself; for 'Tri' it is the triangle
    half-width, or None to use 1/feature_dimension).

    Calling the object on an (n, 1) tensor returns the embedding divided by
    the largest row norm found in that batch ('none' returns the input as is).
    '''

    def __init__(self, name, param=None):
        self.name = name

        # Parameter-free encodings
        if name == 'none':
            self.dim = 1
            return
        if name == 'basic':
            self.dim = 2
            return

        self.dim = param[1]
        if name == 'RFF':
            # random frequencies scaled by sigma
            self.sig = param[0]
            self.b = param[0]*torch.randn((int(param[1]/2),1))
        elif name == 'rffb':
            # frequencies supplied by the caller
            self.b = param[0]
        elif name == 'LinF':
            # linearly spaced frequencies between 2^0 and 2^param[0]
            self.b = torch.linspace(2.**0., 2.**param[0], steps=int(param[1]/2)).reshape(-1,1)
        elif name == 'LogF':
            # log-spaced frequencies between 2^0 and 2^param[0]
            self.b = 2.**torch.linspace(0., param[0], steps=int(param[1]/2)).reshape(-1,1)
        elif name in ('Gau', 'Tri'):
            # evenly spaced centres in [0, 1)
            self.dic = torch.linspace(0., 1, steps=param[1]+1)[:-1].reshape(1,-1)
            if name == 'Gau':
                self.sig = param[0]
            else:
                self.d = 1/param[1] if param[0] is None else param[0]
        else:
            print('Undifined encoding!')

    def __call__(self, x):
        kind = self.name
        if kind == 'none':
            return x

        if kind == 'basic':
            phase = 2.*np.pi*x
            emb = torch.cat((torch.sin(phase),torch.cos(phase)),1)
        elif kind in ('RFF', 'rffb', 'LinF', 'LogF'):
            phase = (2.*np.pi*x) @ self.b.T
            emb = torch.cat((torch.sin(phase),torch.cos(phase)),1)
        elif kind == 'Gau':
            emb = torch.exp(-0.5*(x-self.dic)**2/(self.sig**2))
        elif kind == 'Tri':
            emb = 1-(x-self.dic).abs()/self.d
            emb = emb*(emb>0)
        else:
            # unknown encodings produce nothing
            return None

        # scale by the largest row norm of the batch
        return emb/(emb.norm(dim=1).max())
55 |
56 |
57 | # simple 2D encoding. For Fourier-based methods we use 2 directions. For the shifted-version encoders we use 4 directions. The total feature dimension is the same.
class encoding_func_2D:
    '''
    Positional encoding for 2D coordinates.

    name selects the encoding: 'none', 'basic', 'RFF', 'rffb', 'LinF',
    'LogF', 'Gau2', 'Gau4' (alias 'Gau'), 'Tri2' or 'Tri4' (alias 'Tri').
    For every name except 'none' and 'basic', param is indexed as
    [sigma, feature_dimension] (for 'rffb' the first entry is the frequency
    matrix; for the 'Tri' variants it is the triangle half-width, or None to
    use 1/feature_dimension).

    The '2' variants encode the two coordinate axes; the '4' variants add
    the two diagonal projections 0.5*(x+y) and 0.5*(x-y+1).

    Calling the object on an (n, 2) tensor returns the embedding divided by
    the largest row norm found in that batch ('none' returns the input as is).
    '''

    def __init__(self, name, param=None):
        self.name = name

        # Parameter-free encodings
        if name == 'none':
            self.dim = 2
            return
        if name == 'basic':
            self.dim = 4
            return

        self.dim = param[1]
        if name == 'RFF':
            self.b = param[0]*torch.randn((int(param[1]/2),2))
        elif name == 'rffb':
            self.b = param[0]
        elif name == 'LinF':
            self.b = torch.linspace(2.**0., 2.**param[0], steps=int(param[1]/4)).reshape(-1,1)
        elif name == 'LogF':
            self.b = 2.**torch.linspace(0., param[0], steps=int(param[1]/4)).reshape(-1,1)
        elif name in ('Gau2', 'Gau4', 'Gau'):
            # the feature budget is split evenly over the encoded directions
            directions = 2 if name == 'Gau2' else 4
            self.dic = torch.linspace(0., 1, steps=int(param[1]/directions)+1)[:-1].reshape(1,-1)
            self.sig = param[0]
        elif name in ('Tri2', 'Tri4', 'Tri'):
            directions = 2 if name == 'Tri2' else 4
            self.dic = torch.linspace(0., 1, steps=int(param[1]/directions)+1)[:-1].reshape(1,-1)
            self.d = 1/param[1] if param[0] is None else param[0]
        else:
            print('Undifined encoding!')

    def _waves(self, t):
        # sin/cos features of t against the frequency matrix
        phase = (2.*np.pi*t) @ self.b.T
        return torch.cat((torch.sin(phase),torch.cos(phase)),1)

    def _gauss(self, t):
        # Gaussian bumps centred on the dictionary points
        return torch.exp(-0.5*(t-self.dic)**2/(self.sig**2))

    def _tri(self, t):
        # triangular bumps centred on the dictionary points, clipped at zero
        bump = 1-(t-self.dic).abs()/self.d
        return bump*(bump>0)

    def __call__(self, x):
        kind = self.name
        if kind == 'none':
            return x

        u, v = x[:,:1], x[:,1:2]
        if kind == 'basic':
            phase = 2.*np.pi*x
            emb = torch.cat((torch.sin(phase),torch.cos(phase)),1)
        elif kind in ('RFF', 'rffb'):
            emb = self._waves(x)
        elif kind in ('LinF', 'LogF'):
            emb = torch.cat([self._waves(u),self._waves(v)],1)
        elif kind == 'Gau2':
            emb = torch.cat([self._gauss(u),self._gauss(v)],1)
        elif kind in ('Gau4', 'Gau'):
            emb = torch.cat([self._gauss(u),self._gauss(v),
                             self._gauss(0.5*(u+v)),self._gauss(0.5*(u-v+1))],1)
        elif kind == 'Tri2':
            emb = torch.cat([self._tri(u),self._tri(v)],1)
        elif kind in ('Tri4', 'Tri'):
            emb = torch.cat([self._tri(u),self._tri(v),
                             self._tri(0.5*(u+v)),self._tri(0.5*(u-v+1))],1)
        else:
            # unknown encodings produce nothing
            return None

        # scale by the largest row norm of the batch
        return emb/(emb.norm(dim=1).max())
141 |
142 |
143 | # simple 3D encoding. All methods use 3 directions.
class encoding_func_3D:
    '''
    Positional encoding for 3D coordinates.

    name selects the encoding: 'none', 'basic', 'RFF', 'rffb', 'LinF',
    'LogF', 'Gau' or 'Tri'. For every name except 'none' and 'basic',
    param is indexed as [sigma, feature_dimension] (for 'rffb' the first
    entry is the frequency matrix; for 'Tri' it is the triangle half-width,
    or None to use 1/feature_dimension).

    ``dim`` is the number of columns produced by ``__call__``: 3 for 'none'
    (the input is returned unchanged), 6 for 'basic' (sin and cos of each of
    the three coordinates), and param[1] otherwise.

    Calling the object on an (n, 3) tensor returns the embedding divided by
    the largest row norm found in that batch.
    '''

    def __init__(self, name, param=None):
        self.name = name

        # Parameter-free encodings. The widths must match what __call__
        # returns for three coordinates (3 and 6, not the 2D values 2 and 4).
        if name == 'none':
            self.dim = 3
            return
        if name == 'basic':
            self.dim = 6
            return

        self.dim = param[1]
        if name == 'RFF':
            self.b = param[0]*torch.randn((int(param[1]/2),3))
        elif name == 'rffb':
            self.b = param[0]
        elif name == 'LinF':
            # sin+cos for each of the 3 axes -> param[1]/6 frequencies
            self.b = torch.linspace(2.**0., 2.**param[0], steps=int(param[1]/6)).reshape(-1,1)
        elif name == 'LogF':
            self.b = 2.**torch.linspace(0., param[0], steps=int(param[1]/6)).reshape(-1,1)
        elif name == 'Gau':
            # param[1]/3 evenly spaced centres in [0, 1) per axis
            self.dic = torch.linspace(0., 1, steps=int(param[1]/3)+1)[:-1].reshape(1,-1)
            self.sig = param[0]
        elif name == 'Tri':
            self.dic = torch.linspace(0., 1, steps=int(param[1]/3)+1)[:-1].reshape(1,-1)
            self.d = 1/param[1] if param[0] is None else param[0]
        else:
            print('Undifined encoding!')

    def _waves(self, t):
        # sin/cos features of t against the frequency matrix
        phase = (2.*np.pi*t) @ self.b.T
        return torch.cat((torch.sin(phase),torch.cos(phase)),1)

    def _gauss(self, t):
        # Gaussian bumps centred on the dictionary points
        return torch.exp(-0.5*(t-self.dic)**2/(self.sig**2))

    def _tri(self, t):
        # triangular bumps centred on the dictionary points, clipped at zero
        bump = 1-(t-self.dic).abs()/self.d
        return bump*(bump>0)

    def __call__(self, x):
        kind = self.name
        if kind == 'none':
            return x

        axes = (x[:,:1], x[:,1:2], x[:,2:3])
        if kind == 'basic':
            phase = 2.*np.pi*x
            emb = torch.cat((torch.sin(phase),torch.cos(phase)),1)
        elif kind in ('RFF', 'rffb'):
            emb = self._waves(x)
        elif kind in ('LinF', 'LogF'):
            emb = torch.cat([self._waves(t) for t in axes],1)
        elif kind == 'Gau':
            emb = torch.cat([self._gauss(t) for t in axes],1)
        elif kind == 'Tri':
            emb = torch.cat([self._tri(t) for t in axes],1)
        else:
            # unknown encodings produce nothing
            return None

        # scale by the largest row norm of the batch
        return emb/(emb.norm(dim=1).max())
204 |
205 |
206 | # blending matrix for 2D random samples
class blending_func_2D:
    '''
    Blending weights that interpolate 2D samples from the four surrounding nodes of a dim x dim grid.

    encoding_func : inner encoding func. Its name is used to choose the corresponding closed form distance func, which is much faster than evaluating the encoding experimentally.
    dim : number of inner encoding funcs to interpolate along each axis, usually equal to the feature dimension of the inner encoding. Pass None to use encoding_func.dim.
    indexing : default True to return grid point indices and weights. Set False to return a sparse matrix, which will be slower in later use.
    '''
    def __init__(self, encoding_func, dim=256, indexing=True):
        self.name = encoding_func.name
        self.indexing = indexing
        self.dim = encoding_func.dim if dim is None else dim

        # D(x1, x2): similarity between the inner encodings located at grid coordinates x1 and x2
        if self.name == 'RFF':
            self.D = lambda x1,x2: (-2*(np.pi*(x1-x2)/(self.dim-1)*encoding_func.sig)**2).exp()
        elif self.name == 'Gau':
            self.D = lambda x1,x2: (-0.25*((x1-x2)/(self.dim-1))**2/(encoding_func.sig**2)).exp()
        elif self.name == 'Tri':
            self.D = lambda x1,x2: 0.25*torch.maximum(2*encoding_func.d-(x1-x2).abs()/(self.dim-1),torch.tensor(0))**2
        else:
            # no closed form: evaluate the inner encoding and take the inner product
            self.D = lambda x1,x2: (encoding_func(x1/(self.dim-1))*encoding_func(x2/(self.dim-1))).sum(-1).unsqueeze(1)

    def _axis_weights(self, lo, t):
        '''Return the normalised weights of grid nodes lo and lo+1 for coordinates t along one axis.'''
        d_lo = self.D(lo, lo)
        d_cross = self.D(lo, lo+1)
        d_hi = self.D(lo+1, lo+1)
        d_a = self.D(lo, t)
        d_b = self.D(lo+1, t)
        # solve the 2x2 system [[d_lo, d_cross], [d_cross, d_hi]] @ [w_a, w_b] = [d_a, d_b]
        det = d_lo*d_hi-d_cross**2
        w_a = (d_a*d_hi-d_b*d_cross)/det
        w_b = (d_b*d_lo-d_a*d_cross)/det
        # rescale so the two weights sum to one
        total = w_a+w_b
        return w_a/total, w_b/total

    def __call__(self, x):
        '''
        x : 2D tensor of samples; columns 0 and 1 are the coordinates, clamped to [0, 1-1e-3] before use.
        Returns [index, weight] (both N x 4) when indexing is True, with corners ordered (x0,y0), (x0,y1), (x1,y0), (x1,y1) and index = ix*dim+iy. Otherwise returns a sparse COO tensor of shape (N, dim**2) holding the same weights.
        '''
        # make x in the grid
        x = x.clamp(0,1-1e-3)
        x = (self.dim-1)*x

        cells, weights = [], []
        for axis in range(2):
            t = x[:, axis:axis+1]
            lo = torch.floor(t)
            weights.append(self._axis_weights(lo, t))
            # integer cell coordinates keep the flat index arithmetic exact
            cells.append(lo.type(torch.LongTensor))
        (xa, xb), (ya, yb) = weights
        ix, iy = cells

        N = self.dim
        c = torch.cat([xa*ya, xa*yb, xb*ya, xb*yb], 1)
        y = torch.cat([ix*N+iy, ix*N+iy+1, (ix+1)*N+iy, (ix+1)*N+iy+1], 1)
        if self.indexing:
            return [y, c]

        # sparse layout: entries are stored corner by corner, each corner covering all samples
        Ns = x.shape[0]
        rows = torch.arange(Ns).repeat(4)
        return torch.sparse_coo_tensor(torch.stack([rows, y.T.reshape(-1)]), c.T.reshape(-1), (Ns, N**2))
276 |
277 |
278 |
279 |
280 |
class blending_func_3D:
    '''
    Blending weights that interpolate 3D samples from the eight surrounding nodes of a dim x dim x dim grid.

    encoding_func : inner encoding func. Its name is used to choose the corresponding closed form distance func, which is much faster than evaluating the encoding experimentally.
    dim : number of inner encoding funcs to interpolate along each axis, usually equal to the feature dimension of the inner encoding. Defaults to encoding_func.dim.
    indexing : default True to return grid point indices and weights. Set False to return a sparse matrix, which will be slower in later use.
    '''
    def __init__(self, encoding_func, dim=None, indexing=True):
        self.name = encoding_func.name
        self.indexing = indexing
        self.dim = encoding_func.dim if dim is None else dim

        # D(x1, x2): similarity between the inner encodings located at grid coordinates x1 and x2
        if self.name == 'RFF':
            self.D = lambda x1,x2: (-2*(np.pi*(x1-x2)/(self.dim-1)*encoding_func.sig)**2).exp()
        elif self.name == 'Gau':
            self.D = lambda x1,x2: (-0.25*((x1-x2)/(self.dim-1))**2/(encoding_func.sig**2)).exp()
        elif self.name == 'Tri':
            self.D = lambda x1,x2: 0.25*torch.maximum(2*encoding_func.d-(x1-x2).abs()/(self.dim-1),torch.tensor(0))**2
        else:
            # no closed form: evaluate the inner encoding and take the inner product
            self.D = lambda x1,x2: (encoding_func(x1/(self.dim-1))*encoding_func(x2/(self.dim-1))).sum(-1).unsqueeze(1)

    def _axis_weights(self, lo, t, gram=None):
        '''
        Return the normalised weights of grid nodes lo and lo+1 for coordinates t along one axis.
        gram : optional precomputed (D(lo,lo), D(lo,lo+1), D(lo+1,lo+1)); evaluated per sample when omitted.
        '''
        if gram is None:
            d_lo = self.D(lo, lo)
            d_cross = self.D(lo, lo+1)
            d_hi = self.D(lo+1, lo+1)
        else:
            d_lo, d_cross, d_hi = gram
        d_a = self.D(lo, t)
        d_b = self.D(lo+1, t)
        # solve the 2x2 system [[d_lo, d_cross], [d_cross, d_hi]] @ [w_a, w_b] = [d_a, d_b]
        det = d_lo*d_hi-d_cross**2
        w_a = (d_a*d_hi-d_b*d_cross)/det
        w_b = (d_b*d_lo-d_a*d_cross)/det
        # rescale so the two weights sum to one
        total = w_a+w_b
        return w_a/total, w_b/total

    def __call__(self, x):
        '''
        x : 2D tensor of samples; columns 0, 1 and 2 are the coordinates, clamped to [0, 1-1e-3] before use.
        Returns [index, weight] (both N x 8) when indexing is True, with corners ordered z fastest, then y, then x, and index = ix*dim**2+iy*dim+iz. Otherwise returns a sparse COO tensor of shape (N, dim**3) holding the same weights.
        '''
        # make x in the grid
        x = x.clamp(0,1-1e-3)
        x = (self.dim-1)*x

        gram = None
        if self.name in ('RFF', 'Gau', 'Tri'):
            # the closed forms depend only on x1-x2, so the node-to-node terms are constants
            d0 = self.D(torch.tensor([0]),torch.tensor([0]))
            dd = self.D(torch.tensor([0]),torch.tensor([1]))
            gram = (d0, dd, d0)

        cells, weights = [], []
        for axis in range(3):
            t = x[:, axis:axis+1]
            lo = torch.floor(t)
            weights.append(self._axis_weights(lo, t, gram))
            # integer cell coordinates: float32 cannot represent flat indices above 2**24 exactly
            cells.append(lo.type(torch.LongTensor))
        ix, iy, iz = cells
        wx, wy, wz = weights

        N = self.dim
        corners = [(i, j, k) for i in (0, 1) for j in (0, 1) for k in (0, 1)]
        c = torch.cat([wx[i]*wy[j]*wz[k] for i, j, k in corners], 1)
        y = torch.cat([(ix+i)*N**2+(iy+j)*N+(iz+k) for i, j, k in corners], 1)
        if self.indexing:
            return [y, c]

        # sparse layout: entries are stored corner by corner, each corner covering all samples
        Ns = x.shape[0]
        rows = torch.arange(Ns).repeat(8)
        return torch.sparse_coo_tensor(torch.stack([rows, y.T.reshape(-1)]), c.T.reshape(-1), (Ns, N**3))
424 |
425 |
426 |
class fast_blending_func_3D:
    '''
    A fast version (may not really) without many options for common use.
    Only encodings with a closed form distance func ('RFF', 'Gau', 'Tri') are supported.

    encoding_func : inner encoding func. Its name is used to choose the corresponding closed form distance func, which is much faster than evaluating the encoding experimentally.
    dim : number of inner encoding funcs to interpolate along each axis, usually equal to the feature dimension of the inner encoding. Defaults to encoding_func.dim.
    Raises ValueError for any other encoding name (previously a message was printed and construction then failed on the missing distance func).
    '''
    def __init__(self, encoding_func, dim=None):
        self.name = encoding_func.name
        self.dim = encoding_func.dim if dim is None else dim

        # D(t): similarity between two inner encodings whose grid coordinates differ by t
        if self.name == 'RFF':
            self.D = lambda t: (-2*(np.pi*(t)/(self.dim-1)*encoding_func.sig)**2).exp()
        elif self.name == 'Gau':
            self.D = lambda t: (-0.25*((t)/(self.dim-1))**2/(encoding_func.sig**2)).exp()
        elif self.name == 'Tri':
            self.D = lambda t: 0.25*torch.maximum(2*encoding_func.d-(t).abs()/(self.dim-1),torch.tensor(0))**2
        else:
            raise ValueError('No fast version for encoding {!r}!'.format(self.name))

        # constant coefficients: inverse of the 2x2 matrix [[d0, dd], [dd, d0]]
        d0 = self.D(torch.tensor([0]))
        dd = self.D(torch.tensor([1]))
        ff = d0**2-dd**2
        self.coe = torch.tensor([[d0,-dd],[-dd,d0]])/ff

        self.dim2 = self.dim**2
        self.idx_matrix = torch.tensor([self.dim**2,self.dim,1.0]).reshape(3,1)

    def __call__(self, x):
        '''
        x : N x 3 tensor of coordinates, clamped to [0, 1-1e-3] before use.
        Returns [index, weight], both N x 8, with corners ordered z fastest, then y, then x, and index = ix*dim**2+iy*dim+iz. Unlike blending_func_3D the per-axis weights are not renormalised to sum to one.
        '''
        # make x in the grid
        x = x.clamp(0,1-1e-3)
        x = (self.dim-1)*x
        xmin = torch.floor(x)

        # per-axis weights of the lower ([..., 0]) and upper ([..., 1]) node, shape N x 3 x 2
        d_ratio = torch.stack([self.D(xmin-x),self.D(xmin+1-x)],-1)
        itp_weights = d_ratio@self.coe

        # calculate mixing product
        wx, wy, wz = itp_weights[:,0], itp_weights[:,1], itp_weights[:,2]
        c = torch.stack([wx[:,i]*wy[:,j]*wz[:,k] for i in (0,1) for j in (0,1) for k in (0,1)],1)

        # integer arithmetic: float32 cannot represent flat indices above 2**24 exactly
        cell = xmin.type(torch.LongTensor)
        base = cell[:,0:1]*self.dim2+cell[:,1:2]*self.dim+cell[:,2:3]
        y = torch.cat([base,base+1],1)
        y = torch.cat([y,y+self.dim],1)
        y = torch.cat([y,y+self.dim2],1)
        return [y,c]
480 |
481 |
482 |
def srank_func(X,return_l1=False):
    '''
    Stable rank of a matrix, computed from its singular values.

    X : matrix tensor.
    return_l1 : when True also return the l1 variant.
    Returns sr2 = sum(s**2)/s[0]**2, or the pair (sr1, sr2) with sr1 = sum(s)/s[0] when return_l1 is True. The result is NaN when the largest singular value is zero.
    '''
    # only the singular values are needed, so skip the singular vectors
    # (torch.svd is deprecated); values come back in descending order, s[0] is the largest
    s = torch.linalg.svdvals(X)
    sr2 = (s*s).sum()/s[0]/s[0]
    if not return_l1:
        return sr2
    sr1 = s.sum()/s[0]
    return sr1,sr2
491 |
492 |
def psnr_func(x,y,return_mse=False):
    '''
    Peak signal-to-noise ratio in dB between two tensors, taking the peak value as 1.

    x, y : tensors whose difference x-y is defined.
    return_mse : when True return the pair (psnr, mse) instead of psnr alone.
    '''
    err = x - y
    mse = (err*err).flatten().mean()
    psnr = -10*(mse.log10())
    return (psnr, mse) if return_mse else psnr
500 |
501 |
def smooth(X,filter):
    '''
    Smoothness penalty of a 2D tensor: mean squared response of a 1D filter applied along the rows plus the same along the columns.

    X : 2D tensor.
    filter : conv1d weight for a single input channel.
    '''
    conv = torch.nn.functional.conv1d
    # every row is treated as a separate single-channel signal
    along_rows = conv(X.unsqueeze(1), filter).pow(2).mean()
    # transposing turns the columns into the signals
    along_cols = conv(X.T.unsqueeze(1), filter).pow(2).mean()
    return along_rows+along_cols
506 |
507 |
def smooth_3D(X,filter):
    '''
    Smoothness penalty of a 3D tensor: mean squared response of a 1D filter applied along each of the three axes, summed.

    X : 3D tensor.
    filter : conv1d weight for a single input channel.
    '''
    conv = torch.nn.functional.conv1d
    # each view moves the axis being filtered to the last position, then the
    # two leading axes are flattened into a batch of single-channel signals
    views = (X, X.transpose(1,2), X.transpose(0,2))
    terms = [conv(v.flatten(0,1).unsqueeze(1), filter).pow(2).mean() for v in views]
    return terms[0]+terms[1]+terms[2]
513 |
--------------------------------------------------------------------------------