├── checkpoints
│   ├── Arena
│   │   └── finalModel.pkl
│   ├── Lewis
│   │   └── finalModel.pkl
│   ├── Subway
│   │   └── finalModel.pkl
│   └── SunTemple
│       └── finalModel.pkl
├── LICENSE
├── utils
│   ├── gen_video.py
│   ├── metrics.py
│   └── matlab_metric.py
├── benchmark.py
├── .gitignore
├── README.md
├── eval.py
├── train.py
├── model.py
└── dataloaders.py
/checkpoints/Arena/finalModel.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryanhe312/STSSNet-AAAI2024/HEAD/checkpoints/Arena/finalModel.pkl -------------------------------------------------------------------------------- /checkpoints/Lewis/finalModel.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryanhe312/STSSNet-AAAI2024/HEAD/checkpoints/Lewis/finalModel.pkl -------------------------------------------------------------------------------- /checkpoints/Subway/finalModel.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryanhe312/STSSNet-AAAI2024/HEAD/checkpoints/Subway/finalModel.pkl -------------------------------------------------------------------------------- /checkpoints/SunTemple/finalModel.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryanhe312/STSSNet-AAAI2024/HEAD/checkpoints/SunTemple/finalModel.pkl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ruian He 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /utils/gen_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "true" 3 | 4 | import cv2 5 | import tqdm 6 | import numpy as np 7 | 8 | from PIL import Image 9 | 10 | def inference(save_dir, res_format, gt_path, min_id, max_id, offset): 11 | cnt = 0 12 | 13 | if gt_path is not None: 14 | gt_writer = None 15 | pred_writer = None 16 | 17 | for idx in tqdm.tqdm(range(min_id, max_id+1)): 18 | img = cv2.imread(os.path.join(save_dir,res_format%(idx+offset))) 19 | h,w,c = img.shape 20 | if h != 1080 or w != 1920: 21 | img = np.asarray(Image.fromarray(img).resize((1920, 1080), Image.BICUBIC)) 22 | if gt_path is not None: 23 | gt = np.load(gt_path%idx).astype(np.float32) 24 | if img.shape != gt.shape: 25 | print(img.shape) 26 | print(gt.shape) 27 | print(f"Unmatched resolution at frame idx {idx}!!!") 28 | break 29 | 30 | gt = (np.clip(gt,0,1) * 255.0).astype(np.uint8) 31 | 32 | if gt_writer is None: 33 | h,w,c = gt.shape 34 | gt_writer = cv2.VideoWriter() 35 | gt_writer.open(os.path.join(save_dir, 'gt.avi'), cv2.VideoWriter_fourcc('p', 'n', 'g', ' '), 60, (w, h), True) 36 | gt_writer.write(gt) 37 | 38 | if pred_writer is None: 39 | h,w,c = img.shape 40 | pred_writer = cv2.VideoWriter() 41 | pred_writer.open(os.path.join(save_dir, 'pred.avi'), cv2.VideoWriter_fourcc('p', 'n', 'g', ' '), 60, (w, h), True) 42 | pred_writer.write(img) 43 | 44 | cnt += 1 45 | 46 | if gt_path is not None: 47 | gt_writer.release() 48 | pred_writer.release() 49 | 50 | if __name__ == '__main__': 51 | inference('output/Lewis', 'res.%04d.png', '/home/user2/dataset/rendering/Lewis/test/HR/compressedHR.%04d.npy', 5, 1000, 0) -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | import torch_tensorrt 8 | 9 | from thop import profile, clever_format 10 | from model import STSSNet 11 | from dataloaders import * 12 | 13 | def benchmark(dataLoaderIns, model, half=False, use_trt=False, scale=1): 14 | model = model.to('cuda:0') 15 | model.eval() 16 | with torch.no_grad(): 17 | for index, (input,features,mask, hisBuffer, label) in dataLoaderIns: 18 | input=input.cuda() 19 | hisBuffer=hisBuffer.cuda() 20 | mask=mask.cuda() 21 | features=features.cuda() 22 | 23 | B, C, H, W = input.shape 24 | 25 | if scale != 1: 26 | input = F.interpolate(input, scale_factor=scale, mode='bilinear') 27 | hisBuffer = F.interpolate(hisBuffer.reshape(B,-1,H,W), scale_factor=scale, mode='bilinear').reshape(3*B,-1,int(scale*H),int(scale*W)) 28 | mask = F.interpolate(mask, scale_factor=scale, mode='bilinear') 29 | features = F.interpolate(features, scale_factor=scale, mode='bilinear') 30 | 31 | print('Input Shape:', list(input.shape[-2:])) 32 | 33 | B, C, H, W = input.shape 34 | input=F.pad(input,(0, 8-W%8,0,8-H%8)) 35 | hisBuffer=F.pad(hisBuffer,(0, 8-W%8,0,8-H%8)) 36 | mask=F.pad(mask,(0, 8-W%8,0,8-H%8)) 37 | features=F.pad(features,(0, 8-W%8,0,8-H%8)) 38 | 39 | if half: 40 | model = model.half() 41 | input = input.half() 42 | hisBuffer = hisBuffer.half() 43 | mask = mask.half() 44 | features = features.half() 45 | 46 | # calculate flops 47 | macs, params = profile(model, inputs=(input, features, mask, hisBuffer)) 
48 | macs, params = clever_format([macs, params], "%.3f") 49 | print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 50 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 51 | 52 | # compile trt model 53 | if use_trt: 54 | print('Compiling trt model...') 55 | traced_model = torch.jit.trace(model,(input, features, mask, hisBuffer)) 56 | inputs = [] 57 | for tensor in [input, features, mask, hisBuffer]: 58 | inputs.append(torch_tensorrt.Input(list(tensor.shape), dtype=torch.half if half else torch.float)) 59 | model = torch_tensorrt.compile(traced_model, 60 | inputs= inputs, 61 | enabled_precisions= {torch.half if half else torch.float} # Run with FP16 62 | ) 63 | else: 64 | print('Using traced model...') 65 | model = torch.jit.trace(model,(input, features, mask, hisBuffer)) 66 | 67 | times = [] 68 | 69 | # warm up 70 | for i in range(10): 71 | res=model(input, features, mask, hisBuffer) 72 | 73 | # benchmark 74 | for i in range(101): 75 | torch.cuda.synchronize() 76 | tt = time.time() 77 | res = model(input, features, mask, hisBuffer) 78 | torch.cuda.synchronize() 79 | times.append(time.time()-tt) 80 | 81 | print("Time: %.3f"%(np.mean(times[1:])*1000),'ms') 82 | break 83 | 84 | if __name__ =="__main__": 85 | dataset = get_Lewis_test_data() 86 | testLoader = data.DataLoader(dataset,1,shuffle=False,num_workers=1, pin_memory=False) 87 | 88 | model = STSSNet(6,3,9,4) 89 | 90 | benchmark(testLoader, model, half=True, use_trt=True) 91 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # accelerated pytorch implementation for training 2 | 3 | import torch 4 | import numpy as np 5 | import pytorch_msssim 6 | 7 | class cvtColor: 8 | def __init__(self) -> None: 9 | scale1 = 1.0/255.0 10 | offset1 = 1.0/255.0 11 | self.rgb2ycbcr_coeffs = [ 12 | 65.481 * scale1, 128.553 * scale1, 24.966 * scale1, 16.0 * offset1, 13 | -37.797 * scale1, -74.203 * scale1, 112.0 * scale1, 128.0 * offset1, 14 | 112.0 * scale1, -93.786 * scale1, -18.214 * scale1, 128.0 * offset1] 15 | scale2 = 255.0 16 | offset2 = 1.0 17 | self.ycbcr2rgb_coeffs = [ 18 | 0.0045662 * scale2, 0, 0.0062589 * scale2, -0.8742024 * offset2, 19 | 0.0045662 * scale2, -0.0015363 * scale2, -0.0031881 * scale2, 0.5316682 * offset2, 20 | 0.0045662 * scale2, 0.0079107 * scale2, 0 , -1.0856326 * offset2 21 | ] 22 | def rgb2ycbcr(self, tensor): 23 | """ 24 | tensor = B x C x H x W 25 | """ 26 | R = tensor[:,0:1] 27 | G = tensor[:,1:2] 28 | B = tensor[:,2:3] 29 | 30 | Y = self.rgb2ycbcr_coeffs[0] * R + self.rgb2ycbcr_coeffs[1] * G + self.rgb2ycbcr_coeffs[2] * B + self.rgb2ycbcr_coeffs[3] 31 | Cb = self.rgb2ycbcr_coeffs[4] * R + self.rgb2ycbcr_coeffs[5] * G + self.rgb2ycbcr_coeffs[6] * B + self.rgb2ycbcr_coeffs[7] 32 | Cr = self.rgb2ycbcr_coeffs[8] * R + self.rgb2ycbcr_coeffs[9] * G + self.rgb2ycbcr_coeffs[10] * B + self.rgb2ycbcr_coeffs[11] 33 | 34 | return torch.cat([Y,Cb,Cr],dim=1) 35 | 36 | def ycrcb2rgb(self, tensor): 37 | """ 38 | tensor = B x C x H x W 39 | """ 40 | 41 | Y = tensor[:,0:1] 42 | Cb = tensor[:,1:2] 43 | Cr = tensor[:,2:3] 44 | 45 | R = self.ycbcr2rgb_coeffs[0] * Y + self.ycbcr2rgb_coeffs[1] * Cb + self.ycbcr2rgb_coeffs[2] * Cr + self.ycbcr2rgb_coeffs[3] 46 | G = self.ycbcr2rgb_coeffs[4] * Y + self.ycbcr2rgb_coeffs[5] * Cb + self.ycbcr2rgb_coeffs[6] * Cr + self.ycbcr2rgb_coeffs[7] 47 | B = self.ycbcr2rgb_coeffs[8] * Y + self.ycbcr2rgb_coeffs[9] * Cb + self.ycbcr2rgb_coeffs[10] * Cr + self.ycbcr2rgb_coeffs[11] 48 | 49 | return torch.cat([R,G,B],dim=1) 50 | 51 | cvtColor = cvtColor() 52 | 53 | def accuracy(output, target): 54 | with torch.no_grad(): 55 | pred = torch.argmax(output, dim=1) 56 | # print(target.shape) 57 | # print(pred.shape) 58 | # print(output.shape) 59 | assert pred.shape[0] == len(target) 60 | correct = 0 61 | correct += torch.sum(pred == target).item() 62 | # print("c:", correct) 63 | # print('len(target):', len(target)) 64 | return correct / len(target) 65 | 66 | 67 | def top_k_acc(output, target, k=3): 68 | with torch.no_grad(): 69 | pred = torch.topk(output, k, dim=1)[1] 70 | assert pred.shape[0] == len(target) 71 | correct = 0 72 | for i in range(k): 73 | correct += torch.sum(pred[:, i] == target).item() 74 | return correct / len(target) 75 | 76 | def mse(output, target): 77 | with torch.no_grad(): 78 | mse = (output - target).square().mean() 79 | return mse 80 | 81 | def psnr(output, target, only_y=False): 82 | output = torch.clamp(output, 0.0, 1.0) 83 | target = torch.clamp(target, 0.0, 1.0) 84 | if only_y: 85 | output = cvtColor.rgb2ycbcr(output) 86 | target = cvtColor.rgb2ycbcr(target) 87 | output = output[:,0:1] 88 | target = target[:,0:1] 89 | with torch.no_grad(): 90 | mse = (output * 255.0 - target * 255.0).square().mean() 91 | psnr = 20.0 * torch.log10(255.0/mse.sqrt()) 92 | return psnr 93 | 94 | def ssim(output, target, only_y=False): 95 | output = torch.clamp(output, 0.0, 1.0) 96 
| target = torch.clamp(target, 0.0, 1.0) 97 | if only_y: 98 | output = cvtColor.rgb2ycbcr(output) 99 | target = cvtColor.rgb2ycbcr(target) 100 | output = output[:,0:1] 101 | target = target[:,0:1] 102 | # print(output.dtype,target.dtype) 103 | ssim = pytorch_msssim.ssim(output, target, data_range=1) 104 | return ssim -------------------------------------------------------------------------------- /utils/matlab_metric.py: -------------------------------------------------------------------------------- 1 | ''' 2 | calculate the PSNR and SSIM. 3 | same as MATLAB's results 4 | ''' 5 | 6 | import os 7 | import math 8 | import numpy as np 9 | import cv2 10 | import glob 11 | import os 12 | 13 | def rgb2ycbcr(img, only_y=True): 14 | '''same as matlab rgb2ycbcr 15 | only_y: only return Y channel 16 | Input: 17 | uint8, [0, 255] 18 | float, [0, 1] 19 | ''' 20 | in_img_type = img.dtype 21 | img.astype(np.float32) 22 | if in_img_type != np.uint8: 23 | img *= 255. 24 | # convert 25 | if only_y: 26 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 27 | else: 28 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 29 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 30 | if in_img_type == np.uint8: 31 | rlt = rlt.round() 32 | else: 33 | rlt /= 255. 34 | return rlt.astype(in_img_type) 35 | #########################calc_metrics############################# 36 | def calc_metrics(img1, img2, crop_border, test_Y=True, norm=False, mask=None): 37 | if norm: 38 | img1 = (np.clip(img1,0,1) * 255.0).astype(np.uint8) 39 | img2 = (np.clip(img2,0,1) * 255.0).astype(np.uint8) 40 | 41 | img1 = img1 / 255. 42 | img2 = img2 / 255. 43 | 44 | if test_Y and img1.shape[2] == 3: # evaluate on Y channel in YCbCr color space 45 | im1_in = rgb2ycbcr(img1) 46 | im2_in = rgb2ycbcr(img2) 47 | else: 48 | im1_in = img1 49 | im2_in = img2 50 | 51 | if crop_border != 0: 52 | if im1_in.ndim == 3: 53 | cropped_im1 = im1_in[crop_border:-crop_border, crop_border:-crop_border, :] 54 | cropped_im2 = im2_in[crop_border:-crop_border, crop_border:-crop_border, :] 55 | elif im1_in.ndim == 2: 56 | cropped_im1 = im1_in[crop_border:-crop_border, crop_border:-crop_border] 57 | cropped_im2 = im2_in[crop_border:-crop_border, crop_border:-crop_border] 58 | else: 59 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im1_in.ndim)) 60 | else: 61 | cropped_im1 = im1_in 62 | cropped_im2 = im2_in 63 | 64 | psnr = calc_psnr(cropped_im1 * 255, cropped_im2 * 255, mask=mask) 65 | ssim = calc_ssim(cropped_im1 * 255, cropped_im2 * 255, mask=mask) 66 | return psnr, ssim 67 | 68 | def calc_metrics_y(img1, img2, crop_border, test_Y=True): 69 | img1 = img1 / 255. 70 | img2 = img2 / 255. 71 | im1_in = img1 72 | im2_in = img2 73 | if im1_in.ndim == 3: 74 | cropped_im1 = im1_in[crop_border:-crop_border, crop_border:-crop_border, :] 75 | cropped_im2 = im2_in[crop_border:-crop_border, crop_border:-crop_border, :] 76 | elif im1_in.ndim == 2: 77 | cropped_im1 = im1_in[crop_border:-crop_border, crop_border:-crop_border] 78 | cropped_im2 = im2_in[crop_border:-crop_border, crop_border:-crop_border] 79 | else: 80 | raise ValueError('Wrong image dimension: {}. 
Should be 2 or 3.'.format(im1_in.ndim)) 81 | 82 | psnr = calc_psnr(cropped_im1 * 255, cropped_im2 * 255) 83 | ssim = calc_ssim(cropped_im1 * 255, cropped_im2 * 255) 84 | return psnr, ssim 85 | 86 | def calc_psnr(img1, img2, mask=None): 87 | # img1 and img2 have range [0, 255] 88 | img1 = img1.astype(np.float64) 89 | img2 = img2.astype(np.float64) 90 | if mask is not None: 91 | mse = np.sum((img1 - img2)**2 * mask) / (np.sum(mask) + 1e-5) 92 | else: 93 | mse = np.mean((img1 - img2)**2) 94 | if mse == 0: 95 | return float('inf') 96 | return 20 * math.log10(255.0 / math.sqrt(mse)) 97 | 98 | def ssim(img1, img2, mask=None): 99 | C1 = (0.01 * 255)**2 100 | C2 = (0.03 * 255)**2 101 | 102 | img1 = img1.astype(np.float64) 103 | img2 = img2.astype(np.float64) 104 | kernel = cv2.getGaussianKernel(11, 1.5) 105 | window = np.outer(kernel, kernel.transpose()) 106 | 107 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 108 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 109 | mu1_sq = mu1**2 110 | mu2_sq = mu2**2 111 | mu1_mu2 = mu1 * mu2 112 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 113 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 114 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 115 | 116 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 117 | (sigma1_sq + sigma2_sq + C2)) 118 | if mask is not None: 119 | return np.sum(ssim_map * mask[5:-5, 5:-5]) / (np.sum(mask[5:-5, 5:-5]) + 1e-5) 120 | else: 121 | return ssim_map.mean() 122 | 123 | 124 | def calc_ssim(img1, img2, mask=None): 125 | '''calculate SSIM 126 | the same outputs as MATLAB's 127 | img1, img2: [0, 255] 128 | ''' 129 | if not img1.shape == img2.shape: 130 | raise ValueError('Input images must have the same dimensions.') 131 | if img1.ndim == 2: 132 | return ssim(img1, img2) 133 | elif img1.ndim == 3: 134 | if img1.shape[2] == 3: 135 | ssims = [] 136 | for i in range(3): 137 | ssims.append(ssim(img1, img2, mask=mask)) 138 | return np.array(ssims).mean() 139 | elif img1.shape[2] == 1: 140 | return ssim(np.squeeze(img1), np.squeeze(img2), mask=mask) 141 | else: 142 | raise ValueError('Wrong input image dimensions.') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # STSSNet-AAAI2024 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2312.10890-b31b1b.svg)](https://arxiv.org/abs/2312.10890) 4 | 5 | Official Implementation of "Low-latency Space-time Supersampling for Real-time Rendering" (AAAI 2024). 6 | 7 | 8 | [![](https://markdown-videos-api.jorgenkh.no/youtube/8aPu2ECwVLk)](https://youtu.be/8aPu2ECwVLk) 9 | 10 | ## Environment 11 | 12 | We use Torch-TensorRT 1.1.0, PyTorch 1.11, CUDA 11.4, cuDNN 8.2 and TensorRT 8.2.5.1. 13 | 14 | Please download the corresponding version of [CUDA](https://developer.nvidia.com/cuda-11-4-1-download-archive), [cuDNN](https://developer.nvidia.com/rdp/cudnn-archive), and [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-download). 
Then set the environment variables as follows: 15 | 16 | ```bash 17 | export TRT_RELEASE=~/project/TensorRT-8.2.5.1 18 | export PATH="/usr/local/cuda-11.4/bin:$PATH" 19 | export CUDA_HOME="/usr/local/cuda-11.4" 20 | export LD_LIBRARY_PATH="$TRT_RELEASE/lib:/usr/local/cuda-11.4/lib64:$LD_LIBRARY_PATH" 21 | ``` 22 | 23 | Next, create a conda environment and install PyTorch and Torch-TensorRT: 24 | 25 | ```bash 26 | conda create -n tensorrt python=3.7 pytorch=1.11 torchvision torchaudio cudatoolkit=11.3 -c pytorch -y 27 | conda activate tensorrt 28 | pip3 install $TRT_RELEASE/python/tensorrt-8.2.5.1-cp37-none-linux_x86_64.whl 29 | pip3 install torch-tensorrt==1.1.0 -f https://github.com/pytorch/TensorRT/releases/download/v1.1.0/torch_tensorrt-1.1.0-cp37-cp37m-linux_x86_64.whl 30 | pip3 install opencv-python tqdm thop matplotlib scikit-image lpips visdom numpy pytorch_msssim 31 | ``` 32 | 33 | ## Dataset 34 | 35 | We release the dataset used in our paper at [ModelScope](https://www.modelscope.cn/datasets/ryanhe312/STSSNet-AAAI2024). The dataset contains four scenes, Lewis, SunTemple, Subway, and Arena, each with around 6000 frames for training and 1000 for testing. Each frame is stored as a compressed NumPy array of 16-bit floats. 36 | 37 | You can install ModelScope by running: 38 | 39 | ```bash 40 | pip install modelscope 41 | ``` 42 | 43 | Then you can download the dataset by running the following code in Python: 44 | 45 | ```python 46 | from modelscope.msdatasets import MsDataset 47 | ds = MsDataset.load('ryanhe312/STSSNet-AAAI2024', subset_name='Lewis', split='test') 48 | # ds = MsDataset.load('ryanhe312/STSSNet-AAAI2024', subset_name='SunTemple', split='test') 49 | # ds = MsDataset.load('ryanhe312/STSSNet-AAAI2024', subset_name='Subway', split='test') 50 | # ds = MsDataset.load('ryanhe312/STSSNet-AAAI2024', subset_name='Arena', split='test') 51 | ``` 52 | 53 | Note that each test scene is around 40GB, so downloading may take a while. 54 | 55 | Please modify the dataset paths in `dataloaders.py` to your own paths before the next step. 56 | 57 | ## Evaluation 58 | 59 | You can modify the `dataset` and `mode` in `eval.py` to evaluate different scenes and modes. 60 | 61 | `all` mode evaluates all pixels, `edge` mode evaluates pixels on the Canny edges of the HR frame, and `hole` mode evaluates pixels inside the warping holes of the LR frame. 62 | 63 | Run the following command to evaluate the model in terms of PSNR, SSIM, and LPIPS: 64 | 65 | ```bash 66 | python eval.py 67 | ``` 68 | 69 | To evaluate VMAF, you need to: 70 | 71 | 1. Set `save_img` to `True` in `eval.py` and run it. 72 | 2. Run `utils/gen_video.py` to generate `gt.avi` and `pred.avi`. 73 | 3. Install ffmpeg and add it to your `PATH` environment variable. 74 | 4. Follow the instructions of [VMAF](https://github.com/Netflix/vmaf) to use ffmpeg to compute the VMAF metric between `gt.avi` and `pred.avi` (a minimal ffmpeg sketch is given after the Benchmark section below). 75 | 76 | ## Benchmark 77 | 78 | You can test model size, FLOPs, and the inference speed of our model by running: 79 | 80 | ```bash 81 | python benchmark.py 82 | ``` 83 | 84 | You should get the following results: 85 | 86 | ``` 87 | Computational complexity: 31.502G 88 | Number of parameters: 417.241K 89 | Time: 4.350 ms 90 | ``` 91 | 92 | Inference speed is tested on a single RTX 3090 GPU and may vary on different machines. 
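For step 4 of the VMAF evaluation above, a minimal invocation might look like the following sketch. It assumes an ffmpeg build with `libvmaf` enabled and uses the `output/Lewis` paths produced by `eval.py` and `utils/gen_video.py`; consult the VMAF repository for the options that match your ffmpeg version.

```bash
# Sketch only: requires ffmpeg compiled with --enable-libvmaf.
# The first input is the distorted video, the second input is the reference.
ffmpeg -i output/Lewis/pred.avi -i output/Lewis/gt.avi \
       -lavfi libvmaf=log_fmt=json:log_path=vmaf.json -f null -
```

The aggregate VMAF score should appear in the console output and in `vmaf.json`.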
93 | 94 | ## Training 95 | 96 | You can download the training dataset by running the following code in Python: 97 | 98 | ```python 99 | from modelscope.msdatasets import MsDataset 100 | ds = MsDataset.load('ryanhe312/STSSNet-AAAI2024', subset_name='Lewis', split='train') 101 | ds = MsDataset.load('ryanhe312/STSSNet-AAAI2024', subset_name='Lewis', split='validation') 102 | ``` 103 | 104 | This will download two sequences, train1 and train2. You can modify `subset_name` to download different scenes (one of 'Lewis', 'SunTemple', and 'Subway'). Each sequence is around 150GB, so downloading may take a while. 105 | 106 | Please modify the dataset paths in `dataloaders.py` to your own paths, then run `train.py` to train on different scenes. 107 | 108 | Visdom is used for visualization. You can run `python -m visdom.server` to start a Visdom server, then open `http://localhost:8097/` in your browser to monitor the training process. 109 | 110 | ## Acknowledgement 111 | 112 | We thank the authors of [ExtraNet](https://github.com/fuxihao66/ExtraNet) for their great work and data generation pipeline. 113 | 114 | ## Citation 115 | 116 | If you find our work useful in your research, please consider citing: 117 | 118 | ```bibtex 119 | @misc{he2023lowlatency, 120 | title={Low-latency Space-time Supersampling for Real-time Rendering}, 121 | author={Ruian He and Shili Zhou and Yuqi Sun and Ri Cheng and Weimin Tan and Bo Yan}, 122 | year={2023}, 123 | eprint={2312.10890}, 124 | archivePrefix={arXiv}, 125 | primaryClass={cs.CV} 126 | } 127 | ``` -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import lpips 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import torch.utils.data as data 8 | 9 | from tqdm import tqdm 10 | 11 | from utils import matlab_metric, metrics 12 | from dataloaders import * 13 | 14 | def ImgWrite(mPath,prefix,idx,img): 15 | cv2.imwrite(os.path.join(mPath,prefix+"."+str(idx).zfill(4)+".png"),img) 16 | 17 | @torch.no_grad() 18 | def save_res(dataLoaderIns, model, modelPath, save_dir, save_img=True, mode='all'): 19 | if not os.path.exists(save_dir): 20 | os.makedirs(save_dir) 21 | 22 | if modelPath.endswith(".tar"): 23 | model_CKPT = torch.load(modelPath, map_location="cuda:0")["state_dict"] 24 | elif modelPath.endswith(".ckpt"): 25 | model_CKPT = {k[6:]:v for k,v in torch.load(modelPath, map_location="cuda:0")["state_dict"].items() if 'vgg' not in k} 26 | else: 27 | model_CKPT = torch.load(modelPath, map_location="cuda:0") 28 | model.load_state_dict(model_CKPT) 29 | model = model.to("cuda:0") 30 | model.eval() 31 | 32 | all_PSNR_SF = [] 33 | all_ssim_SF = [] 34 | all_lpips_SF = [] 35 | 36 | all_PSNR_IF = [] 37 | all_ssim_IF = [] 38 | all_lpips_IF = [] 39 | loss_fn_alex = lpips.LPIPS(net='alex').cuda() 40 | 41 | 42 | print('saving to ',save_dir) 43 | f = open(os.path.join(save_dir, 'metrics.csv'), 'w') 44 | print('frame,psnr,ssim,lpips', file=f) 45 | for index, (input,features,mask,hisBuffer,label) in tqdm(dataLoaderIns): 46 | index = index[0].item() 47 | input=input.cuda() 48 | hisBuffer=hisBuffer.cuda() 49 | mask=mask.cuda() 50 | features=features.cuda() 51 | label=label.cuda() 52 | 53 | B,C,H,W = input.size() 54 | 55 | input = F.pad(input,(0,0,0,4),'replicate') 56 | mask = F.pad(mask,(0,0,0,4),'replicate') 57 | features = F.pad(features,(0,0,0,4),'replicate') 58 | hisBuffer = 
F.pad(hisBuffer.reshape(B,-1,H,W),(0,0,0,4),'replicate').reshape(B,3,4,H+4,W) 59 | 60 | res=model(input, features, mask, hisBuffer) 61 | res = res[:,:,:-8] 62 | 63 | ## mask 64 | if mode == 'edge': 65 | gray = cv2.cvtColor((label[0].permute(1,2,0).detach().cpu().numpy() * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) 66 | mask = cv2.Canny(gray, 100, 200) 67 | elif mode == 'hole': 68 | mask = 1 - mask[:, :, :-4] 69 | mask = F.interpolate(mask, scale_factor=2, mode='bilinear').squeeze().cpu().numpy() 70 | else: 71 | mask = None 72 | 73 | ## calculate metrics 74 | psnr, ssim = matlab_metric.calc_metrics(res[0].permute(1,2,0).detach().cpu().numpy(), label[0].permute(1,2,0).detach().cpu().numpy(), 0, norm=True, mask=mask) 75 | with torch.no_grad(): 76 | lpips_ = loss_fn_alex(res, label).item() 77 | 78 | if index % 2 == 0: 79 | all_PSNR_SF.append(psnr) 80 | all_ssim_SF.append(ssim) 81 | all_lpips_SF.append(lpips_) 82 | else: 83 | all_PSNR_IF.append(psnr) 84 | all_ssim_IF.append(ssim) 85 | all_lpips_IF.append(lpips_) 86 | 87 | print(index, psnr, ssim, lpips_, file=f, sep=',', flush=True) 88 | 89 | ## save res 90 | if save_img: 91 | res=res.squeeze(0).cpu().numpy().transpose([1,2,0]) 92 | res=cv2.cvtColor(res,cv2.COLOR_RGB2BGR) 93 | res = (np.clip(res, 0, 1) * 255).astype(np.uint8) 94 | ImgWrite(save_dir,"res",index,res) 95 | 96 | psnr_sf = np.mean(all_PSNR_SF) 97 | ssim_sf = np.mean(all_ssim_SF) 98 | lpips_sf = np.mean(all_lpips_SF) 99 | 100 | psnr_if = np.mean(all_PSNR_IF) 101 | ssim_if = np.mean(all_ssim_IF) 102 | lpips_if = np.mean(all_lpips_IF) 103 | 104 | print('SF', psnr_sf, ssim_sf, lpips_sf, file=f, sep=',') 105 | print('IF', psnr_if, ssim_if, lpips_if, file=f, sep=',') 106 | print('MEAN', (psnr_sf+psnr_if)/2, (ssim_sf+ssim_if)/2, (lpips_if+lpips_sf)/2, file=f, sep=',') 107 | 108 | f.close() 109 | 110 | def plot_res(save_dir): 111 | import matplotlib.pyplot as plt 112 | with open(os.path.join(save_dir, 'metrics.csv'), 'r') as f: 113 | data = np.loadtxt(f, delimiter=',', skiprows=1, usecols=(1,2,3))[:-3] 114 | all_PSNR_SF, all_ssim_SF, all_lpips_SF = data[::2,0], data[::2,1], data[::2,2] 115 | all_PSNR_IF, all_ssim_IF, all_lpips_IF = data[1::2,0], data[1::2,1], data[1::2,2] 116 | 117 | plt.plot(np.arange(len(all_PSNR_SF)), np.array(all_PSNR_SF)) 118 | plt.plot(np.arange(len(all_PSNR_IF)), np.array(all_PSNR_IF)) 119 | plt.legend(['PSNR-SF', 'PSNR-IF']) 120 | plt.savefig(os.path.join(save_dir, 'PSNR-curve.jpg')) 121 | 122 | if __name__ =="__main__": 123 | # Lewis 124 | dataset = get_Lewis_test_data() 125 | modelPath = 'checkpoints/Lewis/finalModel.pkl' 126 | savePath = 'output/Lewis' 127 | 128 | # Subway 129 | # dataset = get_Subway_test_data() 130 | # modelPath = 'checkpoints/Subway/finalModel.pkl' 131 | # savePath = 'output/Subway' 132 | 133 | # SunTemple 134 | # dataset = get_SunTemple_test_data() 135 | # modelPath = 'checkpoints/SunTemple/finalModel.pkl' 136 | # savePath = 'output/SunTemple' 137 | 138 | # Arena 139 | # dataset = get_Arena_test_data() 140 | # modelPath = 'checkpoints/Arena/finalModel.pkl' 141 | # savePath = 'output/Arena' 142 | 143 | testLoader = data.DataLoader(dataset,1,shuffle=False,num_workers=2, pin_memory=True) 144 | 145 | from model import STSSNet 146 | model = STSSNet(6,3,9,4) 147 | 148 | mode = 'all' # 'all', 'edge', 'hole' 149 | 150 | save_res(testLoader, model, modelPath, savePath, save_img=False, mode=mode) 151 | plot_res(savePath) 152 | 153 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import lpips 5 | import torchvision as tv 6 | import torch.nn.functional as F 7 | import torch.utils.data as data 8 | 9 | from torch import optim 10 | from torch.cuda import amp 11 | from visdom import Visdom 12 | from model import STSSNet 13 | from tqdm.auto import tqdm 14 | 15 | from dataloaders import * 16 | from utils import metrics 17 | 18 | mdevice=torch.device("cuda:0") 19 | learningrate=1e-4 20 | epoch=100 21 | printevery=50 22 | batch_size=2 23 | 24 | class VisdomWriter: 25 | def __init__(self, visdom_port): 26 | self.viz = Visdom(port=visdom_port) 27 | self.names = [] 28 | def add_scalar(self, name, val, step): 29 | try: 30 | val = val.item() 31 | except: 32 | val = float(val) 33 | if name not in self.names: 34 | self.names.append(name) 35 | self.viz.line([val], [step], win=name, opts=dict(title=name)) 36 | else: 37 | self.viz.line([val], [step], win=name, update='append') 38 | def add_image(self, name, image, step): 39 | self.viz.image(image, win=name, opts=dict(title=name)) 40 | def close(self): 41 | return 42 | 43 | def colornorm(img): 44 | img = img.clamp(0,1) 45 | return img 46 | 47 | def train(dataLoaderIns, modelSavePath, save_dir, reload=None, port=2336): 48 | if not os.path.exists(save_dir): 49 | os.makedirs(save_dir) 50 | vgg_model = lpips.LPIPS(net='vgg').cuda() 51 | 52 | model = STSSNet(6,3,9,4) 53 | 54 | model = model.to(mdevice) 55 | scaler = amp.GradScaler() 56 | optimizer = optim.Adam(model.parameters(), lr=learningrate) 57 | 58 | scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=50,gamma=0.9) 59 | 60 | writer = VisdomWriter(port) 61 | global_step = 0 62 | start_e = 0 63 | 64 | if reload is not None: 65 | pth = torch.load(os.path.join(save_dir, f'totalModel.{reload}.pth.tar')) 66 | start_e = pth['epoch'] 67 | model.load_state_dict(pth['state_dict']) 68 | optimizer.load_state_dict(pth['optimizer']) 69 | for e in range(start_e): 70 | if e > 20: 71 | scheduler.step() 72 | 73 | print('start epoch:', start_e) 74 | for e in range(start_e, epoch): 75 | model.train() 76 | 77 | iter=0 78 | loss_all=0 79 | startTime = time.time() 80 | 81 | for input,features,mask,hisBuffer,label in tqdm(dataLoaderIns): 82 | 83 | input=input.cuda() 84 | hisBuffer=hisBuffer.cuda() 85 | mask=mask.cuda() 86 | features=features.cuda() 87 | label=label.cuda() 88 | 89 | input_lst, hisBuffer_lst, mask_lst, features_lst, label_lst = [], [], [], [], [] 90 | for i in range(4): 91 | i, j, h, w = tv.transforms.RandomCrop.get_params(input, output_size=(256, 256)) 92 | input_lst.append(tv.transforms.functional.crop(input, i, j, h, w)) 93 | hisBuffer_lst.append(tv.transforms.functional.crop(hisBuffer, i, j, h, w)) 94 | mask_lst.append(tv.transforms.functional.crop(mask, i, j, h, w)) 95 | features_lst.append(tv.transforms.functional.crop(features, i, j, h, w)) 96 | label_lst.append(tv.transforms.functional.crop(label, i*2, j*2, h*2, w*2)) 97 | input, hisBuffer, mask, features, label = torch.cat(input_lst),torch.cat(hisBuffer_lst), torch.cat(mask_lst), torch.cat(features_lst), torch.cat(label_lst) 98 | 99 | optimizer.zero_grad() 100 | 101 | with amp.autocast(): 102 | res=model(input, features, mask, hisBuffer).float() 103 | loss_full = torch.abs(res-label).mean() 104 | mask_up = (F.interpolate(mask,scale_factor=2,mode='bilinear') > 0).float() 105 | loss_mask = (torch.abs(res-label)*(1-mask_up)).sum()/(1-mask_up).sum().clamp_min(1e-6) 106 | loss_lpips = vgg_model(res*2 
-1,label*2-1) 107 | loss = loss_full + loss_mask + 0.01 * loss_lpips 108 | 109 | scaler.scale(loss.mean()).backward() 110 | scaler.step(optimizer) 111 | scaler.update() 112 | 113 | if iter % printevery == 0: 114 | with torch.no_grad(): 115 | writer.add_scalar('loss/loss_total', loss.mean(), global_step) 116 | writer.add_scalar('loss/loss_full', loss_full.mean(), global_step) 117 | writer.add_scalar('loss/loss_mask', loss_mask.mean(), global_step) 118 | writer.add_scalar('loss/loss_lpips', loss_lpips.mean(), global_step) 119 | 120 | writer.add_image('img/input', colornorm(input[-1,:3]).cpu().detach(), global_step) 121 | writer.add_image('img/gt', colornorm(label[-1]).cpu().detach(), global_step) 122 | writer.add_image('img/pred', colornorm(res[-1]).cpu().detach(), global_step) 123 | writer.add_image('img/mask', mask[-1].cpu().detach(), global_step) 124 | writer.add_scalar('metric/psnr', metrics.psnr(res,label), global_step) 125 | writer.add_scalar('metric/ssim', metrics.ssim(res,label), global_step) 126 | 127 | iter+=1 128 | global_step += 1 129 | loss_all+=loss.mean().item() 130 | 131 | endTime = time.time() 132 | print("epoch time is ",endTime - startTime) 133 | print("epoch %d mean loss for train is %f"%(e,loss_all/iter)) 134 | 135 | if e > 20: 136 | scheduler.step() 137 | 138 | if e % 5 == 0: 139 | torch.save({'epoch': e + 1, 'state_dict': model.state_dict(), 140 | 'optimizer': optimizer.state_dict()}, 141 | os.path.join(save_dir, 'totalModel.{}.pth.tar'.format(e))) 142 | 143 | torch.save(model.state_dict(), os.path.join(save_dir,modelSavePath)) 144 | 145 | if __name__ =="__main__": 146 | # Lewis 147 | dataset = get_Lewis_train_data() 148 | 149 | # Subway 150 | # dataset = get_Subway_train_data() 151 | 152 | # SunTemple 153 | # dataset = get_SunTemple_train_data() 154 | 155 | # Arena 156 | # dataset = get_Lewis_train_data() + get_Subway_train_data() + get_SunTemple_train_data() 157 | 158 | trainDiffuseLoader = data.DataLoader(dataset, batch_size, shuffle=True, num_workers=4, pin_memory=False) 159 | train(trainDiffuseLoader, "finalModel.pkl", 'checkpoints/Lewis', port=8097, reload=None) 160 | 161 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | 5 | class DoubleConv(nn.Module): 6 | """(convolution => [BN] => ReLU) * 2""" 7 | 8 | def __init__(self, in_channels, out_channels, mid_channels=None ,kernel_size=3,padding=1): 9 | super().__init__() 10 | if not mid_channels: 11 | mid_channels = out_channels 12 | self.double_conv = nn.Sequential( 13 | nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, padding=padding), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | def forward(self, x): 20 | return self.double_conv(x) 21 | 22 | class ConvUp(nn.Module): 23 | def __init__(self, in_channels, out_channels,kernel_size=3,padding=1, bilinear=True): 24 | super().__init__() 25 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 26 | self.conv = DoubleConv(in_channels, out_channels, mid_channels=in_channels // 2, kernel_size=kernel_size, 27 | padding=padding) 28 | def forward(self, x1): 29 | x1 = self.up(x1) 30 | return self.conv(x1) 31 | 32 | class LWGatedConv2D(nn.Module): 33 | def __init__(self, input_channel1, output_channel, pad, kernel_size, stride): 34 
| super(LWGatedConv2D, self).__init__() 35 | 36 | self.conv_feature = nn.Conv2d(in_channels=input_channel1, out_channels=output_channel, kernel_size=kernel_size, 37 | stride=stride, padding=pad) 38 | 39 | self.conv_mask = nn.Sequential( 40 | nn.Conv2d(in_channels=input_channel1, out_channels=1, kernel_size=kernel_size, stride=stride, 41 | padding=pad), 42 | nn.Sigmoid() 43 | ) 44 | 45 | def forward(self, inputs): 46 | newinputs = self.conv_feature(inputs) 47 | mask = self.conv_mask(inputs) 48 | 49 | return newinputs*mask 50 | 51 | class DownLWGated(nn.Module): 52 | """Downscaling with maxpool then double conv""" 53 | 54 | def __init__(self, in_channels, out_channels): 55 | super().__init__() 56 | self.downsample = LWGatedConv2D(in_channels, in_channels, kernel_size=3, pad=1, stride=2) 57 | self.conv1 = LWGatedConv2D(in_channels, out_channels, kernel_size=3, stride=1, pad=1) 58 | self.relu1 = nn.ReLU(inplace=True) 59 | self.conv2 = LWGatedConv2D(out_channels, out_channels, kernel_size=3, stride=1, pad=1) 60 | self.relu2 = nn.ReLU(inplace=True) 61 | 62 | def forward(self, x): 63 | x= self.downsample(x) 64 | x= self.conv1(x) 65 | x = self.relu1(x) 66 | x= self.conv2(x) 67 | x = self.relu2(x) 68 | return x 69 | 70 | class Up(nn.Module): 71 | """Upscaling then double conv""" 72 | 73 | def __init__(self, in_channels, out_channels, bilinear=True): 74 | super().__init__() 75 | 76 | if bilinear: 77 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 78 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 79 | else: 80 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 81 | self.conv = DoubleConv(in_channels, out_channels) 82 | 83 | 84 | def forward(self, x1, x2): 85 | x1 = self.up(x1) 86 | 87 | x = torch.cat([x2, x1], dim=1) 88 | return self.conv(x) 89 | 90 | class STSSNet(nn.Module): 91 | def __init__(self, in_ch, out_ch, feat_ch, his_ch, skip=True): 92 | super(STSSNet, self).__init__() 93 | self.skip = skip 94 | 95 | self.convHis1 = nn.Sequential( 96 | nn.Conv2d(his_ch, 24, kernel_size=3, stride=2, padding=1), 97 | nn.ReLU(inplace=True), 98 | nn.Conv2d(24, 24, kernel_size=3, stride=1, padding=1), 99 | nn.ReLU(inplace=True), 100 | nn.Conv2d(24, 24, kernel_size=3, stride=1, padding=1), 101 | nn.ReLU(inplace=True) 102 | ) 103 | self.convHis2 = nn.Sequential( 104 | nn.Conv2d(24, 32, kernel_size=3, stride=2, padding=1), 105 | nn.ReLU(inplace=True), 106 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 109 | nn.ReLU(inplace=True) 110 | ) 111 | self.convHis3 = nn.Sequential( 112 | nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 115 | nn.ReLU(inplace=True), 116 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 117 | nn.ReLU(inplace=True) 118 | ) 119 | 120 | self.latentEncoder = nn.Sequential( 121 | nn.Conv2d(32+feat_ch, 64, kernel_size=3, stride=1, padding=1), 122 | nn.ReLU(inplace=True), 123 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(64, 64, kernel_size = 3, stride = 1, dilation = 1, padding = 1, bias=True) 126 | ) 127 | self.KEncoder = nn.Sequential( 128 | nn.Conv2d(feat_ch, 32, kernel_size=3, stride=1, padding=1), 129 | nn.ReLU(inplace=True), 130 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 131 | nn.ReLU(inplace=True), 132 | nn.Conv2d(32, 32, kernel_size = 
3, stride = 1, dilation = 1, padding = 1, bias=True) 133 | ) 134 | 135 | self.lowlevelGated = LWGatedConv2D(32*3, 32, kernel_size=3, stride=1, pad=1) 136 | 137 | self.conv1 = LWGatedConv2D(in_ch+in_ch+feat_ch, 24, kernel_size=3, stride=1, pad=1) 138 | self.relu1 = nn.ReLU(inplace=True) 139 | self.conv2 = LWGatedConv2D(24, 24, kernel_size=3, stride=1, pad=1) 140 | self.relu2 = nn.ReLU(inplace=True) 141 | self.down1 = DownLWGated(24, 24) 142 | self.down2 = DownLWGated(24, 32) 143 | self.down3 = DownLWGated(32, 32) 144 | 145 | self.up1 = Up(96+32, 32) 146 | self.up2 = Up(56, 24) 147 | self.up3 = Up(48, 24) 148 | self.outc = nn.Conv2d(24, out_ch*4, kernel_size=1) 149 | self.outfinal = nn.PixelShuffle(2) 150 | 151 | def hole_inpaint(self, x, mask, feature): 152 | x_down = x 153 | mask_down = F.interpolate(mask,scale_factor=0.125,mode='bilinear') 154 | feature_down = F.interpolate(feature,scale_factor=0.125,mode='bilinear') 155 | 156 | latent_code = self.latentEncoder(torch.cat([x_down,feature_down], dim=1)) * mask_down 157 | K_map = F.normalize(self.KEncoder(feature_down), p=2, dim=1) 158 | 159 | b,c,h,w = list(K_map.size()) 160 | md = 2 161 | f1 = F.unfold(K_map*mask_down, kernel_size=(2*md+1, 2*md+1), padding=(md, md), stride=(1, 1)) 162 | f1 = f1.view([b, c, -1, h, w]) 163 | f2 = K_map.view([b, c, 1, h, w]) 164 | weight_k = torch.relu((f1*f2).sum(dim=1, keepdim=True)) 165 | 166 | b,c,h,w = list(latent_code.size()) 167 | v = F.unfold(latent_code, kernel_size=(2*md+1, 2*md+1), padding=(md, md), stride=(1, 1)) 168 | v = v.view([b, c, -1, h, w]) 169 | 170 | agg_latent = (v * weight_k).sum(dim=2)/(weight_k.sum(dim=2).clamp_min(1e-6)) 171 | return agg_latent 172 | 173 | def forward(self, x, feature, mask, hisBuffer): 174 | 175 | hisBuffer = hisBuffer.reshape(-1, 4, hisBuffer.shape[-2], hisBuffer.shape[-1]) 176 | 177 | hisDown1 = self.convHis1(hisBuffer) 178 | hisDown2 = self.convHis2(hisDown1) 179 | hisDown3 = self.convHis3(hisDown2) 180 | cathisDown3 = hisDown3.reshape(-1, 3*32, hisDown3.shape[-2], hisDown3.shape[-1]) # 64 181 | 182 | motionFeature = self.lowlevelGated(cathisDown3) 183 | 184 | x1 = torch.cat([x, x*mask, feature],dim=1) 185 | x1 = self.conv1(x1) 186 | x1 = self.relu1(x1) 187 | x1 = self.conv2(x1) 188 | x1 = self.relu2(x1) 189 | 190 | x2 = self.down1(x1) 191 | x3 = self.down2(x2) 192 | x4 = self.down3(x3) 193 | 194 | inpaint_feat = self.hole_inpaint(x4, mask, feature) 195 | x4 = torch.cat([inpaint_feat, motionFeature], dim=1) 196 | 197 | res = self.up1(x4, x3) 198 | res= self.up2(res, x2) 199 | res= self.up3(res, x1) 200 | logits = self.outc(res) 201 | logits = self.outfinal(logits) 202 | 203 | if self.skip: 204 | x1, x2 = x.chunk(2,dim=1) 205 | x_up = F.interpolate(x1,scale_factor=2,mode='bilinear') 206 | logits = logits + x_up 207 | 208 | return logits -------------------------------------------------------------------------------- /dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import torch.nn.functional as F 4 | import os 5 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "true" 6 | import cv2 7 | import numpy as np 8 | import time 9 | import random 10 | import pickle 11 | import torchvision as tv 12 | 13 | 14 | def ReadData(path, rd=None, depth_op = 'clip'): 15 | 16 | try: 17 | total = np.load(path) 18 | except: 19 | print(path) 20 | quit() 21 | 22 | if rd is not None: 23 | if rd<0.2: 24 | total = np.flip(total,0) 25 | elif rd<0.3: 26 | total = np.flip(total, 1) 27 | elif 
rd<0.35: 28 | total = np.flip(total,(0,1)) 29 | 30 | img = total[:,:,0:3] 31 | img3 = total[:,:,3:6] 32 | img5 = total[:,:,6:9] 33 | imgOcc = total[:,:,9:12] 34 | img_no_hole = total[:,:,12:15] 35 | img_no_hole3 = total[:,:,15:18] 36 | img_no_hole5 = total[:,:,18:21] 37 | basecolor = total[:,:,21:24] 38 | metalic = total[:,:,24:25] 39 | roughness = total[:,:,25:26] 40 | depth = total[:,:,26:27] 41 | normal = total[:,:,27:30] 42 | 43 | if total.shape[2] == 35: 44 | motion = total[:,:,31:34] 45 | else: 46 | motion = total[:,:,30:33] 47 | img = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_RGB2BGR) 48 | img3 = cv2.cvtColor(img3.astype(np.float32), cv2.COLOR_RGB2BGR) 49 | img5 = cv2.cvtColor(img5.astype(np.float32), cv2.COLOR_RGB2BGR) 50 | imgOcc = cv2.cvtColor(imgOcc.astype(np.float32), cv2.COLOR_RGB2BGR) 51 | img_no_hole = cv2.cvtColor(img_no_hole.astype(np.float32), cv2.COLOR_RGB2BGR) 52 | img_no_hole3 = cv2.cvtColor(img_no_hole3.astype(np.float32), cv2.COLOR_RGB2BGR) 53 | img_no_hole5 = cv2.cvtColor(img_no_hole5.astype(np.float32), cv2.COLOR_RGB2BGR) 54 | basecolor = cv2.cvtColor(basecolor.astype(np.float32), cv2.COLOR_RGB2BGR) 55 | normal = cv2.cvtColor(normal.astype(np.float32), cv2.COLOR_RGB2BGR) 56 | 57 | if depth_op == 'clip': 58 | depth = np.clip(depth,0,1) 59 | else: 60 | depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6) 61 | 62 | motion = motion.astype(np.float32) 63 | motion = cv2.cvtColor(motion, cv2.COLOR_BGR2RGB)[:,:,:2] 64 | motion[:,:,0] = -motion[:,:,0] 65 | 66 | return img, img3, img5, imgOcc, img_no_hole, img_no_hole3, img_no_hole5, basecolor, metalic, roughness, depth, normal, motion 67 | 68 | 69 | def crop(lst,gt,size=256): 70 | i, j, h, w = tv.transforms.RandomCrop.get_params(lst[0], output_size=(size, size)) 71 | for i in range(len(lst)): 72 | lst[i] = tv.transforms.functional.crop(lst[i], i, j, h, w) 73 | gt = tv.transforms.functional.crop(gt, i*2, j*2, h*2, w*2) 74 | return lst, gt 75 | 76 | def form_input(data_path, is_train, index, get_his=True, get_motion=False, add_mask=True, random_flip=False, depth_op='clip'): 77 | 78 | npz_path, hr_path = data_path 79 | 80 | if is_train and random_flip: 81 | rd = random.random() 82 | else: 83 | rd = None 84 | 85 | gt = np.load(hr_path) 86 | gt = cv2.cvtColor(gt.astype(np.float32), cv2.COLOR_RGB2BGR) 87 | 88 | img, img_2, img_3, occ_warp_img, woCheckimg, woCheckimg_2, woCheckimg_3, basecolor, metalic, Roughnessimg, Depthimg, Normalimg, motion = ReadData(npz_path, rd, depth_op) 89 | 90 | input = img 91 | mask = input.copy() 92 | mask[mask==0.]=1.0 93 | mask[mask==-1]=0.0 94 | mask[mask!=0.0]=1.0 95 | 96 | occ_warp_img[occ_warp_img < 0.0] = 0.0 97 | woCheckimg[woCheckimg < 0.0] = 0.0 98 | 99 | features = np.concatenate([Normalimg, Depthimg, Roughnessimg, metalic, basecolor], axis=2).copy() 100 | 101 | if get_his: 102 | woCheckimg_2[woCheckimg_2 < 0.0] = 0.0 103 | woCheckimg_3[woCheckimg_3 < 0.0] = 0.0 104 | mask2 = img_2.copy() 105 | mask2[mask2==0.]=1.0 106 | mask2[mask2==-1]=0.0 107 | mask2[mask2!=0.0]=1.0 108 | 109 | mask3 = img_3.copy() 110 | mask3[mask3==0.]=1.0 111 | mask3[mask3==-1]=0.0 112 | mask3[mask3!=0.0]=1.0 113 | 114 | his_1 = np.concatenate([woCheckimg, mask[:,:,0].reshape((Normalimg.shape[0],Normalimg.shape[1], 1))], axis=2).transpose([2,0,1]).reshape(1, 4, Normalimg.shape[0],Normalimg.shape[1]) 115 | his_2 = np.concatenate([woCheckimg_2, mask2[:,:,0].reshape((Normalimg.shape[0],Normalimg.shape[1], 1))], axis=2).transpose([2,0,1]).reshape(1, 4, Normalimg.shape[0],Normalimg.shape[1]) 116 | his_3 = 
np.concatenate([woCheckimg_3, mask3[:,:,0].reshape((Normalimg.shape[0],Normalimg.shape[1], 1))], axis=2).transpose([2,0,1]).reshape(1, 4, Normalimg.shape[0],Normalimg.shape[1]) 117 | hisBuffer = np.concatenate([his_3, his_2, his_1], axis=0) 118 | 119 | input = np.concatenate([occ_warp_img,woCheckimg],axis=2).copy() 120 | 121 | if is_train and add_mask: 122 | h,w,c = input.shape 123 | 124 | h0 = random.randint(1,h//4) 125 | w0 = random.randint(1,min(h*w//16//h0, w//4)) 126 | 127 | hstart = random.randint(0, h-h0) 128 | wstart = random.randint(0, w-w0) 129 | 130 | full_mask = np.zeros((h,w,3),dtype=np.uint8) 131 | cv2.rectangle(full_mask, (wstart,hstart), (wstart+w0-1,hstart+h0-1), (255, 255, 255), -1) 132 | 133 | mask = np.logical_and((full_mask[:,:,:1] == 0), mask[:,:,:1]).astype(np.float32) 134 | else: 135 | mask = mask[:,:,:1].astype(np.float32) 136 | 137 | if rd is not None: 138 | if rd<0.2: 139 | gt = np.flip(gt,0) 140 | elif rd<0.3: 141 | gt = np.flip(gt, 1) 142 | elif rd<0.35: 143 | gt = np.flip(gt,(0,1)) 144 | 145 | out_puts = [] 146 | out_puts.extend([torch.tensor(input.transpose([2,0,1])), torch.tensor(features.transpose([2,0,1])), torch.tensor(mask.transpose([2,0,1]))]) 147 | if get_his: 148 | out_puts.append(torch.tensor(hisBuffer)) 149 | out_puts.append(torch.tensor(gt.copy().transpose([2,0,1]))) 150 | if get_motion: 151 | out_puts.append(torch.tensor(motion.transpose([2,0,1]))) 152 | 153 | return out_puts 154 | 155 | class RenderDataset(data.Dataset): 156 | def __init__(self, npz_formats, HR_format, idx_list, is_train=True, depth_op='clip'): 157 | self.depth_op = depth_op 158 | self.is_train = is_train 159 | self.data_list = [] 160 | 161 | for npz_format in npz_formats: 162 | for idx in idx_list: 163 | if idx == 1427: continue 164 | npz_path = npz_format%idx 165 | hr_path = HR_format%idx 166 | if os.path.exists(npz_path) and os.path.exists(hr_path): 167 | self.data_list.append((npz_path, hr_path)) 168 | 169 | def __getitem__(self, index): 170 | outputs = form_input(self.data_list[index], self.is_train, index, get_his=True, random_flip=False, depth_op=self.depth_op) 171 | return outputs 172 | 173 | def __len__(self): 174 | return len(self.data_list) 175 | 176 | class RenderDataset_eval(data.Dataset): 177 | def __init__(self, SF_format, IF_format, HR_format, idx_list, is_train=False, depth_op='clip'): 178 | self.depth_op = depth_op 179 | self.data_list = [] 180 | self.is_train = is_train 181 | 182 | for idx in idx_list: 183 | if idx % 2 == 0: 184 | npz_path = SF_format%idx 185 | else: 186 | npz_path = IF_format%idx 187 | hr_path = HR_format%idx 188 | 189 | if os.path.exists(npz_path) and os.path.exists(hr_path): 190 | self.data_list.append((idx, npz_path, hr_path)) 191 | else: 192 | print(idx) 193 | print(npz_path) 194 | print(hr_path) 195 | quit() 196 | 197 | def __getitem__(self, index): 198 | outputs = form_input(self.data_list[index][1:], False, index, get_motion=False, add_mask=False, depth_op = self.depth_op) 199 | 200 | if self.is_train: 201 | return outputs 202 | else: 203 | return torch.tensor(self.data_list[index][0]), outputs 204 | 205 | def __len__(self): 206 | return len(self.data_list) 207 | 208 | def get_SunTemple_train_data(): 209 | warp_format = '/home/user2/dataset/rendering/SunTemple/train1/NPY/compressed.%04d.Warp.npy' 210 | nowarp_format = '/home/user2/dataset/rendering/SunTemple/train1/NPY/compressed.%04d.NoWarp.npy' 211 | HR_format = '/home/user2/dataset/rendering/SunTemple/train1/HR/compressedHR.%04d.npy' 212 | idx_list = list(range(5,3002)) 213 | 
dataset1 = RenderDataset([warp_format, nowarp_format], HR_format, idx_list) 214 | print(len(dataset1)) 215 | warp_format = '/home/user2/dataset/rendering/SunTemple/train2/NPY/compressed.%04d.Warp.npy' 216 | nowarp_format = '/home/user2/dataset/rendering/SunTemple/train2/NPY/compressed.%04d.NoWarp.npy' 217 | HR_format = '/home/user2/dataset/rendering/SunTemple/train2/HR/compressedHR.%04d.npy' 218 | idx_list = list(range(5,2747)) 219 | dataset2 = RenderDataset([warp_format, nowarp_format], HR_format, idx_list) 220 | print(len(dataset2)) 221 | 222 | return dataset1 + dataset2 223 | 224 | def get_Subway_train_data(): 225 | warp_format = '/home/user2/dataset/rendering/Subway/training1/NPY/compressed.%04d.Warp.npy' 226 | nowarp_format = '/home/user2/dataset/rendering/Subway/training1/NPY/compressed.%04d.NoWarp.npy' 227 | HR_format = '/home/user2/dataset/rendering/Subway/training1/HR/compressedHR.%04d.npy' 228 | idx_list = list(range(6,3005)) 229 | dataset1 = RenderDataset([warp_format, nowarp_format], HR_format, idx_list) 230 | print(len(dataset1)) 231 | 232 | warp_format = '/home/user2/dataset/rendering/Subway/training2/NPY/compressed.%04d.Warp.npy' 233 | nowarp_format = '/home/user2/dataset/rendering/Subway/training2/NPY/compressed.%04d.NoWarp.npy' 234 | HR_format = '/home/user2/dataset/rendering/Subway/training2/HR/compressedHR.%04d.npy' 235 | idx_list = list(range(6,3600)) 236 | dataset2 = RenderDataset([warp_format, nowarp_format], HR_format, idx_list) 237 | print(len(dataset2)) 238 | 239 | return dataset1 + dataset2 240 | 241 | def get_Lewis_train_data(): 242 | warp_format = '/home/user2/dataset/rendering/Lewis/training1/NPY/compressed.%04d.Warp.npy' 243 | nowarp_format = '/home/user2/dataset/rendering/Lewis/training1/NPY/compressed.%04d.NoWarp.npy' 244 | HR_format = '/home/user2/dataset/rendering/Lewis/training1/HR/compressedHR.%04d.npy' 245 | idx_list = list(range(9, 2803)) 246 | dataset1 = RenderDataset([warp_format, nowarp_format], HR_format, idx_list, depth_op='scale') 247 | print(len(dataset1)) 248 | 249 | warp_format = '/home/user2/dataset/rendering/Lewis/training2/NPY/compressed.%04d.Warp.npy' 250 | nowarp_format = '/home/user2/dataset/rendering/Lewis/training2/NPY/compressed.%04d.NoWarp.npy' 251 | HR_format = '/home/user2/dataset/rendering/Lewis/training2/HR/compressedHR.%04d.npy' 252 | idx_list = list(range(5,3585)) 253 | dataset2 = RenderDataset([warp_format, nowarp_format], HR_format, idx_list, depth_op='scale') 254 | print(len(dataset2)) 255 | 256 | return dataset1 + dataset2 257 | 258 | def get_Lewis_test_data(is_train=False, depth_op='scale'): 259 | warp_format = '/home/user2/dataset/rendering/Lewis/test/NPY/compressed.%04d.Warp.npy' 260 | nowarp_format = '/home/user2/dataset/rendering/Lewis/test/NPY/compressed.%04d.NoWarp.npy' 261 | HR_format = '/home/user2/dataset/rendering/Lewis/test/HR/compressedHR.%04d.npy' 262 | idx_list = list(range(5,1001)) 263 | return RenderDataset_eval(nowarp_format, warp_format, HR_format, idx_list, is_train=is_train, depth_op=depth_op) 264 | 265 | def get_Subway_test_data(is_train=False): 266 | warp_format = '/home/user2/dataset/rendering/Subway/test/NPY/compressed.%04d.Warp.npy' 267 | nowarp_format = '/home/user2/dataset/rendering/Subway/test/NPY/compressed.%04d.NoWarp.npy' 268 | HR_format = '/home/user2/dataset/rendering/Subway/test/HR/compressedHR.%04d.npy' 269 | idx_list = list(range(6,1001)) 270 | return RenderDataset_eval(nowarp_format, warp_format, HR_format, idx_list, is_train=is_train) 271 | 272 | def 
get_Arena_test_data(is_train=False): 273 | warp_format = '/home/user2/dataset/rendering/Arena/test/NPY/compressed.%04d.Warp.npy' 274 | nowarp_format = '/home/user2/dataset/rendering/Arena/test/NPY/compressed.%04d.NoWarp.npy' 275 | HR_format = '/home/user2/dataset/rendering/Arena/test/HR/compressedHR.%04d.npy' 276 | idx_list = list(range(500,1501)) 277 | return RenderDataset_eval(nowarp_format, warp_format, HR_format, idx_list, is_train=is_train) 278 | 279 | def get_SunTemple_test_data(is_train=False): 280 | warp_format = '/home/user2/dataset/rendering/SunTemple/test/NPY/compressed.%04d.Warp.npy' 281 | nowarp_format = '/home/user2/dataset/rendering/SunTemple/test/NPY/compressed.%04d.NoWarp.npy' 282 | HR_format = '/home/user2/dataset/rendering/SunTemple/test/HR/compressedHR.%04d.npy' 283 | idx_list = list(range(5,1001)) 284 | return RenderDataset_eval(nowarp_format, warp_format, HR_format, idx_list, is_train=is_train) 285 | 286 | if __name__ == '__main__': 287 | dataset = get_Lewis_train_data() 288 | print(len(dataset)) 289 | input,features,mask,hisBuffer,label = dataset[0] 290 | print(input.shape,features.shape,mask.shape,hisBuffer.shape,label.shape) --------------------------------------------------------------------------------
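As a closing reference, here is a minimal shape sanity check for the network defined in `model.py`, using the same `STSSNet(6,3,9,4)` configuration instantiated in `eval.py` and `benchmark.py`. The tensors are random placeholders (real inputs come from `dataloaders.py`), and the 272x480 resolution is an arbitrary assumption chosen to be divisible by 8 so that no padding is needed; this is a sketch for quick verification, not part of the released pipeline.

```python
import torch
from model import STSSNet

# Channel layout as used in eval.py / benchmark.py: STSSNet(in_ch=6, out_ch=3, feat_ch=9, his_ch=4).
# in_ch=6:  occlusion-aware warped frame + hole-free warped frame (3+3, see form_input in dataloaders.py)
# feat_ch=9: normal (3) + depth (1) + roughness (1) + metallic (1) + basecolor (3) G-buffers
# his_ch=4: each of the 3 history frames carries a warped color image plus its validity mask
model = STSSNet(6, 3, 9, 4).eval()

B, H, W = 1, 272, 480                  # dummy LR size, assumed divisible by 8 so the U-Net skips line up
x        = torch.rand(B, 6, H, W)      # LR input pair
features = torch.rand(B, 9, H, W)      # LR G-buffer features
mask     = torch.ones(B, 1, H, W)      # 1 = valid pixel, 0 = warping hole
his      = torch.rand(B, 3, 4, H, W)   # 3 history frames x 4 channels

with torch.no_grad():
    out = model(x, features, mask, his)

print(out.shape)  # torch.Size([1, 3, 544, 960]) -> 2x spatial upsampling of the 272x480 input
```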