├── .DS_Store ├── LICENSE ├── README.md ├── demo_decode.py ├── demo_eval.py ├── encoded_weights ├── .DS_Store ├── model_lego.pth └── model_tarot.pth └── training_scripts ├── README.txt ├── eval_net.py ├── network.py ├── preprocess.py └── train_net.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AugmentariumLab/SIGNET/1c04c87c1148e9070cc586effe9333ce35c942fd/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 The Augmentarium 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SIGNET: Efficient Neural Representations for Light Fields 2 | This repository contains the demo code for SIGNET: Efficient Neural Representations for Light Fields, published at ICCV 2021. We provide the Python implementation of Gegenbauer embedding as well as the network used to encode the light fields. 3 | 4 | ## Requirements 5 | * CUDA 6 | * PyTorch 7 | * Numpy 8 | * PIL 9 | 10 | ## Demo 11 | 12 | To decode an image at light field view point (u, v), please run 13 | * `python demo_decode.py -u [u] -v [v] --scene [scene_name]` 14 | * u and v are integers within the range [0, 16], specifying the viewpoint coordinates in the original light field 15 | We provide the pretrained weights for scenes "lego" and "tarot" in the `encoded_weights` folder. 16 | 17 | ## Related Publication 18 | 19 | Please refer to for our paper published in ICCV 2021: "SIGNET: Efficient Neural Representations for Light Fields". 20 | 21 | ## References 22 | 23 | If you use this in your research, please reference it as: 24 | 25 | @inproceedings{Feng2021SIGNET, 26 | author={Feng, Brandon Y. and Varshney, Amitabh}, 27 | booktitle={Proceedings of the International Conference on Computer Vision (ICCV 2021)}, 28 | title={SIGNET: Efficient Neural Representations for Light Fields}, 29 | year={2021}, 30 | } 31 | 32 | or 33 | 34 | Brandon Y. Feng and Amitabh Varshney. 2021. SIGNET: Efficient Neural Representations for Light Fields. International Conference on Computer Vision (ICCV 2021). 35 | -------------------------------------------------------------------------------- /demo_decode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | def carte_to_geg(x, N_freqs, alpha=0.5): 11 | n = x.shape[0] 12 | c = np.zeros(( n, N_freqs + 1 )) 13 | c[:, 0] = 1.0 14 | c[:, 1] = 2.0 * alpha * x 15 | for i in range(2, N_freqs + 1 ): 16 | c[:, i] = ( ( 2 * i - 2 + 2.0 * alpha ) * x * c[:, i-1] + (- i + 2 - 2.0 * alpha ) * c[:, i-2] ) / i 17 | return c[:, 1:] 18 | 19 | class Embedding(nn.Module): 20 | def __init__(self, N_freqs, n_size, alpha=0.5): 21 | super(Embedding, self).__init__() 22 | self.N_freqs = N_freqs 23 | self.alpha = alpha 24 | 25 | x = np.linspace(-0.5, 0.5, n_size) 26 | #x = np.linspace(-1, 1, n_size) 27 | self.cache_geg = nn.Parameter(torch.from_numpy(carte_to_geg(x, N_freqs, alpha)).float(), requires_grad=False) 28 | 29 | def forward(self, x): 30 | return self.cache_geg[x.long()] 31 | 32 | class SineLayer(nn.Module): 33 | def __init__(self, in_features, out_features, bias=True, is_first=False, is_res=False, omega_0=30): 34 | super().__init__() 35 | self.omega_0 = omega_0 36 | self.is_first = is_first 37 | self.is_res = is_res 38 | self.in_features = in_features 39 | self.linear = nn.Linear(in_features, out_features, bias=bias) 40 | self.init_weights(self.linear) 41 | 42 | def init_weights(self, layer): 43 | with torch.no_grad(): 44 | if self.is_first: 45 | layer.weight.uniform_(-1 / self.in_features, 46 | 1 / self.in_features) 47 | else: 48 | layer.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 49 | np.sqrt(6 / self.in_features) / self.omega_0) 50 | 51 | def forward(self, input): 52 | if self.is_res: 53 | return input + torch.sin(self.omega_0 * self.linear(input)) 54 | else: 55 | return torch.sin(self.omega_0 * self.linear(input)) 56 | 57 | class SIGNET(nn.Module): 58 | def __init__(self, first_omega_0=30, hidden_omega_0=30., hidden_layers=8, in_feature_ratio=1, out_features=3, 59 | hidden_features=512, alpha=0.5, with_res=False, with_norm=False): 60 | super().__init__() 61 | 62 | in_features = int(in_feature_ratio * 512) 63 | self.with_res = with_res 64 | self.D = hidden_layers + 2 65 | 66 | for i in range(hidden_layers+1): 67 | if i == 0: 68 | layer = SineLayer(in_features, hidden_features, is_first=True, is_res=False, omega_0=first_omega_0) 69 | else: 70 | layer = SineLayer(hidden_features, hidden_features, is_first=False, is_res=self.with_res, omega_0=hidden_omega_0) 71 | if with_norm: 72 | layer = nn.Sequential(layer, nn.LayerNorm(hidden_features, elementwise_affine=True)) 73 | setattr(self, f"encoding_{i+1}", layer) 74 | 75 | final_linear = nn.Linear(hidden_features, out_features) 76 | 77 | with torch.no_grad(): 78 | final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, np.sqrt(6 / hidden_features) / hidden_omega_0) 79 | setattr(self, f"encoding_{hidden_layers+2}", final_linear) 80 | 81 | self.N_xy = int(in_feature_ratio * (240)) 82 | self.N_uv = int(in_feature_ratio * (16)) 83 | 84 | self.xy_embedd = Embedding(self.N_xy, 1024, alpha) 85 | self.uv_embedd = Embedding(self.N_uv, 17, alpha) 86 | 87 | def forward(self, x): 88 | emb_x = torch.cat( [self.uv_embedd(x[:, 0]), self.uv_embedd(x[:, 1]), self.xy_embedd(x[:, 2]), self.xy_embedd(x[:, 3]) ], axis=1).to(x.device) 89 | out = emb_x 90 | for i in range(self.D): 91 | out = getattr(self, f"encoding_{i+1}")(out) 92 | return out 93 | 94 | def get_LF_val(u, v, width=1024, height=1024): 95 | x = np.linspace(0, width-1, width) 96 | y = np.linspace(0, height-1, height) 97 | 98 | xv, yv = np.meshgrid(y, x) 99 | img_grid = torch.from_numpy(np.stack([yv, xv], axis=-1)) 100 | 101 | uv_grid = torch.ones_like(img_grid) 102 | uv_grid[:, :, 0], uv_grid[:, :, 1] = u, v 103 | 104 | val_inp_t = torch.cat([uv_grid, img_grid], dim = -1).float() 105 | 106 | del img_grid, xv, yv 107 | return val_inp_t.view(-1, val_inp_t.shape[-1]) 108 | 109 | def eval_im(val_inp_t, batches, device): 110 | b_size = val_inp_t.shape[0] // batches 111 | with torch.no_grad(): 112 | out = [] 113 | for b in range(batches): 114 | out.append(model(val_inp_t[b_size*b:b_size*(b+1)].to(device))) 115 | out = torch.cat(out, dim = 0) 116 | out = torch.clamp(out, 0, 1) 117 | out_np = out.view(1024, 1024, 3).cpu().numpy() * 255 118 | return out_np 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("-u", type=int, default=0, help="angular dimension u") 123 | parser.add_argument("-v", type=int, default=0, help="angular dimension v") 124 | parser.add_argument("-b", type=int, default=4, help="batch size in inference") 125 | parser.add_argument("--scene", type=str, default="lego", help="lego or tarot") 126 | args = parser.parse_args() 127 | 128 | OUT_DIR = f'./decoded_images/{args.scene}' 129 | if not os.path.exists(OUT_DIR): 130 | os.makedirs(OUT_DIR) 131 | 132 | args = parser.parse_args() 133 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 134 | 135 | model = SIGNET(hidden_layers=8, alpha=0.5, hidden_features=512, in_feature_ratio=1, with_norm=True, with_res=True) 136 | m_state_dict = torch.load(f'./encoded_weights/model_{args.scene}.pth') 137 | model.load_state_dict(m_state_dict) 138 | model.eval() 139 | model = model.to(device) 140 | val_inp_t = get_LF_val(u=args.u, v=args.v).to(device) 141 | out_np = eval_im(val_inp_t, args.b, device) 142 | Image.fromarray(np.uint8(out_np)).save(f'{OUT_DIR}/%s_u%d_v%d.png' % (args.scene, args.u, args.v)) 143 | -------------------------------------------------------------------------------- /demo_eval.py: -------------------------------------------------------------------------------- 1 | import os, tqdm, itertools, argparse 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from skimage.metrics import structural_similarity as ssim 6 | 7 | from demo_decode import SIGNET, get_LF_val, eval_im 8 | 9 | def compute_psnr(img1, img2): 10 | img1 = img1.astype(np.float64) 11 | img2 = img2.astype(np.float64) 12 | mse = np.mean((img1 - img2) ** 2) 13 | return 10 * np.log10(255**2 / mse) 14 | 15 | def compute_ssim(img1, img2): 16 | img1 = img1.astype(np.float64) 17 | img2 = img2.astype(np.float64) 18 | score = ssim(img1, img2, multichannel=True, data_range=255) 19 | return score 20 | 21 | def read_uv_view(u, v, img_dir): 22 | for file in sorted(os.listdir(img_dir)): 23 | _u, _v = file.split('_')[1], file.split('_')[2] 24 | if f'{u:02d}' == _u and f'{v:02d}' == _v: 25 | img = np.asarray(Image.open(f'{img_dir}/{file}')).astype(np.uint8) 26 | return img 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("-b", type=int, default=4, help="batch size in inference") 31 | parser.add_argument("--scene", type=str, default="lego", help="lego or tarot") 32 | parser.add_argument("--img_dir", type=str, default="./data/lego", help="path to folder with all ground truth images") 33 | 34 | args = parser.parse_args() 35 | 36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | model = SIGNET(hidden_layers=8, alpha=0.5, hidden_features=512, in_feature_ratio=1, with_norm=True, with_res=True) 39 | m_state_dict = torch.load(f'./encoded_weights/model_{args.scene}.pth') 40 | model.load_state_dict(m_state_dict) 41 | model.eval() 42 | model = model.to(device) 43 | val_inp_t = get_LF_val(u=0, v=0).to(device) 44 | 45 | uv_range = 17 46 | tbar = tqdm.tqdm(list(itertools.product(range(uv_range), range(uv_range)))) 47 | p, s = 0, 0 48 | ct = 0 49 | for (u, v) in tbar: 50 | val_inp_t[..., 0] = v 51 | val_inp_t[..., 1] = u 52 | out_np = eval_im(model, val_inp_t, batches=args.b, device=device) 53 | img_gt = read_uv_view(u, v, args.img_dir) 54 | 55 | p_ = compute_psnr(img_gt, out_np) 56 | s_ = compute_ssim(img_gt, out_np) 57 | ct += 1 58 | p += p_ 59 | s += s_ 60 | tbar.set_postfix(PSNR = p_, AvePSNR = p / ct, SSIM = s_, AveSSIM = s / ct) 61 | 62 | print(f'PSNR: {p / ct} | SSIM: {s / ct}') 63 | -------------------------------------------------------------------------------- /encoded_weights/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AugmentariumLab/SIGNET/1c04c87c1148e9070cc586effe9333ce35c942fd/encoded_weights/.DS_Store -------------------------------------------------------------------------------- /encoded_weights/model_lego.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AugmentariumLab/SIGNET/1c04c87c1148e9070cc586effe9333ce35c942fd/encoded_weights/model_lego.pth -------------------------------------------------------------------------------- /encoded_weights/model_tarot.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AugmentariumLab/SIGNET/1c04c87c1148e9070cc586effe9333ce35c942fd/encoded_weights/model_tarot.pth -------------------------------------------------------------------------------- /training_scripts/README.txt: -------------------------------------------------------------------------------- 1 | Here we provide demo scripts to process the Stanford Light Field dataset (http://lightfield.stanford.edu/lfs.html) and train the neural representation with SIGNET. 2 | 3 | Step 1: 4 | Download specific datasets to your local directory. For example, put the unzipped Lego Knights images into a folder called "{ROOT}/data/lego". 5 | 6 | Step 2: 7 | Run "python preprocess.py --img_dir {ROOT}/data/lego/ --save_dir {ROOT}/patch_data/lego_patches" 8 | 9 | This step is performed to preprocess the light field dataset into small batches convenient for network training. 10 | You may adjust the batch size within this script, or increase the batch size in the dataloader during training. 11 | Here the validation image is hardcoded as the "01_01" view. Feel free to adjust according to your needs. 12 | 13 | Step 3: 14 | Run "python train_net.py --root_dir {ROOT} --exp_name lego_test --trainset_dir {ROOT}/patch_data/lego_patches" 15 | 16 | Training should begin following this command. 17 | You should find the trained weights and validation output at folder "{ROOT}/{exp_name}" during training. 18 | The image resolution is assumed to be less than 1024x1024. If you work on data with higher resolution, please use the "--img_W" and "--img_H" arguments to adjust accordingly. 19 | 20 | To decode light field view at (u, v) from trained weights, please run "python eval_net.py --exp_dir {ROOT}/{exp_name} -u u -v v" For the demo scenes, the u/v coordinates are restricted to integers between 0 and 16. 21 | 22 | The purpose of these scripts is to provide a simple implementation that helps you kickstart your own experiments. 23 | If you encounter any error or have any suggestion, please don't hesitate to reach out to yfeng97@umd.edu. Thank you! 24 | -------------------------------------------------------------------------------- /training_scripts/eval_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from network import SIGNET_static 11 | 12 | def get_LF_val(u, v, width=1024, height=1024): 13 | x = np.linspace(0, width-1, width) 14 | y = np.linspace(0, height-1, height) 15 | 16 | xv, yv = np.meshgrid(y, x) 17 | img_grid = torch.from_numpy(np.stack([yv, xv], axis=-1)) 18 | 19 | uv_grid = torch.ones_like(img_grid) 20 | uv_grid[:, :, 0], uv_grid[:, :, 1] = u, v 21 | 22 | val_inp_t = torch.cat([uv_grid, img_grid], dim = -1).float() 23 | 24 | val_inp_t[..., :2] /= 17 25 | val_inp_t[..., 2] /= width 26 | val_inp_t[..., 3] /= height 27 | 28 | del img_grid, xv, yv 29 | return val_inp_t.view(-1, val_inp_t.shape[-1]) 30 | 31 | def eval_im(val_inp_t, batches, device): 32 | b_size = val_inp_t.shape[0] // batches 33 | with torch.no_grad(): 34 | out = [] 35 | for b in range(batches): 36 | out.append(model(val_inp_t[b_size*b:b_size*(b+1)].to(device))) 37 | out = torch.cat(out, dim = 0) 38 | out = torch.clamp(out, 0, 1) 39 | out_np = out.view(1024, 1024, 3).cpu().numpy() * 255 40 | return out_np 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("-u", type=int, default=0, help="angular dimension u") 45 | parser.add_argument("-v", type=int, default=0, help="angular dimension v") 46 | parser.add_argument("-b", type=int, default=4, help="batch size in inference") 47 | parser.add_argument("--exp_dir", type=str, help="directory to trained weights") 48 | args = parser.parse_args() 49 | 50 | OUT_DIR = f'./{args.exp_dir}/eval_output' 51 | if not os.path.exists(OUT_DIR): 52 | os.makedirs(OUT_DIR) 53 | 54 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 55 | 56 | model = SIGNET_static(hidden_layers=8, alpha=0.5, skips=[], hidden_features=512, with_norm=True, with_res=True) 57 | m_state_dict = torch.load(f'{args.exp_dir}/model.pth') 58 | model.load_state_dict(m_state_dict, strict=False) 59 | model.eval() 60 | model = model.to(device) 61 | val_inp_t = get_LF_val(u=args.u, v=args.v).to(device) 62 | out_np = eval_im(val_inp_t, args.b, device) 63 | Image.fromarray(np.uint8(out_np)).save(f'{OUT_DIR}/u{args.u}_v{args.v}.png') 64 | -------------------------------------------------------------------------------- /training_scripts/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | class GegEmbedding(nn.Module): 7 | def __init__(self, N_freqs, alpha=0.5): 8 | super(GegEmbedding, self).__init__() 9 | self.N_freqs = N_freqs 10 | self.alpha = alpha 11 | 12 | def forward(self, x): 13 | n = len(x) 14 | x_in = x.squeeze() 15 | c = torch.zeros(( n, self.N_freqs + 1 ), device=x.device) 16 | c[..., 0] = 1.0 17 | c[..., 1] = 2.0 * self.alpha * x_in 18 | for i in range(2, self.N_freqs + 1 ): 19 | c[..., i] = ( ( 2 * i - 2 + 2.0 * self.alpha ) * x_in * c[..., i-1] + (- i + 2 - 2.0 * self.alpha ) * c[..., i-2] ) / i 20 | 21 | return c[..., 1:].contiguous().view(n, -1) 22 | 23 | class SineLayer(nn.Module): 24 | def __init__(self, in_features, out_features, bias=False, is_first=False, is_res=False, omega_0=30): 25 | super().__init__() 26 | self.omega_0 = omega_0 27 | self.is_first = is_first 28 | self.is_res = is_res 29 | 30 | self.in_features = in_features 31 | self.linear = nn.Linear(in_features, out_features, bias=bias) 32 | self.init_weights(self.linear) 33 | 34 | def init_weights(self, layer): 35 | with torch.no_grad(): 36 | if self.is_first: 37 | layer.weight.uniform_(-1 / self.in_features, 1 / self.in_features) 38 | else: 39 | layer.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, np.sqrt(6 / self.in_features) / self.omega_0) 40 | 41 | def forward(self, input): 42 | if self.is_res: 43 | return input + torch.sin(self.omega_0 * self.linear(input)) 44 | else: 45 | return torch.sin(self.omega_0 * self.linear(input)) 46 | 47 | class SIGNET_static(nn.Module): 48 | def __init__(self, first_omega_0=30, hidden_omega_0=30., hidden_layers=8, in_feature_ratio=1, out_features=3, skips=[3], hidden_features=512, alpha=0.5, with_res=False, with_sigmoid=False, with_norm=False): 49 | super().__init__() 50 | 51 | in_features = int(in_feature_ratio * 512) 52 | 53 | self.with_res = with_res 54 | self.with_sigmoid = with_sigmoid 55 | self.D = hidden_layers + 2 56 | self.skips = skips 57 | 58 | for i in range(hidden_layers+1): 59 | if i == 0: 60 | layer = SineLayer(in_features, hidden_features, bias=True, is_first=True, is_res=False, omega_0=first_omega_0) 61 | elif i in skips: 62 | layer = SineLayer(hidden_features + in_features, hidden_features, bias=True, is_first=False, is_res=False, omega_0=hidden_omega_0) 63 | else: 64 | layer = SineLayer(hidden_features, hidden_features, is_first=False, bias=True, is_res=self.with_res, omega_0=hidden_omega_0) 65 | if with_norm: 66 | layer = nn.Sequential(layer, nn.LayerNorm(hidden_features, elementwise_affine=True)) 67 | setattr(self, f"encoding_{i+1}", layer) 68 | 69 | final_linear = nn.Linear(hidden_features, out_features, bias=True) 70 | 71 | with torch.no_grad(): 72 | final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, np.sqrt(6 / hidden_features) / hidden_omega_0) 73 | setattr(self, f"encoding_{hidden_layers+2}", final_linear) 74 | 75 | self.N_xy = int(in_feature_ratio * 240) 76 | self.N_uv = int(in_feature_ratio * 16) 77 | 78 | self.xy_embedd = GegEmbedding(self.N_xy, alpha) 79 | self.uv_embedd = GegEmbedding(self.N_uv, alpha) 80 | 81 | def forward(self, x): 82 | # x: [B, 4] 83 | emb_x = torch.cat( [self.uv_embedd(x[:, 0]), self.uv_embedd(x[:, 1]), self.xy_embedd(x[:, 2]), self.xy_embedd(x[:, 3]) ], axis=1) 84 | out = emb_x 85 | for i in range(self.D): 86 | if i in self.skips: 87 | out = torch.cat([emb_x, out], -1) 88 | out = getattr(self, f"encoding_{i+1}")(out) 89 | return out 90 | -------------------------------------------------------------------------------- /training_scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from PIL import Image 5 | import scipy 6 | 7 | import pickle 8 | from tqdm import tqdm 9 | from models import Embedding 10 | import torch 11 | 12 | import argparse 13 | 14 | def save_dict(di_, filename_): 15 | with open(filename_, 'wb') as f: 16 | pickle.dump(di_, f) 17 | 18 | def load_dict(filename_): 19 | with open(filename_, 'rb') as f: 20 | ret_di = pickle.load(f) 21 | return ret_di 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--img_dir", type=str, help="Dataset directory") 26 | parser.add_argument("--save_dir", type=str, help="Directory to store the preprocessed training patches") 27 | args = parser.parse_args() 28 | 29 | img_dir = args.img_dir 30 | save_dir = args.save_dir 31 | 32 | if not os.path.isdir(save_dir): 33 | os.makedirs(save_dir) 34 | os.makedirs(save_dir+'/train') 35 | 36 | if img_dir == './data/bracelet/': 37 | lf_arr = np.zeros((17, 17, 640, 1024, 3)) 38 | width, height, u_dim, v_dim = 1024, 640, 17, 17 39 | val_img_path = './data/bracelet/out_01_01_-1512.372681_883.201477_.png' 40 | else: 41 | lf_arr = np.zeros((17, 17, 1024, 1024, 3)) 42 | width, height, u_dim, v_dim = 1024, 1024, 17, 17 43 | if img_dir.endswith('tarot/'): 44 | val_img_path = './data/tarot/out_01_01_-863.245667_1019.015076.png' 45 | else: 46 | val_img_path = './data/lego/out_01_01_-390.668457_1095.314209.png' 47 | 48 | batch_size = width * 2 49 | i, j = 0, 0 50 | for file in sorted(os.listdir(img_dir)): 51 | if j == u_dim: 52 | i += 1 53 | j = 0 54 | if file.endswith(".png"): 55 | img_ = np.asarray(Image.open(img_dir + file).convert('RGB')) / 255. 56 | lf_arr[i, j] = img_ 57 | j += 1 58 | print('Light field loaded') 59 | 60 | img_data = lf_arr.reshape(-1, 3) 61 | del lf_arr 62 | 63 | x = np.linspace(0, width-1, width) 64 | y = np.linspace(0, height-1, height) 65 | u = np.linspace(0, u_dim-1, u_dim) 66 | v = np.linspace(0, v_dim-1, v_dim) 67 | i = [u, v, x, y] 68 | uv, vv, xv, yv = np.meshgrid(*i) 69 | img_grid = np.stack([uv, vv, xv, yv], axis=-1) 70 | 71 | val_x = img_grid[1, 1] 72 | val_img_grid = val_x.reshape(-1, 4) 73 | 74 | img_ = np.asarray(Image.open(val_img_path).convert('RGB')) / 255. 75 | lf_arr = np.zeros((height, width, 3)) 76 | lf_arr[:, :, :] = img_ 77 | 78 | val_img_data = lf_arr.reshape(-1, 3).astype(np.float32) 79 | img_grid = img_grid.reshape(-1, 4) 80 | print('Linspace grid created') 81 | 82 | del uv, vv, xv, yv, x, y, u, v 83 | p_num = width * height * u_dim * v_dim 84 | batch_num = p_num // batch_size 85 | idx_list = np.split(np.random.permutation(np.arange(0, p_num)), batch_num) 86 | 87 | print(f'Writing patches to {save_dir}') 88 | val_data = {'x':val_img_grid,'y':val_img_data} 89 | save_dict(val_data, f'{save_dir}/patch_val.pkl') 90 | for p in tqdm(range(batch_num)): 91 | slt_idx = idx_list[p] 92 | x_p = img_grid[slt_idx] 93 | y_p = img_data[slt_idx] 94 | save_data = {'x':x_p,'y':y_p} 95 | save_dict(save_data, os.path.join(save_dir, 'train', f'patch_{p:05d}.pkl')) 96 | -------------------------------------------------------------------------------- /training_scripts/train_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tqdm 4 | from PIL import Image 5 | import argparse 6 | import pickle 7 | 8 | import torch 9 | torch.manual_seed(0) 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = False 12 | 13 | from torch.utils.data import DataLoader 14 | 15 | from network import SIGNET_static 16 | 17 | class LFPatchDataset(torch.utils.data.Dataset): 18 | def __init__(self, is_train=True, file_dir = './patch_data/Lego/'): 19 | if is_train: 20 | self.file_dir = f'{file_dir}/train' 21 | self.file_list = [] 22 | for f in sorted(os.listdir(self.file_dir)): 23 | self.file_list.append(f'{file_dir}/train/{f}') 24 | self.batch_num = len(self.file_list) 25 | else: 26 | self.batch_num = 1 27 | self.file_list = [f'{file_dir}/patch_val.pkl']*1 28 | def __len__(self): 29 | return self.batch_num 30 | 31 | def __getitem__(self, idx): 32 | filename_ = self.file_list[idx] 33 | with open(filename_, 'rb') as f: 34 | ret_di = pickle.load(f) 35 | 36 | lab_t = torch.from_numpy(ret_di['y']).float() 37 | inp_G_t = torch.from_numpy(ret_di['x']).float() 38 | 39 | return inp_G_t, lab_t 40 | 41 | def compute_psnr(img1, img2): 42 | img1 = img1.astype(np.float64) 43 | img2 = img2.astype(np.float64) 44 | mse = np.mean((img1 - img2) ** 2) 45 | return 10 * np.log10(255**2 / mse) 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--root_dir", type=str, help="Root directory") 50 | parser.add_argument("--exp_name", type=str, default="test", help="Experiment name") 51 | parser.add_argument("--trainset_dir", type=str, default="lego") 52 | parser.add_argument("--num_epochs", type=int, default=30) 53 | parser.add_argument("--img_W", type=int, default=1024) 54 | parser.add_argument("--img_H", type=int, default=1024) 55 | 56 | args = parser.parse_args() 57 | 58 | device = ("cuda:0" if torch.cuda.is_available() else "cpu" ) 59 | 60 | root_dir = args.root_dir 61 | exp_dir = f'{root_dir}/{args.exp_name}' 62 | print(f'Current experiment directory is: {exp_dir}') 63 | trainset_dir = f'{root_dir}/{args.trainset_dir}' 64 | 65 | num_epochs = args.num_epochs 66 | 67 | if not os.path.isdir(exp_dir): 68 | os.makedirs(exp_dir) 69 | os.makedirs(f'{exp_dir}/valout') 70 | 71 | val_im_shape = [1024, 1024] 72 | 73 | model = SIGNET_static(hidden_layers=8, alpha=0.5, skips=[], hidden_features=512, with_norm=True, with_res=True) 74 | model = model.to(device) 75 | 76 | trainset = LFPatchDataset(is_train=True, file_dir = trainset_dir) 77 | valset = LFPatchDataset(is_train=False, file_dir = trainset_dir) 78 | val_inp_t, _ = valset[0] 79 | 80 | bsize = 1 81 | train_loader = DataLoader(trainset, batch_size=bsize, drop_last=False, num_workers=8, pin_memory=True) 82 | iters = len(train_loader) 83 | 84 | # Frequency to save validation image 85 | val_freq = 200#iters * 2 86 | # Frequency to save the checkpoint 87 | save_freq = 5 88 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5) 89 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-7) 90 | 91 | print('Starts training') 92 | mse_losses, psnrs = [], [] 93 | 94 | for epoch in range(num_epochs): 95 | e_psnr, e_loss, it = 0, 0, 0 96 | t = tqdm.tqdm(train_loader) 97 | 98 | for batch_idx, (inp_G_t, lab_t) in enumerate(t): 99 | optimizer.zero_grad() 100 | inp_G_t, lab_t = inp_G_t.view(-1, inp_G_t.shape[-1]).to(device), lab_t.view(-1, 3).to(device) 101 | 102 | # scale the input coordinates from integers to floats 103 | inp_G_t[..., :2] /= 17 104 | inp_G_t[..., 2] /= args.img_W 105 | inp_G_t[..., 3] /= args.img_H 106 | 107 | out = model(inp_G_t) 108 | mse_loss = torch.nn.functional.mse_loss(out, lab_t) 109 | loss = mse_loss 110 | loss.backward() 111 | optimizer.step() 112 | 113 | psnr = 10 * np.log10(1 / mse_loss.item()) 114 | e_psnr += psnr 115 | e_loss += mse_loss.item() 116 | 117 | if it % val_freq == 0: 118 | val_inp_t = val_inp_t.view(-1, val_inp_t.shape[-1]) 119 | val_inp_t[..., :2] /= 17 120 | val_inp_t[..., 2] /= args.img_W 121 | val_inp_t[..., 3] /= args.img_H 122 | b_size = val_inp_t.shape[0] // 16 123 | model.eval() 124 | with torch.no_grad(): 125 | out = [] 126 | for b in range(16): 127 | out.append(model(val_inp_t[b_size*b:b_size*(b+1)].to(device))) 128 | out = torch.cat(out, dim = 0) 129 | out = torch.clamp(out, 0, 1) 130 | out_np = out.view(val_im_shape[0], val_im_shape[1], 3).cpu().numpy() * 255 131 | out_im = Image.fromarray(np.uint8(out_np)) 132 | out_name = f'valout/valout_e_{epoch}_it_{it}.png' 133 | out_im.save(f'{exp_dir}/{out_name}') 134 | model.train() 135 | it += 1 136 | t.set_postfix(PSNR = psnr, EpochPSNR = e_psnr / it, EpochLoss = e_loss / it) 137 | 138 | scheduler.step() 139 | 140 | print('Epoch: %s Ave PSNR: %s Ave Loss: %s'%(epoch, e_psnr / it, e_loss / it)) 141 | psnrs.append(e_psnr / it); mse_losses.append(e_loss / it) 142 | 143 | if epoch % save_freq == 0: 144 | torch.save(model.state_dict(), f'{exp_dir}/model.pth') 145 | 146 | torch.save(model.state_dict(), f'{exp_dir}/model.pth') 147 | 148 | np.savetxt(f'{exp_dir}/mse_stats.txt', mse_losses, delimiter=',') 149 | np.savetxt(f'{exp_dir}/psnr_stats.txt', psnrs, delimiter=',') 150 | --------------------------------------------------------------------------------