├── LICENSE.txt ├── README.md ├── augmentations.py ├── cfg.py ├── data └── README.md ├── datasets.py ├── docker └── Dockerfile ├── layer_network.py ├── license.html ├── sample_network.py ├── torch_utils ├── README.md ├── clean.bat ├── setup.py ├── tests │ ├── test_weighted.py │ ├── torch_utils_ref │ │ ├── __init__.py │ │ └── weighted_filter.py │ └── update_pkg.bat └── torch_utils │ ├── __init__.py │ ├── cuda_weighted_filter.cu │ ├── torch_utils.cpp │ └── weighted_filter.py ├── train.py ├── unet.py └── utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for Neural Denoising with Layer Embeddings 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | “Software” means the original work of authorship made available under this License. 7 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License. 8 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 9 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 10 | 11 | 2. License Grant 12 | 13 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 14 | 15 | 3. Limitations 16 | 17 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 18 | 19 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 20 | 21 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 22 | 23 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 24 | 25 | 3.5 Trademarks. 
This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 26 | 27 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately. 28 | 29 | 4. Disclaimer of Warranty. 30 | 31 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF M ERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 32 | 33 | 5. Limitation of Liability. 34 | 35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMM ERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 36 | 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Layered Denoiser with Per-Sample Input in PyTorch 2 | ================================================== 3 | 4 | Description 5 | ----------- 6 | Training code for a layered denoiser that works on per-sample data, 7 | as described in the paper: 8 | *Neural Denoising with Layer Embeddings* 9 | https://research.nvidia.com/publication/2020-06_Neural-Denoising-with 10 | 11 | The code base also includes a few variants of per-sample denoisers, which are (simplified) 12 | networks adapted from the paper: 13 | *Sample-based Monte Carlo Denoising using a Kernel-Splatting Network* 14 | https://groups.csail.mit.edu/graphics/rendernet/ 15 | 16 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/) 17 | 18 | License 19 | ------- 20 | 21 | Copyright © 2020, NVIDIA Corporation. All rights reserved. 22 | 23 | This work is made available under the [NVIDIA Source Code License](https://nvlabs.github.io/layerdenoise/license.html). 24 | 25 | Citation 26 | -------- 27 | 28 | ``` 29 | @article{Munkberg2020, 30 | author = {Jacob Munkberg and Jon Hasselgren}, 31 | title = {Neural Denoising with Layer Embeddings}, 32 | journal = {Computer Graphics Forum}, 33 | volume = {39}, 34 | number = {4}, 35 | pages = {1--12}, 36 | year = {2020} 37 | } 38 | ``` 39 | 40 | Requirements 41 | ------------ 42 | - Python (tested on Anaconda Python 3.7) 43 | - Pytorch 1.5 or newer 44 | - h5py for loading hdf5 files 45 | - Visual Studio 2019 and CUDA 10.2 are required to build the included PyTorch extension [torch_utils](./torch_utils/) on Windows. 
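As a quick sanity check of the environment (this snippet is not part of the repository), you can verify from Python that PyTorch, CUDA and h5py are available before building the extension:

```python
# Hypothetical environment check; adjust to your own setup.
import torch
import h5py

print("PyTorch:", torch.__version__)           # expect 1.5 or newer
print("CUDA available:", torch.cuda.is_available())
print("CUDA toolkit:", torch.version.cuda)     # expect 10.2 for the Windows build described above
print("h5py:", h5py.__version__)
```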
46 | 47 | Configure PyTorch environment 48 | ----------------------------- 49 | 50 | Create a PyTorch Anaconda environment 51 | ``` 52 | conda create -n pytorch python=3.7 53 | activate pytorch 54 | conda install pytorch torchvision cudatoolkit=10.2 -c pytorch 55 | conda install -c anaconda h5py 56 | ``` 57 | 58 | In the same environment install `torch_utils`, which is included in the `torch_utils` folder of this repository. 59 | and run `python setup.py install` to build and install the package on your local computer. 60 | Please check the [README.md](./torch_utils/README.md) for detailed installation instructions. 61 | 62 | Docker 63 | ------ 64 | 65 | Navigate to the folder where you've cloned this repository and build the docker image 66 | `docker build --tag ldenoiser:latest -f docker/Dockerfile .` 67 | 68 | Launch a container 69 | `docker run --gpus device=0 --shm-size 16G --rm -v /raid:/raid -it ldenoiser:latest bash` 70 | 71 | Usage 72 | ----- 73 | 74 | Open a command line with PyTorch support (typically, on Windows: `activate pytorch` to activate an environment in Anaconda). 75 | 76 | Download example training data, `valsetegsr16k.zip` and `indoorC.zip` from https://github.com/NVlabs/layerdenoise/releases/tag/release_1 77 | and place it in the `./data/` folder. Note that this is a small example set, not the full set used in the paper. 78 | 79 | Training: `python train.py --job myjob --config cfg.py --network Layer` 80 | 81 | The checkpoints are automatically stored in the `models` folder. 82 | 83 | Note that we append a random hash to each jobname. 84 | 85 | Example: For job `test_ccc3e302` 86 | - Debug images at `./jobs/test_ccc3e302/images` 87 | - Checkpoints at `./jobs/test_ccc3e302/model` 88 | 89 | Stop training: `ctrl+c` 90 | 91 | Inference: `python train.py --job [jobID] --config cfg.py --inference`. The latest checkpoint from the `models` directory is used. 92 | The output images are stored in the `./out` directory. Use the switch `--savedir` to specify another folder. 93 | 94 | Output format from inference run: 95 | - `img00000input.png` noisy input image 96 | - `img00000ref.png` target image 97 | - `img00000out.png` denoised image 98 | 99 | 100 | Run a pretrained model 101 | ---------------------- 102 | 103 | Download the pretrained weights `model_0700.tar` from https://github.com/NVlabs/layerdenoise/releases/tag/release_1 104 | and the testset `valsetegsr16k_*.h5` from the same location. 105 | 106 | Place the weights in 107 | `[installation dir]/jobs/server/model/model_0700.tar` 108 | and the dataset in the `./data/` folder. 109 | 110 | Run inference with 111 | `python train.py --job server --config cfg.py --inference --savedir results --network Layer` 112 | 113 | The folder `results` will be populated with the denoised images (input, denoised and reference images). 114 | The expected error metrics for this trained network on the set of images in `valsetegsr16k_*.h5` are: 115 | ``` 116 | relMSE, SMAPE, PSNR 117 | 0.0263, 0.0346, 33.885 118 | ``` 119 | 120 | Dataset generation 121 | ------------------ 122 | 123 | Datasets are generated in hdf5 format. 
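As an illustrative sketch (not part of the repository), the hdf5 files can be opened with h5py to verify the dataset names and the 5D layout described below; the file name is one of the example datasets listed in `cfg.py`:

```python
import h5py

# Example dataset file; any of the hdf5 files in ./data/ can be inspected this way.
with h5py.File("./data/indoorC_input.h5", "r") as f:
    for name, dset in f.items():
        # Each dataset is stored as [frames, samples, channels, height, width]
        print(name, dset.shape, dset.dtype)
    # Load the color samples of the first frame as a numpy array
    color = f["color"][0]  # shape: [samples, channels, height, width]
```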
A small example dataset can be downloaded from 124 | https://github.com/NVlabs/layerdenoise/releases/tag/release_1 125 | 126 | The datasets are 5D tensors on the form: `[frames, samples, channels, height, width]` 127 | 128 | The current datasets have the header: 129 | ``` 130 | ...\git\layerdl\data>h5dump -H indoorC_input.h5 131 | HDF5 "indoorC_input.h5" { 132 | GROUP "/" { 133 | DATASET "albedo" { 134 | DATATYPE 16-bit little-endian floating-point 135 | DATASPACE SIMPLE { ( 128, 8, 3, 256, 256 ) / ( 128, 8, 3, 256, 256 ) } 136 | } 137 | DATASET "color" { 138 | DATATYPE 16-bit little-endian floating-point 139 | DATASPACE SIMPLE { ( 128, 8, 3, 256, 256 ) / ( 128, 8, 3, 256, 256 ) } 140 | } 141 | DATASET "motionvecs" { 142 | DATATYPE 16-bit little-endian floating-point 143 | DATASPACE SIMPLE { ( 128, 8, 3, 256, 256 ) / ( 128, 8, 3, 256, 256 ) } 144 | } 145 | DATASET "normals_depth" { 146 | DATATYPE 16-bit little-endian floating-point 147 | DATASPACE SIMPLE { ( 128, 8, 4, 256, 256 ) / ( 128, 8, 4, 256, 256 ) } 148 | } 149 | DATASET "specular" { 150 | DATATYPE 16-bit little-endian floating-point 151 | DATASPACE SIMPLE { ( 128, 8, 4, 256, 256 ) / ( 128, 8, 4, 256, 256 ) } 152 | } 153 | DATASET "uvt" { 154 | DATATYPE 16-bit little-endian floating-point 155 | DATASPACE SIMPLE { ( 128, 8, 3, 256, 256 ) / ( 128, 8, 3, 256, 256 ) } 156 | } 157 | } 158 | } 159 | ``` 160 | -------------------------------------------------------------------------------- /augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import torch 10 | import numpy as np 11 | 12 | def flip_x(x): 13 | return x.flip(len(x.shape) - 1) 14 | 15 | def flip_y(x): 16 | return x.flip(len(x.shape) - 2) 17 | 18 | def rot90(x): 19 | return flip_y(torch.transpose(x, -2, -1)) 20 | 21 | ############################################################################### 22 | # Data augmentation utility functions 23 | ############################################################################### 24 | 25 | def augment_rot90(sequenceHeader, sequenceData): 26 | # Randomly rotate by 90 degrees 27 | if np.random.random() < 0.5: 28 | # Iterate all frames in sequence 29 | for f in sequenceData.frameData: 30 | # Iterate all keys (different data streams / features etc.) 31 | for key in f.__dict__.keys(): 32 | if key == "normals_depth": 33 | nd = rot90(getattr(f, key)) 34 | nd = torch.stack([-nd.select(-3, 1), 35 | nd.select(-3, 0)] + 36 | [nd.select(-3, i) for i in range(2, nd.shape[-3])], dim=-3) 37 | setattr(f, key, nd) 38 | else: 39 | setattr(f, key, rot90(getattr(f, key))) 40 | 41 | def augment_flip_xy(sequenceData): 42 | # Randomly flip y axis 43 | if np.random.random() < 0.5: 44 | # Iterate all frames in sequence 45 | for f in sequenceData.frameData: 46 | # Iterate all keys (different data streams / features etc.) 
47 | for key in f.__dict__.keys(): 48 | if key == "normals_depth": 49 | nd = flip_y(getattr(f, key)) 50 | nd = torch.stack([ nd.select(-3, 0), 51 | -nd.select(-3, 1)] + 52 | [nd.select(-3, i) for i in range(2, nd.shape[-3])], dim=-3) 53 | setattr(f, key, nd) 54 | else: 55 | setattr(f, key, flip_y(getattr(f, key))) 56 | 57 | # Randomly flip x axis 58 | if np.random.random() < 0.5: 59 | # Iterate all frames in sequence 60 | for f in sequenceData.frameData: 61 | # Iterate all keys (different data streams / features etc.) 62 | for key in f.__dict__.keys(): 63 | if key == "normals_depth": 64 | nd = flip_x(getattr(f, key)) 65 | nd = torch.stack([-nd.select(-3, 0), 66 | nd.select(-3, 1)] + 67 | [nd.select(-3, i) for i in range(2, nd.shape[-3])], dim=-3) 68 | setattr(f, key, nd) 69 | else: 70 | setattr(f, key, flip_x(getattr(f, key))) 71 | 72 | def augment(sequenceHeader, sequenceData): 73 | augment_rot90(sequenceHeader, sequenceData) 74 | augment_flip_xy(sequenceData) 75 | 76 | 77 | -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | datadir = './data/' 2 | 3 | valscene = 'valsetegsr16k_in.h5' 4 | 5 | scenes = [ 6 | 'indoorC_input.h5' 7 | ] 8 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Place datasets in this folder. 2 | 3 | A small example dataset is available here: 4 | https://github.com/NVlabs/layerdenoise/releases/tag/release_1 5 | 6 | The datasets are 5D tensors on the form: `[frames, samples, channels, height, width]` -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import torch 10 | import h5py 11 | import numpy as np 12 | 13 | import utils 14 | 15 | ############################################################################### 16 | # Image sequence data class 17 | ############################################################################### 18 | 19 | class SequenceHeader(object): 20 | def __init__(self, nSequence, resolution, cropSize, sequenceData): 21 | self.nSequence = nSequence 22 | self.resolution = resolution 23 | self.cropSize = cropSize 24 | 25 | shape_dict = dict([(key, value.shape) for key, value in sequenceData.frameData[0].__dict__.items()]) 26 | self.frameShape = utils.object_from_dict(shape_dict) 27 | 28 | class SequenceData(object): 29 | def makeCudaTensors(self, x): 30 | if isinstance(x, dict): 31 | _dict = {key : self.makeCudaTensors(value) for key, value in x.items()} 32 | return utils.object_from_dict(_dict) 33 | elif isinstance(x, (list, tuple)): 34 | return [self.makeCudaTensors(y) for y in x] 35 | elif torch.is_tensor(x): 36 | return x.cuda().float() 37 | else: 38 | return x 39 | 40 | def __init__(self, dataset, sequenceData): 41 | self.frameData = self.makeCudaTensors(sequenceData) 42 | 43 | ############################################################################### 44 | # Sample dataset 45 | ############################################################################### 46 | class SampleDataset(torch.utils.data.Dataset): 47 | def __init__(self, filename, filename_ref, cropSize, flags, limit=None, randomCrop=True): 48 | super().__init__() 49 | 50 | self.filename = filename 51 | self.filename_ref = filename_ref 52 | self.limit = limit 53 | self.cropSize = cropSize 54 | self.randomCrop = randomCrop 55 | 56 | # Copy required FLAGS 57 | self._spp = flags.spp 58 | 59 | # Parse out header information 60 | h5py_file = h5py.File(self.filename, 'r') 61 | h5py_file_ref = h5py.File(self.filename_ref, 'r') 62 | 63 | self.resolution = h5py_file['color'].shape[-2:] 64 | self.nDim = len(h5py_file['color'].shape) 65 | self.nFramesPerClip = h5py_file['color'].shape[0] 66 | 67 | assert(self.nDim == 5) # Dataset with 5D tensors [frame, sample, channel, y, x] 68 | 69 | pcrop = self.cropSize 70 | if pcrop == None: 71 | pcrop = self.resolution[0] 72 | 73 | print("Dataset %s - Res: %dx%d, Crop: %dx%d" % (self.filename, self.resolution[0], self.resolution[1], pcrop, pcrop)) 74 | 75 | def getHeader(self): 76 | return SequenceHeader(1, self.resolution, self.cropSize, SequenceData(self, self.__getitem__(0))) 77 | 78 | def __len__(self): 79 | return self.nFramesPerClip if self.limit is None else self.limit 80 | 81 | def __getitem__(self, idx): 82 | # Create random crop. 
This data augmentation is added to the reader as it affects disk I/O 83 | cw, ch = self.resolution[1], self.resolution[0] 84 | ow, oh = 0, 0 85 | if self.cropSize != None: 86 | cw, ch = self.cropSize, self.cropSize 87 | sw, sh = max(0, self.resolution[1] - cw), max(0, self.resolution[0] - ch) 88 | if self.randomCrop: 89 | ow, oh = torch.randint(0, sw + 1, (1,)).item(), torch.randint(0, sh + 1, (1,)).item() 90 | 91 | h5py_file = h5py.File(self.filename, 'r') 92 | h5py_file_ref = h5py.File(self.filename_ref, 'r') 93 | 94 | assert(self.nDim == 5) 95 | # Load data 96 | # Dataset with 5D tensors [frame, sample, channel, y, x] stored in fp16 97 | color = h5py_file['color'][idx, 0:self._spp, ..., oh:oh+ch, ow:ow+cw] # Linear radiance in HDR 98 | normals_depth = h5py_file['normals_depth'][idx, 0:self._spp, ..., oh:oh+ch, ow:ow+cw] # View space normals in xyz, normalized world space depth in w 99 | albedo = h5py_file['albedo'][idx, 0:self._spp, ..., oh:oh+ch, ow:ow+cw] # Albedo map at first hit 100 | specular = h5py_file['specular'][idx, 0:self._spp, ..., oh:oh+ch, ow:ow+cw] # Specular map at first hit 101 | uvt = h5py_file['uvt'][idx, 0:self._spp, ..., oh:oh+ch, ow:ow+cw] # Lens position (xy) and time (z) 102 | motionvecs = h5py_file['motionvecs'][idx, 0:self._spp, ..., oh:oh+ch, ow:ow+cw] # NDC Motion vectors in xy, signed CoC radius in z 103 | target = h5py_file_ref['color'][idx, ..., oh:oh+ch, ow:ow+cw] # Reference radiance in linear HDR 104 | 105 | # Create object with frame data 106 | frame_dict = { 107 | "color" : np.clip(color, 0.0, 65535.0), 108 | "normals_depth" : normals_depth, 109 | "albedo" : albedo, 110 | "specular" : specular, 111 | "uvt" : uvt, 112 | "motionvecs" : motionvecs, 113 | "target" : np.clip(target, 0.0, 65535.0) 114 | } 115 | return [frame_dict] 116 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:20.10-py3 10 | FROM $BASE_IMAGE 11 | 12 | # Install torch_utils 13 | COPY torch_utils /tmp/torch_utils/ 14 | RUN cd /tmp/torch_utils && python setup.py install 15 | -------------------------------------------------------------------------------- /layer_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch_utils import WeightedFilter 12 | 13 | from unet import * 14 | from utils import * 15 | 16 | EPSILON = 0.000001 # Small epsilon to avoid division by zero 17 | 18 | ############################################################################### 19 | # Layer network definition 20 | # 21 | # The large network described in "Neural Denoising with Layer Embeddings" 22 | # 23 | # https://research.nvidia.com/publication/2020-06_Neural-Denoising-with 24 | # 25 | ############################################################################### 26 | 27 | class LayerNet(nn.Module): 28 | def __init__(self, sequenceHeader, tonemapper, splat, num_samples, kernel_size): 29 | 30 | super(LayerNet, self).__init__() 31 | self.tonemapper = tonemapper 32 | self.output_channels = 128 33 | self.embed_channels = 32 34 | self.kernel_size = kernel_size 35 | self.num_samples = int(num_samples) 36 | self.splat = splat 37 | self.resolution = sequenceHeader.resolution 38 | frameShape = sequenceHeader.frameShape 39 | self.input_channels = frameShape.color[1] + frameShape.normals_depth[1] + frameShape.albedo[1] + frameShape.specular[1] + frameShape.uvt[1] + frameShape.motionvecs[1] 40 | self.layers = 2 41 | 42 | # Sample reducer: Maps from input channels to sample embeddings, uses 1x1 convolutions 43 | self._red1 = nn.Sequential( 44 | nn.Conv2d(self.input_channels, self.embed_channels, 1, padding=0), 45 | Activation, 46 | nn.Conv2d(self.embed_channels, self.embed_channels, 1, padding=0), 47 | Activation, 48 | nn.Conv2d(self.embed_channels, self.embed_channels, 1, padding=0), 49 | Activation, 50 | ) 51 | 52 | # Sample partitioner: Computes weights for splatting samples to layers, uses 1x1 convolutions 53 | self._sample_partitioner = nn.Sequential( 54 | nn.Conv2d(self.output_channels+self.embed_channels, 32, 1, padding=0), 55 | Activation, 56 | nn.Conv2d(32, 16, 1, padding=0), 57 | Activation, 58 | nn.Conv2d(16, self.layers, 1, padding=0), # One splat weight per layer 59 | ) 60 | 61 | # Kernel generator: Computes filter kernels per-layer, uses 1x1 convolutions 62 | self._kernel_generator = nn.Sequential( 63 | nn.Conv2d(self.output_channels+self.embed_channels, 128, 1, padding=0), 64 | Activation, 65 | nn.Conv2d(128, 128, 1, padding=0), 66 | Activation, 67 | nn.Conv2d(128, self.kernel_size*self.kernel_size, 1, padding=0), # output kernel weights 68 | ) 69 | 70 | # U-Net: Generates context features 71 | self._unet = UNet(self.embed_channels, self.output_channels, encoder_features=[[64, 64], [128], [256], [512], [512]], bottleneck_features=[512], decoder_features=[[512, 512], [256, 256], [128, 128], [128, 128], [128, 128]]) 72 | 73 | # Filter for applying predicted kernels 74 | self._kpn = WeightedFilter(channels=3, kernel_size=self.kernel_size, bias=False, splat=self.splat) 75 | 76 | # Initialize network weights 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | nn.init.xavier_uniform_(m.weight.data) 80 | if m.bias is not None: 81 | m.bias.data.zero_() 82 | 83 | def forward(self, sequenceData, epoch): 84 | frame = sequenceData.frameData[0] 85 | 86 | radiance = frame.color # Linear color values 87 | 88 | rgb = self.tonemapper(frame.color) 89 | normals_depth = frame.normals_depth 90 | motionvecs = frame.motionvecs 91 | albedo = frame.albedo 92 | specular = frame.specular 93 | uvt = frame.uvt 94 | 95 | xc = torch.cat((rgb, normals_depth, motionvecs, albedo, specular, uvt), dim=2) 96 | 97 | # loop over samples to create embeddings 98 | sh = xc.shape 99 
| embedding = torch.cuda.FloatTensor(sh[0], sh[1], self.embed_channels, sh[3], sh[4]).fill_(0) 100 | for i in range(sh[1]): 101 | embedding[:, i, ...] = self._red1(xc[:,i,...]) 102 | avg_embeddings = embedding.mean(dim=1) # average over embeddings dimension 103 | 104 | # Run U-net 105 | context = self._unet(avg_embeddings) 106 | 107 | # Allocate buffers 108 | l_radiance = [torch.cuda.FloatTensor(sh[0], 3, sh[3], sh[4]).fill_(0) for i in range(self.layers)] 109 | l_weights = [torch.cuda.FloatTensor(sh[0], 1, sh[3], sh[4]).fill_(0) for i in range(self.layers)] 110 | l_n = [torch.cuda.FloatTensor(sh[0], 1, sh[3], sh[4]).fill_(0) for i in range(self.layers)] 111 | l_e = [torch.cuda.FloatTensor(sh[0], self.embed_channels, sh[3], sh[4]).fill_(0) for i in range(self.layers)] 112 | 113 | # Splat samples to layers 114 | for i in range(0, self.num_samples): # loop over samples 115 | w = self._sample_partitioner(torch.cat((embedding[:, i, ...], context), dim=1)) 116 | w = torch.softmax(w, dim=1) / self.num_samples 117 | 118 | for j in range(self.layers): 119 | l_radiance[j] += radiance[:, i, ...] * w[:, j:j+1, ...] 120 | l_weights[j] += w[:, j:j+1, ...] 121 | l_e[j] += embedding[:, i, ...] * w[:, j:j+1, ...] 122 | l_n[j] += torch.sum(w[:, j:self.layers, ...], dim=1, keepdim=True) # increment only for samples in or in front 123 | 124 | # Generate layer weights and take exp to make them positive 125 | layer_weights = torch.cat(tuple(self._kernel_generator(torch.cat((l_e[i], context), dim=1)) for i in range(self.layers)), dim=1) 126 | weight_max = torch.max(layer_weights, dim=1, keepdim=True)[0] 127 | layer_weights = torch.exp(layer_weights - weight_max) # subtract largest weight for stability 128 | num_weights = self.kernel_size*self.kernel_size 129 | 130 | # Alpha-blending compositing 131 | col_sum = torch.cuda.FloatTensor(sh[0], 3, sh[3], sh[4]).fill_(0) 132 | k = torch.cuda.FloatTensor(sh[0], 1, sh[3], sh[4]).fill_(1.0) 133 | for j in range(self.layers): 134 | startw = num_weights*j 135 | endw = num_weights*(j+1) 136 | kernel = layer_weights[:, startw:endw, ...] 137 | 138 | filtered_rad = self._kpn(l_radiance[j].contiguous(), kernel.contiguous()) 139 | alpha = self._kpn(l_weights[j].contiguous(), kernel.contiguous()) 140 | filtered_n = self._kpn(l_n[j].contiguous(), kernel.contiguous()) 141 | filtered_rad = filtered_rad / (filtered_n + EPSILON) 142 | alpha = alpha / (filtered_n + EPSILON) 143 | col_sum += filtered_rad * k 144 | k = (1.0 - alpha) * k 145 | 146 | return utils.object_from_dict({'color' : col_sum}) 147 | 148 | def inference(self, sequenceData): 149 | return self.forward(sequenceData, 0) 150 | -------------------------------------------------------------------------------- /license.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Nvidia Source Code License-NC 7 | 8 | 56 | 57 | 58 | 59 |

NVIDIA Source Code License for Neural Denoising with Layer Embeddings

1. Definitions

“Licensor” means any person or entity that distributes its Work.

“Software” means the original work of authorship made available under this License.

“Work” means the Software and any additions to or derivative works of the Software that are made available under this License.

The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately.

3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /sample_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch_utils import WeightedFilter 13 | 14 | from unet import * 15 | from utils import * 16 | 17 | EPSILON = 0.000001 # Small epsilon to avoid division by zero 18 | 19 | ############################################################################### 20 | # Sample network definition 21 | # 22 | # A scaled-down version of Sample-based Monte Carlo Denoising using a Kernel-Splatting Network 23 | # 24 | # https://groups.csail.mit.edu/graphics/rendernet/ 25 | # 26 | ############################################################################### 27 | 28 | class SampleNet(nn.Module): 29 | def __init__(self, sequenceHeader, tonemapper, num_samples=8, splat=False, use_sample_info=False, kernel_size=17): 30 | 31 | super(SampleNet, self).__init__() 32 | self.use_sample_info = use_sample_info 33 | self.tonemapper = tonemapper 34 | self.output_channels = 128 35 | self.embed_channels = 32 36 | self.kernel_size = kernel_size 37 | self.num_samples = int(num_samples) 38 | self.splat = splat 39 | self.resolution = sequenceHeader.resolution 40 | frameShape = sequenceHeader.frameShape 41 | self.input_channels = frameShape.color[1] + frameShape.normals_depth[1] + frameShape.albedo[1] + frameShape.specular[1] + frameShape.uvt[1] + frameShape.motionvecs[1] 42 | 43 | # Sample Reducer: Maps from input channels to sample embeddings, uses 1x1 convolutions 44 | self._sample_reducer = nn.Sequential( 45 | nn.Conv2d(self.input_channels, self.embed_channels, 1, padding=0), 46 | Activation, 47 | nn.Conv2d(self.embed_channels, self.embed_channels, 1, padding=0), 48 | Activation, 49 | nn.Conv2d(self.embed_channels, self.embed_channels, 1, padding=0), 50 | Activation, 51 | ) 52 | 53 | # Pixel reducer: Used instead of sample reducer for the per-pixel network, uses 1x1 convolutions 54 | self._pixel_reducer = nn.Sequential( 55 | nn.Conv2d(self.input_channels*2, self.embed_channels, 1, padding=0), 56 | Activation, 57 | nn.Conv2d(self.embed_channels, self.embed_channels, 1, padding=0), 58 | Activation, 59 | nn.Conv2d(self.embed_channels, self.embed_channels, 1, padding=0), 60 | Activation, 61 | ) 62 | 63 | # Kernel generator: Combines UNet per-pixel output with per-sample or per-pixel embeddings, uses 1x1 convolutions 64 | self._kernel_generator = nn.Sequential( 65 | nn.Conv2d(self.output_channels+self.embed_channels, 128, 1, padding=0), 66 | Activation, 67 | nn.Conv2d(128, 128, 1, padding=0), 68 | Activation, 69 | nn.Conv2d(128, self.kernel_size*self.kernel_size, 1, padding=0), # output kernel weights 70 | ) 71 | 72 | # U-Net: Generates context features 73 | self._unet = UNet(self.embed_channels, self.output_channels, encoder_features=[[64, 64], [128], [256], [512], [512]], bottleneck_features=[512], decoder_features=[[512, 512], [256, 256], [128, 128], [128, 128], [128, 128]]) 74 | 75 | # Filter for applying predicted 
kernels 76 | self._kpn = WeightedFilter(channels=3, kernel_size=self.kernel_size, bias=False, splat=self.splat) 77 | 78 | # Initialize network weights 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | nn.init.xavier_uniform_(m.weight.data) 82 | if m.bias is not None: 83 | m.bias.data.zero_() 84 | 85 | def forward(self, sequenceData, epoch): 86 | num_weights = self.kernel_size*self.kernel_size 87 | 88 | frame = sequenceData.frameData[0] 89 | radiance = frame.color[:, :, 0:3, ...] 90 | rgb = self.tonemapper(frame.color) 91 | 92 | xc = torch.cat((rgb, frame.normals_depth, frame.albedo, frame.specular, frame.uvt, frame.motionvecs), dim=2) 93 | 94 | # We transform a 5D tensor [batch, sample, channel, weight, width] 95 | # into a 4D tensor [batch, embedding, weight, width] 96 | 97 | # loop over samples to create embeddings 98 | sh = xc.shape 99 | embedding = torch.cuda.FloatTensor(sh[0], sh[1], self.embed_channels, sh[3], sh[4]).fill_(0) 100 | 101 | if self.use_sample_info: 102 | # loop over samples to create embeddings 103 | for i in range(sh[1]): 104 | embedding[:,i,...] = self._sample_reducer(xc[:,i,...]) 105 | avg_embeddings = embedding.mean(dim=1) # average over embeddings dimension 106 | else: 107 | # average per-sample info 108 | xc_mean = torch.mean(xc, dim=1) 109 | xc_variance = torch.var(xc, dim=1, unbiased=False) 110 | embedding[:,0,...] = self._pixel_reducer(torch.cat((xc_mean,xc_variance), dim=1)) 111 | avg_embeddings = embedding[:,0,...] 112 | 113 | context = self._unet(avg_embeddings) 114 | ones = torch.cuda.FloatTensor(sh[0], 1, sh[3], sh[4]).fill_(1.0) 115 | 116 | if self.use_sample_info: # work on individual samples 117 | accum = torch.cuda.FloatTensor(sh[0], 3, sh[3], sh[4]).fill_(0) 118 | accum_w = torch.cuda.FloatTensor(sh[0], 1, sh[3], sh[4]).fill_(0) 119 | # create sample weights 120 | sample_weights = torch.cat(tuple(self._kernel_generator(torch.cat((embedding[:, i, ...], context), dim=1)) for i in range(0, self.num_samples)), dim=1) 121 | weight_max = torch.max(sample_weights, dim=1, keepdim=True)[0] 122 | sample_weights = torch.exp(sample_weights - weight_max) 123 | 124 | for i in range(self.num_samples): # loop over samples 125 | startw = num_weights*(i) 126 | endw = num_weights*(i+1) 127 | accum += self._kpn(radiance[:, i, ...].contiguous(), sample_weights[:, startw:endw, ...].contiguous()) 128 | accum_w += self._kpn(ones.contiguous(), sample_weights[:, startw:endw, ...].contiguous()) 129 | filtered = accum / (accum_w + EPSILON) 130 | 131 | else: # work on pixel aggregates 132 | radiance_mean = torch.mean(radiance, dim=1) 133 | pixel_weights = self._kernel_generator(torch.cat((embedding[:,0,...], context), dim=1)) 134 | weight_max = torch.max(pixel_weights, dim=1, keepdim=True)[0] 135 | pixel_weights = torch.exp(pixel_weights - weight_max) 136 | col = self._kpn(radiance_mean.contiguous(), pixel_weights) 137 | w = self._kpn(ones.contiguous(), pixel_weights) 138 | filtered = col/(w+EPSILON) 139 | 140 | return utils.object_from_dict({'color' : filtered}) 141 | 142 | def inference(self, sequenceData): 143 | return self.forward(sequenceData, 0) 144 | -------------------------------------------------------------------------------- /torch_utils/README.md: -------------------------------------------------------------------------------- 1 | # Cuda implementation of per-pixel kernel evaluation for PyTorch 2 | 3 | ## Introduction 4 | 5 | This folder represents a plugin layers for PyTorch. 
The layer is implemented with 6 | CUDA kernels and are thus more performant than the PyTorch equivalent. The layer is a regular `torch.nn.Module` primitive, behaves very similar to PyTorch stock layers (such as `Conv2d`), and supports both forward evaluation and back-propagation. 7 | 8 | ### WeightedFilter 9 | The weighted filter primitive is a two dimensional convolution filter that accepts a per-pixel weight matrix, rather than a set of trainable weights. This primitive is intended to be used for the final layer of a kernel predicting network. It has no trainable parameters, but supports back-propagation. Given an input tensor $`\mathbf{x}`$ 10 | and a weight tensor $`\mathbf{w}`$, the output activations are compute as: 11 | 12 | ```math 13 | out[c,y,x] = \sum_{i,i^+}^{N} \sum_{j,j^+}^{N} w[i^+\cdot N + j^+,y,x] x[c,y+i,x+j] 14 | ``` 15 | 16 | Note that the weights in $`\mathbf{w}`$ are applied equally to all feature channels of $`\mathbf{x}`$, producing an output with as many 17 | feature channels as $`\mathbf{x}`$. It is assumed that the input and weight tensors have the same *height* and *width* dimensions, and 18 | border conditions are handled by zero padding. 19 | 20 | **Splatting** The weighted filter also supports splat kernels. Instead of gathering the output activation as a nested sum, the contribution 21 | of each activation in the input tensor is scattered according to the pseudo code below. 22 | ``` 23 | for i in range(0, N): 24 | for j in range(0, N): 25 | out[c, y + i - N/2, x + j - N/2] += w[i * N + j, y, x] * x[c, y, x] 26 | ``` 27 | However, one can easily realize that this can be rewritten as a nested sum (gather) by modifying how the weight tensor is indexed. 28 | ```math 29 | out[c,y,x] = \sum_{i,i^+}^{N} \sum_{j,j^+}^{N} w[(N-i^+-1)\cdot N + (N-j^+-1),y+i,x+j] x[c,y+i,x+j] 30 | ``` 31 | 32 | ## Usage example 33 | 34 | ```python 35 | import torch 36 | from torch_utils import WeightedFilter 37 | 38 | in_channels = 3 39 | kernel_size = 5 40 | 41 | # Create kernel predicting network model 42 | network = UNet(out_channels=kernel_size*kernel_size) 43 | 44 | # Create a weighted (kpn) network layer without bias and using gather (no splatting) 45 | kpn_layer = WeightedFilter(in_channels, kernel_size, bias=False, splat=False) 46 | 47 | # Load image and guide 48 | color = loadImage('color.png') 49 | guide = loadImage('normal.png') 50 | target = loadImage('target.png') 51 | 52 | # Run forward pass 53 | kpn_weights = network(color, guide) 54 | out = kpn_layer(color, kpn_weights) 55 | 56 | # Compute loss 57 | loss = torch.nn.MSELoss()(out, target) 58 | 59 | # Back propagate 60 | loss.backward() 61 | ``` 62 | 63 | ## Windows/Anaconda installation 64 | 65 | ### Requirements 66 | - PyTorch in Anaconda environment (tested with Python 3.7 and PyTorch 1.6) 67 | - Visual Studio 2019. 68 | - Cuda 10.2 69 | 70 | ### Installing 71 | 72 | Open a **"x64 Native Tools Command Prompt for VS 2019"** and start your PyTorch Anaconda environment 73 | from that prompt (it need to be that prompt so the paths to the correct compiler is properly set). 
74 | 75 | Then type: 76 | ``` 77 | cd [layerdl installation path]\torch_utils 78 | set DISTUTILS_USE_SDK=1 79 | python setup.py install 80 | ``` 81 | 82 | - List installed packages: `pip list` 83 | - Remove package: `pip uninstall torch-utils` 84 | 85 | ## Installation in a Docker container 86 | 87 | Navigate to the folder where you've cloned `layerdenoise` and build the docker image 88 | `docker build --tag ldenoiser:latest -f docker/Dockerfile .` 89 | 90 | Launch a container 91 | `docker run --gpus device=0 --shm-size 16G --rm -v /raid:/raid -it ldenoiser:latest bash` 92 | 93 | ### Tutorial for building custom modules 94 | 95 | https://pytorch.org/tutorials/advanced/cpp_extension.html 96 | 97 | -------------------------------------------------------------------------------- /torch_utils/clean.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | rmdir /q /s build > nul 2>&1 3 | rmdir /q /s dist > nul 2>&1 4 | rmdir /q /s torch_utils.egg-info > nul 2>&1 5 | rmdir /q /s tests\__pycache__ > nul 2>&1 6 | 7 | -------------------------------------------------------------------------------- /torch_utils/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from setuptools import setup, find_packages 10 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 11 | 12 | CPP_FILES = [ 13 | 'torch_utils/torch_utils.cpp', 14 | 'torch_utils/cuda_weighted_filter.cu', 15 | ] 16 | 17 | setup( 18 | name='torch_utils', 19 | version='0.1', 20 | author="Jon Hasselgren", 21 | author_email="jhasselgren@nvidia.com", 22 | description="torch_utils - fast kernel evaluations", 23 | url="https://github.com/NVlabs/layerdenoise", 24 | install_requires=['torch'], 25 | packages=find_packages(exclude=['test*']), 26 | ext_modules=[CUDAExtension('torch_utils_cpp', CPP_FILES, extra_compile_args={'cxx' : [], 'nvcc' : ['-arch', 'compute_70']})], 27 | py_modules=["torch_utils/weighted_filter"], 28 | cmdclass={ 29 | 'build_ext': BuildExtension 30 | }, 31 | classifiers=[ 32 | "Programming Language :: Python :: 3", 33 | "Operating System :: OS Independent", 34 | ], 35 | python_requires='>=3.6', 36 | ) 37 | -------------------------------------------------------------------------------- /torch_utils/tests/test_weighted.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import sys 10 | import torch 11 | import numpy as np 12 | 13 | from torch_utils import WeightedFilter 14 | from torch_utils_ref import WeightedFilterPy 15 | 16 | ################################################## 17 | # Utility function 18 | ################################################## 19 | 20 | GRAD_DICT = {} 21 | def save_grad(name): 22 | def hook(grad): 23 | global GRAD_DICT 24 | GRAD_DICT[name] = grad 25 | return hook 26 | 27 | def max_relative_err(x,y): 28 | return (torch.abs(x - y) / torch.abs(y).max()).max() 29 | 30 | ################################################## 31 | # Networks 32 | ################################################## 33 | 34 | class RefNet(torch.nn.Module): 35 | def __init__(self, input_w, weight_w, b, splat): 36 | super(RefNet, self).__init__() 37 | self.c1 = torch.nn.Conv2d(input_w.shape[1], input_w.shape[0], input_w.shape[2], padding=input_w.shape[2]//2, bias=False) 38 | self.c2 = torch.nn.Conv2d(weight_w.shape[1], weight_w.shape[0], weight_w.shape[2], padding=weight_w.shape[2]//2, bias=False) 39 | self.c3 = WeightedFilterPy(input_w.shape[0], weight_w.shape[2], splat=splat) 40 | 41 | self.c1.weight.data = input_w.clone() 42 | self.c2.weight.data = weight_w.clone() 43 | self.c3.bias.data = b.clone() 44 | 45 | def forward(self, x, w): 46 | self.input = self.c1(x) 47 | self.weight = self.c2(w) 48 | self.input.register_hook(save_grad("ref_input_grad")) 49 | self.weight.register_hook(save_grad("ref_weight_grad")) 50 | return self.c3(self.input, self.weight) 51 | 52 | class OurNet(torch.nn.Module): 53 | def __init__(self, input_w, weight_w, b, splat): 54 | super(OurNet, self).__init__() 55 | self.c1 = torch.nn.Conv2d(input_w.shape[1], input_w.shape[0], input_w.shape[2], padding=input_w.shape[2]//2, bias=False) 56 | self.c2 = torch.nn.Conv2d(weight_w.shape[1], weight_w.shape[0], weight_w.shape[2], padding=weight_w.shape[2]//2, bias=False) 57 | self.c3 = WeightedFilter(input_w.shape[0], weight_w.shape[2], splat=splat) 58 | 59 | self.c1.weight.data = input_w.clone() 60 | self.c2.weight.data = weight_w.clone() 61 | self.c3.bias.data = b.clone() 62 | 63 | def forward(self, x, w): 64 | self.input = self.c1(x) 65 | self.weight = self.c2(w) 66 | self.input.register_hook(save_grad("our_input_grad")) 67 | self.weight.register_hook(save_grad("our_weight_grad")) 68 | return self.c3(self.input, self.weight) 69 | 70 | ################################################## 71 | # Test 72 | ################################################## 73 | 74 | for splat in [False, True]: 75 | print("Splatting: %s" % str(splat)) 76 | 77 | num_tests = 10000 78 | kernel_size = 3 79 | img_size = 256 80 | batch_size = 1 81 | channels = 1 82 | 83 | e_forward = 0.0 84 | e_input_grad = 0.0 85 | e_weight_grad = 0.0 86 | for i in range(num_tests): 87 | print("%5d / %5d" % (i, num_tests), end="\r", flush=True) 88 | 89 | # Create random image & initialize random weights 90 | input = torch.randn((batch_size, channels, img_size, img_size)).cuda() 91 | target = torch.randn((batch_size, channels, img_size, img_size)).cuda() 92 | W = torch.randn((batch_size, kernel_size*kernel_size, img_size, img_size)).cuda() 93 | 94 | input_w = torch.randn((channels, channels, kernel_size, kernel_size)).cuda() 95 | weight_w = torch.randn((kernel_size*kernel_size, kernel_size*kernel_size, kernel_size, kernel_size)).cuda() 96 | b = torch.zeros((channels)).cuda() 97 | 98 | # Setup our and refernce networks 99 | ref_net = RefNet(input_w, weight_w, b, splat).cuda() 100 | our_net = OurNet(input_w, weight_w, b, 
splat).cuda() 101 | 102 | # Run forward pass 103 | ref_res = ref_net(input, W) 104 | our_res = our_net(input, W) 105 | 106 | # Compute loss and back propagate 107 | our_loss = torch.nn.MSELoss()(our_res, target) 108 | ref_loss = torch.nn.MSELoss()(ref_res, target) 109 | our_loss.backward() 110 | ref_loss.backward() 111 | 112 | fwd = max_relative_err(our_res, ref_res) 113 | igrad = max_relative_err(GRAD_DICT["our_input_grad"], GRAD_DICT["ref_input_grad"]) 114 | wgrad = max_relative_err(GRAD_DICT["our_weight_grad"], GRAD_DICT["ref_weight_grad"]) 115 | 116 | ################################################################## 117 | # Debug prints 118 | 119 | #if fwd > e_forward: 120 | # print("\nNew max forward error:\n", our_res - ref_res) 121 | #if igrad > e_input_grad: 122 | # print("\nNew max input gradient error:\n", GRAD_DICT["our_input_grad"] - GRAD_DICT["ref_input_grad"]) 123 | #if wgrad > e_weight_grad: 124 | # print("\nNew max input gradient error:\n", (GRAD_DICT["our_weight_grad"] - GRAD_DICT["ref_weight_grad"]) / GRAD_DICT["ref_weight_grad"].max()) 125 | 126 | # Find errors everywhere 127 | e_forward = max(e_forward, fwd) 128 | e_input_grad = max(e_input_grad, igrad) 129 | e_weight_grad = max(e_weight_grad, wgrad) 130 | 131 | print("Forward: %f" % e_forward) 132 | print("Input grad: %f" % e_input_grad) 133 | print("Weight grad: %f" % e_weight_grad) 134 | -------------------------------------------------------------------------------- /torch_utils/tests/torch_utils_ref/__init__.py: -------------------------------------------------------------------------------- 1 | from .weighted_filter import WeightedFilterPy -------------------------------------------------------------------------------- /torch_utils/tests/torch_utils_ref/weighted_filter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import torch 10 | 11 | ############################################################################### 12 | # Weighted filter for kernel predicting networks 13 | ############################################################################### 14 | 15 | class WeightedFilterPy(torch.nn.Module): 16 | def __init__(self, in_channels, kernel_size, bias=True, splat=False): 17 | super(WeightedFilterPy, self).__init__() 18 | self.in_channels = in_channels 19 | self.out_channels = in_channels 20 | self.kernel_size = kernel_size 21 | self.splat = splat 22 | 23 | if bias: 24 | self.bias = torch.nn.Parameter(torch.Tensor(self.out_channels)) 25 | 26 | def forward(self, input, weight): 27 | 28 | HEIGHT = input.shape[2] # assume input is a tensor with shape NCHW 29 | WIDTH = input.shape[3] 30 | 31 | v1, v0 = torch.meshgrid([torch.linspace(-1.0, 1.0, HEIGHT).cuda(), torch.linspace(-1.0, 1.0, WIDTH).cuda()]) 32 | 33 | offsetx = 2.0 / (WIDTH - 1) 34 | offsety = 2.0 / (HEIGHT - 1) 35 | 36 | radius = self.kernel_size // 2 37 | batch_size = input.shape[0] 38 | 39 | out = torch.zeros_like(input) 40 | for i in range(-radius, radius + 1): 41 | for j in range(-radius, radius + 1): 42 | # compute tap offset 43 | v0_tap = v0 + j*offsetx 44 | v1_tap = v1 + i*offsety 45 | 46 | mvs = torch.stack((v0_tap, v1_tap), dim=2) 47 | 48 | # shift image according to offset 49 | tap_col = torch.nn.functional.grid_sample(input, mvs.expand(batch_size,-1,-1,-1), padding_mode='zeros', align_corners=True) 50 | 51 | # If using "splat" kernels, shift weights along with colors 52 | if self.splat: 53 | tap_w = torch.nn.functional.grid_sample(weight, mvs.expand(batch_size,-1,-1,-1), padding_mode='zeros', align_corners=True) 54 | out = out + tap_col[:, :, ...] * tap_w[:, (radius - i)*self.kernel_size + (radius - j), ...].unsqueeze(1) 55 | else: 56 | out = out + tap_col[:, :, ...] * weight[:, (i + radius)*self.kernel_size + (j + radius), ...].unsqueeze(1) 57 | 58 | if hasattr(self, 'bias'): 59 | for oc in range(self.out_channels): 60 | out[:, oc, ...] = out[:, oc, ...] + self.bias[oc].expand_as(out[:, oc, ...]) 61 | 62 | return out 63 | -------------------------------------------------------------------------------- /torch_utils/tests/update_pkg.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | cd .. 3 | python setup.py install 4 | cd tests 5 | -------------------------------------------------------------------------------- /torch_utils/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .weighted_filter import WeightedFilter 2 | -------------------------------------------------------------------------------- /torch_utils/torch_utils/cuda_weighted_filter.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | #include 17 | 18 | #define WEIGHTS_TILE_SIZE 64 19 | 20 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 21 | // Utility 22 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 23 | 24 | static dim3 compute_blocks(dim3 size, dim3 threads) 25 | { 26 | return dim3((size.x + threads.x - 1) / threads.x, (size.y + threads.y - 1) / threads.y); 27 | } 28 | 29 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 30 | // Forward pass 31 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 32 | 33 | __global__ void cuda_weighted_filter_forward_kernel( 34 | const float* __restrict__ input, 35 | const float* __restrict__ weight, 36 | float* __restrict__ output, 37 | int32_t batch_size, 38 | int32_t in_channels, 39 | int32_t height, 40 | int32_t width, 41 | int32_t filter_h, 42 | int32_t filter_w, 43 | bool splat) 44 | { 45 | const int32_t ch = blockIdx.y * blockDim.y + threadIdx.y; 46 | const int32_t pixel_index = blockIdx.x * blockDim.x + threadIdx.x; 47 | 48 | if (pixel_index >= height * width || ch >= in_channels) 49 | return; 50 | 51 | const int32_t py = pixel_index / width; 52 | const int32_t px = pixel_index % width; 53 | 54 | float result = 0.0f; 55 | for (int32_t fy = 0; fy < filter_h; ++fy) 56 | { 57 | for (int32_t fx = 0; fx < filter_w; ++fx) 58 | { 59 | // Compute tap coordinates, used for input activations and bilateral guides 60 | int32_t y = py + fy - (filter_h - 1) / 2; 61 | int32_t x = px + fx - (filter_w - 1) / 2; 62 | 63 | if (y < 0 || x < 0 || y >= height || x >= width) 64 | continue; 65 | 66 | // Filter using custom weight, use either gathering or splatting (scatter) 67 | if (splat) 68 | result += input[ch*height*width + y * width + x] * weight[((filter_h - fy - 1)*filter_w + (filter_w - fx - 1))*height*width + y*width + x]; // Splatting 69 | else 70 | result += input[ch*height*width + y*width + x] * weight[(fy*filter_w + fx)*height*width + py*width + px]; // Gathering 71 | } 72 | } 73 | output[ch*height*width + pixel_index] = result; 74 | } 75 | 76 | at::Tensor cuda_weighted_filter_forward(at::Tensor input, at::Tensor weight, int64_t kernel_size, bool splat) 77 | { 78 | // Get tensor shapes 79 | at::IntList input_shape = input.sizes(); 80 | at::IntList weight_shape = weight.sizes(); 81 | at::IntList output_shape = input_shape; 82 | 83 | // Initialize output tensor to zero 84 | at::Tensor output = at::zeros(output_shape, input.options()); 85 | 86 | // Setup dimensions for cuda kernel 87 | dim3 threads = dim3(32, 4); 88 | dim3 size = dim3(input_shape[2] * input_shape[3], input_shape[1]); // #pixels, out_channels = in_channels 89 | dim3 blocks = compute_blocks(size, threads); 90 | 91 | // Invoke separate cuda kernel for each batch 92 | for (int64_t batch = 0; batch < input_shape[0]; ++batch) 93 | { 94 | cuda_weighted_filter_forward_kernel <<>> ( 95 | input.data() + batch * input_shape[1] * input_shape[2] * input_shape[3], 96 | weight.data() + batch * weight_shape[1] * weight_shape[2] * weight_shape[3], 97 | output.data() + 
batch * output_shape[1] * output_shape[2] * output_shape[3], 98 | (int32_t)input_shape[0], // batch_size 99 | (int32_t)input_shape[1], // in_channels 100 | (int32_t)input_shape[2], // height 101 | (int32_t)input_shape[3], // width 102 | (int32_t)kernel_size, // filter_h 103 | (int32_t)kernel_size, // filter_w 104 | splat // splatting vs gather 105 | ); 106 | } 107 | 108 | // Return result 109 | return output; 110 | } 111 | 112 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 113 | // Backward pass 114 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 115 | 116 | __global__ void cuda_weighted_filter_backward_kernel_activations( 117 | const float* __restrict__ grad_out, 118 | const float* __restrict__ weight, 119 | float* __restrict__ grad_input, 120 | int32_t batch_size, 121 | int32_t in_channels, 122 | int32_t height, 123 | int32_t width, 124 | int32_t filter_h, 125 | int32_t filter_w, 126 | bool splat) 127 | { 128 | const int32_t ch = blockIdx.y * blockDim.y + threadIdx.y; 129 | const int32_t pixel_index = blockIdx.x * blockDim.x + threadIdx.x; 130 | 131 | if (pixel_index >= height * width || ch >= in_channels) 132 | return; 133 | 134 | const int32_t py = pixel_index / width; 135 | const int32_t px = pixel_index % width; 136 | 137 | float result = 0.0f; 138 | for (int32_t fy = 0; fy < filter_h; ++fy) 139 | { 140 | for (int32_t fx = 0; fx < filter_w; ++fx) 141 | { 142 | // Gradient and guide coordinates, regular unflipped. This probably wont work with even sized filters 143 | int32_t y = py + fy - (filter_h - 1) / 2; 144 | int32_t x = px + fx - (filter_w - 1) / 2; 145 | 146 | // Check for out-of-bounds access 147 | if (y < 0 || x < 0 || y >= height || x >= width) 148 | continue; 149 | 150 | // Compute activation derivative 151 | if (splat) 152 | result += grad_out[ch*height*width + y * width + x] * weight[(fy*filter_w + fx)*height*width + py * width + px]; 153 | else 154 | result += grad_out[ch*height*width + y*width + x] * weight[((filter_h-fy-1)*filter_w + (filter_w-fx-1))*height*width + y*width + x]; 155 | } 156 | } 157 | grad_input[ch*height*width + pixel_index] = result; 158 | } 159 | 160 | __global__ void cuda_weighted_filter_backward_kernel_weights( 161 | const float* __restrict__ grad_out, 162 | const float* __restrict__ input, 163 | float* __restrict__ grad_weight, 164 | int32_t batch_size, 165 | int32_t in_channels, 166 | int32_t height, 167 | int32_t width, 168 | int32_t filter_h, 169 | int32_t filter_w, 170 | bool splat) 171 | { 172 | const int32_t weight_index = blockIdx.y * blockDim.y + threadIdx.y; 173 | const int32_t pixel_index = blockIdx.x * blockDim.x + threadIdx.x; 174 | 175 | if (pixel_index >= height * width || weight_index >= filter_h*filter_w) 176 | return; 177 | 178 | // Compute pixel coordinate 179 | const int32_t py = pixel_index / width; 180 | const int32_t px = pixel_index % width; 181 | 182 | // Compute tap/weight coordinate 183 | const int32_t fy = weight_index / filter_w; 184 | const int32_t fx = weight_index % filter_w; 185 | 186 | // Compute gradient, use zero if tap points to outside image region 187 | float result = 0.0f; 188 | if (splat) 189 | { 190 | // Compute tap offset in image space 191 | int32_t y = py + (fy - (filter_h - 1) / 2); 192 | int32_t x = px + (fx - (filter_w - 1) / 2); 193 | 194 | if (y >= 0 
&& x >= 0 && y < height && x < width) 195 | { 196 | for (int32_t ch = 0; ch < in_channels; ++ch) 197 | { 198 | // Result based on output gradient at pixel coordinate and input activation 199 | result += grad_out[ch*height*width + y*width + x] * input[ch*height*width + py*width + px]; 200 | } 201 | } 202 | } 203 | else 204 | { 205 | // Compute tap offset in image space 206 | int32_t y = py + fy - (filter_h - 1) / 2; 207 | int32_t x = px + fx - (filter_w - 1) / 2; 208 | 209 | if (y >= 0 && x >= 0 && y < height && x < width) 210 | { 211 | for (int32_t ch = 0; ch < in_channels; ++ch) 212 | { 213 | // Result based on output gradient at pixel coordinate and input activation 214 | result += grad_out[ch*height*width + py*width + px] * input[ch*height*width + y*width + x]; 215 | } 216 | } 217 | } 218 | 219 | grad_weight[weight_index*height*width + pixel_index] = result; 220 | } 221 | 222 | 223 | std::vector<at::Tensor> cuda_weighted_filter_backward(at::Tensor grad_out, at::Tensor input, at::Tensor weight, int64_t kernel_size, bool splat) 224 | { 225 | // Get tensor shapes 226 | at::IntList input_shape = input.sizes(); 227 | at::IntList weight_shape = weight.sizes(); 228 | at::IntList output_shape = grad_out.sizes(); 229 | 230 | // Initialize output gradient tensors to zero 231 | at::Tensor input_grad = at::zeros_like(input); 232 | at::Tensor weight_grad = at::zeros_like(weight); 233 | 234 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////// 235 | // Gradients for input activations 236 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////// 237 | { 238 | // Setup dimensions for cuda kernel 239 | dim3 threads = dim3(128, 4); 240 | dim3 size = dim3(input_shape[2] * input_shape[3], input_shape[1]); // #pixels, #out_channels = in_channels 241 | dim3 blocks = compute_blocks(size, threads); 242 | 243 | // Invoke separate cuda kernel for each batch 244 | for (int64_t batch = 0; batch < input_shape[0]; ++batch) 245 | { 246 | cuda_weighted_filter_backward_kernel_activations<<<blocks, threads>>>( 247 | grad_out.data_ptr<float>() + batch * output_shape[1] * output_shape[2] * output_shape[3], 248 | weight.data_ptr<float>() + batch * weight_shape[1] * weight_shape[2] * weight_shape[3], 249 | input_grad.data_ptr<float>() + batch * input_shape[1] * input_shape[2] * input_shape[3], 250 | (int32_t)input_shape[0], // batch_size 251 | (int32_t)input_shape[1], // in_channels 252 | (int32_t)input_shape[2], // height 253 | (int32_t)input_shape[3], // width 254 | (int32_t)kernel_size, // filter_h 255 | (int32_t)kernel_size, // filter_w 256 | splat // splatting vs gather 257 | ); 258 | } 259 | } 260 | 261 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////// 262 | // Gradients for weights 263 | //////////////////////////////////////////////////////////////////////////////////////////////////////////////// 264 | { 265 | // Setup dimensions for cuda kernel 266 | dim3 threads = dim3(128, 4); 267 | dim3 size = dim3(weight_shape[2] * weight_shape[3], weight_shape[1]); // #pixels, #weights 268 | dim3 blocks = compute_blocks(size, threads); 269 | 270 | // Invoke separate cuda kernel for each batch 271 | for (int64_t batch = 0; batch < input_shape[0]; ++batch) 272 | { 273 | cuda_weighted_filter_backward_kernel_weights<<<blocks, threads>>>( 274 | grad_out.data_ptr<float>() + batch * output_shape[1] * output_shape[2] * output_shape[3], 275 | input.data_ptr<float>() + batch * input_shape[1] * input_shape[2] * input_shape[3], 276 |
weight_grad.data_ptr<float>() + batch * weight_shape[1] * weight_shape[2] * weight_shape[3], 277 | (int32_t)input_shape[0], // batch_size 278 | (int32_t)input_shape[1], // in_channels 279 | (int32_t)input_shape[2], // height 280 | (int32_t)input_shape[3], // width 281 | (int32_t)kernel_size, // filter_h 282 | (int32_t)kernel_size, // filter_w 283 | splat // splatting vs gather 284 | ); 285 | } 286 | } 287 | 288 | return { input_grad, weight_grad }; 289 | } 290 | -------------------------------------------------------------------------------- /torch_utils/torch_utils/torch_utils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <vector> 11 | 12 | // CUDA forward declarations 13 | at::Tensor cuda_weighted_filter_forward(at::Tensor input, at::Tensor weight, int64_t kernel_size, bool splat); 14 | std::vector<at::Tensor> cuda_weighted_filter_backward(at::Tensor grad_out, at::Tensor input, at::Tensor weight, int64_t kernel_size, bool splat); 15 | 16 | 17 | ////////////////////////////////////////////////////////////////////////////////// 18 | // C++ / Python interface 19 | ////////////////////////////////////////////////////////////////////////////////// 20 | 21 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 22 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 23 | #define CHECK_TYPE(x) AT_ASSERTM(x.type().scalarType() == at::ScalarType::Float, #x " must be a float tensor") 24 | #define CHECK_DIM(x) AT_ASSERTM(x.dim() == 4LL, #x " must be a 4D tensor") 25 | #define CHECK_CUDA_CONTIGUOUS(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | #define CHECK_TENSOR_4D_FLOAT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_TYPE(x); CHECK_DIM(x) 27 | 28 | at::Tensor weighted_filter_forward( 29 | at::Tensor input, 30 | at::Tensor weights, 31 | int64_t kernel_size, 32 | bool splat 33 | ) { 34 | CHECK_TENSOR_4D_FLOAT(input); 35 | CHECK_TENSOR_4D_FLOAT(weights); 36 | AT_ASSERTM(weights.size(0) == input.size(0) && weights.size(2) == input.size(2) && weights.size(3) == input.size(3), "Input and weight tensors mismatch"); 37 | AT_ASSERTM(weights.size(1) == kernel_size * kernel_size, "Weight tensors and kernel size mismatch"); 38 | AT_ASSERTM(kernel_size % 2 == 1, "Kernel size must be odd"); 39 | 40 | return cuda_weighted_filter_forward(input, weights, kernel_size, splat); 41 | } 42 | 43 | std::vector<at::Tensor> weighted_filter_backward( 44 | at::Tensor grad_out, 45 | at::Tensor input, 46 | at::Tensor weights, 47 | int64_t kernel_size, 48 | bool splat 49 | ) { 50 | CHECK_TENSOR_4D_FLOAT(grad_out); 51 | CHECK_TENSOR_4D_FLOAT(input); 52 | CHECK_TENSOR_4D_FLOAT(weights); 53 | AT_ASSERTM(grad_out.size(0) == input.size(0) && grad_out.size(2) == input.size(2) && grad_out.size(3) == input.size(3), "Input and gradient tensors mismatch"); 54 | AT_ASSERTM(weights.size(0) == input.size(0) && weights.size(2) == input.size(2) && weights.size(3) == input.size(3), "Input and weight tensors mismatch"); 55 | AT_ASSERTM(weights.size(1) == kernel_size * kernel_size, "Weight tensors and kernel size
missmatch"); 56 | AT_ASSERTM(kernel_size % 2 == 1, "Kernel size must be odd"); 57 | 58 | return cuda_weighted_filter_backward(grad_out, input, weights, kernel_size, splat); 59 | } 60 | 61 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 62 | m.def("weighted_filter_forward", &weighted_filter_forward, "weighted_filter_forward"); 63 | m.def("weighted_filter_backward", &weighted_filter_backward, "weighted_filter_backward"); 64 | } 65 | -------------------------------------------------------------------------------- /torch_utils/torch_utils/weighted_filter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import torch 10 | import torch_utils_cpp 11 | 12 | class WeightedFilterFunction(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, input, weights, kernel_size, splat): 15 | # Store input activations and weights for back propagation pass 16 | ctx.save_for_backward(input, weights) 17 | 18 | # Store kernel size 19 | assert not hasattr(ctx, '_weighted_kernel_size') or ctx._weighted_kernel_size is None 20 | ctx._weighted_kernel_size = kernel_size 21 | ctx._weighted_splat = splat 22 | 23 | # Evaluate convolution 24 | return torch_utils_cpp.weighted_filter_forward(input, weights, kernel_size, splat) 25 | 26 | @staticmethod 27 | def backward(ctx, grad_out): 28 | grad_input, grad_weights = torch_utils_cpp.weighted_filter_backward(grad_out.contiguous(), *ctx.saved_variables, ctx._weighted_kernel_size, ctx._weighted_splat) 29 | return grad_input, grad_weights, None, None 30 | 31 | 32 | class WeightedFilter(torch.nn.Module): 33 | def __init__(self, channels, kernel_size, bias=True, splat=False): 34 | super(WeightedFilter, self).__init__() 35 | self.in_channels = channels 36 | self.out_channels = channels 37 | self.kernel_size = kernel_size 38 | self.splat = splat 39 | 40 | if bias: 41 | self.bias = torch.nn.Parameter(torch.Tensor(self.out_channels)) 42 | 43 | def forward(self, input, weight): 44 | bilat = WeightedFilterFunction.apply(input, weight, self.kernel_size, self.splat) 45 | if hasattr(self, 'bias'): 46 | return bilat + self.bias.view(1, self.out_channels, 1, 1) 47 | else: 48 | return bilat 49 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
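# Example invocation (paths and .h5 scene names below are placeholders; for a
# scene file "<scene>_<suffix>.h5" a matching reference file "<scene>_ref.h5"
# is expected in the same directory):
#
#   python train.py --job layer_run --datadir ./data --network Layer --spp 8 \
#       --scenes classroom_8spp.h5 livingroom_8spp.h5 --valscene sponza_8spp.h5
#
# Pass --inference to evaluate the latest checkpoint of an existing job
# directory instead of training.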
8 | 9 | import os 10 | import re 11 | import sys 12 | import glob 13 | import multiprocessing 14 | import time 15 | import argparse 16 | import uuid 17 | import importlib 18 | import logging 19 | import inspect 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.utils.data 26 | from torch.optim import Adam, lr_scheduler 27 | 28 | from utils import * 29 | from datasets import * 30 | from augmentations import * 31 | from sample_network import * 32 | from layer_network import * 33 | 34 | FLAGS = None 35 | 36 | ############################################################################### 37 | # Configuration 38 | ############################################################################### 39 | 40 | # Number of training epochs. Each epoch is a complete pass over the training images 41 | NUM_EPOCHS = 1000 42 | VALIDATE_AFTER_EACH_X_EPOCHS = 10 43 | 44 | # Save training data to a checkpoint file after each x epochs 45 | SAVE_AFTER_NUM_EPOCHS = 100 46 | 47 | # Configuration of learning rate 48 | LEARNING_RATE = 0.0005 49 | 50 | # Gradient clamping 51 | GRADIENT_CLAMP_N = 0.001 52 | GRADIENT_CLAMP = 0.25 53 | 54 | ############################################################################### 55 | # Utility functions 56 | ############################################################################### 57 | 58 | def tonemap(f): 59 | return tonemap_srgb(tonemap_log(f)) 60 | 61 | def latest_checkpoint(modeldir): 62 | ckpts = glob.glob(os.path.join(modeldir, "model_*.tar")) 63 | nums = [int(re.findall('model_\d+', x)[0][6:]) for x in ckpts] 64 | return ckpts[nums.index(max(nums))] 65 | 66 | def get_learning_rate(optimizer): 67 | lr = 0.0 68 | for param_group in optimizer.param_groups: 69 | lr = param_group['lr'] 70 | return lr 71 | 72 | def dumpResult(savedir, idx, output, frameData): 73 | saveImg(os.path.join(savedir, "img%05d_in.png" % idx), tonemap(frameData.color[0, 0:int(FLAGS.spp),...].cpu().numpy().mean(axis=0))) 74 | saveImg(os.path.join(savedir, "img%05d_out.png" % idx), tonemap(output.color[0, ...].cpu().numpy())) 75 | saveImg(os.path.join(savedir, "img%05d_ref.png" % idx), tonemap(frameData.target[0, ...].cpu().numpy())) 76 | 77 | ############################################################################### 78 | # Dump error metrics 79 | ############################################################################### 80 | 81 | def computeErrorMetrics(savedir, output, frameData): 82 | out = output.color 83 | ref = frameData.target 84 | 85 | relmse_val = relMSE(out, ref).item() 86 | smape_val = SMAPE(out,ref).item() 87 | 88 | outt = torch.clamp(tonemap(out), 0.0, 1.0) 89 | reft = torch.clamp(tonemap(ref), 0.0, 1.0) 90 | psnr_val = PSNR(outt, reft).item() 91 | 92 | print("relMSE: %1.4f - SMAPE: %1.3f - PSNR: %2.2f" % (relmse_val, smape_val, psnr_val)) 93 | return relmse_val, smape_val, psnr_val 94 | 95 | ############################################################################### 96 | # Network setup 97 | ############################################################################### 98 | 99 | def createNetwork(FLAGS, dataset, sequenceHeader): 100 | if FLAGS.network == "SampleSplat": 101 | return SampleNet(sequenceHeader, tonemap, splat=True, use_sample_info=True, num_samples = FLAGS.spp, kernel_size=FLAGS.kernel_size).cuda() 102 | elif FLAGS.network == "PixelGather": 103 | return SampleNet(sequenceHeader, tonemap, splat=False, use_sample_info=False, num_samples = FLAGS.spp, kernel_size=FLAGS.kernel_size).cuda() 104 | elif FLAGS.network == 
"PixelSplat": 105 | return SampleNet(sequenceHeader, tonemap, splat=True, use_sample_info=False, num_samples = FLAGS.spp, kernel_size=FLAGS.kernel_size).cuda() 106 | elif FLAGS.network == "SampleGather": 107 | return SampleNet(sequenceHeader, tonemap, splat=False, use_sample_info=True, num_samples = FLAGS.spp, kernel_size=FLAGS.kernel_size).cuda() 108 | elif FLAGS.network == "Layer": 109 | return LayerNet(sequenceHeader, tonemap, splat=True, num_samples = FLAGS.spp, kernel_size=FLAGS.kernel_size).cuda() 110 | else: 111 | print("Unsupported network type", FLAGS.network) 112 | assert False 113 | 114 | 115 | ############################################################################### 116 | # Inference and training 117 | ############################################################################### 118 | 119 | def inference(data): 120 | mkdir(FLAGS.savedir) 121 | 122 | dataset = SampleDataset(data[0], data[1], cropSize=None, flags=FLAGS, randomCrop=False) 123 | loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=FLAGS.num_workers, drop_last=True) 124 | 125 | # Get animation sequence header information 126 | sequenceHeader = dataset.getHeader() 127 | 128 | # Setup network model 129 | model = createNetwork(FLAGS, dataset, sequenceHeader) 130 | 131 | ckpt_name = latest_checkpoint(FLAGS.modeldir) 132 | print("loading checkpoint %s" % ckpt_name) 133 | checkpoint = torch.load(ckpt_name) 134 | model.load_state_dict(checkpoint['model.state_dict']) 135 | 136 | with open(os.path.join(FLAGS.savedir, 'metrics.txt'), 'w') as fout: 137 | fout.write('ID, relMSE, SMAPE, PSNR \n') 138 | 139 | print("Number of images", len(dataset)) 140 | 141 | arelmse = np.empty(len(dataset)) 142 | asmape = np.empty(len(dataset)) 143 | apsnr = np.empty(len(dataset)) 144 | 145 | cnt = 0 146 | with torch.no_grad(): 147 | for sequenceData in loader: 148 | sequenceData = SequenceData(dataset, sequenceData) 149 | output = model.inference(sequenceData) 150 | 151 | # compute losses 152 | relmse_val, smape_val, psnr_val = computeErrorMetrics(FLAGS.savedir, output, sequenceData.frameData[-1]) 153 | 154 | arelmse[cnt] = relmse_val 155 | asmape[cnt] = smape_val 156 | apsnr[cnt] = psnr_val 157 | 158 | line = "%d, %1.8f, %1.8f, %2.8f \n" % (cnt, relmse_val, smape_val, psnr_val) 159 | fout.write(line) 160 | 161 | dumpResult(FLAGS.savedir, cnt, output, sequenceData.frameData[-1]) 162 | cnt += 1 163 | 164 | line = "AVERAGES: %1.4f, %1.4f, %2.3f \n" % (np.mean(arelmse), np.mean(asmape), np.mean(apsnr)) 165 | fout.write(line) 166 | 167 | # compute average values 168 | print("relMSE, SMAPE, PSNR \n") 169 | print("%1.4f, %1.4f, %2.3f \n" % (np.mean(arelmse), np.mean(asmape), np.mean(apsnr))) 170 | 171 | def loss_fn(output, target): 172 | return SMAPE(output, target) 173 | 174 | def train(data_train, data_validation): 175 | # Setup dataloader 176 | datasets = [] 177 | for d in data_train: 178 | datasets.append(SampleDataset(d[0], d[1], cropSize=FLAGS.cropsize, flags=FLAGS, limit=FLAGS.limit)) 179 | dataset = torch.utils.data.ConcatDataset(datasets) 180 | loader = torch.utils.data.DataLoader(dataset, batch_size=FLAGS.batch, shuffle=True, num_workers=FLAGS.num_workers, drop_last=True) 181 | 182 | if FLAGS.validate: 183 | val_dataset = SampleDataset(data_validation[0], data_validation[1], cropSize=256, flags=FLAGS, limit=None, randomCrop=False) 184 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=FLAGS.batch, shuffle=False, num_workers=FLAGS.num_workers) 185 | 186 | # Enable for debugging 187 | 
# torch.autograd.set_detect_anomaly(True) 188 | 189 | # Get animation sequence header information 190 | sequenceHeader = datasets[0].getHeader() 191 | 192 | # Setup network model 193 | model = createNetwork(FLAGS, dataset, sequenceHeader) 194 | 195 | # Setup optimizer and scheduler 196 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 197 | scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 198 | 199 | # Setup modeldir, create or resume from checkpoint if needed 200 | start_epoch = 1 201 | if FLAGS.resume and os.path.exists(FLAGS.modeldir): 202 | ckpt_name = latest_checkpoint(FLAGS.modeldir) 203 | print("-> Resuming from checkpoint: %s" % ckpt_name) 204 | checkpoint = torch.load(ckpt_name) 205 | start_epoch = checkpoint['epoch'] 206 | model.load_state_dict(checkpoint['model.state_dict']) 207 | optimizer.load_state_dict(checkpoint['optimizer.state_dict']) 208 | scheduler.load_state_dict(checkpoint['scheduler.state_dict']) 209 | elif os.path.exists(FLAGS.modeldir): 210 | print("ERROR: modeldir [%s] already exists, use --resume to continue training" % FLAGS.modeldir) 211 | sys.exit(1) 212 | 213 | mkdir(FLAGS.modeldir) 214 | 215 | with open(os.path.join(FLAGS.jobdir, 'output.log'), 'w') as fout: 216 | fout.write('LOG FILE: TRAINING LOSS \n') 217 | 218 | with open(os.path.join(FLAGS.jobdir, 'outputval.log'), 'w') as fout: 219 | fout.write('LOG FILE: VALIDATION LOSS \n') 220 | 221 | imagedir = os.path.join(FLAGS.jobdir, 'images') 222 | mkdir(imagedir) 223 | 224 | val_loss = 1.0 225 | for epoch in range(start_epoch, NUM_EPOCHS+1): 226 | start_time = time.time() 227 | sum = 0.0 228 | num = 0.0 229 | # train 230 | for sequenceData in loader: 231 | sequenceData = SequenceData(dataset, sequenceData) 232 | 233 | augment(sequenceHeader, sequenceData) 234 | 235 | optimizer.zero_grad() 236 | output = model.forward(sequenceData, epoch) 237 | 238 | loss = loss_fn(output.color, sequenceData.frameData[0].target) 239 | 240 | loss.backward() 241 | torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLAMP_N) 242 | torch.nn.utils.clip_grad_value_(model.parameters(), GRADIENT_CLAMP) 243 | optimizer.step() 244 | 245 | sum += loss.item() 246 | num += 1 247 | 248 | train_loss = sum / max(num, 1.0) 249 | 250 | # Compute validation loss 251 | if FLAGS.validate and epoch % VALIDATE_AFTER_EACH_X_EPOCHS == 0: 252 | val_sum = 0.0 253 | val_num = 0.0 254 | with torch.no_grad(): 255 | for sequenceData in val_loader: 256 | sequenceData = SequenceData(val_dataset, sequenceData) 257 | output = model.forward(sequenceData, epoch) 258 | dumpResult(imagedir, epoch, output, sequenceData.frameData[-1]) 259 | 260 | loss = loss_fn(output.color, sequenceData.frameData[0].target) 261 | val_sum = val_sum + loss.item() 262 | val_num = val_num + 1 263 | val_loss = val_sum / max(val_num, 1.0) 264 | 265 | with open(os.path.join(FLAGS.jobdir, 'outputval.log'), 'a') as fout: 266 | line = "%3d %1.6f \n" % (epoch, val_loss) 267 | fout.write(str(line)) 268 | 269 | duration = time.time() - start_time 270 | remaining = (NUM_EPOCHS-epoch)*duration/(60*60) 271 | timestring = getTimeString(remaining) 272 | print("Epoch %3d - Learn rate: %1.6f - train loss: %5.5f - validation loss: %5.5f - time %.1f ms (remaining %.1f %s) - time/step: %1.2f ms" 273 | % (epoch, get_learning_rate(optimizer), train_loss, val_loss, duration*1000.0, remaining, timestring, duration*1000.0 / len(dataset))) 274 | 275 | with open(os.path.join(FLAGS.jobdir, 'output.log'), 'a') as fout: 276 | line = "%3d %1.6f \n" % (epoch, train_loss) 
277 | fout.write(str(line)) 278 | 279 | if epoch % SAVE_AFTER_NUM_EPOCHS == 0 or epoch == NUM_EPOCHS: 280 | torch.save({ 281 | 'epoch': epoch + 1, 282 | 'train_loss': train_loss, 283 | 'val_loss': val_loss, 284 | 'model.state_dict': model.state_dict(), 285 | 'optimizer.state_dict': optimizer.state_dict(), 286 | 'scheduler.state_dict': scheduler.state_dict() 287 | }, 288 | os.path.join(FLAGS.modeldir, "model_%04d.tar" % epoch)) 289 | scheduler.step() 290 | 291 | ############################################################################### 292 | # Main function 293 | ############################################################################### 294 | 295 | if __name__ == '__main__': 296 | multiprocessing.freeze_support() 297 | 298 | print("Pytorch version:", torch.__version__) 299 | 300 | # Parse command line flags 301 | parser = argparse.ArgumentParser() 302 | parser.add_argument('--job', type=str, default='', help='Directory to store the trained model', required=True) 303 | parser.add_argument('--resume', action='store_true', default=False, help='Resume training from latest checkpoint') 304 | parser.add_argument('--batch', type=int, default=4, help="Training batch size") 305 | parser.add_argument('--cropsize', type=int, default=128, help="Training crop size") 306 | parser.add_argument('--inference', action='store_true', default=False, help="Run inference instead of training, get checkpoint from job modeldir") 307 | parser.add_argument('--savedir', type=str, default='./out/', help='Directory to save inference data') 308 | parser.add_argument('--datadir', type=str, default='./', help='Training data directory') 309 | parser.add_argument('--network', default="PixelGather", choices=["SampleSplat","PixelGather","SampleGather", "PixelSplat", "Layer"], help="Set network type [SampleSplat,PixelGather,SampleGather,PixelSplat,Layer]") 310 | parser.add_argument('--limit', type=int, default=None, help="Limit the number of frames") 311 | parser.add_argument('--scenes', nargs='*', default=[], help="List of scenes") 312 | parser.add_argument('--valscene', type=str, default=None, help='Validation scene') 313 | parser.add_argument('--num_workers', type=int, default=8, help="Number of workers") 314 | parser.add_argument('--spp', type=float, default=8, help='Samples per pixel: 1-8') 315 | parser.add_argument('--kernel_size', type=int, default=17, help='Kernel size [17x17]') 316 | parser.add_argument('--config', type=str, default=None, help='Config file') 317 | 318 | FLAGS, unparsed = parser.parse_known_args() 319 | 320 | # Read config file 321 | if FLAGS.config is not None: 322 | cfg = importlib.import_module(FLAGS.config[:-len('.py')] if FLAGS.config.endswith('.py') else FLAGS.config) 323 | 324 | for key in cfg.__dict__: 325 | if not key.startswith("__") and not inspect.ismodule(cfg.__dict__[key]): 326 | FLAGS.__dict__[key] = cfg.__dict__[key] 327 | 328 | FLAGS.savedir = os.path.join(FLAGS.savedir, '') 329 | FLAGS.validate = True 330 | FLAGS.num_workers = min(multiprocessing.cpu_count(), FLAGS.num_workers) 331 | 332 | # Add hash to the job directory to avoid collisions 333 | if not FLAGS.inference: 334 | uid = uuid.uuid4() 335 | FLAGS.job = FLAGS.job + "_" + str(str(uid.hex)[:8]) 336 | 337 | print("Commandline arguments") 338 | print("----") 339 | for arg in sorted(vars(FLAGS)): 340 | print("%-12s %s" % (str(arg), str(getattr(FLAGS, arg)))) 341 | print("----") 342 | 343 | script_path = os.path.split(os.path.realpath(__file__))[0] 344 | all_jobs_path = os.path.join(script_path, 'jobs') 345 | FLAGS.jobdir = 
os.path.join(all_jobs_path, FLAGS.job) 346 | FLAGS.modeldir = os.path.join(FLAGS.jobdir, 'model') 347 | 348 | # Create input data 349 | data_train = [] # holds tuple of train and ref data file names 350 | for s in FLAGS.scenes: 351 | data_in = os.path.join(FLAGS.datadir, s) 352 | data_ref = os.path.join(FLAGS.datadir, s[0:s.rfind("_")] + "_ref.h5") 353 | data_train.append((data_in, data_ref)) 354 | 355 | # validation scene file name 356 | if FLAGS.valscene is None: 357 | print("--valscene required flag") 358 | sys.exit(1) 359 | data_in = os.path.join(FLAGS.datadir, FLAGS.valscene) 360 | data_ref = os.path.join(FLAGS.datadir, FLAGS.valscene[0:FLAGS.valscene.rfind("_")] + "_ref.h5") 361 | data_validation = (data_in, data_ref) 362 | 363 | mkdir(all_jobs_path) 364 | mkdir(FLAGS.jobdir) 365 | 366 | if FLAGS.inference: 367 | inference(data_validation) 368 | else: 369 | train(data_train, data_validation) 370 | -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | import utils 13 | 14 | ############################################################################### 15 | # Activation 16 | ############################################################################### 17 | 18 | Activation = nn.LeakyReLU(negative_slope=0.01, inplace=True) 19 | 20 | ############################################################################### 21 | # Regular U-net 22 | ############################################################################### 23 | 24 | class UNet(nn.Module): 25 | def __init__(self, input_channels, output_channels, 26 | encoder_features=[[64, 64], [128], [256], [512], [512]], 27 | bottleneck_features=[512], 28 | decoder_features=[[512, 512], [256, 256], [128, 128], [64, 64], [64, 64]]): 29 | super().__init__() 30 | 31 | self.output_channels = output_channels 32 | self.input_channels = input_channels 33 | self.encoder_features = encoder_features 34 | self.bottleneck_features = bottleneck_features 35 | self.decoder_features = decoder_features 36 | self.initNetwork() 37 | 38 | def initNetwork(self): 39 | # Utility function that creates a convolution "block" from a list of features, 40 | # with one convolutional layer per feature count in the list 41 | def make_conv_block(in_features, features): 42 | layers = [] 43 | prev_features = in_features 44 | for f in features: 45 | layers = layers + [nn.Conv2d(prev_features, f, 3, padding=1), Activation] 46 | prev_features = f 47 | return layers 48 | 49 | prev_features = self.input_channels 50 | 51 | # Create encoder 52 | enc = [] 53 | for enc_f in self.encoder_features: 54 | enc = enc + [nn.Sequential(*make_conv_block(prev_features, enc_f), nn.MaxPool2d(2))] 55 | prev_features = enc_f[-1] 56 | self.enc = nn.ModuleList(enc) 57 | 58 | # Create bottleneck 59 | self.bottleneck = nn.Sequential(*make_conv_block(prev_features, self.bottleneck_features)).cuda() 60 | prev_features = self.bottleneck_features[-1] 61 | 62 | # Create decoder 63 | dec = [] 
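# Each decoder block consumes the upsampled features from the previous level
# concatenated with the skip connection from the matching encoder level; the
# final block concatenates the raw network input instead (see forward()).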
64 | for idx, dec_f in enumerate(self.decoder_features[:-1]): 65 | skip_features = self.encoder_features[len(self.decoder_features) - idx - 2][-1] 66 | dec = dec + [nn.Sequential(*make_conv_block(prev_features + skip_features, dec_f)).cuda()] 67 | prev_features = dec_f[-1] 68 | dec = dec + [nn.Sequential(*make_conv_block(prev_features + self.input_channels, self.decoder_features[-1])).cuda()] 69 | self.dec = nn.ModuleList(dec) 70 | 71 | # Final layer 72 | self.final = nn.Conv2d(self.decoder_features[-1][-1], self.output_channels, 3, padding=1) 73 | 74 | # initialize weights 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | nn.init.xavier_uniform_(m.weight.data) 78 | if m.bias is not None: 79 | m.bias.data.zero_() 80 | 81 | def forward(self, prev): 82 | 83 | # Run encoder 84 | enc_vars = [prev] 85 | for block in self.enc: 86 | prev = block(prev) 87 | enc_vars = enc_vars + [prev] 88 | 89 | # Run bottleneck 90 | prev = self.bottleneck(prev) 91 | 92 | # Run decoder 93 | for idx, block in enumerate(self.dec): 94 | prev = nn.functional.interpolate(prev, scale_factor=2, mode='nearest', align_corners=None) # Upscale result from previous step 95 | concat = torch.cat((prev, enc_vars[len(self.dec) - idx - 1]), dim=1) # Concatenate skip connection 96 | prev = block(concat) 97 | 98 | # Run final composition 99 | output = self.final(prev) 100 | 101 | # Return output color & all decoder levels 102 | return output 103 | 104 | 105 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import os 10 | 11 | from PIL import Image 12 | 13 | import numpy as np 14 | import torch 15 | 16 | ############################################################################### 17 | # Some utility functions to make pytorch and numpy behave the same 18 | ############################################################################### 19 | 20 | def _pow(x, y): 21 | if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor): 22 | return torch.pow(x, y) 23 | else: 24 | return np.power(x, y) 25 | 26 | def _log(x): 27 | if isinstance(x, torch.Tensor): 28 | return torch.log(x) 29 | else: 30 | return np.log(x) 31 | 32 | def _clamp(x, y, z): 33 | if isinstance(x, torch.Tensor) or isinstance(y, torch.Tensor) or isinstance(z, torch.Tensor): 34 | return torch.clamp(x, y, z) 35 | else: 36 | return np.clip(x, y, z) 37 | 38 | ############################################################################### 39 | # Create a object with members from a dictionary 40 | ############################################################################### 41 | 42 | class DictObject: 43 | def __init__(self, _dict): 44 | self.__dict__.update(**_dict) 45 | 46 | def object_from_dict(_dict): 47 | return DictObject(_dict) 48 | 49 | ############################################################################### 50 | # SMAPE Loss 51 | ############################################################################### 52 | 53 | def SMAPE(d, r): 54 | denom = torch.abs(d) + torch.abs(r) + 0.01 55 | return torch.mean(torch.abs(d-r) / denom) 56 | 57 | ############################################################################### 58 | # relMSE Loss 59 | ############################################################################### 60 | 61 | def relMSE(d, r): 62 | diff = d - r 63 | denom = torch.pow(r, 2.0) + 0.0001 64 | return torch.mean(torch.pow(diff, 2) / denom) 65 | 66 | ############################################################################### 67 | # PSNR 68 | ############################################################################### 69 | 70 | def PSNR(d, ref): 71 | MSE = torch.mean(torch.pow(d - ref, 2.0)) 72 | PSNR = 10. 
* torch.log10(1.0/MSE) 73 | return PSNR 74 | 75 | ############################################################################### 76 | # Tonemapping 77 | ############################################################################### 78 | 79 | def tonemap_log(f): 80 | fc = _clamp(f, 0.00001, 65536.0) 81 | return _log(fc + 1.0) 82 | 83 | # Transfer function taken from https://arxiv.org/pdf/1712.02327.pdf 84 | def tonemap_srgb(f): 85 | a = 0.055 86 | if isinstance(f, torch.Tensor): 87 | return torch.where(f > 0.0031308, _pow(f, 1.0/2.4)*(1 + a) - a, 12.92*f) 88 | else: 89 | return np.where(f > 0.0031308, _pow(f, 1.0/2.4)*(1 + a) - a, 12.92*f) 90 | 91 | 92 | ############################################################################### 93 | # Image load/store 94 | ############################################################################### 95 | 96 | def saveImg(img_file, img): 97 | # Convert image from chw to hwc 98 | hwc_img = np.swapaxes(np.swapaxes(img, -3, -2), -2, -1) if len(img.shape) == 3 else img 99 | if len(hwc_img.shape) == 3 and hwc_img.shape[2] == 1: 100 | hwc_img = np.squeeze(hwc_img, axis=2) 101 | if len(hwc_img.shape) == 3 and hwc_img.shape[2] == 2: 102 | hwc_img = np.concatenate((hwc_img, np.zeros_like(hwc_img[..., 0:1])), axis=2) 103 | 104 | # Save image 105 | img_array = (np.clip(hwc_img , 0.0, 1.0) * 255.0).astype(np.uint8) 106 | im = Image.fromarray(img_array) 107 | im.save(img_file) 108 | 109 | ############################################################################### 110 | # Create a folder if it doesn't exist 111 | ############################################################################### 112 | 113 | def mkdir(x): 114 | if not os.path.exists(x): 115 | os.mkdir(x) 116 | 117 | ############################################################################### 118 | # Get time string with remaining time formatted 119 | ############################################################################### 120 | 121 | def getTimeString(remaining): 122 | timestring = "hours" 123 | if (remaining < 1): 124 | remaining *= 60 125 | timestring = "minutes" 126 | if (remaining < 1): 127 | remaining *= 60 128 | timestring = "seconds" 129 | return timestring 130 | 131 | --------------------------------------------------------------------------------
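A minimal sketch of how the metric and tonemapping helpers in utils.py are combined when reporting image errors, mirroring computeErrorMetrics in train.py (the random tensors below are only placeholders for a denoised result and its reference):

```
import torch
from utils import SMAPE, relMSE, PSNR, tonemap_log, tonemap_srgb

out = torch.rand(1, 3, 64, 64)   # placeholder denoised output (linear radiance)
ref = torch.rand(1, 3, 64, 64)   # placeholder reference image

# relMSE and SMAPE are evaluated directly on the linear HDR values
print("relMSE: %1.4f - SMAPE: %1.3f" % (relMSE(out, ref).item(), SMAPE(out, ref).item()))

# PSNR is computed on log + sRGB tonemapped images clamped to [0, 1]
outt = torch.clamp(tonemap_srgb(tonemap_log(out)), 0.0, 1.0)
reft = torch.clamp(tonemap_srgb(tonemap_log(ref)), 0.0, 1.0)
print("PSNR: %2.2f" % PSNR(outt, reft).item())
```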