├── .gitignore ├── LICENSE ├── README.md ├── convert.py ├── data ├── LICENSE.txt ├── albert-compare.gif ├── albert.jpg └── tokyo-compare.gif ├── encoding.py ├── requirements.txt ├── train.py ├── utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ending Hsiao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hash Grid Encoding 2 | 3 | This repo contains an implementation of NVidia's hash grid encoding from [Instant Neural Graphics Primitives](https://nvlabs.github.io/instant-ngp/) and a runnable example of a gigapixel image task. The hash grid is implemented in ***pure PyTorch***, hence it is more human-friendly than [NVidia's original implementation (C++/CUDA)](https://github.com/NVlabs/instant-ngp). 4 | 5 | 6 | Some features: 7 | * Implemented in ***pure PyTorch***. 8 | * Supports ***arbitrary dimensions***. 9 | 10 | ## How To Use 11 | 12 | ### MultiResHashGrid 13 | To use [`MultiResHashGrid`](encoding.py#L129) in your own project, simply copy the code in `encoding.py` into your codebase.
For example: 14 | ```python 15 | import torch 16 | import encoding 17 | 18 | enc = encoding.MultiResHashGrid(2) # 2D image data 19 | enc = encoding.MultiResHashGrid(3) # 3D data 20 | 21 | dim = 3 22 | batch_size = 100 23 | 24 | # The input values must be within the range [0, 1] 25 | input = torch.rand((batch_size, dim), dtype=torch.float32) 26 | enc_input = enc(input) 27 | 28 | # Then you can feed the encoded input into your network 29 | model = MyMLP(dim=enc.output_dim, out_dim=1) 30 | output = model(enc_input) 31 | 32 | # Move to other devices 33 | enc = enc.to(device='cuda') 34 | ``` 35 | 36 | ### Gigapixel task 37 | This repo also contains a runnable gigapixel image task, which is implemented with [PyTorch Lightning](https://www.pytorchlightning.ai/). For instructions on running this code, see [Examples](#examples). 38 | 39 | ## Examples 40 | 41 | ### Albert 42 | 43 | ![](https://github.com/Ending2015a/hash-grid-encoding/blob/master/data/albert-compare.gif) 44 | 45 | Run this example: 46 | ``` 47 | python train.py -i data/albert.jpg --enc_method hashgrid --visualize 48 | ``` 49 | 50 | 51 | ### Tokyo 52 | 53 | ![](https://github.com/Ending2015a/hash-grid-encoding/blob/master/data/tokyo-compare.gif) 54 | 55 | https://user-images.githubusercontent.com/18180004/174231919-16705ae3-357e-4c50-832c-bae6f1d92556.mp4 56 | 57 | Download [the tokyo image](https://www.flickr.com/photos/trevor_dobson_inefekt69/29314390837) and place it at `data/tokyo.jpg`. 58 | 59 | To run the Tokyo example at its original size (56718 x 21450 pixels), your GPU needs at least 20GB of memory. If your GPU does not have that much memory, you can use the `convert.py` script to scale the image down to half size. Converting to `.npy` format also speeds up loading: 60 | 61 | ```shell 62 | python convert.py -i data/tokyo.jpg -o data/tokyo.npy --scale 0.5 63 | ``` 64 | 65 | Then run the experiment: 66 | 67 | ```shell 68 | python train.py -i data/tokyo.npy --enc_method hashgrid --finest_resolution 32768 --visualize 69 | ``` 70 | 71 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # --- built in --- 2 | import os 3 | import time 4 | import argparse 5 | # --- 3rd party --- 6 | import matplotlib.pyplot as plt 7 | from skimage.transform import resize 8 | import numpy as np 9 | import imageio 10 | import PIL.Image 11 | PIL.Image.MAX_IMAGE_PIXELS = 10000000000 12 | # --- my module --- 13 | import utils 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('-i', '--input', type=str) 18 | parser.add_argument('-o', '--output', type=str, default=None) 19 | parser.add_argument('-s', '--scale', type=float, default=0.5) 20 | return parser.parse_args() 21 | 22 | def main(): 23 | a = get_args() 24 | filename = a.input 25 | scale = a.scale 26 | 27 | start_time = time.time() 28 | image = utils.read_image(filename) 29 | print('Took {} seconds to load image'.format(time.time() - start_time)) 30 | 31 | h, w, c = image.shape 32 | print(f"{w}x{h} pixels, {c} channels") 33 | if scale != 1.0: 34 | h = int(h*scale) 35 | w = int(w*scale) 36 | print(f"Scaling image to {w}x{h} pixels") 37 | image = resize(image, (h, w)) 38 | 39 | output = a.output 40 | if a.output is None: 41 | output = os.path.splitext(filename)[0] + '.npy' 42 | 43 | utils.write_image(output, image.astype(np.float16)) 44 | 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 |
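As a quick sanity check after running `convert.py`, the generated `.npy` file can be loaded back with `utils.read_image`. The sketch below is a minimal, illustrative example; it assumes `data/tokyo.npy` was produced by the conversion command shown in the README above.

```python
# Minimal sanity check (illustrative, not one of the repo's scripts):
# convert.py stores the image as a linear-space float16 array of shape (H, W, C).
import utils

img = utils.read_image('data/tokyo.npy')  # .npy files are loaded via np.load, no sRGB conversion
print(img.shape, img.dtype)               # e.g. (10725, 28359, 3) float16 for the half-size Tokyo image
```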
-------------------------------------------------------------------------------- /data/LICENSE.txt: -------------------------------------------------------------------------------- 1 | albert.jpg - Public domain photograph (see https://commons.wikimedia.org/wiki/File:Albert_Einstein_Head.jpg) 2 | -------------------------------------------------------------------------------- /data/albert-compare.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ending2015a/hash-grid-encoding/caf745cc33095b198474c9a228ef19e9e5552dde/data/albert-compare.gif -------------------------------------------------------------------------------- /data/albert.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ending2015a/hash-grid-encoding/caf745cc33095b198474c9a228ef19e9e5552dde/data/albert.jpg -------------------------------------------------------------------------------- /data/tokyo-compare.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ending2015a/hash-grid-encoding/caf745cc33095b198474c9a228ef19e9e5552dde/data/tokyo-compare.gif -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | # --- built in --- 2 | import math 3 | # --- 3rd party --- 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | # --- my module --- 8 | 9 | """ 10 | The MIT License (MIT) 11 | Copyright (c) 2022 Joe Hsiao (Ending2015a) 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy 14 | of this software and associated documentation files (the "Software"), to deal 15 | in the Software without restriction, including without limitation the rights 16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | copies of the Software, and to permit persons to whom the Software is 18 | furnished to do so, subject to the following conditions: 19 | 20 | The above copyright notice and this permission notice shall be included in all 21 | copies or substantial portions of the Software. 22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 24 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 25 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 26 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 27 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 28 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 29 | OR OTHER DEALINGS IN THE SOFTWARE. 30 | """ 31 | 32 | # --- constants --- 33 | PRIMES = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737] 34 | 35 | 36 | class Frequency(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | n_levels: int = 10 41 | ): 42 | """Positional encoding from NeRF: https://www.matthewtancik.com/nerf 43 | [sin(x), cos(x), sin(2x), cos(2x), sin(4x), cos(4x), 44 | ..., sin(2^(n_levels-1)*x), cos(2^(n_levels-1)*x)] 45 | 46 | Args: 47 | dim (int): input dimensions 48 | n_levels (int, optional): number of frequencies. Defaults to 10. 49 | """ 50 | super().__init__() 51 | self.n_levels = n_levels 52 | assert self.n_levels > 0 53 | freqs = 2.
** torch.linspace(0., n_levels-1, n_levels) 54 | self.register_buffer('freqs', freqs, persistent=False) 55 | # --- 56 | self.input_dim = dim 57 | self.output_dim = dim * n_levels * 2 58 | 59 | def forward(self, x: torch.Tensor): 60 | x = x.unsqueeze(dim=-1) # (..., dim, 1) 61 | x = x * self.freqs # (..., dim, L) 62 | x = torch.cat((torch.sin(x), torch.cos(x)), dim=-1) # (..., dim, L*2) 63 | return x.flatten(-2, -1) # (..., dim * L * 2) 64 | 65 | 66 | @torch.no_grad() 67 | def fast_hash(ind: torch.Tensor, primes: torch.Tensor, hashmap_size: int): 68 | """Hashing function from: 69 | https://github.com/NVlabs/tiny-cuda-nn/blob/master/include/tiny-cuda-nn/encodings/grid.h#L76-L92 70 | """ 71 | d = ind.shape[-1] 72 | ind = (ind * primes[:d]) & 0xffffffff # uint32 73 | for i in range(1, d): 74 | ind[..., 0] ^= ind[..., i] 75 | return ind[..., 0] % hashmap_size 76 | 77 | class _HashGrid(nn.Module): 78 | def __init__( 79 | self, 80 | dim: int, 81 | n_features: int, 82 | hashmap_size: int, 83 | resolution: float 84 | ): 85 | super().__init__() 86 | self.dim = dim 87 | self.n_features = n_features 88 | self.hashmap_size = hashmap_size 89 | self.resolution = resolution 90 | 91 | # you can add more primes for supporting more dimensions 92 | assert self.dim <= len(PRIMES), \ 93 | f"HashGrid only supports at most {len(PRIMES)}-D inputs" 94 | 95 | # create look-up table 96 | self.embedding = nn.Embedding(hashmap_size, n_features) 97 | nn.init.uniform_(self.embedding.weight, a=-0.0001, b=0.0001) 98 | 99 | primes = torch.tensor(PRIMES, dtype=torch.int64) 100 | self.register_buffer('primes', primes, persistent=False) 101 | 102 | # create interpolation binary mask 103 | n_neigs = 1 << self.dim 104 | neigs = np.arange(n_neigs, dtype=np.int64).reshape((-1, 1)) 105 | dims = np.arange(self.dim, dtype=np.int64).reshape((1, -1)) 106 | bin_mask = torch.tensor(neigs & (1 << dims) == 0, dtype=bool) # (neig, dim) 107 | self.register_buffer('bin_mask', bin_mask, persistent=False) 108 | 109 | def forward(self, x: torch.Tensor): 110 | # x: (b..., dim), torch.float32, range: [0, 1] 111 | bdims = len(x.shape[:-1]) 112 | x = x * self.resolution 113 | xi = x.long() 114 | xf = x - xi.float().detach() 115 | xi = xi.unsqueeze(dim=-2) # (b..., 1, dim) 116 | xf = xf.unsqueeze(dim=-2) # (b..., 1, dim) 117 | # to match the input batch shape 118 | bin_mask = self.bin_mask.reshape((1,)*bdims + self.bin_mask.shape) # (1..., neig, dim) 119 | # get neighbors' indices and weights on each dim 120 | inds = torch.where(bin_mask, xi, xi+1) # (b..., neig, dim) 121 | ws = torch.where(bin_mask, 1-xf, xf) # (b..., neig, dim) 122 | # aggregate neighbors' interp weights 123 | w = ws.prod(dim=-1, keepdim=True) # (b..., neig, 1) 124 | # hash neighbors' id and look up table 125 | hash_ids = fast_hash(inds, self.primes, self.hashmap_size) # (b..., neig) 126 | neig_data = self.embedding(hash_ids) # (b..., neig, feat) 127 | return torch.sum(neig_data * w, dim=-2) # (b..., feat) 128 | 129 | class MultiResHashGrid(nn.Module): 130 | def __init__( 131 | self, 132 | dim: int, 133 | n_levels: int = 16, 134 | n_features_per_level: int = 2, 135 | log2_hashmap_size: int = 15, 136 | base_resolution: int = 16, 137 | finest_resolution: int = 512, 138 | ): 139 | """NVidia's hash grid encoding 140 | https://nvlabs.github.io/instant-ngp/ 141 | 142 | The output dimension is `n_levels` * `n_features_per_level`, 143 | or you can simply access `model.output_dim` to get the output dimension 144 | 145 | Args: 146 | dim (int): input dimensions, supports at most 7D data.
147 | n_levels (int, optional): number of grid levels. Defaults to 16. 148 | n_features_per_level (int, optional): number of features per grid level. 149 | Defaults to 2. 150 | log2_hashmap_size (int, optional): maximum size of the hashmap of each 151 | level in log2 scale. According to the paper, this value can be set to 152 | 14 ~ 24 depending on your problem size. Defaults to 15. 153 | base_resolution (int, optional): coarsest grid resolution. Defaults to 16. 154 | finest_resolution (int, optional): finest grid resolution. According to 155 | the paper, this value can be set to 512 ~ 524288. Defaults to 512. 156 | """ 157 | super().__init__() 158 | self.dim = dim 159 | self.n_levels = n_levels 160 | self.n_features_per_level = n_features_per_level 161 | self.log2_hashmap_size = log2_hashmap_size 162 | self.base_resolution = base_resolution 163 | self.finest_resolution = finest_resolution 164 | 165 | # from paper eq (3) 166 | b = math.exp((math.log(finest_resolution) - math.log(base_resolution))/(n_levels-1)) 167 | 168 | levels = [] 169 | for level_idx in range(n_levels): 170 | resolution = math.floor(base_resolution * (b ** level_idx)) 171 | hashmap_size = min(resolution ** dim, 2 ** log2_hashmap_size) 172 | levels.append(_HashGrid( 173 | dim = dim, 174 | n_features = n_features_per_level, 175 | hashmap_size = hashmap_size, 176 | resolution = resolution 177 | )) 178 | self.levels = nn.ModuleList(levels) 179 | 180 | self.input_dim = dim 181 | self.output_dim = n_levels * n_features_per_level 182 | 183 | def forward(self, x: torch.Tensor): 184 | return torch.cat([level(x) for level in self.levels], dim=-1) 185 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image 2 | pytorch_lightning 3 | imageio 4 | numpy 5 | matplotlib 6 | tqdm 7 | 8 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # --- built in --- 2 | from typing import Any, List, Dict, Union, Tuple, Optional, Callable 3 | import os 4 | import math 5 | import argparse 6 | # --- 3rd party --- 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | from torch.utils.data import Dataset, DataLoader 11 | import pytorch_lightning as pl 12 | import matplotlib.pyplot as plt 13 | import imageio 14 | import tqdm 15 | # --- my module --- 16 | import utils 17 | import encoding 18 | 19 | # --- datasets --- 20 | 21 | class Sampler2D(nn.Module): 22 | def __init__(self, filename: str): 23 | super().__init__() 24 | data = torch.from_numpy(utils.read_image(filename)).to(dtype=torch.float16) 25 | self.register_buffer('data', data, persistent=False) 26 | mesh = self.get_mesh().float() 27 | self.register_buffer('mesh', mesh, persistent=False) 28 | self.shape = self.data.shape 29 | h, w, c = self.shape 30 | self.h = h 31 | self.w = w 32 | self.c = c 33 | self.num_pixels = h * w 34 | resolution = torch.tensor((self.shape[1], self.shape[0]), dtype=torch.float32) 35 | self.register_buffer('resolution', resolution, persistent=True) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | shape = self.data.shape 39 | x = x * self.resolution 40 | ind = x.long() 41 | w = x - ind.float() 42 | x0 = ind[:, 0].clamp(min=0, max=shape[1]-1) 43 | y0 = ind[:, 1].clamp(min=0, max=shape[0]-1) 44 | x1 = (x0 + 1).clamp(max=shape[1]-1) 45 | y1 = (y0 + 1).clamp(max=shape[0]-1) 46 | 
return ( 47 | self.data[y0, x0].to(dtype=torch.float32) * (1.0 - w[:,0:1]) * (1.0 - w[:,1:2]) + 48 | self.data[y0, x1].to(dtype=torch.float32) * w[:,0:1] * (1.0 - w[:,1:2]) + 49 | self.data[y1, x0].to(dtype=torch.float32) * (1.0 - w[:,0:1]) * w[:,1:2] + 50 | self.data[y1, x1].to(dtype=torch.float32) * w[:,0:1] * w[:,1:2] 51 | ) 52 | 53 | def get_mesh(self) -> torch.Tensor: 54 | h, w, c = self.data.shape 55 | n_pixels = h * w 56 | u_res = 0.5 / h 57 | v_res = 0.5 / w 58 | u = np.linspace(u_res, 1-u_res, h) 59 | v = np.linspace(v_res, 1-v_res, w) 60 | u, v = np.meshgrid(u, v, indexing='ij') 61 | xy = np.stack((v.flatten(), u.flatten()), axis=0).T # (n, 2) 62 | xy = xy.astype(np.float32) 63 | xy = torch.from_numpy(xy) 64 | return xy 65 | 66 | class TaskDataset(Dataset): 67 | def __init__( 68 | self, 69 | sampler: Sampler2D, 70 | batch_size: int, 71 | n_samples: int 72 | ): 73 | super().__init__() 74 | self.sampler = sampler 75 | self.batch_size = batch_size 76 | self.n_samples = n_samples 77 | self.jit_sampler = None 78 | 79 | def setup(self): 80 | self.jit_sampler = torch.jit.trace(self.sampler, self.get_rand()) 81 | 82 | def __len__(self): 83 | return self.n_samples 84 | 85 | def get_rand(self): 86 | return torch.rand([self.batch_size, 2], 87 | dtype=torch.float32, device=self.sampler.data.device) 88 | 89 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: 90 | if self.jit_sampler is None: 91 | self.setup() 92 | x = self.get_rand() 93 | return x, self.jit_sampler(x) 94 | 95 | # --- networks --- 96 | 97 | class MLP(nn.Module): 98 | def __init__( 99 | self, 100 | dim: int, 101 | out_dim: int = 1, 102 | mlp_units: List[int] = [64, 64] 103 | ): 104 | super().__init__() 105 | layers = [] 106 | self.input_dim = dim 107 | self.output_dim = out_dim 108 | in_dim = dim 109 | for out_dim in mlp_units: 110 | layers.append(nn.Linear(in_dim, out_dim)) 111 | layers.append(nn.ReLU(inplace=True)) 112 | in_dim = out_dim 113 | layers.append(nn.Linear(in_dim, self.output_dim)) 114 | self.model = nn.Sequential(*layers) 115 | 116 | def forward(self, x: torch.tensor): 117 | return self.model(x) 118 | 119 | class ToyNet(nn.Module): 120 | def __init__( 121 | self, 122 | dim: int, 123 | out_dim: int = 1, 124 | mlp_units: List[int] = [64, 64], 125 | enc_method: str = 'freq', 126 | enc_kwargs: dict = {} 127 | ): 128 | super().__init__() 129 | if enc_method == 'freq': 130 | self.enc = encoding.Frequency(dim, **enc_kwargs) 131 | dim = self.enc.output_dim 132 | elif enc_method == 'hashgrid': 133 | self.enc = encoding.MultiResHashGrid(dim, **enc_kwargs) 134 | dim = self.enc.output_dim 135 | else: 136 | print(f'Disable encoding: {enc_method}') 137 | self.enc = None 138 | self.mlp = MLP(dim, out_dim=out_dim, mlp_units=mlp_units) 139 | 140 | def forward(self, x: torch.Tensor): 141 | if self.enc is not None: 142 | x = self.enc(x) 143 | return self.mlp(x) 144 | 145 | 146 | class Task(pl.LightningModule): 147 | def __init__( 148 | self, 149 | filename: str, 150 | batch_size: int = 65536, 151 | n_samples: int = 10, 152 | lr: float = 1e-3, 153 | mlp_units: List[int] = [64, 64], 154 | relative_l2: bool = False, 155 | enc_method: Optional[str] = None, 156 | enc_kwargs: Dict[str, Any] = {}, 157 | channels: int = None, 158 | vis_freq: Callable = None, 159 | inference_only: bool = False 160 | ): 161 | super().__init__() 162 | 163 | self.vis_freq = vis_freq 164 | self.inference_only = inference_only 165 | 166 | if not inference_only: 167 | self.sampler = Sampler2D(filename) 168 | channels = self.sampler.c 169 | 
self.save_hyperparameters(ignore=['inference_only', 'vis_freq']) 170 | 171 | if not inference_only: 172 | self.setup_dataset() 173 | self.setup_model() 174 | 175 | def setup_dataset(self): 176 | self.trainset = TaskDataset( 177 | self.sampler, 178 | batch_size = self.hparams.batch_size, 179 | n_samples = self.hparams.n_samples 180 | ) 181 | 182 | def setup_model(self): 183 | self.model = ToyNet( 184 | dim = 2, 185 | out_dim = self.hparams.channels, 186 | mlp_units = self.hparams.mlp_units, 187 | enc_method = self.hparams.enc_method, 188 | enc_kwargs = self.hparams.enc_kwargs 189 | ) 190 | 191 | def configure_optimizers(self): 192 | optim = torch.optim.Adam( 193 | self.model.parameters(), 194 | lr = self.hparams.lr, 195 | weight_decay = 1e-8, 196 | eps = 1e-8, 197 | betas = (0.9, 0.99), 198 | ) 199 | return optim 200 | 201 | def train_dataloader(self): 202 | return DataLoader( 203 | self.trainset, 204 | batch_size = None, # manual batching 205 | num_workers = 0, # main thread 206 | ) 207 | 208 | def forward( 209 | self, 210 | x: torch.Tensor, 211 | ): 212 | x = torch.as_tensor(x, dtype=torch.float32, device=self.device) 213 | return self.model(x) 214 | 215 | def l2_loss(self, y, y_, relative=False): 216 | if relative: 217 | return ((y-y_)**2.0) / (y_.detach()**2.0 + 0.01) 218 | else: 219 | return ((y-y_)**2.0) 220 | 221 | def training_step(self, batch, batch_idx: int): 222 | x, y = batch 223 | y_ = self(x) 224 | loss = self.l2_loss(y, y_, relative=self.hparams.relative_l2).mean() 225 | 226 | self.log( 227 | "train/loss", 228 | loss.item(), 229 | on_step = True, 230 | on_epoch = True, 231 | sync_dist = True, 232 | prog_bar = True 233 | ) 234 | return loss 235 | 236 | @torch.no_grad() 237 | def _preview(self): 238 | batch_size = self.hparams.batch_size * 8 239 | num_batches = self.sampler.num_pixels // batch_size + 1 240 | start_idx = 0 241 | pixels = [] 242 | mesh = self.sampler.mesh 243 | for _ in range(num_batches): 244 | if start_idx >= self.sampler.num_pixels: 245 | break 246 | stop_idx = min(start_idx + batch_size, self.sampler.num_pixels) 247 | mesh_slice = mesh[start_idx:stop_idx] 248 | outs = self(mesh_slice) 249 | pixels.append(outs.cpu()) 250 | start_idx = stop_idx 251 | pixels = torch.cat(pixels, dim=0) 252 | canvas = pixels.reshape(self.sampler.shape).detach().cpu().numpy() 253 | 254 | path = os.path.join( 255 | self.logger.log_dir, 256 | f"predictions/steps_{self.global_step:06d}.jpg" 257 | ) 258 | os.makedirs(os.path.dirname(path), exist_ok=True) 259 | utils.write_image(path, canvas, quality=95) 260 | 261 | def on_save_checkpoint(self, checkpoint): 262 | if self.trainer.is_global_zero: 263 | res = (self.vis_freq is not None 264 | and self.vis_freq(self.current_epoch, self.global_step)) 265 | if res: 266 | print('Visualizing results...') 267 | self._preview() 268 | 269 | 270 | def get_args(): 271 | parser = argparse.ArgumentParser() 272 | parser.add_argument('-i', '--input', type=str, help='Path to input image (.jpg/.npy)') 273 | parser.add_argument('--root', type=str, default='./logs') 274 | parser.add_argument('--trace', type=str, default='experiments') 275 | parser.add_argument('--batch_size', type=int, default=65536) 276 | parser.add_argument('--epochs', type=int, default=400, help='100 steps per epoch') 277 | parser.add_argument('--device', type=int, default=0) 278 | parser.add_argument('--enc_method', choices=['freq', 'hashgrid', 'none']) 279 | parser.add_argument('--n_levels', type=int, default=16) 280 | parser.add_argument('--n_features_per_level', type=int, default=2) 
281 | parser.add_argument('--log2_hashmap_size', type=int, default=15) 282 | parser.add_argument('--base_resolution', type=int, default=16) 283 | parser.add_argument('--finest_resolution', type=int, default=8192) 284 | parser.add_argument('--visualize', action='store_true', default=False) 285 | return parser.parse_args() 286 | 287 | if __name__ == '__main__': 288 | 289 | a = get_args() 290 | 291 | def vis_func(epoch, step): 292 | # visualize at epochs [1, 2, 4, 8, 10, 20, ..., 90, 100, 200, ...] 293 | epoch += 1 294 | if epoch < 10: 295 | return (epoch & (epoch-1)) == 0 296 | if epoch < 100: 297 | return epoch % 10 == 0 298 | if epoch < 1000: 299 | return epoch % 100 == 0 300 | 301 | root_dir = a.root 302 | image_file = os.path.basename(a.input) 303 | trace_name = a.trace 304 | image_name = image_file.split(".")[0] 305 | 306 | dir_path = os.path.join(root_dir, trace_name, image_name) 307 | 308 | if a.enc_method == 'freq': 309 | enc_kwargs = dict( 310 | n_levels = a.n_levels 311 | ) 312 | elif a.enc_method == 'hashgrid': 313 | enc_kwargs = dict( 314 | n_levels = a.n_levels, 315 | n_features_per_level = a.n_features_per_level, 316 | log2_hashmap_size = a.log2_hashmap_size, 317 | base_resolution = a.base_resolution, 318 | finest_resolution = a.finest_resolution 319 | ) 320 | else: # 'none' or not specified: disable encoding 321 | a.enc_method = None 322 | enc_kwargs = dict() 323 | 324 | model = Task( 325 | filename = a.input, 326 | batch_size = a.batch_size, 327 | n_samples = 100, 328 | lr = 1e-3, 329 | relative_l2 = True, 330 | enc_method = a.enc_method, 331 | enc_kwargs = enc_kwargs, 332 | vis_freq = vis_func if a.visualize else None 333 | ) 334 | 335 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 336 | every_n_epochs = 1 337 | ) 338 | 339 | trainer = pl.Trainer( 340 | callbacks = checkpoint_callback, 341 | max_epochs = a.epochs, 342 | accelerator = "gpu", 343 | devices = [a.device], 344 | default_root_dir = dir_path 345 | ) 346 | 347 | trainer.fit(model) 348 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # --- built in --- 2 | import os 3 | # --- 3rd party --- 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import imageio 7 | import PIL.Image 8 | PIL.Image.MAX_IMAGE_PIXELS = 10000000000 9 | 10 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 11 | # 12 | # Redistribution and use in source and binary forms, with or without modification, are permitted 13 | # provided that the following conditions are met: 14 | # * Redistributions of source code must retain the above copyright notice, this list of 15 | # conditions and the following disclaimer. 16 | # * Redistributions in binary form must reproduce the above copyright notice, this list of 17 | # conditions and the following disclaimer in the documentation and/or other materials 18 | # provided with the distribution. 19 | # * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used 20 | # to endorse or promote products derived from this software without specific prior written 21 | # permission. 22 | # 23 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 24 | # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND 25 | # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE 26 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 27 | # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 28 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 29 | # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | def srgb_to_linear(img): 33 | limit = 0.04045 34 | return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92) 35 | 36 | def linear_to_srgb(img): 37 | limit = 0.0031308 38 | return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img) 39 | 40 | def read_image_imageio(filename): 41 | image = imageio.imread(filename) 42 | image = np.asarray(image).astype(np.float32) 43 | if len(image.shape) == 2: 44 | image = image[:, :, np.newaxis] 45 | return image / 255.0 46 | 47 | def write_image_imageio(filename, image, quality=95): 48 | image = (np.clip(image, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8) 49 | kwargs = {} 50 | if os.path.splitext(filename)[1].lower() in ['.jpg', '.jpeg']: 51 | if image.ndim >= 3 and image.shape[2] > 3: 52 | image = image[:, :, :3] 53 | kwargs['quality'] = quality 54 | kwargs['subsampling'] = 0 55 | imageio.imwrite(filename, image, **kwargs) 56 | 57 | def read_image(filename): 58 | if os.path.splitext(filename)[1] == '.npy': 59 | image = np.load(filename) 60 | else: 61 | image = read_image_imageio(filename) 62 | image = srgb_to_linear(image) 63 | return image 64 | 65 | def write_image(filename, image, quality=95): 66 | if os.path.splitext(filename)[1] == '.npy': 67 | # save the raw array in .npy format 68 | np.save(filename, image) 69 | else: 70 | image = linear_to_srgb(np.clip(image, 0.0, 1.0)) 71 | write_image_imageio(filename, image, quality=quality) 72 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | # --- built in --- 2 | import os 3 | import argparse 4 | # --- 3rd party --- 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import tqdm 9 | # --- my module --- 10 | from train import Task 11 | import utils 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--ckpt', type=str) 16 | parser.add_argument('-o', '--output', type=str, default='output/frame_{:06d}.png') 17 | parser.add_argument('--zoom', type=float, default=10.0) 18 | parser.add_argument('--center_x', type=float, default=0.5) 19 | parser.add_argument('--center_y', type=float, default=0.5) 20 | parser.add_argument('--width', type=int, default=480) 21 | parser.add_argument('--height', type=int, default=640) 22 | parser.add_argument('--n_frames', type=int, default=150) 23 | return parser.parse_args() 24 | 25 | 26 | def smoothstep(e0, e1, x): 27 | t = np.clip((x-e0)/(e1-e0), 0.0, 1.0) 28 | return t * t * (3.0 - 2.0 * t) 29 | 30 | def lerp(x, y, a): 31 | return x * (1-a) + y * a 32 | 33 | @torch.no_grad() 34 | def render(a, model, grid, frame_idx, zoom_factor, center): 35 | grid = (grid-center) / zoom_factor + center 36 | bound_min = np.min(grid, axis=0) 37 | bound_max = np.max(grid, axis=0) 38 | move = np.maximum(0.0 - bound_min, 0.0) 39 | move = move + np.minimum(1.0 - (bound_max + move), 0.0) 40 | grid = grid + move 41 | outs =
model(torch.from_numpy(grid)).cpu().numpy() 42 | return outs.reshape((a.height, a.width, -1)) 43 | 44 | 45 | def main(): 46 | a = get_args() 47 | model = Task.load_from_checkpoint(a.ckpt).to(device='cuda') 48 | 49 | # generate grid 50 | u_res = 0.0 51 | v_res = 0.0 52 | u = np.linspace(u_res, 1-u_res, a.height) 53 | v = np.linspace(v_res, 1-v_res, a.width) 54 | u, v = np.meshgrid(u, v, indexing='ij') 55 | grid = np.stack((v.flatten(), u.flatten()), axis=0).T # (n, 2) 56 | grid = grid.astype(np.float32) 57 | 58 | stay = 25 59 | zoom_in = stay + int(0.8 * (a.n_frames-stay)) 60 | zoom_out = a.n_frames 61 | 62 | target_center = np.array((a.center_x, a.center_y), dtype=np.float32) 63 | 64 | for frame_idx in tqdm.tqdm(range(a.n_frames)): 65 | if frame_idx < stay: 66 | zoom_factor = 1.0 67 | elif frame_idx < zoom_in: 68 | frame_time = smoothstep(stay, zoom_in, frame_idx) 69 | zoom_factor = lerp(1.0, a.zoom, frame_time) 70 | else: 71 | frame_time = smoothstep(zoom_in, zoom_out, frame_idx) 72 | zoom_factor = lerp(a.zoom, 1.0, frame_time) 73 | canvas = render(a, model, grid.copy(), frame_idx, zoom_factor, target_center) 74 | path = a.output.format(frame_idx) 75 | os.makedirs(os.path.dirname(path), exist_ok=True) 76 | utils.write_image(path, canvas, quality=95) 77 | 78 | if __name__ == '__main__': 79 | main() 80 | 81 | 82 | --------------------------------------------------------------------------------
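`video.py` writes each rendered zoom frame as a separate image file; assembling the frames into an animation is left to the user. Below is a minimal, illustrative sketch using `imageio` (already listed in `requirements.txt`); the output path and the assumption that the default `output/frame_{:06d}.png` pattern was used are illustrative, not part of the repo.

```python
# Illustrative post-processing: stitch the frames rendered by video.py into a GIF.
# Assumes video.py was run with its default --output pattern 'output/frame_{:06d}.png'.
import glob
import imageio

frame_paths = sorted(glob.glob('output/frame_*.png'))
frames = [imageio.imread(p) for p in frame_paths]
# Frame-timing options (fps/duration) differ between imageio versions, so they are omitted here.
imageio.mimsave('output/zoom.gif', frames)
```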