├── .github
│   └── workflows
│       └── python-publish.yaml
├── .gitignore
├── LICENSE
├── README.md
├── config
│   └── config.yaml
├── datasets.py
├── figs
│   ├── fashionmnist-recons.png
│   ├── fashionmnist-samples.png
│   ├── gon.png
│   ├── mnist-recons.png
│   └── mnist-samples.png
├── gon_pytorch
│   ├── __init__.py
│   ├── data.py
│   ├── modules.py
│   └── utils.py
├── setup.py
└── train_gon.py

--------------------------------------------------------------------------------
/.github/workflows/python-publish.yaml:
--------------------------------------------------------------------------------
name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
dist/
build/
runs/
*.egg-info

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Kristian Klemon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Gradient Origin Networks in PyTorch
===================================

Unofficial PyTorch implementation of [Gradient Origin Networks](https://arxiv.org/abs/2007.02798).
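
A GON infers the latent code for each sample as the negative gradient of the reconstruction loss with respect to a zero latent "origin", so no separate encoder network is needed. A toy sketch of this single inference step (illustration only, not the library API):

```python
import torch

# Stand-in "decoder"; the library uses a coordinate-based MLP instead.
decoder = torch.nn.Linear(8, 3)
target = torch.rand(3)

z0 = torch.zeros(8, requires_grad=True)      # latent origin
loss = ((decoder(z0) - target) ** 2).mean()  # reconstruction loss at the origin
z = -torch.autograd.grad(loss, z0, create_graph=True)[0]  # latent = negative gradient

# The decoder is then optimized to reconstruct the target from z.
recon = decoder(z)
```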

![](./figs/gon.png)

| Reconstructions | Samples |
| ----------------| ------- |
| ![](./figs/fashionmnist-recons.png) | ![](./figs/fashionmnist-samples.png) |
| ![](./figs/mnist-recons.png) | ![](./figs/mnist-samples.png) |


Usage
-----

### Training

Requirements:
- [pytorch](https://github.com/pytorch/pytorch)
- [numpy](https://github.com/numpy/numpy)
- [hydra](https://github.com/facebookresearch/hydra)

After cloning the repository, a GON can be trained using the `train_gon.py` script:

```bash
python train_gon.py dataset.name=MNIST dataset.root=./data
```

`dataset.name` must be one of the classes defined in `datasets.py` (`MNIST`, `FashionMNIST` or `CIFAR10`). All configuration options are listed in `config/config.yaml`. See the [hydra](https://github.com/facebookresearch/hydra) documentation for more information on configuration.


### From Code

Install the package:

```bash
pip install gon-pytorch
```

Instantiate a GON with [NeRF](https://arxiv.org/abs/2003.08934) positional encodings:

```python
import torch
from gon_pytorch import NeRFPositionalEncoding, ImplicitDecoder, GON, SirenBlockFactory

pos_encoder = NeRFPositionalEncoding(in_dim=2)
decoder = ImplicitDecoder(
    latent_dim=128,
    out_dim=3,
    hidden_dim=128,
    num_layers=4,
    block_factory=SirenBlockFactory(),
    pos_encoder=pos_encoder
)
gon = GON(decoder)

coords = torch.randn(1, 32, 32, 2)
image = torch.rand(1, 32, 32, 3)

# Obtain the latent for the image by a gradient step from the origin
latent, latent_loss = gon.infer_latents(coords, image)

# Reconstruct from the latent
recon = gon(coords, latent)

# Compute gradients for the decoder weights
loss = ((recon - image) ** 2).mean()
loss.backward()
```


Differences to the original implementation
------------------------------------------

- Binary cross-entropy is used as the reconstruction loss instead of MSE, as this seems to improve results.
- The original implementation obtains gradients with respect to the origin by taking the mean over the latent loss. This introduces a dependence on the batch size, since the mean distributes the loss evenly across the individual latents in the backward pass. This implementation instead sums the latent loss over the batch dimension; see the sketch below.
- Latent modulation from [Modulated Periodic Activations for Generalizable Local Functional Representations](https://arxiv.org/abs/2104.03960) is implemented and can optionally be used.
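
The second point can be verified with a small, self-contained example (not part of the library): with mean reduction each latent's gradient shrinks with the batch size, while sum reduction keeps it constant:

```python
import torch

latents = torch.zeros(4, 8, requires_grad=True)    # a batch of 4 latent origins
loss_per_item = (latents - 1.0).pow(2).sum(dim=1)  # arbitrary per-item loss

# With mean reduction, each latent's gradient is scaled by 1 / batch_size ...
grad_mean = torch.autograd.grad(loss_per_item.mean(), latents, retain_graph=True)[0]
# ... while sum reduction keeps per-latent gradients independent of the batch size.
grad_sum = torch.autograd.grad(loss_per_item.sum(), latents)[0]

assert torch.allclose(grad_mean * len(latents), grad_sum)
```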


Citations
---------

```bibtex
@misc{bondtaylor2021gradient,
    title={Gradient Origin Networks},
    author={Sam Bond-Taylor and Chris G. Willcocks},
    year={2021},
    eprint={2007.02798},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

```bibtex
@misc{sitzmann2020implicit,
    title={Implicit Neural Representations with Periodic Activation Functions},
    author={Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
    year={2020},
    eprint={2006.09661},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

```bibtex
@misc{mildenhall2020nerf,
    title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
    author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
    year={2020},
    eprint={2003.08934},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

```bibtex
@misc{mehta2021modulated,
    title={Modulated Periodic Activations for Generalizable Local Functional Representations},
    author={Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
    year={2021},
    eprint={2104.03960},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```
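
Any value in `config/config.yaml` (next section) can be overridden from the command line via hydra's dotted syntax; for example:

```bash
python train_gon.py dataset.name=FashionMNIST model.latent_dim=64 \
    model.pos_encoder.name=nerf training.epochs=20
```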
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
hydra:
  run:
    dir: ${logging.log_dir}/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}
  job:
    name: ${logging.run_name}-${dataset.name}

logging:
  log_dir: runs
  run_name: gon
  log_every: 100
  n_samples_per_epoch: 64
  n_recons_per_epoch: 32

model:
  hidden_dim: 128
  num_layers: 5
  activation: siren
  bias: true
  dropout: 0.1
  latent_dim: 128
  latent_reg: 0
  latent_modulation: false
  latent_updates: 1
  learn_origin: false
  pos_encoder:
    name: gaussian
    args: {}

dataset:
  name: ???
  root: ./data
  image_size: 32

training:
  device: null
  batch_size: 128
  lr: 1e-4
  epochs: 10
  num_workers: 4

--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
from torchvision import datasets


class MNIST:
    def __init__(self, root, transform):
        self.num_channels = 1

        self.train = datasets.MNIST(root, train=True, download=True, transform=transform)
        self.test = datasets.MNIST(root, train=False, download=True, transform=transform)


class FashionMNIST:
    def __init__(self, root, transform):
        self.num_channels = 1

        self.train = datasets.FashionMNIST(root, train=True, download=True, transform=transform)
        self.test = datasets.FashionMNIST(root, train=False, download=True, transform=transform)


class CIFAR10:
    def __init__(self, root, transform):
        self.num_channels = 3

        self.train = datasets.CIFAR10(root, train=True, download=True, transform=transform)
        self.test = datasets.CIFAR10(root, train=False, download=True, transform=transform)
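
`train_gon.py` looks up `dataset.name` in this module via `getattr`, so adding a dataset only requires a class with the same interface. A hypothetical example (SVHN is not part of the repository; note that torchvision's SVHN uses `split` instead of `train`):

```python
from torchvision import datasets


class SVHN:
    def __init__(self, root, transform):
        self.num_channels = 3

        self.train = datasets.SVHN(root, split='train', download=True, transform=transform)
        self.test = datasets.SVHN(root, split='test', download=True, transform=transform)
```

With such a class in place, `python train_gon.py dataset.name=SVHN` would work unchanged.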
--------------------------------------------------------------------------------
/figs/fashionmnist-recons.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kklemon/gon-pytorch/2e374124cdf4ec57f135fe103e5f7923e07c96c8/figs/fashionmnist-recons.png

--------------------------------------------------------------------------------
/figs/fashionmnist-samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kklemon/gon-pytorch/2e374124cdf4ec57f135fe103e5f7923e07c96c8/figs/fashionmnist-samples.png

--------------------------------------------------------------------------------
/figs/gon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kklemon/gon-pytorch/2e374124cdf4ec57f135fe103e5f7923e07c96c8/figs/gon.png

--------------------------------------------------------------------------------
/figs/mnist-recons.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kklemon/gon-pytorch/2e374124cdf4ec57f135fe103e5f7923e07c96c8/figs/mnist-recons.png

--------------------------------------------------------------------------------
/figs/mnist-samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kklemon/gon-pytorch/2e374124cdf4ec57f135fe103e5f7923e07c96c8/figs/mnist-samples.png

--------------------------------------------------------------------------------
/gon_pytorch/__init__.py:
--------------------------------------------------------------------------------
from .modules import (
    CoordinateEncoding,
    IdentityPositionalEncoding,
    GaussianFourierFeatureTransform,
    NeRFPositionalEncoding,
    LinearBlock,
    SirenLinear,
    LinearBlockFactory,
    SirenBlockFactory,
    Swish,
    Sine,
    ImplicitDecoder,
    GON
)

--------------------------------------------------------------------------------
/gon_pytorch/data.py:
--------------------------------------------------------------------------------
from PIL import Image
from torch.utils import data
from pathlib import Path


class ImageFolderDataset(data.Dataset):
    def __init__(
        self,
        root,
        transform=None,
        # Leading dots are required, since Path.suffix includes the dot
        ext=('.png', '.jpg', '.jpeg', '.bmp'),
        recursive=False,
        return_filename=False
    ):
        if recursive:
            glob_pattern = '**/*'
        else:
            glob_pattern = '*'

        self.files = [f for f in Path(root).glob(glob_pattern) if f.is_file() and f.suffix.lower() in ext]
        self.transform = transform
        self.return_filename = return_filename
        self.mode = 'RGB'

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename = self.files[idx]
        with open(filename, 'rb') as f:
            img = Image.open(f)
            img = img.convert(self.mode)
        if self.transform:
            img = self.transform(img)
        if self.return_filename:
            return img, str(filename)
        return img


class NoLabelWrapper(data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset[idx][0]

    def __len__(self):
        return len(self.dataset)
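
A minimal usage sketch for the folder dataset above (the directory path is hypothetical):

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from gon_pytorch.data import ImageFolderDataset

dataset = ImageFolderDataset(
    './my-images',                       # any directory of .png/.jpg/.jpeg/.bmp files
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]),
    recursive=True
)
batches = DataLoader(dataset, batch_size=16, shuffle=True)
```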
--------------------------------------------------------------------------------
/gon_pytorch/modules.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import copy
from functools import partial
from typing import Optional, List, Callable
from math import pi, sqrt


class CoordinateEncoding(nn.Module):
    def __init__(self, proj_matrix, is_trainable=False):
        super().__init__()
        if is_trainable:
            self.register_parameter('proj_matrix', nn.Parameter(proj_matrix))
        else:
            self.register_buffer('proj_matrix', proj_matrix)
        self.in_dim = self.proj_matrix.size(0)
        self.out_dim = self.proj_matrix.size(1) * 2

    def forward(self, x):
        shape = x.shape
        channels = shape[-1]

        assert channels == self.in_dim, f'Expected input to have {self.in_dim} channels (got {channels} channels)'

        x = x.reshape(-1, channels)
        x = x @ self.proj_matrix
        x = x.view(*shape[:-1], -1)
        x = 2 * pi * x

        return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)


class IdentityPositionalEncoding(CoordinateEncoding):
    def __init__(self, in_dim):
        super().__init__(torch.eye(in_dim))
        self.out_dim = in_dim

    def forward(self, x):
        return x


class GaussianFourierFeatureTransform(CoordinateEncoding):
    def __init__(self, in_dim: int, mapping_size: int = 32, sigma: float = 1.0, is_trainable: bool = False, seed=None):
        super().__init__(self.get_transform_matrix(in_dim, mapping_size, sigma, seed=seed), is_trainable=is_trainable)
        self.mapping_size = mapping_size
        self.sigma = sigma
        self.seed = seed

    @classmethod
    def get_transform_matrix(cls, in_dim, mapping_size, sigma, seed=None):
        generator = None
        if seed is not None:
            generator = torch.Generator().manual_seed(seed)
        return torch.normal(mean=0, std=sigma, size=(in_dim, mapping_size), generator=generator)

    @classmethod
    def from_matrix(cls, projection_matrix):
        in_dim, mapping_size = projection_matrix.shape
        feature_transform = cls(in_dim, mapping_size)
        # The buffer registered by CoordinateEncoding is named proj_matrix
        feature_transform.proj_matrix.data = projection_matrix
        return feature_transform

    def __repr__(self):
        return f'{self.__class__.__name__}(in_dim={self.in_dim}, mapping_size={self.mapping_size}, sigma={self.sigma})'


class NeRFPositionalEncoding(CoordinateEncoding):
    def __init__(self, in_dim, n=10):
        super().__init__((2.0 ** torch.arange(n))[None, :])
        self.in_dim = in_dim
        self.out_dim = n * 2 * in_dim

    def forward(self, x):
        shape = x.shape
        x = x.unsqueeze(-1) * self.proj_matrix
        x = pi * x
        x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
        x = x.view(*shape[:-1], -1)
        return x


class LinearBlock(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 linear_cls,
                 activation=nn.ReLU,
                 bias=True,
                 is_first=False,
                 is_last=False):
        super().__init__()
        self.in_f = in_features
        self.out_f = out_features
        self.linear = linear_cls(in_features, out_features, bias=bias)
        self.bias = bias
        self.is_first = is_first
        self.is_last = is_last
        self.activation = None if is_last else activation()

    def forward(self, x):
        x = self.linear(x)
        if self.activation is not None:
            return self.activation(x)
        else:
            return x

    def __repr__(self):
        return f'LinearBlock(in_features={self.in_f}, out_features={self.out_f}, linear_cls={self.linear}, ' \
               f'activation={self.activation}, bias={self.bias}, is_first={self.is_first}, is_last={self.is_last})'


class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class Sine(nn.Module):
    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)

    def __repr__(self):
        return f'Sine(w0={self.w0})'


class SirenLinear(LinearBlock):
    def __init__(self, in_features, out_features, linear_cls=nn.Linear, w0=30, bias=True, is_first=False, is_last=False):
        super().__init__(in_features, out_features, linear_cls, partial(Sine, w0), bias, is_first, is_last)
        self.w0 = w0
        self.init_weights()

    def init_weights(self):
        # Initialization scheme from the SIREN paper
        if self.is_first:
            b = 1 / self.in_f
        else:
            b = sqrt(6 / self.in_f) / self.w0

        with torch.no_grad():
            self.linear.weight.uniform_(-b, b)
            if self.linear.bias is not None:
                self.linear.bias.uniform_(-b, b)


class BatchedLinear(nn.Module):
    def __init__(self, in_feat, out_feat, num_models, bias=True):
        super().__init__()

        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_models = num_models

        self.weight = nn.Parameter(torch.Tensor(num_models, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_models, out_feat))
        else:
            self.bias = None

        self.init_weights()

    def init_weights(self):
        for i in range(self.num_models):
            w = self.weight[i]
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))
            if self.bias is not None:
                b = self.bias[i]
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
                # Match nn.Linear's default bias initialization
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(b, -bound, bound)

    def forward(self, x):
        x = x.transpose(1, -1)
        orig_shape = x.shape
        x = x.reshape(x.size(0), x.size(1), -1)

        out = torch.bmm(self.weight, x)
        if self.bias is not None:
            out += self.bias.unsqueeze(-1)

        out = out.view((out.size(0), self.weight.shape[1]) + orig_shape[2:])
        out = out.transpose(1, -1)

        return out

    def get_layer_by_index(self, idx):
        linear = nn.Linear(self.in_feat, self.out_feat, bias=self.bias is not None)
        linear.weight.data = self.weight[idx].data
        if self.bias is not None:
            linear.bias.data = self.bias[idx].data
        return linear

    def get_layers(self):
        return list(map(self.get_layer_by_index, range(self.num_models)))


class BaseBlockFactory:
    def __call__(self, in_f, out_f, is_first=False, is_last=False):
        raise NotImplementedError


class LinearBlockFactory(BaseBlockFactory):
    def __init__(self, linear_cls=nn.Linear, activation_cls=nn.ReLU, bias=True):
        self.linear_cls = linear_cls
        self.activation_cls = activation_cls
        self.bias = bias

    def __call__(self, in_f, out_f, is_first=False, is_last=False):
        return LinearBlock(in_f, out_f, self.linear_cls, self.activation_cls, self.bias, is_first, is_last)


class SirenBlockFactory(BaseBlockFactory):
    def __init__(self, linear_cls=nn.Linear, w0=30, bias=True):
        self.linear_cls = linear_cls
        self.w0 = w0
        self.bias = bias

    def __call__(self, in_f, out_f, is_first=False, is_last=False):
        return SirenLinear(in_f, out_f, self.linear_cls, self.w0, self.bias, is_first, is_last)


class MLP(nn.Module):
    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 block_factory: BaseBlockFactory,
                 dropout: float = 0.0,
                 final_activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        self.blocks = nn.ModuleList()

        if self.num_layers < 1:
            raise ValueError(f'num_layers must be >= 1 (input to output); got {self.num_layers}')

        for i in range(self.num_layers):
            in_feat = self.in_dim if i == 0 else self.hidden_dim
            out_feat = self.out_dim if i + 1 == self.num_layers else self.hidden_dim

            is_first = i == 0
            is_last = i + 1 == self.num_layers

            curr_block = [block_factory(
                in_feat,
                out_feat,
                is_first=is_first,
                is_last=is_last
            )]
            if not is_last and dropout:
                curr_block.append(nn.Dropout(dropout))

            self.blocks.append(nn.Sequential(*curr_block))

        self.final_activation = final_activation
        if final_activation is None:
            self.final_activation = nn.Identity()

    def forward(self, x, modulations=None):
        for i, block in enumerate(self.blocks):
            x = block(x)
            if modulations is not None and len(self.blocks) > i + 1:
                x *= modulations[i][:, None, None, :]
        return self.final_activation(x)


class BatchedImageMLP(MLP):
    def __init__(self, num_models: int, block_factory: BaseBlockFactory, *args, **kwargs):

        multi_model_block_factory = copy(block_factory)
        multi_model_block_factory.linear_cls = partial(BatchedLinear, num_models=num_models)

        super().__init__(*args, block_factory=multi_model_block_factory, **kwargs)

        self.block_factory = block_factory
        self.num_models = num_models
        self.expected_batch_size = num_models

    def get_model_by_index(self, idx):
        model = MLP(
            self.in_dim,
            self.out_dim,
            self.hidden_dim,
            self.num_layers,
            self.block_factory,
            self.dropout,
            self.final_activation
        )
        for src_block, trg_block in zip(self.blocks, model.blocks):
            if hasattr(src_block, 'linear'):
                trg_block.linear = src_block.linear.get_layer_by_index(idx)
        return model

    def get_model_splits(self):
        return list(map(self.get_model_by_index, range(self.num_models)))


class ModulationNetwork(nn.Module):
    def __init__(self, in_dim: int, mod_dims: List[int], activation=nn.ReLU):
        super().__init__()

        self.blocks = nn.ModuleList()
        for i in range(len(mod_dims)):
            self.blocks.append(nn.Sequential(
                nn.Linear(in_dim + (mod_dims[i - 1] if i else 0), mod_dims[i]),
                activation()
            ))

    def forward(self, input):
        out = input
        mods = []
        for block in self.blocks:
            out = block(out)
            mods.append(out)
            out = torch.cat([out, input], dim=-1)
        return mods


class ImplicitDecoder(nn.Module):
    def __init__(self,
                 latent_dim: int,
                 out_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 block_factory: BaseBlockFactory,
                 pos_encoder: CoordinateEncoding = None,
                 modulation: bool = False,
                 dropout: float = 0.0,
                 final_activation=torch.sigmoid):
        super().__init__()

        self.pos_encoder = pos_encoder
        self.latent_dim = latent_dim

        self.mod_network = None
        if modulation:
            self.mod_network = ModulationNetwork(
                in_dim=latent_dim,
                mod_dims=[hidden_dim for _ in range(num_layers - 1)],
                activation=nn.ReLU
            )

        self.net = MLP(
            in_dim=pos_encoder.out_dim + latent_dim * (not modulation),
            out_dim=out_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            block_factory=block_factory,
            dropout=dropout,
            final_activation=final_activation
        )

    def forward(self, input, latent):
        if self.pos_encoder is not None:
            input = self.pos_encoder(input)

        if self.mod_network is None:
            # Broadcast the latent over all spatial positions and concatenate it
            # with the encoded coordinates
            b, *spatial_dims, c = input.shape
            latent = latent.view(b, *((1,) * len(spatial_dims)), -1).repeat(1, *spatial_dims, 1)
            out = self.net(torch.cat([latent, input], dim=-1))
        else:
            mods = self.mod_network(latent)
            out = self.net(input, mods)

        return out


class GON(nn.Module):
    def __init__(self, decoder: ImplicitDecoder, latent_updates: int = 1, learn_origin: bool = False):
        super().__init__()

        self.decoder = decoder
        self.latent_updates = latent_updates

        if learn_origin:
            self.init_latent = nn.Parameter(torch.zeros(1, self.decoder.latent_dim))
        else:
            self.register_buffer('init_latent', torch.zeros(1, self.decoder.latent_dim))

    def get_init_latent(self, n):
        return self.init_latent.repeat(n, 1)

    def loss_inner(self, output, target):
        # Summing over the batch dimension keeps per-latent gradients independent
        # of the batch size (see README)
        return F.binary_cross_entropy(
            output.view(-1), target.view(-1), reduction='none'
        ).view(target.shape).sum(0).mean()

    def loss_outer(self, output, target):
        return F.binary_cross_entropy(
            output.view(-1), target.view(-1), reduction='none'
        ).view(target.shape).mean()

    def infer_latents(self, input, target):
        latent = self.get_init_latent(len(target)).requires_grad_(True)

        for i in range(self.latent_updates):
            out = self.decoder(input, latent)
            inner_loss = self.loss_inner(out, target)
            # The latent is the negative gradient of the inner loss at the origin;
            # create_graph allows backpropagating through this step during training
            latent = latent - torch.autograd.grad(inner_loss, [latent], create_graph=True, retain_graph=True)[0]

        return latent, inner_loss

    def forward(self, input, latent):
        return self.decoder(input, latent)
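
A quick shape reference for the positional encodings and the modulation network above (illustrative; the shapes follow from the code):

```python
import torch
from gon_pytorch import GaussianFourierFeatureTransform, NeRFPositionalEncoding
from gon_pytorch.modules import ModulationNetwork

coords = torch.rand(1, 32, 32, 2)            # (batch, H, W, 2)

gauss = GaussianFourierFeatureTransform(in_dim=2, mapping_size=32)
print(gauss(coords).shape)                   # (1, 32, 32, 64) = mapping_size * 2

nerf = NeRFPositionalEncoding(in_dim=2, n=10)
print(nerf(coords).shape)                    # (1, 32, 32, 40) = n * 2 * in_dim

# One multiplicative gate per hidden layer, derived from the latent
mod_net = ModulationNetwork(in_dim=64, mod_dims=[128, 128, 128])
mods = mod_net(torch.rand(8, 64))
print([m.shape for m in mods])               # 3 x (8, 128)
```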
--------------------------------------------------------------------------------
/gon_pytorch/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from gon_pytorch import modules


def get_block_factory(activation='siren', bias=True):
    if activation == 'siren':
        return modules.SirenBlockFactory(nn.Linear, bias=bias)
    if activation == 'relu':
        return modules.LinearBlockFactory(nn.Linear, activation_cls=nn.ReLU, bias=bias)
    if activation == 'leaky_relu':
        return modules.LinearBlockFactory(nn.Linear, activation_cls=lambda: nn.LeakyReLU(0.2), bias=bias)
    if activation == 'swish':
        return modules.LinearBlockFactory(nn.Linear, activation_cls=modules.Swish, bias=bias)
    raise ValueError(f'Unknown activation {activation}')


def get_xy_grid(width, height):
    x_coords = torch.linspace(-1, 1, width)
    y_coords = torch.linspace(-1, 1, height)

    # (1, width, height, 2) grid of coordinates in [-1, 1]
    xy_grid = torch.stack(torch.meshgrid(x_coords, y_coords), -1).unsqueeze(0)

    return xy_grid
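
Putting the pieces together outside of the training script, a minimal sketch that builds a coordinate grid with `utils.get_xy_grid`, infers latents for a random batch and checks the resulting shapes (hyperparameters here are arbitrary):

```python
import torch
from gon_pytorch import modules, utils

grid = utils.get_xy_grid(32, 32)             # (1, 32, 32, 2)
pos_encoder = modules.GaussianFourierFeatureTransform(in_dim=2, mapping_size=64)
decoder = modules.ImplicitDecoder(
    latent_dim=64,
    out_dim=1,
    hidden_dim=128,
    num_layers=4,
    block_factory=modules.SirenBlockFactory(),
    pos_encoder=pos_encoder
)
gon = modules.GON(decoder)

target = torch.rand(8, 32, 32, 1)            # batch of 8 grayscale "images"
coords = grid.repeat(8, 1, 1, 1)

latents, inner_loss = gon.infer_latents(coords, target)
recon = gon(coords, latents)
print(latents.shape, recon.shape)            # (8, 64), (8, 32, 32, 1)
```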
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from pathlib import Path
from setuptools import setup, find_packages

ROOT = Path(__file__).parent
README = (ROOT / 'README.md').read_text()

setup(
    name='gon-pytorch',
    packages=find_packages(),
    version='0.1.1',
    license='MIT',
    description='Gradient Origin Networks for PyTorch',
    long_description=README,
    long_description_content_type='text/markdown',
    author='Kristian Klemon',
    author_email='kristian.klemon@gmail.com',
    url='https://github.com/kklemon/gon-pytorch',
    keywords=['artificial intelligence', 'deep learning'],
    install_requires=['torch']
)
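
Since the project ships a standard `setup.py`, a development checkout can also be installed in editable mode with plain pip:

```bash
pip install -e .
```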
--------------------------------------------------------------------------------
/train_gon.py:
--------------------------------------------------------------------------------
import pickle
import hydra
import numpy as np
import torch
import datasets

from pathlib import Path
from itertools import chain
from omegaconf import DictConfig
from torchvision.utils import save_image, make_grid
from torchvision import transforms
from torch.utils.data import DataLoader
from gon_pytorch import modules, utils


def train(model, batches, input, opt, device, epoch, latent_reg, log_every=100):
    model.train()

    seen_samples = 0
    inner_loss_sum = 0
    outer_loss_sum = 0
    latent_l2_loss_sum = 0

    latent_buffer = torch.zeros(len(batches.dataset), model.decoder.latent_dim)
    label_buffer = torch.zeros(len(batches.dataset), dtype=torch.long)

    for step, (images, labels) in enumerate(batches):
        images = images.to(device)

        batch_input = input.repeat(len(images), 1, 1, 1)

        # Obtain latent with respect to origin
        latents, inner_loss = model.infer_latents(batch_input, images)

        # Optimize model with obtained latent
        out = model(batch_input, latents)
        outer_loss = model.loss_outer(out, images)

        latent_reg_loss = latent_reg * (latents ** 2).sum()
        loss = outer_loss + latent_reg_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

        latent_buffer[seen_samples:seen_samples + len(images)] = latents.detach().cpu()
        label_buffer[seen_samples:seen_samples + len(images)] = labels

        seen_samples += len(images)
        inner_loss_sum += inner_loss.item()
        outer_loss_sum += outer_loss.item()
        latent_l2_loss_sum += latent_reg_loss.item()

        step += 1

        if step % log_every == 0:
            stats = {
                'avg inner loss': inner_loss_sum,
                'avg outer loss': outer_loss_sum
            }
            if latent_reg:
                stats['latent l2 loss'] = latent_l2_loss_sum

            print(f'[EPOCH {epoch:03d}][{seen_samples:05d}/{len(batches.dataset):05d}] ' +
                  ', '.join(f'{k}: {v / step:.4f}' for k, v in stats.items()))

    return latent_buffer, label_buffer


def eval(model, batches, input, device):
    model.eval()

    seen_samples = 0
    inner_loss_sum = 0
    outer_loss_sum = 0

    latent_buffer = torch.zeros(len(batches.dataset), model.decoder.latent_dim)
    label_buffer = torch.zeros(len(batches.dataset), dtype=torch.long)

    for step, (images, labels) in enumerate(batches):
        images = images.to(device)
        batch_input = input.repeat(len(images), 1, 1, 1)

        # Obtain latent with respect to origin
        latents, inner_loss = model.infer_latents(batch_input, images)

        # Calculate loss for obtained latent
        out = model(batch_input, latents)
        outer_loss = model.loss_outer(out, images)

        inner_loss_sum += inner_loss.item()
        outer_loss_sum += outer_loss.item()

        latent_buffer[seen_samples:seen_samples + len(images)] = latents.detach().cpu()
        label_buffer[seen_samples:seen_samples + len(images)] = labels

        seen_samples += len(images)

    print(f'inner loss: {inner_loss_sum / len(batches):.4f}, outer loss: {outer_loss_sum / len(batches):.4f}')

    return latent_buffer, label_buffer


def sample(model, input, mean, cov, n_samples):
    model.eval()

    latents = torch.tensor(
        np.random.multivariate_normal(mean, cov, size=n_samples), dtype=torch.float32
    ).to(input.device)

    model_input = input.repeat(n_samples, 1, 1, 1)
    samples = model(model_input, latents)
    return samples


@hydra.main(config_path='config', config_name='config')
def main(cfg: DictConfig):
    device = cfg.training.device or ('cuda' if torch.cuda.is_available() else 'cpu')

    log_dir = Path.cwd()
    log_dir.mkdir(parents=True, exist_ok=True)

    recon_dir = log_dir / 'reconstructions'
    recon_dir.mkdir(exist_ok=True)

    sample_dir = log_dir / 'samples'
    sample_dir.mkdir(exist_ok=True)

    print(f'Logging to {str(log_dir)}')

    dataset_cls = getattr(datasets, cfg.dataset.name, None)
    if dataset_cls is None:
        raise ValueError(f'Unknown dataset {cfg.dataset.name}')

    dataset = dataset_cls(cfg.dataset.root, transforms.Compose([
        transforms.Resize(cfg.dataset.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t.permute(1, 2, 0))
    ]))

    train_batches = DataLoader(
        dataset.train, cfg.training.batch_size, shuffle=True, num_workers=cfg.training.num_workers
    )
    test_batches = DataLoader(
        dataset.test, cfg.training.batch_size, shuffle=False, num_workers=cfg.training.num_workers
    )

    fixed_batch = next(iter(train_batches))[0][:cfg.logging.n_recons_per_epoch].to(device)

    pos_encoder_kwargs = {'in_dim': 2, **cfg.model.pos_encoder.get('args', {})}
    pos_encoder_cls = {
        'none': modules.IdentityPositionalEncoding,
        'gaussian': modules.GaussianFourierFeatureTransform,
        'nerf': modules.NeRFPositionalEncoding
    }.get(cfg.model.pos_encoder.name)
    if pos_encoder_cls is None:
        raise ValueError(f'Unknown positional encoder \'{cfg.model.pos_encoder.name}\'')

    pos_encoder = pos_encoder_cls(**pos_encoder_kwargs)
    grid = utils.get_xy_grid(cfg.dataset.image_size, cfg.dataset.image_size)
    model_input = grid.to(device)

    decoder = modules.ImplicitDecoder(
        latent_dim=cfg.model.latent_dim,
        out_dim=dataset.num_channels,
        hidden_dim=cfg.model.hidden_dim,
        num_layers=cfg.model.num_layers,
        block_factory=utils.get_block_factory(cfg.model.activation, cfg.model.bias),
        pos_encoder=pos_encoder,
        modulation=cfg.model.latent_modulation,
        dropout=cfg.model.dropout,
        final_activation=torch.sigmoid
    )
    model = modules.GON(decoder, cfg.model.latent_updates, cfg.model.learn_origin).to(device)

    print(model)
    print(f'# of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')

    opt = torch.optim.Adam(model.parameters(), lr=cfg.training.lr, weight_decay=1e-3)

    try:
        print('TRAINING')
        for epoch in range(cfg.training.epochs):
            train_latents, train_labels = train(
                model, train_batches, model_input, opt, device, epoch, cfg.model.latent_reg, cfg.logging.log_every
            )

            model.eval()

            train_latents = train_latents.numpy()
            train_labels = train_labels.numpy()

            recon_input = model_input.repeat(len(fixed_batch), 1, 1, 1)
            latent = model.infer_latents(recon_input, fixed_batch)[0]
            recon = model.forward(recon_input, latent)

            gt_recon_pairs = torch.stack(list(chain.from_iterable(zip(fixed_batch, recon))))
            save_image(make_grid(gt_recon_pairs.permute(0, 3, 1, 2), normalize=True), recon_dir / f'{epoch:03d}.png')

            if cfg.logging.n_samples_per_epoch:
                cov = np.cov(train_latents.T)
                mean = np.mean(train_latents, 0)

                samples = sample(model, model_input, mean, cov, cfg.logging.n_samples_per_epoch)
                save_image(samples.permute(0, 3, 1, 2), sample_dir / f'{epoch:03d}.png', normalize=True)

                stats = {'cov': cov, 'mean': mean}
                (log_dir / 'stats.p').write_bytes(pickle.dumps(stats))

            (log_dir / 'train_data.p').write_bytes(pickle.dumps({'latents': train_latents, 'labels': train_labels}))
            torch.save(model, log_dir / 'model.p')
    except KeyboardInterrupt:
        print('Interrupting training')

    print('EVALUATION')
    test_latents, test_labels = eval(model, test_batches, model_input, device)

    test_latents = test_latents.numpy()
    test_labels = test_labels.numpy()

    (log_dir / 'test_data.p').write_bytes(pickle.dumps({'latents': test_latents, 'labels': test_labels}))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------