├── figures
│   └── network.png
├── normflow
│   ├── flows
│   │   ├── base.py
│   │   ├── __init__.py
│   │   ├── radial.py
│   │   ├── residual.py
│   │   ├── normalization.py
│   │   ├── planar.py
│   │   ├── glow.py
│   │   ├── stochastic.py
│   │   ├── reshape.py
│   │   ├── neural_spline.py
│   │   ├── mixing.py
│   │   └── affine_coupling.py
│   ├── distributions
│   │   ├── __init__.py
│   │   ├── linear_interpolation.py
│   │   ├── mh_proposal.py
│   │   ├── decoder.py
│   │   ├── target.py
│   │   ├── encoder.py
│   │   └── prior.py
│   ├── __init__.py
│   ├── transforms.py
│   ├── HAIS.py
│   ├── utils.py
│   └── nets.py
├── ops
│   ├── __init__.py
│   ├── operator.py
│   ├── radon_3d_lib.py
│   ├── traveltime_lib.py
│   └── odl_lib.py
├── config_IP_solver.py
├── config_generative.py
├── flow_model.py
├── LICENSE
├── config_funknn.py
├── datasets.py
├── README.md
├── laplacian_loss.py
├── autoencoder_model.py
├── train_funknn.py
├── environment.yml
├── train_generative.py
├── results.py
├── utils.py
└── funknn_model.py
/figures/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swing-research/FunkNN/HEAD/figures/network.png -------------------------------------------------------------------------------- /normflow/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from .core import * 5 | from . import flows 6 | from . import distributions 7 | from . import transforms 8 | from . import nets 9 | from . import utils 10 | from . 
import HAIS 11 | 12 | __version__ = '1.1' 13 | -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .odl_lib import OperatorAsAutogradFunction 2 | from .odl_lib import ParallelBeamGeometryOp 3 | from .odl_lib import ParallelBeamGeometryOpBroken 4 | from .odl_lib import ParallelBeamGeometryOpNonUniform 5 | 6 | from .radon_3d_lib import ParallelBeamGeometry3DOp, ParallelBeamGeometry3DOpBroken 7 | 8 | from .traveltime_lib import TravelTimeOperator 9 | 10 | from .operator import get_operator_dict 11 | 12 | -------------------------------------------------------------------------------- /config_IP_solver.py: -------------------------------------------------------------------------------- 1 | gpu_num = 0 2 | image_size = 256 # Working resolution for solving inverse problems 3 | problem = 'PDE' # inverse problem: {CT, PDE} 4 | sparse_derivatives = False # Sparse derivative option, only for the PDE problem 5 | funknn_path = 'experiments/' + 'funknn_celeba-hq_128_factor_default' # Trained FunkNN folder 6 | autoencoder_path = 'experiments/' + 'generator_celeba-hq_5_256_128_default' # Trained generative autoencoder folder 7 | exp_desc = 'default' # A note to indicate which versions of FunkNN and the generative autoencoder are combined -------------------------------------------------------------------------------- /normflow/flows/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | 5 | # Generic flow module 6 | class Flow(nn.Module): 7 | """ 8 | Generic class for flow functions 9 | """ 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z): 14 | """ 15 | :param z: input variable, first dimension is batch dim 16 | :return: transformed z and log of absolute determinant 17 | """ 18 | raise NotImplementedError('Forward pass has not been implemented.') 19 | 20 | def inverse(self, z): 21 | raise NotImplementedError('This flow has no algebraic inverse.') -------------------------------------------------------------------------------- /normflow/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDistribution, DiagGaussian, ClassCondDiagGaussian, \ 2 | GlowBase, AffineGaussian, GaussianMixture, GaussianPCA 3 | from .target import Target, TwoMoons, CircularGaussianMixture, RingMixture 4 | 5 | from .encoder import BaseEncoder, Dirac, Uniform, NNDiagGaussian 6 | from .decoder import BaseDecoder, NNDiagGaussianDecoder, NNBernoulliDecoder 7 | from .prior import PriorDistribution, ImagePrior, TwoModes, Sinusoidal, \ 8 | Sinusoidal_split, Sinusoidal_gap, Smiley 9 | 10 | from .mh_proposal import MHProposal, DiagGaussianProposal 11 | 12 | from .linear_interpolation import LinearInterpolation -------------------------------------------------------------------------------- /normflow/flows/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Flow 2 | 3 | from .reshape import Merge, Split, Squeeze 4 | from .mixing import Permute, InvertibleAffine, Invertible1x1Conv, LULinearPermute 5 | from .normalization import BatchNorm, ActNorm 6 | 7 | from .planar import Planar 8 | from .radial import Radial 9 | 10 | from .affine_coupling import AffineConstFlow, CCAffineConst, AffineCoupling, MaskedAffineFlow, AffineCouplingBlock 11 | from .glow import GlowBlock 12 | 13 | from .residual import Residual 14 | from .neural_spline import CoupledRationalQuadraticSpline, AutoregressiveRationalQuadraticSpline 15 | 16 | from .stochastic import MetropolisHastings, HamiltonianMonteCarlo 17 | 
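The flows and distributions packages above expose the individual building blocks; flow_model.py later in this listing composes them the same way. A minimal composition sketch (the NormalizingFlow wrapper lives in normflow/core.py, which is not reproduced here; its sample() signature is assumed to follow the upstream normflows package):

```python
import torch
import normflow as nf

# Two planar layers on 2-D data over a diagonal Gaussian base.
flows = [nf.flows.Planar((2,)) for _ in range(2)]
q0 = nf.distributions.DiagGaussian(2)
model = nf.NormalizingFlow(q0=q0, flows=flows)

z, log_q = model.sample(8)  # assumed API: samples and their log-density
print(z.shape, log_q.shape)
```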
-------------------------------------------------------------------------------- /config_generative.py: -------------------------------------------------------------------------------- 1 | epochs_aeder = 200 # number of epochs to train autoencoder 2 | epochs_flow = 200 # number of epochs to train flow 3 | flow_depth = 5 4 | latent_dim = 256 # Latent dimension of the flow model 5 | batch_size = 64 6 | dataset = 'celeb-hq' 7 | gpu_num = 0 # GPU number 8 | exp_desc = 'default' # Add a small descriptor to the experiment 9 | image_size = 128 # Resolution of the dataset 10 | c = 3 # Number of channels of the dataset 11 | train_aeder = True # Train autoencoder or just reload 12 | train_flow = True # Train flow or just reload 13 | restore_aeder = False # Restore the trained autoencoder if it exists 14 | restore_flow = False # Restore the trained flow if it exists 15 | -------------------------------------------------------------------------------- /normflow/distributions/linear_interpolation.py: -------------------------------------------------------------------------------- 1 | class LinearInterpolation: 2 | """ 3 | Linear interpolation of two distributions in the log space 4 | """ 5 | def __init__(self, dist1, dist2, alpha): 6 | """ 7 | Constructor 8 | :param dist1: First distribution 9 | :param dist2: Second distribution 10 | :param alpha: Interpolation parameter, 11 | log_p = alpha * log_p_1 + (1 - alpha) * log_p_2 12 | """ 13 | self.alpha = alpha 14 | self.dist1 = dist1 15 | self.dist2 = dist2 16 | 17 | def log_prob(self, z): 18 | return self.alpha * self.dist1.log_prob(z)\ 19 | + (1 - self.alpha) * self.dist2.log_prob(z) -------------------------------------------------------------------------------- /flow_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import normflow as nf 3 | import numpy as np 4 | 5 | def real_nvp(latent_dim, K=64): 6 | 7 | 8 | b = torch.Tensor([1 if i % 2 == 0 else 0 for i in range(latent_dim)]) 9 | flows = [] 10 | for i in range(K): 11 | s = nf.nets.MLP([latent_dim, 2 * latent_dim, latent_dim], init_zeros=True) 12 | t = nf.nets.MLP([latent_dim, 2 * latent_dim, latent_dim], init_zeros=True) 13 | if i % 2 == 0: 14 | flows += [nf.flows.MaskedAffineFlow(b, t, s)] 15 | else: 16 | flows += [nf.flows.MaskedAffineFlow(1 - b, t, s)] 17 | flows += [nf.flows.ActNorm(latent_dim)] 18 | 19 | q0 = nf.distributions.DiagGaussian(latent_dim) 20 | 21 | # Construct flow model 22 | nfm = nf.NormalizingFlow(q0=q0, flows=flows) 23 | 24 | return nfm 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 swing-research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 
| copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /normflow/flows/radial.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from .base import Flow 6 | 7 | 8 | 9 | class Radial(Flow): 10 | """ 11 | Radial flow as introduced in arXiv: 1505.05770 12 | f(z) = z + beta * h(alpha, r) * (z - z_0) 13 | """ 14 | def __init__(self, shape, z_0=None): 15 | """ 16 | Constructor of the radial flow 17 | :param shape: shape of the latent variable z 18 | :param z_0: parameter of the radial flow 19 | """ 20 | super().__init__() 21 | self.d_cpu = torch.prod(torch.tensor(shape)) 22 | self.register_buffer('d', self.d_cpu) 23 | self.beta = nn.Parameter(torch.empty(1)) 24 | lim = 1.0 / np.prod(shape) 25 | nn.init.uniform_(self.beta, -lim - 1.0, lim - 1.0) 26 | self.alpha = nn.Parameter(torch.empty(1)) 27 | nn.init.uniform_(self.alpha, -lim, lim) 28 | 29 | if z_0 is not None: 30 | self.z_0 = nn.Parameter(z_0) 31 | else: 32 | self.z_0 = nn.Parameter(torch.randn(shape)[None]) 33 | 34 | def forward(self, z): 35 | beta = torch.log(1 + torch.exp(self.beta)) - torch.abs(self.alpha) 36 | dz = z - self.z_0 37 | r = torch.norm(dz, dim=list(range(1, self.z_0.dim()))) 38 | h_arr = beta / (torch.abs(self.alpha) + r) 39 | h_arr_ = - beta * r / (torch.abs(self.alpha) + r) ** 2 40 | z_ = z + h_arr.unsqueeze(1) * dz 41 | log_det = (self.d - 1) * torch.log(1 + h_arr) + torch.log(1 + h_arr + h_arr_) 42 | return z_, log_det -------------------------------------------------------------------------------- /config_funknn.py: -------------------------------------------------------------------------------- 1 | epochs_funknn = 200 # number of epochs to train the funknn network 2 | batch_size = 64 3 | dataset = 'celeb-hq' 4 | gpu_num = 0 # GPU number 5 | exp_desc = 'default' # Add a small descriptor to the experiment 6 | image_size = 128 # Maximum resolution of the training dataset 7 | c = 3 # Number of channels of the dataset 8 | train_funknn = True # Train or just reload to test 9 | restore_funknn = True 10 | training_mode = 'factor' # Training modes for funknn: {continuous, factor, single} 11 | ood_analysis = True # Evaluate the performance of the model on out-of-distribution data (LSUN-bedroom) 12 | interpolation_kernel = 'bicubic' # interpolation kernels: {'cubic_conv', 'bilinear', 'bicubic'} 13 | # interpolation_kernel = 'bicubic' cannot be used for solving PDEs, as its derivatives are not computed, 14 | # but can be safely used for the super-resolution task. 15 | # interpolation_kernel = 'bilinear' is fast but can only be used for PDEs with first-order derivatives. 16 | # interpolation_kernel = 'cubic_conv' is slow but can be used for solving PDEs with first- and second-order derivatives 17 | network = 'MLP' # The network can be a 'CNN' or 'MLP' ('MLP' is supposed to be significantly faster) 18 | activation = 'relu' # Activation function 'relu' or 'sin' ('sin' for more accurate spatial derivatives) 19 | 20 | # Evaluation arguments 21 | max_scale = 2 # Maximum scale to generate in test time (2 or 4 or 8) (<=8 for celeba-hq and 2 for other datasets) 22 | recursive = True # Recursive image reconstructions (Use just for factor training mode) 23 | sample_number = 25 # Number of samples in evaluation 24 | derivatives_evaluation = False # To evaluate the performance of the model for computing the derivatives 25 | # (Keep it False for 'bicubic' kernel) 26 | 27 | # Datasets paths: 28 | ood_path = 'datasets/lsun_bedroom_val' 29 | train_path = 'datasets/celeba_hq/celeba_hq_1024_train/' 30 | test_path = 'datasets/celeba_hq/celeba_hq_1024_test/' 31 | 
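The config files are plain Python modules that the training scripts import directly, so every assignment above becomes a module attribute; a minimal sketch of that pattern (datasets.py below uses it verbatim):

```python
import config_funknn as config

# Every field defined in config_funknn.py is available as an attribute.
assert config.interpolation_kernel in {'cubic_conv', 'bilinear', 'bicubic'}
print(config.dataset, config.image_size, config.training_mode)
```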
-------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import torchvision 4 | from torchvision.datasets import ImageFolder 5 | from utils import * 6 | import config_funknn as config 7 | 8 | 9 | class Dataset_loader(torch.utils.data.Dataset): 10 | def __init__(self, dataset, size=(1024,1024), c=3, quantize=False): 11 | 12 | if c==1: 13 | self.transform = transforms.Compose([ 14 | transforms.Resize(size), 15 | transforms.Grayscale(), 16 | transforms.ToTensor(), 17 | ]) 18 | else: 19 | self.transform = transforms.Compose([ 20 | transforms.Resize(size), 21 | transforms.ToTensor(), 22 | ]) 23 | 24 | self.c = c 25 | self.dataset = dataset 26 | self.meshgrid = get_mgrid(size[0]) 27 | self.im_size = size 28 | self.quantize = quantize 29 | 30 | if self.dataset == 'train': 31 | self.img_dataset = ImageFolder(config.train_path, self.transform) 32 | 33 | elif self.dataset == 'test': 34 | self.img_dataset = ImageFolder(config.test_path, self.transform) 35 | 36 | elif self.dataset == 'ood': 37 | lsun_class = ['bedroom_val'] 38 | self.img_dataset = torchvision.datasets.LSUN(config.ood_path, 39 | classes=lsun_class, transform=self.transform) 40 | 41 | def __len__(self): 42 | return len(self.img_dataset) 43 | 44 | def __getitem__(self, item): 45 | img = self.img_dataset[item][0] 46 | img = transforms.ToPILImage()(img) 47 | img = self.transform(img).permute([1,2,0]) 48 | img = img.reshape(-1, self.c) 49 | 50 | if self.quantize: 51 | img = img * 255.0 52 | img = torch.multiply(8.0, torch.div(img , 8 , rounding_mode = 'floor')) 53 | img = img/255.0 54 | 55 | return img 56 | 57 | -------------------------------------------------------------------------------- /normflow/flows/residual.py: -------------------------------------------------------------------------------- 1 | from .base import Flow 2 | 3 | # Try importing Residual Flow dependencies 4 | try: 5 | from residual_flows.layers import iResBlock 6 | except ImportError: 7 | print('Warning: Dependencies for Residual Flows could ' 8 | 'not be loaded. 
Other models can still be used.') 9 | 10 | 11 | 12 | class Residual(Flow): 13 | """ 14 | Invertible residual net block, wrapper to the implementation of Chen et al., 15 | see https://github.com/rtqichen/residual-flows 16 | """ 17 | def __init__(self, net, n_exact_terms=2, n_samples=1, reduce_memory=True, 18 | reverse=True): 19 | """ 20 | Constructor 21 | :param net: Neural network, must be Lipschitz continuous with L < 1 22 | :param n_exact_terms: Number of terms always included in the power series 23 | :param n_samples: Number of samples used to estimate power series 24 | :param reduce_memory: Flag, if true Neumann series and precomputations 25 | for backward pass in forward pass are done 26 | :param reverse: Flag, if true the map f(x) = x + net(x) is applied in 27 | the inverse pass, otherwise it is done in forward 28 | """ 29 | super().__init__() 30 | self.reverse = reverse 31 | self.iresblock = iResBlock(net, n_samples=n_samples, 32 | n_exact_terms=n_exact_terms, 33 | neumann_grad=reduce_memory, 34 | grad_in_forward=reduce_memory) 35 | 36 | def forward(self, z): 37 | if self.reverse: 38 | z, log_det = self.iresblock.inverse(z, 0) 39 | else: 40 | z, log_det = self.iresblock.forward(z, 0) 41 | return z, -log_det.view(-1) 42 | 43 | def inverse(self, z): 44 | if self.reverse: 45 | z, log_det = self.iresblock.forward(z, 0) 46 | else: 47 | z, log_det = self.iresblock.inverse(z, 0) 48 | return z, -log_det.view(-1)
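The next file, normflow/transforms.py, provides the preprocessing bijections applied to data; a round-trip sketch for its Logit transform (only classes shown below are used — inverse() maps [0, 1] data to the unconstrained space and forward() maps back):

```python
import torch
from normflow import transforms

logit = transforms.Logit(alpha=0.05)
x = torch.rand(4, 3 * 32 * 32)              # flattened images in [0, 1]
z, _ = logit.inverse(x)                     # data -> unconstrained space
x_rec, _ = logit.forward(z)                 # and back through the sigmoid
print(torch.allclose(x, x_rec, atol=1e-5))  # True
```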
-------------------------------------------------------------------------------- /normflow/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from . import flows 4 | 5 | # Transforms to be applied to data as preprocessing 6 | 7 | class Logit(flows.Flow): 8 | """ 9 | Logit mapping of image tensor, see RealNVP paper 10 | logit(alpha + (1 - 2 * alpha) * x) where logit(x) = log(x / (1 - x)) 11 | """ 12 | def __init__(self, alpha=0.05): 13 | """ 14 | Constructor 15 | :param alpha: Alpha parameter, see above 16 | """ 17 | super().__init__() 18 | self.alpha = alpha 19 | 20 | def forward(self, z): 21 | beta = 1 - 2 * self.alpha 22 | sum_dims = list(range(1, z.dim())) 23 | ls = torch.sum(torch.nn.functional.logsigmoid(z), dim=sum_dims) 24 | mls = torch.sum(torch.nn.functional.logsigmoid(-z), dim=sum_dims) 25 | log_det = -np.log(beta) * np.prod([*z.shape[1:]]) + ls + mls 26 | z = (torch.sigmoid(z) - self.alpha) / beta 27 | return z, log_det 28 | 29 | def inverse(self, z): 30 | beta = 1 - 2 * self.alpha 31 | z = self.alpha + beta * z 32 | logz = torch.log(z) 33 | log1mz = torch.log(1 - z) 34 | z = logz - log1mz 35 | sum_dims = list(range(1, z.dim())) 36 | log_det = np.log(beta) * np.prod([*z.shape[1:]]) \ 37 | - torch.sum(logz, dim=sum_dims) \ 38 | - torch.sum(log1mz, dim=sum_dims) 39 | return z, log_det 40 | 41 | 42 | class Shift(flows.Flow): 43 | """ 44 | Shift data by a fixed constant, default is -0.5 to shift data from 45 | interval [0, 1] to [-0.5, 0.5] 46 | """ 47 | def __init__(self, shift=-0.5): 48 | """ 49 | Constructor 50 | :param shift: Shift to apply to the data 51 | """ 52 | super().__init__() 53 | self.shift = shift 54 | 55 | def forward(self, z): 56 | z -= self.shift 57 | log_det = 0. 58 | return z, log_det 59 | 60 | def inverse(self, z): 61 | z += self.shift 62 | log_det = 0. 63 | return z, log_det -------------------------------------------------------------------------------- /normflow/HAIS.py: -------------------------------------------------------------------------------- 1 | ### Implementation of Hamiltonian Annealed Importance Sampling ### 2 | 3 | import torch 4 | from . import distributions 5 | from . import flows 6 | 7 | class HAIS(): 8 | """ 9 | Class which performs HAIS 10 | """ 11 | def __init__(self, betas, prior, target, num_leapfrog, step_size, log_mass): 12 | """ 13 | :param betas: Annealing schedule, the jth target is f_j(x) = 14 | f_0(x)^{\beta_j} f_n(x)^{1-\beta_j} where the target is proportional 15 | to f_0 and the prior is proportional to f_n. The number of 16 | intermediate steps is inferred from the shape of betas. 17 | Should be of the form 1 = \beta_0 > \beta_1 > ... > \beta_n = 0 18 | :param prior: The prior distribution to start the HAIS chain. 19 | :param target: The target distribution from which we would like to draw 20 | weighted samples. 21 | :param num_leapfrog: Number of leapfrog steps in the HMC transitions. 22 | :param step_size: step_size to use for HMC transitions. 23 | :param log_mass: log_mass to use for HMC transitions. 24 | """ 25 | self.prior = prior 26 | self.target = target 27 | self.layers = [] 28 | n = betas.shape[0] - 1 29 | for i in range(n-1, 0, -1): 30 | intermediate_target = distributions.LinearInterpolation(self.target, 31 | self.prior, betas[i]) 32 | self.layers += [flows.HamiltonianMonteCarlo(intermediate_target, 33 | num_leapfrog, torch.log(step_size), log_mass)] 34 | 35 | def sample(self, num_samples): 36 | """ 37 | Run HAIS to draw samples from the target with appropriate weights. 38 | :param num_samples: The number of samples to draw. 39 | """ 40 | samples, log_weights = self.prior.forward(num_samples) 41 | log_weights = -log_weights 42 | for i in range(len(self.layers)): 43 | samples, log_weights_addition = self.layers[i].forward(samples) 44 | log_weights += log_weights_addition 45 | log_weights += self.target.log_prob(samples) 46 | return samples, log_weights 47 | 
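A minimal usage sketch for the sampler above (assuming the vendored normflow API matches the upstream normflows package it is based on — TwoMoons and DiagGaussian are only illustrative choices):

```python
import torch
import normflow as nf
from normflow.HAIS import HAIS

dim = 2
prior = nf.distributions.DiagGaussian(dim)
target = nf.distributions.TwoMoons()
betas = torch.linspace(1.0, 0.0, 21)          # 1 = beta_0 > ... > beta_n = 0

hais = HAIS(betas, prior, target,
            num_leapfrog=5,
            step_size=0.05 * torch.ones(dim),
            log_mass=torch.zeros(dim))
samples, log_w = hais.sample(1000)            # importance-weighted samples
```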
-------------------------------------------------------------------------------- /normflow/distributions/mh_proposal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | 7 | class MHProposal(nn.Module): 8 | """ 9 | Proposal distribution for the Metropolis Hastings algorithm 10 | """ 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def sample(self, z): 15 | """ 16 | Sample new value based on previous z 17 | """ 18 | raise NotImplementedError 19 | 20 | def log_prob(self, z_, z): 21 | """ 22 | :param z_: Potential new sample 23 | :param z: Previous sample 24 | :return: Log probability of proposal distribution 25 | """ 26 | raise NotImplementedError 27 | 28 | def forward(self, z): 29 | """ 30 | Draw samples given z and compute log probability difference 31 | log(p(z | z_new)) - log(p(z_new | z)) 32 | :param z: Previous samples 33 | :return: Proposal, difference of log probability ratio 34 | """ 35 | raise NotImplementedError 36 | 37 | 38 | class DiagGaussianProposal(MHProposal): 39 | """ 40 | Diagonal Gaussian distribution with previous value as mean 41 | as a proposal for Metropolis Hastings algorithm 42 | """ 43 | def __init__(self, shape, scale): 44 | """ 45 | Constructor 46 | :param shape: Shape of variables to sample 47 | :param scale: Standard deviation of distribution 48 | """ 49 | super().__init__() 50 | self.shape = shape 51 | self.scale_cpu = torch.tensor(scale) 52 | self.register_buffer("scale", self.scale_cpu.unsqueeze(0)) 53 | 54 | def sample(self, z): 55 | num_samples = len(z) 56 | eps = torch.randn((num_samples,) + self.shape, dtype=z.dtype, device=z.device) 57 | z_ = eps * self.scale + z 58 | return z_ 59 | 60 | def log_prob(self, z_, z): 61 | log_p = - 0.5 * np.prod(self.shape) * np.log(2 * np.pi) \ 62 | - torch.sum(torch.log(self.scale) + 0.5 * torch.pow((z_ - z) / self.scale, 2), 63 | list(range(1, z.dim()))) 64 | return log_p 65 | 66 | def forward(self, z): 67 | num_samples = len(z) 68 | eps = torch.randn((num_samples,) + self.shape, dtype=z.dtype, device=z.device) 69 | z_ = eps * self.scale + z 70 | log_p_diff = torch.zeros(num_samples, dtype=z.dtype, device=z.device) 71 | return z_, log_p_diff
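DiagGaussianProposal is designed to be paired with the MetropolisHastings flow from normflow/flows/stochastic.py (shown at the end of this listing). A small sketch, with TwoMoons as an illustrative target (its no-argument constructor is assumed from the upstream normflows package):

```python
import torch
import normflow as nf

proposal = nf.distributions.DiagGaussianProposal((2,), scale=0.2)
target = nf.distributions.TwoMoons()
mh = nf.flows.MetropolisHastings(target, proposal, steps=50)

z0 = torch.randn(512, 2)      # arbitrary initialization
z, log_det = mh(z0)           # z approximately follows the target
```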
-------------------------------------------------------------------------------- /normflow/flows/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base import Flow 4 | from .affine_coupling import AffineConstFlow 5 | 6 | 7 | 8 | class ActNorm(AffineConstFlow): 9 | """ 10 | An AffineConstFlow but with a data-dependent initialization, 11 | where on the very first batch we cleverly initialize s, t so that the output 12 | is unit Gaussian. As described in the Glow paper. 13 | """ 14 | 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.data_dep_init_done_cpu = torch.tensor(0.) 18 | self.register_buffer('data_dep_init_done', self.data_dep_init_done_cpu) 19 | 20 | def forward(self, z): 21 | # first batch is used for initialization, c.f. batchnorm 22 | if not self.data_dep_init_done > 0.: 23 | assert self.s is not None and self.t is not None 24 | s_init = -torch.log(z.std(dim=self.batch_dims, keepdim=True) + 1e-6) 25 | self.s.data = s_init.data 26 | self.t.data = (-z.mean(dim=self.batch_dims, keepdim=True) * torch.exp(self.s)).data 27 | self.data_dep_init_done = torch.tensor(1.) 28 | return super().forward(z) 29 | 30 | def inverse(self, z): 31 | # first batch is used for initialization, c.f. batchnorm 32 | if not self.data_dep_init_done: 33 | assert self.s is not None and self.t is not None 34 | s_init = torch.log(z.std(dim=self.batch_dims, keepdim=True) + 1e-6) 35 | self.s.data = s_init.data 36 | self.t.data = z.mean(dim=self.batch_dims, keepdim=True).data 37 | self.data_dep_init_done = torch.tensor(1.) 38 | return super().inverse(z) 39 | 40 | 41 | class BatchNorm(Flow): 42 | """ 43 | Batch Normalization without considering the derivatives of the batch statistics, see arXiv: 1605.08803 44 | """ 45 | def __init__(self, eps=1.e-10): 46 | super().__init__() 47 | self.eps_cpu = torch.tensor(eps) 48 | self.register_buffer('eps', self.eps_cpu) 49 | 50 | def forward(self, z): 51 | """ 52 | Do batch norm over batch and sample dimension 53 | """ 54 | mean = torch.mean(z, dim=0, keepdims=True) 55 | std = torch.std(z, dim=0, keepdims=True) 56 | z_ = (z - mean) / torch.sqrt(std ** 2 + self.eps) 57 | log_det = torch.log(1 / torch.prod(torch.sqrt(std ** 2 + self.eps))).repeat(z.size()[0]) 58 | return z_, log_det
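A quick sketch of ActNorm's data-dependent initialization; the (shape,) constructor signature is inherited from AffineConstFlow in affine_coupling.py, which is not reproduced in this listing, so it is an assumption here:

```python
import torch
from normflow import flows

torch.manual_seed(0)
actnorm = flows.ActNorm((3,))         # assumed AffineConstFlow signature
z = 4.0 * torch.randn(1024, 3) + 2.0
z_, log_det = actnorm(z)              # the first batch fits s and t
print(z_.mean(0), z_.std(0))          # approximately zero mean, unit std
```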
-------------------------------------------------------------------------------- /normflow/distributions/decoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | 7 | class BaseDecoder(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, z): 12 | """ 13 | Decodes z to x 14 | :param z: latent variable 15 | :return: x, std of x 16 | """ 17 | raise NotImplementedError 18 | 19 | def log_prob(self, x, z): 20 | """ 21 | :param x: observable 22 | :param z: latent variable 23 | :return: log(p) of x given z 24 | """ 25 | raise NotImplementedError 26 | 27 | 28 | class NNDiagGaussianDecoder(BaseDecoder): 29 | """ 30 | BaseDecoder representing a diagonal Gaussian distribution with mean and std parametrized by a NN 31 | """ 32 | def __init__(self, net): 33 | """ 34 | Constructor 35 | :param net: neural network parametrizing mean and standard deviation of diagonal Gaussian 36 | """ 37 | super().__init__() 38 | self.net = net 39 | 40 | def forward(self, z): 41 | z_size = z.size() 42 | mean_std = self.net(z.view(-1, *z_size[2:])).view(z_size) 43 | n_hidden = mean_std.size()[2] // 2 44 | mean = mean_std[:, :, :n_hidden, ...] 45 | std = torch.exp(0.5 * mean_std[:, :, n_hidden:(2 * n_hidden), ...]) 46 | return mean, std 47 | 48 | def log_prob(self, x, z): 49 | mean_std = self.net(z.view(-1, *z.size()[2:])).view(*z.size()[:2], x.size(1) * 2, *x.size()[3:]) 50 | n_hidden = mean_std.size()[2] // 2 51 | mean = mean_std[:, :, :n_hidden, ...] 52 | var = torch.exp(mean_std[:, :, n_hidden:(2 * n_hidden), ...]) 53 | log_p = - 0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(2 * np.pi) \ 54 | - 0.5 * torch.sum(torch.log(var) + (x.unsqueeze(1) - mean) ** 2 / var, list(range(2, z.dim()))) 55 | return log_p 56 | 57 | 58 | class NNBernoulliDecoder(BaseDecoder): 59 | """ 60 | BaseDecoder representing a Bernoulli distribution with mean parametrized by a NN 61 | """ 62 | 63 | def __init__(self, net): 64 | """ 65 | Constructor 66 | :param net: neural network parametrizing the Bernoulli mean (mean = sigmoid(nn_out)) 67 | """ 68 | super().__init__() 69 | self.net = net 70 | 71 | def forward(self, z): 72 | mean = torch.sigmoid(self.net(z)) 73 | return mean 74 | 75 | def log_prob(self, x, z): 76 | score = self.net(z) 77 | x = x.unsqueeze(1) 78 | x = x.repeat(1, z.size()[0] // x.size()[0], *((x.dim() - 2) * [1])).view(-1, *x.size()[2:]) 79 | log_sig = lambda a: -torch.relu(-a) - torch.log(1 + torch.exp(-torch.abs(a))) 80 | log_p = torch.sum(x * log_sig(score) + (1 - x) * log_sig(-score), list(range(1, x.dim()))) 81 | return log_p -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FunkNN: Neural Interpolation for Functional Generation 2 | 3 | 4 | [![Paper](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/2212.14042) 5 | [![PWC](https://img.shields.io/badge/PWC-report-blue)](https://paperswithcode.com/paper/funknn-neural-interpolation-for-functional) 6 | 7 | This repository is the official PyTorch implementation of "[FunkNN: Neural Interpolation for Functional Generation](https://openreview.net/forum?id=BT4N_v7CLrk)" in ICLR 2023. 8 | 9 | | [**Project Page**](https://sada.dmi.unibas.ch/en/research/implicit-neural-representation) | 10 | 11 | 12 | 
![FunkNN network architecture](figures/network.png) 13 | 14 | 
15 | 16 | 17 | 18 | ## Requirements 19 | (This code is tested with PyTorch 1.12.1, Python 3.8.3, CUDA 11.6 and cuDNN 7.) 20 | - numpy 21 | - scipy 22 | - matplotlib 23 | - odl 24 | - imageio 25 | - torch==1.12.1 26 | - torchvision==0.13.1 27 | - astra-toolbox 28 | 29 | ## Installation 30 | 31 | Run the following command to create the conda environment from environment.yml: 32 | ```sh 33 | conda env create -f environment.yml 34 | ``` 35 | 36 | ## Datasets 37 | You can download the [CelebA-HQ](https://drive.switch.ch/index.php/s/pA6X3TY9x4jgcxb), [LoDoPaB-CT](https://drive.switch.ch/index.php/s/lQeYWmAIYcEEdlc) and [LSUN-bedroom](https://drive.switch.ch/index.php/s/d1MNcrUZkPpK0zx) validation datasets, split them into train and test sets, and put them in the data folder. You should specify the data folder addresses in config_funknn.py and config_generative.py. 38 | 39 | ## Experiments 40 | ### Train FunkNN 41 | All arguments for training the FunkNN model are explained in config_funknn.py. After specifying your arguments, you can run the following command to train the model: 42 | ```sh 43 | python3 train_funknn.py 44 | ``` 45 | 46 | ### Train generative autoencoder 47 | All arguments for training the generative autoencoder are explained in config_generative.py. After specifying your arguments, you can run the following command to train the model: 48 | ```sh 49 | python3 train_generative.py 50 | ``` 51 | 52 | 53 | ### Solving inverse problems 54 | All arguments for solving inverse problems by combining FunkNN and the generative autoencoder are explained in config_IP_solver.py. After specifying your arguments, including the folder addresses of the trained FunkNN and generator, you can run the following command to solve the inverse problem of your choice (CT or PDE): 55 | ```sh 56 | python3 IP_solver.py 57 | ``` 58 | 59 | ## Citation 60 | If you find the code useful in your research, please consider citing the paper. 61 | 62 | ``` 63 | @inproceedings{ 64 | khorashadizadeh2023funknn, 65 | title={Funk{NN}: Neural Interpolation for Functional Generation}, 66 | author={AmirEhsan Khorashadizadeh and Anadi Chaman and Valentin Debarnot and Ivan Dokmani{\'c}}, 67 | booktitle={The Eleventh International Conference on Learning Representations}, 68 | year={2023}, 69 | url={https://openreview.net/forum?id=BT4N_v7CLrk} 70 | } 71 | ``` 72 | 73 | 
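As a quick sanity check, the flow used by train_generative.py can be built and sampled directly via flow_model.py above; sample() is assumed from the vendored normflow NormalizingFlow API:

```python
import torch
from flow_model import real_nvp

nfm = real_nvp(latent_dim=256, K=64)
z, log_q = nfm.sample(4)       # assumed API: samples and their log-density
print(z.shape, log_q.shape)    # torch.Size([4, 256]) torch.Size([4])
```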
-------------------------------------------------------------------------------- /normflow/flows/planar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from .base import Flow 6 | 7 | 8 | 9 | class Planar(Flow): 10 | """ 11 | Planar flow as introduced in arXiv: 1505.05770 12 | f(z) = z + u * h(w * z + b) 13 | """ 14 | 15 | def __init__(self, shape, act="tanh", u=None, w=None, b=None): 16 | """ 17 | Constructor of the planar flow 18 | :param shape: shape of the latent variable z 19 | :param act: name of the nonlinearity h of the planar flow (see definition of f above) 20 | :param u, w, b: optional initialization for parameters 21 | """ 22 | super().__init__() 23 | lim_w = np.sqrt(2. / np.prod(shape)) 24 | lim_u = np.sqrt(2) 25 | 26 | if u is not None: 27 | self.u = nn.Parameter(u) 28 | else: 29 | self.u = nn.Parameter(torch.empty(shape)[None]) 30 | nn.init.uniform_(self.u, -lim_u, lim_u) 31 | if w is not None: 32 | self.w = nn.Parameter(w) 33 | else: 34 | self.w = nn.Parameter(torch.empty(shape)[None]) 35 | nn.init.uniform_(self.w, -lim_w, lim_w) 36 | if b is not None: 37 | self.b = nn.Parameter(b) 38 | else: 39 | self.b = nn.Parameter(torch.zeros(1)) 40 | 41 | self.act = act 42 | if act == "tanh": 43 | self.h = torch.tanh 44 | elif act == "leaky_relu": 45 | self.h = torch.nn.LeakyReLU(negative_slope=0.2) 46 | else: 47 | raise NotImplementedError('Nonlinearity is not implemented.') 48 | 49 | def forward(self, z): 50 | lin = torch.sum(self.w * z, list(range(1, self.w.dim()))) + self.b 51 | if self.act == "tanh": 52 | inner = torch.sum(self.w * self.u) 53 | u = self.u + (torch.log(1 + torch.exp(inner)) - 1 - inner) * self.w / torch.sum(self.w ** 2) 54 | h_ = lambda x: 1 / torch.cosh(x) ** 2 55 | elif self.act == "leaky_relu": 56 | inner = torch.sum(self.w * self.u) 57 | u = self.u + (torch.log(1 + torch.exp(inner)) - 1 - inner) * self.w / torch.sum( 58 | self.w ** 2) # constraint w.T * u neq -1, use > 59 | h_ = lambda x: (x < 0) * (self.h.negative_slope - 1.0) + 1.0 60 | 61 | z_ = z + u * self.h(lin.unsqueeze(1)) 62 | log_det = torch.log(torch.abs(1 + torch.sum(self.w * u) * h_(lin))) 63 | return z_, log_det 64 | 65 | def inverse(self, z): 66 | if self.act != "leaky_relu": 67 | raise NotImplementedError('This flow has no algebraic inverse.') 68 | lin = torch.sum(self.w * z, list(range(2, self.w.dim())), keepdim=True) + self.b 69 | inner = torch.sum(self.w * self.u) 70 | a = ((lin + self.b) / (1 + inner) < 0) * (self.h.negative_slope - 1.0) + 1.0 # absorb leakyReLU slope into u 71 | u = a * (self.u + (torch.log(1 + torch.exp(inner)) - 1 - inner) * self.w / torch.sum(self.w ** 2)) 72 | z_ = z - 1 / (1 + inner) * (lin + u * self.b) 73 | log_det = -torch.log(torch.abs(1 + torch.sum(self.w * u))) 74 | if log_det.dim() == 0: 75 | log_det = log_det.unsqueeze(0) 76 | if log_det.dim() == 1: 77 | log_det = log_det.unsqueeze(1) 78 | return z_, log_det -------------------------------------------------------------------------------- /laplacian_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | def gaussian_kernel(size=5, device=torch.device('cpu'), channels=3, sigma=1, dtype=torch.float): 8 | # Create Gaussian Kernel. In Numpy 9 | interval = (2*sigma +1)/(size) 10 | ax = np.linspace(-(size - 1)/ 2., (size-1)/2., size) 11 | xx, yy = np.meshgrid(ax, ax) 12 | kernel = np.exp(-0.5 * (np.square(xx)+ np.square(yy)) / np.square(sigma)) 13 | kernel /= np.sum(kernel) 14 | # Change kernel to PyTorch and 
reshape to (channels, 1, size, size) 15 | kernel_tensor = torch.as_tensor(kernel, dtype=dtype) 16 | kernel_tensor = kernel_tensor.repeat(channels, 1 , 1, 1) 17 | kernel_tensor = kernel_tensor.to(device) 18 | return kernel_tensor 19 | 20 | def gaussian_conv2d(x, g_kernel, dtype=torch.float): 21 | # Assumes input x is of shape: (minibatch, depth, height, width) 22 | # Infer depth automatically based on the shape 23 | channels = g_kernel.shape[0] 24 | padding = g_kernel.shape[-1] // 2 # Kernel size needs to be odd number 25 | if len(x.shape) != 4: 26 | raise IndexError('Expected input tensor to be of shape: (batch, depth, height, width) but got: ' + str(x.shape)) 27 | y = F.conv2d(x, weight=g_kernel, stride=1, padding=padding, groups=channels) 28 | return y 29 | 30 | def downsample(x): 31 | # Downsamples along image (H,W). Takes every 2 pixels. output (H, W) = input (H/2, W/2) 32 | return x[:, :, ::2, ::2] 33 | 34 | def create_laplacian_pyramid(x, kernel, levels): 35 | upsample = torch.nn.Upsample(scale_factor=2) # Default mode is nearest: [[1 2],[3 4]] -> [[1 1 2 2],[3 3 4 4]] 36 | pyramids = [] 37 | current_x = x 38 | for level in range(0, levels): 39 | gauss_filtered_x = gaussian_conv2d(current_x, kernel) 40 | down = downsample(gauss_filtered_x) 41 | # Original Algorithm does indeed: L_i = G_i - expand(G_i+1), with L_i as current laplacian layer, and G_i as current gaussian filtered image, and G_i+1 the next. 42 | # Some implementations skip expand(G_i+1) and use gaussian_conv(G_i). We decided to use expand, as is the original algorithm 43 | laplacian = current_x - upsample(down) 44 | pyramids.append(laplacian) 45 | current_x = down 46 | pyramids.append(current_x) 47 | return pyramids 48 | 49 | class LaplacianPyramidLoss(torch.nn.Module): 50 | def __init__(self, max_levels=3, channels=3, kernel_size=5, sigma=1, device=torch.device('cpu'), dtype=torch.float): 51 | super(LaplacianPyramidLoss, self).__init__() 52 | self.max_levels = max_levels 53 | self.kernel = gaussian_kernel(size=kernel_size, device=device, channels=channels, sigma=sigma, dtype=dtype) 54 | 55 | def forward(self, x, target): 56 | input_pyramid = create_laplacian_pyramid(x, self.kernel, self.max_levels) 57 | target_pyramid = create_laplacian_pyramid(target, self.kernel, self.max_levels) 58 | loss = 0.0 59 | cnt = 0 60 | for x, y in zip(input_pyramid, target_pyramid): 61 | w = 2**cnt 62 | loss = loss + w*torch.nn.functional.l1_loss(x,y) 63 | cnt = cnt + 1 64 | # return sum(torch.nn.functional.l1_loss(x,y) for x, y in zip(input_pyramid, target_pyramid)) 65 | return loss 66 | 
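A short usage sketch for the pyramid loss defined above; the shapes are illustrative, and H and W should be divisible by 2**max_levels so that downsample/upsample round-trip cleanly:

```python
import torch
from laplacian_loss import LaplacianPyramidLoss

loss_fn = LaplacianPyramidLoss(max_levels=3, channels=3, kernel_size=5, sigma=1)
x = torch.rand(2, 3, 64, 64)   # prediction, (batch, channels, H, W)
y = torch.rand(2, 3, 64, 64)   # target
print(loss_fn(x, y))           # weighted sum of per-level L1 terms
```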
-------------------------------------------------------------------------------- /normflow/flows/glow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .base import Flow 5 | from .affine_coupling import AffineCouplingBlock 6 | from .mixing import Invertible1x1Conv 7 | from .normalization import ActNorm 8 | from .. import nets 9 | 10 | 11 | 12 | class GlowBlock(Flow): 13 | """ 14 | Glow: Generative Flow with Invertible 1×1 Convolutions, arXiv: 1807.03039 15 | One Block of the Glow model, comprised of 16 | MaskedAffineFlow (affine coupling layer) 17 | Invertible1x1Conv (dropped if there is only one channel) 18 | ActNorm (first batch used for initialization) 19 | """ 20 | def __init__(self, channels, hidden_channels, scale=True, scale_map='sigmoid', 21 | split_mode='channel', leaky=0.0, init_zeros=True, use_lu=True, 22 | net_actnorm=False): 23 | """ 24 | Constructor 25 | :param channels: Number of channels of the data 26 | :param hidden_channels: number of channels in the hidden layer of the ConvNet 27 | :param scale: Flag, whether to include scale in affine coupling layer 28 | :param scale_map: Map to be applied to the scale parameter, can be 'exp' as in 29 | RealNVP or 'sigmoid' as in Glow 30 | :param split_mode: Splitting mode, for possible values see Split class 31 | :param leaky: Leaky parameter of LeakyReLUs of ConvNet2d 32 | :param init_zeros: Flag whether to initialize last conv layer with zeros 33 | :param use_lu: Flag whether to parametrize weights through the LU decomposition 34 | in invertible 1x1 convolution layers 35 | :param logscale_factor: Factor which can be used to control the scale of 36 | the log scale factor, see https://github.com/openai/glow 37 | """ 38 | super().__init__() 39 | self.flows = nn.ModuleList([]) 40 | # Coupling layer 41 | kernel_size = (3, 1, 3) 42 | num_param = 2 if scale else 1 43 | if 'channel' == split_mode: 44 | channels_ = (channels // 2,) + 2 * (hidden_channels,) 45 | channels_ += (num_param * ((channels + 1) // 2),) 46 | elif 'channel_inv' == split_mode: 47 | channels_ = ((channels + 1) // 2,) + 2 * (hidden_channels,) 48 | channels_ += (num_param * (channels // 2),) 49 | elif 'checkerboard' in split_mode: 50 | channels_ = (channels,) + 2 * (hidden_channels,) 51 | channels_ += (num_param * channels,) 52 | else: 53 | raise NotImplementedError('Mode ' + split_mode + ' is not implemented.') 54 | param_map = nets.ConvNet2d(channels_, kernel_size, leaky, init_zeros, 55 | actnorm=net_actnorm) 56 | self.flows += [AffineCouplingBlock(param_map, scale, scale_map, split_mode)] 57 | # Invertible 1x1 convolution 58 | if channels > 1: 59 | self.flows += [Invertible1x1Conv(channels, use_lu)] 60 | # Activation normalization 61 | self.flows += [ActNorm((channels,) + (1, 1))] 62 | 63 | def forward(self, z): 64 | log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device) 65 | for flow in self.flows: 66 | z, log_det = flow(z) 67 | log_det_tot += log_det 68 | return z, log_det_tot 69 | 70 | def inverse(self, z): 71 | log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device) 72 | for i in range(len(self.flows) - 1, -1, -1): 73 | z, log_det = self.flows[i].inverse(z) 74 | log_det_tot += log_det 75 | return z, log_det_tot -------------------------------------------------------------------------------- /ops/operator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from absl import logging 5 | 6 | from .odl_lib import ParallelBeamGeometryOp, ParallelBeamGeometryOpBroken 7 | from .radon_3d_lib import ParallelBeamGeometry3DOp, ParallelBeamGeometry3DOpBroken 8 | from .traveltime_lib import TravelTimeOperator 9 | 10 | def get_operator_dict(config): 11 | if config.problem == 'radon': 12 | operator = ParallelBeamGeometryOp( 13 | config.img_size, 14 | config.op_param, 15 | 
op_snr=np.inf, 16 | angle_max=config.angle_max) 17 | 18 | operator_dict = {} 19 | 20 | if config.opt_strat == 'dip_noisy': 21 | broken_operator = ParallelBeamGeometryOpBroken(operator, config.op_snr) 22 | operator_dict.update({'original_operator': broken_operator}) 23 | operator_dict.update({'noisy_operator': operator}) 24 | elif config.opt_strat in ['dip', 'joint', 'joint_input']: 25 | dense_operator = ParallelBeamGeometryOp( 26 | config.img_size, config.dense_op_param, op_snr=np.inf,angle_max=config.angle_max) 27 | operator_dict.update({'dense_operator': dense_operator}) 28 | broken_operator = ParallelBeamGeometryOpBroken(operator, config.op_snr) 29 | operator_dict.update({'original_operator': broken_operator}) 30 | elif config.opt_strat == 'broken_machine': 31 | dense_operator = ParallelBeamGeometryOp( 32 | config.img_size, 33 | config.op_param, 34 | op_snr=np.inf, 35 | angle_max=config.angle_max) 36 | operator_dict = {'original_operator_clean': dense_operator} 37 | 38 | broken_operator = ParallelBeamGeometryOpBroken(dense_operator, config.op_snr) 39 | operator_dict.update({'original_operator': broken_operator}) 40 | 41 | dense_operator = ParallelBeamGeometryOp( 42 | config.img_size, 43 | config.dense_op_param, 44 | op_snr=np.inf, 45 | angle_max=config.angle_max) 46 | operator_dict['dense_operator'] = dense_operator 47 | 48 | logging.info(f'operator_dict: {operator_dict}') 49 | 50 | else: 51 | raise ValueError(f'Did not recognize opt.strat={config.opt_strat}.') 52 | 53 | elif config.problem == 'radon_3d': 54 | operator_dict = {} 55 | 56 | if config.opt_strat == 'broken_machine': 57 | dense_operator = ParallelBeamGeometry3DOp(config.img_size, config.op_param, op_snr=np.inf,angle_max=config.angle_max) 58 | operator_dict = {'original_operator_clean': dense_operator} 59 | 60 | broken_operator = ParallelBeamGeometry3DOpBroken(dense_operator, config.op_snr) 61 | operator_dict.update({'original_operator': broken_operator}) 62 | 63 | dense_operator = ParallelBeamGeometry3DOp(config.img_size, config.dense_op_param, op_snr=np.inf,angle_max=config.angle_max) 64 | operator_dict['dense_operator'] = dense_operator 65 | 66 | logging.info(f'operator_dict: {operator_dict}') 67 | 68 | else: 69 | raise ValueError(f'Did not recognize opt.strat={config.opt_strat}.') 70 | 71 | 72 | elif config.problem == 'traveltime': 73 | original_sensors = torch.rand( 74 | config.op_param, 2, dtype=torch.float32) 75 | 76 | operator = TravelTimeOperator( 77 | original_sensors, 78 | config.img_size) 79 | operator_dict = {'original_operator': operator} 80 | 81 | # Add new sensors for dense operator. 
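        # The dense operator reuses the original sensors and appends
        # (dense_op_param - op_param) freshly sampled ones, so the first op_param
        # sensors -- and hence their pairwise travel-time measurements -- are shared
        # between the sparse and the dense operator. A minimal usage sketch
        # (hypothetical config values, not part of this file):
        #
        #   config.problem = 'traveltime'; config.img_size = 64
        #   config.op_param = 10; config.dense_op_param = 20
        #   ops = get_operator_dict(config)
        #   y_sparse = ops['original_operator'](img)  # 10 * 9 / 2 = 45 measurements
        #   y_dense = ops['dense_operator'](img)      # 20 * 19 / 2 = 190 measurements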
82 | new_sensors = torch.rand( 83 | config.dense_op_param - config.op_param, 2, dtype=torch.float32) 84 | sensors = torch.cat((original_sensors, new_sensors), dim=0) 85 | 86 | dense_operator = TravelTimeOperator( 87 | sensors, config.img_size) 88 | operator_dict.update({'dense_operator': dense_operator}) 89 | 90 | else: 91 | raise ValueError('Inverse problem unrecognized.') 92 | 93 | return operator_dict -------------------------------------------------------------------------------- /normflow/flows/stochastic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base import Flow 4 | 5 | 6 | 7 | class MetropolisHastings(Flow): 8 | """ 9 | Sampling through Metropolis Hastings in Stochastic Normalizing 10 | Flow, see arXiv: 2002.06707 11 | """ 12 | def __init__(self, dist, proposal, steps): 13 | """ 14 | Constructor 15 | :param dist: Distribution to sample from 16 | :param proposal: Proposal distribution 17 | :param steps: Number of MCMC steps to perform 18 | """ 19 | super().__init__() 20 | self.dist = dist 21 | self.proposal = proposal 22 | self.steps = steps 23 | 24 | def forward(self, z): 25 | # Initialize number of samples and log(det) 26 | num_samples = len(z) 27 | log_det = torch.zeros(num_samples, dtype=z.dtype, device=z.device) 28 | # Get log(p) for current samples 29 | log_p = self.dist.log_prob(z) 30 | for i in range(self.steps): 31 | # Make proposal and get log(p) 32 | z_, log_p_diff = self.proposal(z) 33 | log_p_ = self.dist.log_prob(z_) 34 | # Make acceptance decision 35 | w = torch.rand(num_samples, dtype=z.dtype, device=z.device) 36 | log_w_accept = log_p_ - log_p + log_p_diff 37 | w_accept = torch.clamp(torch.exp(log_w_accept), max=1) 38 | accept = w <= w_accept 39 | # Update samples, log(det), and log(p) 40 | z = torch.where(accept.unsqueeze(1), z_, z) 41 | log_det_ = log_p - log_p_ 42 | log_det = torch.where(accept, log_det + log_det_, log_det) 43 | log_p = torch.where(accept, log_p_, log_p) 44 | return z, log_det 45 | 46 | def inverse(self, z): 47 | # Equivalent to forward pass 48 | return self.forward(z) 49 | 50 | 51 | class HamiltonianMonteCarlo(Flow): 52 | """ 53 | Flow layer using the HMC proposal in Stochastic Normalising Flows, 54 | see arXiv: 2002.06707 55 | """ 56 | def __init__(self, target, steps, log_step_size, log_mass): 57 | """ 58 | Constructor 59 | :param target: The stationary distribution of this Markov transition. Should be logp 60 | :param steps: The number of leapfrog steps 61 | :param log_step_size: The log step size used in the leapfrog integrator. shape (dim) 62 | :param log_mass: The log_mass determining the variance of the momentum samples. 
shape (dim)
63 | """
64 | super().__init__()
65 | self.target = target
66 | self.steps = steps
67 | self.register_parameter('log_step_size', torch.nn.Parameter(log_step_size))
68 | self.register_parameter('log_mass', torch.nn.Parameter(log_mass))
69 |
70 | def forward(self, z):
71 | # Draw momentum
72 | p = torch.randn_like(z) * torch.exp(0.5 * self.log_mass)
73 |
74 | # Leapfrog integration
75 | z_new = z.clone()
76 | p_new = p.clone()
77 | step_size = torch.exp(self.log_step_size)
78 | for i in range(self.steps):
79 | p_half = p_new + (step_size / 2.0) * self.gradlogP(z_new)
80 | z_new = z_new + step_size * (p_half / torch.exp(self.log_mass))
81 | p_new = p_half + (step_size / 2.0) * self.gradlogP(z_new)
82 |
83 | # Metropolis-Hastings correction
84 | probabilities = torch.exp(
85 | self.target.log_prob(z_new) - self.target.log_prob(z) - \
86 | 0.5 * torch.sum(p_new ** 2 / torch.exp(self.log_mass), 1) + \
87 | 0.5 * torch.sum(p ** 2 / torch.exp(self.log_mass), 1))
88 | uniforms = torch.rand_like(probabilities)
89 | mask = uniforms < probabilities
90 | z_out = torch.where(mask.unsqueeze(1), z_new, z)
91 |
92 | return z_out, self.target.log_prob(z) - self.target.log_prob(z_out)
93 |
94 | def inverse(self, z):
95 | return self.forward(z)
96 |
97 | def gradlogP(self, z):
98 | z_ = z.detach().requires_grad_()
99 | logp = self.target.log_prob(z_)
100 | return torch.autograd.grad(logp, z_,
101 | grad_outputs=torch.ones_like(logp))[0]
-------------------------------------------------------------------------------- /normflow/flows/reshape.py: --------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import Flow
4 |
5 |
6 |
7 | # Flow layers to reshape the latent features
8 |
9 | class Split(Flow):
10 | """
11 | Split features into two sets
12 | """
13 | def __init__(self, mode='channel'):
14 | """
15 | Constructor
16 | :param mode: Splitting mode, can be
17 | channel: Splits first feature dimension, usually channels, into two halves
18 | channel_inv: Same as channel, but with z1 and z2 flipped
19 | checkerboard: Splits features using a checkerboard pattern (last feature dimension must be even)
20 | checkerboard_inv: Same as checkerboard, but with inverted coloring
21 | """
22 | super().__init__()
23 | self.mode = mode
24 |
25 | def forward(self, z):
26 | if self.mode == 'channel':
27 | z1, z2 = z.chunk(2, dim=1)
28 | elif self.mode == 'channel_inv':
29 | z2, z1 = z.chunk(2, dim=1)
30 | elif 'checkerboard' in self.mode:
31 | n_dims = z.dim()
32 | cb0 = 0
33 | cb1 = 1
34 | for i in range(1, n_dims):
35 | cb0_ = cb0
36 | cb1_ = cb1
37 | cb0 = [cb0_ if j % 2 == 0 else cb1_ for j in range(z.size(n_dims - i))]
38 | cb1 = [cb1_ if j % 2 == 0 else cb0_ for j in range(z.size(n_dims - i))]
39 | cb = cb1 if 'inv' in self.mode else cb0
40 | cb = torch.tensor(cb)[None].repeat(len(z), *((n_dims - 1) * [1]))
41 | cb = cb.to(z.device)
42 | z_size = z.size()
43 | z1 = z.reshape(-1)[torch.nonzero(cb.view(-1), as_tuple=False)].view(*z_size[:-1], -1)
44 | z2 = z.reshape(-1)[torch.nonzero((1 - cb).view(-1), as_tuple=False)].view(*z_size[:-1], -1)
45 | else:
46 | raise NotImplementedError('Mode ' + self.mode + ' is not implemented.')
47 | log_det = 0
48 | return [z1, z2], log_det
49 |
50 | def inverse(self, z):
51 | z1, z2 = z
52 | if self.mode == 'channel':
53 | z = torch.cat([z1, z2], 1)
54 | elif self.mode == 'channel_inv':
55 | z = torch.cat([z2, z1], 1)
56 | elif 'checkerboard' in self.mode:
57 | n_dims = z1.dim()
58 | z_size = list(z1.size())
59 | z_size[-1] *= 2
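            # Rebuild the same alternating 0/1 checkerboard mask that forward() used for
            # splitting (constructed from the last dimension inwards); z1 and z2 are then
            # interleaved back into the full tensor below via cb * z1 + (1 - cb) * z2.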
cb0 = 0 61 | cb1 = 1 62 | for i in range(1, n_dims): 63 | cb0_ = cb0 64 | cb1_ = cb1 65 | cb0 = [cb0_ if j % 2 == 0 else cb1_ for j in range(z_size[n_dims - i])] 66 | cb1 = [cb1_ if j % 2 == 0 else cb0_ for j in range(z_size[n_dims - i])] 67 | cb = cb1 if 'inv' in self.mode else cb0 68 | cb = torch.tensor(cb)[None].repeat(z_size[0], *((n_dims - 1) * [1])) 69 | cb = cb.to(z1.device) 70 | z1 = z1[..., None].repeat(*(n_dims * [1]), 2).view(*z_size[:-1], -1) 71 | z2 = z2[..., None].repeat(*(n_dims * [1]), 2).view(*z_size[:-1], -1) 72 | z = cb * z1 + (1 - cb) * z2 73 | else: 74 | raise NotImplementedError('Mode ' + self.mode + ' is not implemented.') 75 | log_det = 0 76 | return z, log_det 77 | 78 | 79 | class Merge(Split): 80 | """ 81 | Same as Split but with forward and backward pass interchanged 82 | """ 83 | def __init__(self, mode='channel'): 84 | super().__init__(mode) 85 | 86 | def forward(self, z): 87 | return super().inverse(z) 88 | 89 | def inverse(self, z): 90 | return super().forward(z) 91 | 92 | 93 | class Squeeze(Flow): 94 | """ 95 | Squeeze operation of multi-scale architecture, RealNVP or Glow paper 96 | """ 97 | def __init__(self): 98 | """ 99 | Constructor 100 | """ 101 | super().__init__() 102 | 103 | def forward(self, z): 104 | log_det = 0 105 | s = z.size() 106 | z = z.view(s[0], s[1] // 4, 2, 2, s[2], s[3]) 107 | z = z.permute(0, 1, 4, 2, 5, 3).contiguous() 108 | z = z.view(s[0], s[1] // 4, 2 * s[2], 2 * s[3]) 109 | return z, log_det 110 | 111 | def inverse(self, z): 112 | log_det = 0 113 | s = z.size() 114 | z = z.view(*s[:2], s[2] // 2, 2, s[3] // 2, 2) 115 | z = z.permute(0, 1, 3, 5, 2, 4).contiguous() 116 | z = z.view(s[0], 4 * s[1], s[2] // 2, s[3] // 2) 117 | return z, log_det -------------------------------------------------------------------------------- /ops/radon_3d_lib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import odl 5 | 6 | # from odl.contrib.torch import OperatorFunction 7 | from ops.ODLHelper import OperatorFunction 8 | 9 | from .odl_lib import apply_angle_noise 10 | 11 | class ParallelBeamGeometry3DOp(object): 12 | def __init__(self, img_size, num_angles, op_snr, angle_max=np.pi/3): 13 | self.img_size = img_size 14 | self.num_angles = num_angles 15 | self.angle_max = angle_max 16 | self.reco_space = odl.uniform_discr( 17 | min_pt=[-20, -20, -20], 18 | max_pt=[20, 20, 20], 19 | shape=[img_size, img_size, img_size], 20 | dtype='float32' 21 | ) 22 | 23 | # Make a 3d single-axis parallel beam geometry with flat detector 24 | # Angles: uniformly spaced, n = 180, min = 0, max = pi 25 | # self.angle_partition = odl.uniform_partition(0, np.pi, 180) 26 | self.angle_partition = odl.uniform_partition(-angle_max, angle_max, num_angles) 27 | # Detector: uniformly sampled, n = (512, 512), min = (-30, -30), max = (30, 30) 28 | # self.detector_partition = odl.uniform_partition([-30, -30], [30, 30], [256, 256]) 29 | # self.detector_partition = odl.tomo.parallel_beam_geometry(self.reco_space).det_partition 30 | self.detector_partition = odl.tomo.parallel_beam_geometry(self.reco_space,det_shape=(2*img_size,2*img_size)).det_partition 31 | self.geometry = odl.tomo.Parallel3dAxisGeometry(self.angle_partition, self.detector_partition) 32 | 33 | self.num_detectors_x, self.num_detectors_y = self.geometry.detector.shape 34 | 35 | self.angles = apply_angle_noise(self.geometry.angles, op_snr) 36 | self.optimizable_params = torch.tensor(self.angles, 
dtype=torch.float32) # Convert to torch.Tensor. 37 | 38 | self.op = odl.tomo.RayTransform(self.reco_space, self.geometry, impl='astra_cuda') 39 | 40 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op) 41 | 42 | def __call__(self, x): 43 | return OperatorFunction.apply(self.op, x) 44 | 45 | def pinv(self, y): 46 | return OperatorFunction.apply(self.fbp, y) 47 | 48 | class ParallelBeamGeometry3DOpBroken(ParallelBeamGeometry3DOp): 49 | def __init__(self, clean_operator, op_snr): 50 | super().__init__(clean_operator.img_size, clean_operator.num_angles, op_snr, clean_operator.angle_max) 51 | 52 | self.optimizable_params = torch.tensor(clean_operator.geometry.angles, dtype=torch.float32) 53 | 54 | self.angles = apply_angle_noise(clean_operator.geometry.angles, op_snr) 55 | # angle partition is changed to not be uniform 56 | self.angle_partition = odl.discr.nonuniform_partition(np.sort(self.angles)) 57 | 58 | self.geometry = odl.tomo.Parallel3dAxisGeometry(self.angle_partition, self.detector_partition) 59 | 60 | self.num_detectors_x, self.num_detectors_y = self.geometry.detector.shape 61 | 62 | self.op = odl.tomo.RayTransform(self.reco_space, self.geometry, impl='astra_cuda') 63 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op) 64 | 65 | 66 | class ParallelBeamGeometry3DOpAngles(ParallelBeamGeometry3DOp): 67 | def __init__(self, img_size, angles, op_snr, angle_max=np.pi/3): 68 | super().__init__(img_size, angles.shape[0], op_snr,angle_max) 69 | 70 | self.optimizable_params = torch.tensor(self.geometry.angles, dtype=torch.float32) 71 | 72 | self.angles = angles 73 | # angle partition is changed to not be uniform 74 | self.angle_partition = odl.discr.nonuniform_partition(np.sort(self.angles)) 75 | 76 | # TODO: change following axes to get rotation around another axis 77 | self.geometry = odl.tomo.Parallel3dAxisGeometry(self.angle_partition, self.detector_partition, 78 | axis=(1,0,0),det_axes_init=[(1, 0, 0), (0, 1, 0)],det_pos_init=(0,0,1)) 79 | 80 | 81 | self.num_detectors_x, self.num_detectors_y = self.geometry.detector.shape 82 | 83 | self.op = odl.tomo.RayTransform(self.reco_space, self.geometry, impl='astra_cuda') 84 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op) 85 | 86 | 87 | def unit_test(): 88 | img_size = 64 89 | num_angles = 60 90 | A = ParallelBeamGeometry3DOp(img_size, num_angles, np.inf) 91 | 92 | x = torch.rand([img_size, img_size, img_size]) 93 | y = A(x) 94 | x_hat = A.pinv(y) 95 | print (x.shape) 96 | print (y.shape) 97 | print(x_hat.shape) 98 | 99 | if __name__ == "__main__": 100 | unit_test() -------------------------------------------------------------------------------- /normflow/distributions/target.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | 7 | class Target(nn.Module): 8 | """ 9 | Sample target distributions to test models 10 | """ 11 | def __init__(self, prop_scale=torch.tensor(6.), 12 | prop_shift=torch.tensor(-3.)): 13 | """ 14 | Constructor 15 | :param prop_scale: Scale for the uniform proposal 16 | :param prop_shift: Shift for the uniform proposal 17 | """ 18 | super().__init__() 19 | self.register_buffer("prop_scale", prop_scale) 20 | self.register_buffer("prop_shift", prop_shift) 21 | 22 | def log_prob(self, z): 23 | """ 24 | :param z: value or batch of latent variable 25 | :return: log probability of the distribution for z 26 | """ 27 | raise NotImplementedError('The log probability 
is not implemented yet.') 28 | 29 | def rejection_sampling(self, num_steps=1): 30 | """ 31 | Perform rejection sampling on image distribution 32 | :param num_steps: Number of rejection sampling steps to perform 33 | :return: Accepted samples 34 | """ 35 | eps = torch.rand((num_steps, self.n_dims), dtype=self.prop_scale.dtype, 36 | device=self.prop_scale.device) 37 | z_ = self.prop_scale * eps + self.prop_shift 38 | prob = torch.rand(num_steps, dtype=self.prop_scale.dtype, 39 | device=self.prop_scale.device) 40 | prob_ = torch.exp(self.log_prob(z_) - self.max_log_prob) 41 | accept = prob_ > prob 42 | z = z_[accept, :] 43 | return z 44 | 45 | def sample(self, num_samples=1): 46 | """ 47 | Sample from image distribution through rejection sampling 48 | :param num_samples: Number of samples to draw 49 | :return: Samples 50 | """ 51 | z = torch.zeros((0, self.n_dims), dtype=self.prop_scale.dtype, 52 | device=self.prop_scale.device) 53 | while len(z) < num_samples: 54 | z_ = self.rejection_sampling(num_samples) 55 | ind = np.min([len(z_), num_samples - len(z)]) 56 | z = torch.cat([z, z_[:ind, :]], 0) 57 | return z 58 | 59 | 60 | class TwoMoons(Target): 61 | """ 62 | Bimodal two-dimensional distribution 63 | """ 64 | def __init__(self): 65 | super().__init__() 66 | self.n_dims = 2 67 | self.max_log_prob = 0. 68 | 69 | def log_prob(self, z): 70 | """ 71 | log(p) = - 1/2 * ((norm(z) - 2) / 0.2) ** 2 72 | + log( exp(-1/2 * ((z[0] - 2) / 0.3) ** 2) 73 | + exp(-1/2 * ((z[0] + 2) / 0.3) ** 2)) 74 | :param z: value or batch of latent variable 75 | :return: log probability of the distribution for z 76 | """ 77 | a = torch.abs(z[:, 0]) 78 | log_prob = - 0.5 * ((torch.norm(z, dim=1) - 2) / 0.2) ** 2 \ 79 | - 0.5 * ((a - 2) / 0.3) ** 2 \ 80 | + torch.log(1 + torch.exp(-4 * a / 0.09)) 81 | return log_prob 82 | 83 | 84 | class CircularGaussianMixture(nn.Module): 85 | """ 86 | Two-dimensional Gaussian mixture arranged in a circle 87 | """ 88 | def __init__(self, n_modes=8): 89 | """ 90 | Constructor 91 | :param n_modes: Number of modes 92 | """ 93 | super(CircularGaussianMixture, self).__init__() 94 | self.n_modes = n_modes 95 | self.register_buffer("scale", torch.tensor(2 / 3 * np.sin(np.pi / self.n_modes)).float()) 96 | 97 | def log_prob(self, z): 98 | d = torch.zeros((len(z), 0), dtype=z.dtype, device=z.device) 99 | for i in range(self.n_modes): 100 | d_ = ((z[:, 0] - 2 * np.sin(2 * np.pi / self.n_modes * i)) ** 2 101 | + (z[:, 1] - 2 * np.cos(2 * np.pi / self.n_modes * i)) ** 2)\ 102 | / (2 * self.scale ** 2) 103 | d = torch.cat((d, d_[:, None]), 1) 104 | log_p = - torch.log(2 * np.pi * self.scale ** 2 * self.n_modes) \ 105 | + torch.logsumexp(-d, 1) 106 | return log_p 107 | 108 | def sample(self, num_samples=1): 109 | eps = torch.randn((num_samples, 2), dtype=self.scale.dtype, device=self.scale.device) 110 | phi = 2 * np.pi / self.n_modes * torch.randint(0, self.n_modes, (num_samples,), 111 | device=self.scale.device) 112 | loc = torch.stack((2 * torch.sin(phi), 2 * torch.cos(phi)), 1).type(eps.dtype) 113 | return eps * self.scale + loc 114 | 115 | 116 | class RingMixture(Target): 117 | """ 118 | Mixture of ring distributions in two dimensions 119 | """ 120 | def __init__(self, n_rings=2): 121 | super().__init__() 122 | self.n_dims = 2 123 | self.max_log_prob = 0. 
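        # max_log_prob is the bound used by Target.rejection_sampling above:
        # proposals z_ are accepted with probability exp(log_prob(z_) - max_log_prob),
        # so it should (at least approximately) upper-bound log_prob on the proposal box.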
124 | self.n_rings = n_rings
125 | self.scale = 1 / 4 / self.n_rings
126 |
127 | def log_prob(self, z):
128 | d = torch.zeros((len(z), 0), dtype=z.dtype, device=z.device)
129 | for i in range(self.n_rings):
130 | d_ = ((torch.norm(z, dim=1) - 2 / self.n_rings * (i + 1)) ** 2) \
131 | / (2 * self.scale ** 2)
132 | d = torch.cat((d, d_[:, None]), 1)
133 | return torch.logsumexp(-d, 1)
-------------------------------------------------------------------------------- /normflow/distributions/encoder.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
6 |
7 | class BaseEncoder(nn.Module):
8 | """
9 | Base distribution of a flow-based variational autoencoder
10 | Parameters of the distribution depend on the target variable x
11 | """
12 |
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def forward(self, x, num_samples=1):
17 | """
18 | :param x: Variable to condition on, first dimension is batch size
19 | :param num_samples: number of samples to draw per element of mini-batch
20 | :return: sample of z for x, log probability for sample
21 | """
22 | raise NotImplementedError
23 |
24 | def log_prob(self, z, x):
25 | """
26 | :param z: Primary random variable, first dimension is batch size
27 | :param x: Variable to condition on, first dimension is batch size
28 | :return: log probability of z given x
29 | """
30 | raise NotImplementedError
31 |
32 |
33 | class Dirac(BaseEncoder):
34 | def __init__(self):
35 | super().__init__()
36 |
37 | def forward(self, x, num_samples=1):
38 | z = x.unsqueeze(1).repeat(1, num_samples, 1)
39 | log_p = torch.zeros(z.size()[0:2])
40 | return z, log_p
41 |
42 | def log_prob(self, z, x):
43 | log_p = torch.zeros(z.size()[0:2])
44 | return log_p
45 |
46 |
47 | class Uniform(BaseEncoder):
48 | def __init__(self, zmin=0.0, zmax=1.0):
49 | super().__init__()
50 | self.zmin = zmin
51 | self.zmax = zmax
52 | self.log_p = -np.log(zmax - zmin) # zmin and zmax are Python floats, so use np.log rather than torch.log
53 |
54 | def forward(self, x, num_samples=1):
55 | z = x.unsqueeze(1).repeat(1, num_samples, 1).uniform_(self.zmin, self.zmax) # Tensor.uniform_ takes positional bounds, not min=/max=
56 | log_p = torch.zeros(z.size()[0:2]).fill_(self.log_p)
57 | return z, log_p
58 |
59 | def log_prob(self, z, x):
60 | log_p = torch.zeros(z.size()[0:2]).fill_(self.log_p)
61 | return log_p
62 |
63 |
64 | class ConstDiagGaussian(BaseEncoder):
65 | def __init__(self, loc, scale):
66 | """
67 | Multivariate Gaussian distribution with diagonal covariance and parameters being constant wrt x
68 | :param loc: mean vector of the distribution
69 | :param scale: vector of the standard deviations on the diagonal of the covariance matrix
70 | """
71 | super().__init__()
72 | self.d = len(loc)
73 | if not torch.is_tensor(loc):
74 | loc = torch.tensor(loc).float()
75 | if not torch.is_tensor(scale):
76 | scale = torch.tensor(scale).float()
77 | self.loc = nn.Parameter(loc.reshape((1, 1, self.d)))
78 | self.scale = nn.Parameter(scale)
79 |
80 | def forward(self, x=None, num_samples=1):
81 | """
82 | :param x: Variable to condition on, will only be used to determine the batch size
83 | :param num_samples: number of samples to draw per element of mini-batch
84 | :return: sample of z for x, log probability for sample
85 | """
86 | if x is not None:
87 | batch_size = len(x)
88 | else:
89 | batch_size = 1
90 | eps = torch.randn((batch_size, num_samples, self.d), device=self.loc.device) # use self.loc.device so that x=None is valid
91 | z = self.loc + self.scale * eps
92 | log_p = - 0.5 * self.d * np.log(2 * np.pi) - torch.sum(torch.log(self.scale) + 0.5 *
torch.pow(eps, 2), 2) 93 | return z, log_p 94 | 95 | def log_prob(self, z, x): 96 | """ 97 | :param z: Primary random variable, first dimension is batch dimension 98 | :param x: Variable to condition on, first dimension is batch dimension 99 | :return: log probability of z given x 100 | """ 101 | if z.dim() == 1: 102 | z = z.unsqueeze(0) 103 | if z.dim() == 2: 104 | z = z.unsqueeze(0) 105 | log_p = - 0.5 * self.d * np.log(2 * np.pi) - torch.sum( 106 | torch.log(self.scale) + 0.5 * ((z - self.loc) / self.scale) ** 2, 2) 107 | return log_p 108 | 109 | 110 | class NNDiagGaussian(BaseEncoder): 111 | """ 112 | Diagonal Gaussian distribution with mean and variance determined by a neural network 113 | """ 114 | 115 | def __init__(self, net): 116 | """ 117 | Constructor 118 | :param net: net computing mean (first n / 2 outputs), standard deviation (second n / 2 outputs) 119 | """ 120 | super().__init__() 121 | self.net = net 122 | 123 | def forward(self, x, num_samples=1): 124 | """ 125 | :param x: Variable to condition on 126 | :param num_samples: number of samples to draw per element of mini-batch 127 | :return: sample of z for x, log probability for sample 128 | """ 129 | batch_size = len(x) 130 | mean_std = self.net(x) 131 | n_hidden = mean_std.size()[1] // 2 132 | mean = mean_std[:, :n_hidden, ...].unsqueeze(1) 133 | std = torch.exp(0.5 * mean_std[:, n_hidden:(2 * n_hidden), ...].unsqueeze(1)) 134 | eps = torch.randn((batch_size, num_samples) + tuple(mean.size()[2:]), device=x.device) 135 | z = mean + std * eps 136 | log_p = - 0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(2 * np.pi) \ 137 | - torch.sum(torch.log(std) + 0.5 * torch.pow(eps, 2), list(range(2, z.dim()))) 138 | return z, log_p 139 | 140 | def log_prob(self, z, x): 141 | """ 142 | :param z: Primary random variable, first dimension is batch dimension 143 | :param x: Variable to condition on, first dimension is batch dimension 144 | :return: log probability of z given x 145 | """ 146 | if z.dim() == 1: 147 | z = z.unsqueeze(0) 148 | if z.dim() == 2: 149 | z = z.unsqueeze(0) 150 | mean_std = self.net(x) 151 | n_hidden = mean_std.size()[1] // 2 152 | mean = mean_std[:, :n_hidden, ...].unsqueeze(1) 153 | var = torch.exp(mean_std[:, n_hidden:(2 * n_hidden), ...].unsqueeze(1)) 154 | log_p = - 0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(2 * np.pi) \ 155 | - 0.5 * torch.sum(torch.log(var) + (z - mean) ** 2 / var, 2) 156 | return log_p -------------------------------------------------------------------------------- /normflow/flows/neural_spline.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .base import Flow 4 | 5 | # Try importing Neural Spline Flow dependencies 6 | try: 7 | from neural_spline_flows.nde.transforms.coupling import PiecewiseRationalQuadraticCouplingTransform 8 | from neural_spline_flows.nde.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform 9 | from neural_spline_flows.nn import ResidualNet 10 | from neural_spline_flows.utils import create_alternating_binary_mask 11 | except: 12 | print('Warning: Dependencies for Neural Spline Flows could ' 13 | 'not be loaded. 
Other models can still be used.') 14 | 15 | 16 | 17 | class CoupledRationalQuadraticSpline(Flow): 18 | """ 19 | Neural spline flow coupling layer, wrapper for the implementation 20 | of Durkan et al., see https://github.com/bayesiains/nsf 21 | """ 22 | def __init__( 23 | self, 24 | num_input_channels, 25 | num_blocks, 26 | num_hidden_channels, 27 | num_bins=8, 28 | tail_bound=3, 29 | activation=nn.ReLU, 30 | dropout_probability=0., 31 | reverse_mask=False, 32 | reverse=True 33 | ): 34 | """ 35 | Constructor 36 | :param num_input_channels: Flow dimension 37 | :type num_input_channels: Int 38 | :param num_blocks: Number of residual blocks of the parameter NN 39 | :type num_blocks: Int 40 | :param num_hidden_channels: Number of hidden units of the NN 41 | :type num_hidden_channels: Int 42 | :param num_bins: Number of bins 43 | :type num_bins: Int 44 | :param tail_bound: Bound of the spline tails 45 | :type tail_bound: Int 46 | :param activation: Activation function 47 | :type activation: torch module 48 | :param dropout_probability: Dropout probability of the NN 49 | :type dropout_probability: Float 50 | :param reverse_mask: Flag whether the reverse mask should be used 51 | :type reverse_mask: Boolean 52 | :param reverse: Flag whether forward and backward pass shall be swapped 53 | :type reverse: Boolean 54 | """ 55 | super().__init__() 56 | self.reverse = reverse 57 | 58 | def transform_net_create_fn(in_features, out_features): 59 | return ResidualNet( 60 | in_features=in_features, 61 | out_features=out_features, 62 | context_features=None, 63 | hidden_features=num_hidden_channels, 64 | num_blocks=num_blocks, 65 | activation=activation(), 66 | dropout_probability=dropout_probability, 67 | use_batch_norm=False 68 | ) 69 | 70 | self.prqct=PiecewiseRationalQuadraticCouplingTransform( 71 | mask=create_alternating_binary_mask( 72 | num_input_channels, 73 | even=reverse_mask 74 | ), 75 | transform_net_create_fn=transform_net_create_fn, 76 | num_bins=num_bins, 77 | tails='linear', 78 | tail_bound=tail_bound, 79 | 80 | # Setting True corresponds to equations (4), (5), (6) in the NSF paper: 81 | apply_unconditional_transform=True 82 | ) 83 | 84 | def forward(self, z): 85 | if self.reverse: 86 | z, log_det = self.prqct.inverse(z) 87 | else: 88 | z, log_det = self.prqct(z) 89 | return z, log_det.view(-1) 90 | 91 | def inverse(self, z): 92 | if self.reverse: 93 | z, log_det = self.prqct(z) 94 | else: 95 | z, log_det = self.prqct.inverse(z) 96 | return z, log_det.view(-1) 97 | 98 | 99 | class AutoregressiveRationalQuadraticSpline(Flow): 100 | """ 101 | Neural spline flow coupling layer, wrapper for the implementation 102 | of Durkan et al., see https://github.com/bayesiains/nsf 103 | """ 104 | def __init__( 105 | self, 106 | num_input_channels, 107 | num_blocks, 108 | num_hidden_channels, 109 | num_bins=8, 110 | tail_bound=3, 111 | activation=nn.ReLU, 112 | dropout_probability=0., 113 | reverse=True 114 | ): 115 | """ 116 | Constructor 117 | :param num_input_channels: Flow dimension 118 | :type num_input_channels: Int 119 | :param num_blocks: Number of residual blocks of the parameter NN 120 | :type num_blocks: Int 121 | :param num_hidden_channels: Number of hidden units of the NN 122 | :type num_hidden_channels: Int 123 | :param num_bins: Number of bins 124 | :type num_bins: Int 125 | :param tail_bound: Bound of the spline tails 126 | :type tail_bound: Int 127 | :param activation: Activation function 128 | :type activation: torch module 129 | :param dropout_probability: Dropout probability of the NN 130 
| :type dropout_probability: Float 131 | :param reverse: Flag whether forward and backward pass shall be swapped 132 | :type reverse: Boolean 133 | """ 134 | super().__init__() 135 | self.reverse = reverse 136 | 137 | self.mprqat=MaskedPiecewiseRationalQuadraticAutoregressiveTransform( 138 | features=num_input_channels, 139 | hidden_features=num_hidden_channels, 140 | context_features=None, 141 | num_bins=num_bins, 142 | tails='linear', 143 | tail_bound=tail_bound, 144 | num_blocks=num_blocks, 145 | use_residual_blocks=True, 146 | random_mask=False, 147 | activation=activation(), 148 | dropout_probability=dropout_probability, 149 | use_batch_norm=False) 150 | 151 | def forward(self, z): 152 | if self.reverse: 153 | z, log_det = self.mprqat.inverse(z) 154 | else: 155 | z, log_det = self.mprqat(z) 156 | return z, log_det.view(-1) 157 | 158 | def inverse(self, z): 159 | if self.reverse: 160 | z, log_det = self.mprqat(z) 161 | else: 162 | z, log_det = self.mprqat.inverse(z) 163 | return z, log_det.view(-1) -------------------------------------------------------------------------------- /autoencoder_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Autoencoder (to be used in generative network) 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch 7 | from funknn_model import squeeze 8 | 9 | 10 | class autoencoder(nn.Module): 11 | 12 | def __init__(self, encoder=None, decoder=None): 13 | super(autoencoder, self).__init__() 14 | 15 | self.encoder = encoder 16 | self.decoder = decoder 17 | 18 | 19 | class encoder(nn.Module): 20 | 21 | def __init__(self, latent_dim=256, in_res=64, c=3): 22 | super(encoder, self).__init__() 23 | 24 | self.in_res = in_res 25 | self.c = c 26 | prev_ch = c 27 | c_last = 256 28 | CNNs = [] 29 | CNNs_add = [] 30 | num_layers = [64,128,128,256] 31 | for i in range(len(num_layers)): 32 | CNNs.append(nn.Conv2d(prev_ch, num_layers[i] ,3, 33 | padding = 'same')) 34 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i] ,3, 35 | padding = 'same')) 36 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i] ,3, 37 | padding = 'same')) 38 | prev_ch = num_layers[i] 39 | 40 | if in_res == 64: 41 | num_layers = [c_last] 42 | for i in range(len(num_layers)): 43 | CNNs.append(nn.Conv2d(prev_ch, num_layers[i] ,3, 44 | padding = 'same')) 45 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i] ,3, 46 | padding = 'same')) 47 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i] ,3, 48 | padding = 'same')) 49 | prev_ch = num_layers[i] 50 | 51 | if in_res == 128: 52 | num_layers = [256,c_last] 53 | for i in range(len(num_layers)): 54 | CNNs.append(nn.Conv2d(prev_ch, num_layers[i] ,3, 55 | padding = 'same')) 56 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i] ,3, 57 | padding = 'same')) 58 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i] ,3, 59 | padding = 'same')) 60 | prev_ch = num_layers[i] 61 | 62 | 63 | 64 | 65 | self.CNNs = nn.ModuleList(CNNs) 66 | self.CNNs_add = nn.ModuleList(CNNs_add) 67 | self.maxpool = nn.MaxPool2d(2, 2) 68 | 69 | feature_dim = 2 * 2 * c_last 70 | mlps = [] 71 | mlps.append(nn.Linear(feature_dim, latent_dim)) 72 | 73 | self.mlps = nn.ModuleList(mlps) 74 | 75 | 76 | def forward(self, x): 77 | 78 | x_skip = torch.mean(x , dim = 1, keepdim = True) 79 | for i in range(len(self.CNNs)): 80 | x = self.CNNs[i](x) 81 | x = F.relu(x) 82 | xm = self.maxpool(x) 83 | if i < 4: 84 | 85 | f = 2**(i+1) 86 | xm = xm + squeeze(x_skip , 
f).repeat_interleave(xm.shape[1]//(f**2) , dim = 1) 87 | 88 | x = self.CNNs_add[i*2](xm) 89 | x = F.relu(x) 90 | x = self.CNNs_add[i*2 + 1](x) 91 | x = F.relu(x) 92 | x = x + xm 93 | 94 | x = torch.flatten(x, 1) 95 | for i in range(len(self.mlps)-1): 96 | x = self.mlps[i](x) 97 | x = F.relu(x) 98 | 99 | x = self.mlps[-1](x) 100 | 101 | return x 102 | 103 | 104 | class decoder(nn.Module): 105 | 106 | def __init__(self, latent_dim=256, in_res=64, c=3): 107 | super(decoder, self).__init__() 108 | 109 | self.in_res = in_res 110 | self.c = c 111 | prev_ch = 256 112 | t_CNNs = [] 113 | CNNs = [] 114 | 115 | if in_res == 128: 116 | num_layers = [256,256,128,128,64,self.c] 117 | for i in range(len(num_layers)): 118 | # t_CNNs.append(nn.ConvTranspose2d(prev_ch, num_layers[i] ,3, 119 | # stride=2,padding = 1, output_padding=1)) 120 | c_inter = 64 if num_layers[i] == self.c else num_layers[i] 121 | t_CNNs.append(nn.Conv2d(prev_ch, c_inter ,3, 122 | padding = 'same')) 123 | CNNs.append(nn.Conv2d(c_inter, c_inter ,3, 124 | padding = 'same')) 125 | CNNs.append(nn.Conv2d(c_inter, num_layers[i] ,3, 126 | padding = 'same')) 127 | prev_ch = num_layers[i] 128 | 129 | elif in_res == 64: 130 | 131 | num_layers = [256,128,128,64,self.c] 132 | for i in range(len(num_layers)): 133 | 134 | c_inter = 64 if num_layers[i] == self.c else num_layers[i] 135 | t_CNNs.append(nn.Conv2d(prev_ch, c_inter ,3, 136 | padding = 'same')) 137 | CNNs.append(nn.Conv2d(c_inter, c_inter ,3, 138 | padding = 'same')) 139 | CNNs.append(nn.Conv2d(c_inter, num_layers[i] ,3, 140 | padding = 'same')) 141 | prev_ch = num_layers[i] 142 | 143 | 144 | 145 | self.t_CNNs = nn.ModuleList(t_CNNs) 146 | self.CNNs = nn.ModuleList(CNNs) 147 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 148 | 149 | self.feature_dim = 2 * 2 * 256 150 | mlps = [] 151 | mlps.append(nn.Linear(latent_dim , self.feature_dim)) 152 | 153 | self.mlps = nn.ModuleList(mlps) 154 | 155 | def forward(self, x): 156 | 157 | for i in range(len(self.mlps)): 158 | x = self.mlps[i](x) 159 | x = F.relu(x) 160 | # x = squeeze(x , 16) 161 | b = x.shape[0] 162 | x = x.reshape([b, 256, 2, 2]) 163 | 164 | for i in range(len(self.t_CNNs)-1): 165 | x = self.upsample(x) 166 | x = self.t_CNNs[i](x) 167 | xr = F.relu(x) 168 | x = self.CNNs[i*2](xr) 169 | x = F.relu(x) 170 | x = self.CNNs[i*2+1](x) 171 | x = F.relu(x) 172 | x = x + xr 173 | 174 | x = self.upsample(x) 175 | x = self.t_CNNs[-1](x) 176 | x = F.relu(x) 177 | x = self.CNNs[-2](x) 178 | x = F.relu(x) 179 | x = self.CNNs[-1](x) 180 | 181 | return x 182 | 183 | -------------------------------------------------------------------------------- /normflow/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | from . import flows 6 | 7 | # Try importing ResNet dependencies 8 | try: 9 | from residual_flows.layers.base import InducedNormLinear, InducedNormConv2d 10 | except: 11 | print('Warning: Dependencies for Residual Networks could ' 12 | 'not be loaded. 
Other models can still be used.')
13 |
14 |
15 | def set_requires_grad(module, flag):
16 | """
17 | Sets requires_grad flag of all parameters of a torch.nn.module
18 | :param module: torch.nn.module
19 | :param flag: Flag to set requires_grad to
20 | """
21 |
22 | for param in module.parameters():
23 | param.requires_grad = flag
24 |
25 |
26 | class ConstScaleLayer(nn.Module):
27 | """
28 | Scaling features by a fixed factor
29 | """
30 | def __init__(self, scale=1.):
31 | """
32 | Constructor
33 | :param scale: Scale to apply to features
34 | """
35 | super().__init__()
36 | self.scale_cpu = torch.tensor(scale)
37 | self.register_buffer("scale", self.scale_cpu)
38 |
39 | def forward(self, input):
40 | return input * self.scale
41 |
42 |
43 | class ActNorm(nn.Module):
44 | """
45 | ActNorm layer with just one forward pass
46 | """
47 | def __init__(self, shape, logscale_factor=None):
48 | """
49 | Constructor
50 | :param shape: Same as shape in flows.ActNorm
51 | :param logscale_factor: Same as logscale_factor in flows.ActNorm
52 | """
53 | super().__init__()
54 | self.actNorm = flows.ActNorm(shape, logscale_factor=logscale_factor)
55 |
56 | def forward(self, input):
57 | out, _ = self.actNorm(input)
58 | return out
59 |
60 |
61 | # Dataset transforms
62 |
63 | class Logit():
64 | """
65 | Transform for dataloader
66 | logit(alpha + (1 - alpha) * x) where logit(x) = log(x / (1 - x))
67 | """
68 | def __init__(self, alpha=0):
69 | """
70 | Constructor
71 | :param alpha: see above
72 | """
73 | self.alpha = alpha
74 |
75 | def __call__(self, x):
76 | x_ = self.alpha + (1 - self.alpha) * x
77 | return torch.log(x_ / (1 - x_))
78 |
79 | def inverse(self, x):
80 | return (torch.sigmoid(x) - self.alpha) / (1 - self.alpha)
81 |
82 |
83 | class Jitter():
84 | """
85 | Transform for dataloader
86 | Adds uniform jitter noise to data
87 | """
88 | def __init__(self, scale=1./256):
89 | """
90 | Constructor
91 | :param scale: Scaling factor for noise
92 | """
93 | self.scale = scale
94 |
95 | def __call__(self, x):
96 | eps = torch.rand_like(x) * self.scale
97 | x_ = x + eps
98 | return x_
99 |
100 |
101 | class Scale():
102 | """
103 | Transform for dataloader
104 | Scales data by a constant factor
105 | """
106 | def __init__(self, scale=255./256.):
107 | """
108 | Constructor
109 | :param scale: Scaling factor
110 | """
111 | self.scale = scale
112 |
113 | def __call__(self, x):
114 | return x * self.scale
115 |
116 |
117 | # Nonlinearities
118 |
119 | class ClampExp(torch.nn.Module):
120 | """
121 | Nonlinearity min(exp(x), 1)
122 | """
123 | def __init__(self):
124 | """
125 | Constructor
126 |
127 | """
128 | super(ClampExp, self).__init__()
129 |
130 | def forward(self, x):
131 | one = torch.tensor(1., device=x.device, dtype=x.dtype)
132 | return torch.min(torch.exp(x), one)
133 |
134 |
135 | # Functions for model analysis
136 |
137 | def bitsPerDim(model, x, y=None, trans='logit', trans_param=[0.05]):
138 | """
139 | Computes the bits per dim for a batch of data
140 | :param model: Model to compute bits per dim for
141 | :param x: Batch of data
142 | :param y: Class labels for batch of data if base distribution is class conditional
143 | :param trans: Transformation to be applied to images during training
144 | :param trans_param: List of parameters of the transformation
145 | :return: Bits per dim for data batch under model
146 | """
147 | dims = torch.prod(torch.tensor(x.size()[1:]))
148 | if trans == 'logit':
149 | if y is None:
150 | log_q = model.log_prob(x)
151 | else:
152 | log_q = model.log_prob(x, y)
153 | sum_dims = list(range(1, x.dim()))
154 | ls = torch.nn.LogSigmoid()
155 | sig_ = torch.sum(ls(x) / np.log(2), sum_dims)
156 | sig_ += torch.sum(ls(-x) / np.log(2), sum_dims)
157 | b = - log_q / dims / np.log(2) - np.log2(1 - trans_param[0]) + 8
158 | b += sig_ / dims
159 | else:
160 | raise NotImplementedError('The transformation ' + trans + ' is not implemented.')
161 | return b
162 |
163 |
164 | def bitsPerDimDataset(model, data_loader, class_cond=True, trans='logit',
165 | trans_param=[0.05]):
166 | """
167 | Computes average bits per dim for an entire dataset given by a data loader
168 | :param model: Model to compute bits per dim for
169 | :param data_loader: Data loader of dataset
170 | :param class_cond: Flag indicating whether model is class_conditional
171 | :param trans: Transformation to be applied to images during training
172 | :param trans_param: List of parameters of the transformation
173 | :return: Average bits per dim for dataset
174 | """
175 | n = 0
176 | b_cum = 0
177 | with torch.no_grad():
178 | for x, y in iter(data_loader):
179 | b_ = bitsPerDim(model, x, y.to(x.device) if class_cond else None,
180 | trans, trans_param)
181 | b_np = b_.to('cpu').numpy()
182 | b_cum += np.nansum(b_np)
183 | n += len(x) - np.sum(np.isnan(b_np))
184 | b = b_cum / n
185 | return b
186 |
187 |
188 | def clear_grad(model):
189 | """
190 | Set gradients of model parameters to None as this speeds up training,
191 | see https://www.youtube.com/watch?v=9mS1fIYj1So
192 | :param model: Model to clear gradients of
193 | """
194 | for param in model.parameters():
195 | param.grad = None
196 |
197 |
198 | def update_lipschitz(model, n_iterations):
199 | for m in model.modules():
200 | if isinstance(m, InducedNormConv2d) or isinstance(m, InducedNormLinear):
201 | m.compute_weight(update=True, n_iterations=n_iterations)
202 |
-------------------------------------------------------------------------------- /train_funknn.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from timeit import default_timer
5 | from torch.optim import Adam
6 | import os
7 | import matplotlib.pyplot as plt
8 | from funknn_model import FunkNN
9 | from utils import *
10 | from datasets import *
11 | from results import evaluator
12 | import config_funknn as config
13 |
14 | torch.manual_seed(0)
15 | np.random.seed(0)
16 |
17 | epochs_funknn = config.epochs_funknn
18 | batch_size = config.batch_size
19 | dataset = config.dataset
20 | gpu_num = config.gpu_num
21 | exp_desc = config.exp_desc
22 | image_size = config.image_size
23 | c = config.c
24 | train_funknn = config.train_funknn
25 | training_mode = config.training_mode
26 | ood_analysis = config.ood_analysis
27 |
28 | enable_cuda = True
29 | device = torch.device('cuda:' + str(gpu_num) if torch.cuda.is_available() and enable_cuda else 'cpu')
30 |
31 | all_experiments = 'experiments/'
32 | if not os.path.exists(all_experiments):
33 | os.mkdir(all_experiments)
34 |
35 | # experiment path
36 | exp_path = all_experiments + 'funknn_' + dataset + '_' \
37 | + str(image_size) + '_' + training_mode + '_' + exp_desc
38 |
39 |
40 | if not os.path.exists(exp_path):
41 | os.mkdir(exp_path)
42 |
43 |
44 | learning_rate = 1e-4
45 | step_size = 50
46 | gamma = 0.5
47 | # myloss = F.mse_loss
48 | myloss = F.l1_loss
49 | num_batch_pixels = 3 # The number of iterations over each batch
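# Training operates on random coordinate batches rather than full images:
# every image batch is revisited num_batch_pixels times, and in each pass only
# batch_pixels randomly chosen target pixels are reconstructed and penalized,
# which keeps memory usage independent of the target resolution. Note that,
# despite the legacy variable names further down, the training loss is L1
# (myloss = F.l1_loss above), not MSE.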
batch_pixels = 512 # Number of pixels to optimize in each iteration 51 | k = 2 # super resolution factor for training 52 | 53 | # Print the experiment setup: 54 | print('Experiment setup:') 55 | print('---> epochs_funknn: {}'.format(epochs_funknn)) 56 | print('---> batch_size: {}'.format(batch_size)) 57 | print('---> dataset: {}'.format(dataset)) 58 | print('---> Learning rate: {}'.format(learning_rate)) 59 | print('---> experiment path: {}'.format(exp_path)) 60 | print('---> image size: {}'.format(image_size)) 61 | 62 | # Dataset: 63 | train_dataset = Dataset_loader(dataset = 'train' ,size = (image_size,image_size), c = c) 64 | test_dataset = Dataset_loader(dataset = 'test' ,size = (config.max_scale*image_size,config.max_scale*image_size), c = c) 65 | 66 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=32, shuffle = True) 67 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, num_workers=32) 68 | 69 | ntrain = len(train_loader.dataset) 70 | n_test = len(test_loader.dataset) 71 | 72 | n_ood = 0 73 | if ood_analysis: 74 | ood_dataset = Dataset_loader(dataset = 'ood',size = (2*image_size,2*image_size), c = c) 75 | ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=300, num_workers=32) 76 | n_ood= len(ood_loader.dataset) 77 | 78 | print('---> Number of training, test and ood samples: {}, {}, {}'.format(ntrain,n_test, n_ood)) 79 | 80 | # Loading model 81 | plot_per_num_epoch = 10 if ntrain > 10000 else 10000//ntrain 82 | model = FunkNN(c=c).to(device) 83 | # model = torch.nn.DataParallel(model) # Using multiple GPUs 84 | num_param_funknn = count_parameters(model) 85 | print('---> Number of trainable parameters of funknn: {}'.format(num_param_funknn)) 86 | 87 | optimizer_funknn = Adam(model.parameters(), lr=learning_rate) 88 | scheduler_funknn = torch.optim.lr_scheduler.StepLR(optimizer_funknn, step_size=step_size, gamma=gamma) 89 | 90 | checkpoint_exp_path = os.path.join(exp_path, 'funknn.pt') 91 | if os.path.exists(checkpoint_exp_path) and config.restore_funknn: 92 | checkpoint_funknn = torch.load(checkpoint_exp_path) 93 | model.load_state_dict(checkpoint_funknn['model_state_dict']) 94 | optimizer_funknn.load_state_dict(checkpoint_funknn['optimizer_state_dict']) 95 | print('funknn is restored...') 96 | 97 | if train_funknn: 98 | print('Training...') 99 | 100 | if plot_per_num_epoch == -1: 101 | plot_per_num_epoch = epochs_funknn + 1 # only plot in the last epoch 102 | 103 | loss_funknn_plot = np.zeros([epochs_funknn]) 104 | for ep in range(epochs_funknn): 105 | model.train() 106 | t1 = default_timer() 107 | loss_funknn_epoch = 0 108 | 109 | for image in train_loader: 110 | 111 | batch_size = image.shape[0] 112 | image = image.to(device) 113 | 114 | for i in range(num_batch_pixels): 115 | image_mat = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2) 116 | 117 | image_high, image_low, image_size_high = training_strategy(image_mat, image_size, factor = k , mode = training_mode) 118 | coords = get_mgrid(image_size_high).reshape(-1, 2) 119 | coords = torch.unsqueeze(coords, dim = 0) 120 | coords = coords.expand(batch_size , -1, -1).to(device) 121 | 122 | image_high = image_high.permute(0,2,3,1).reshape(-1, image_size_high * image_size_high, c) 123 | optimizer_funknn.zero_grad() 124 | pixels = np.random.randint(low = 0, high = image_size_high**2, size = batch_pixels) 125 | batch_coords = coords[:,pixels] 126 | batch_image = image_high[:,pixels] 127 | 128 | out = model(batch_coords, image_low) 129 | 
mse_loss = myloss(out.reshape(batch_size, -1) , batch_image.reshape(batch_size, -1) ) 130 | total_loss = mse_loss 131 | total_loss.backward() 132 | optimizer_funknn.step() 133 | loss_funknn_epoch += total_loss.item() 134 | 135 | scheduler_funknn.step() 136 | t2 = default_timer() 137 | loss_funknn_epoch/= ntrain 138 | loss_funknn_plot[ep] = loss_funknn_epoch 139 | 140 | plt.plot(np.arange(epochs_funknn)[:ep] , loss_funknn_plot[:ep], 'o-', linewidth=2) 141 | plt.title('FunkNN_loss') 142 | plt.xlabel('epoch') 143 | plt.ylabel('MSE loss') 144 | plt.savefig(os.path.join(exp_path, 'funknn_loss.jpg')) 145 | np.save(os.path.join(exp_path, 'funknn_loss.npy'), loss_funknn_plot[:ep]) 146 | plt.close() 147 | 148 | torch.save({ 149 | 'model_state_dict': model.state_dict(), 150 | 'optimizer_state_dict': optimizer_funknn.state_dict() 151 | }, checkpoint_exp_path) 152 | 153 | print('ep: {}/{} | time: {:.0f} | FunkNN_loss {:.6f} '.format(ep, epochs_funknn, t2-t1,loss_funknn_epoch)) 154 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file: 155 | file.write('ep: {}/{} | time: {:.0f} | FunkNN_loss {:.6f} '.format(ep, epochs_funknn, t2-t1,loss_funknn_epoch)) 156 | file.write('\n') 157 | 158 | if ep % plot_per_num_epoch == 0 or (ep + 1) == epochs_funknn: 159 | 160 | evaluator(ep = ep, subset = 'test', data_loader = test_loader, model = model, exp_path = exp_path) 161 | if ood_analysis: 162 | evaluator(ep = ep, subset = 'ood', data_loader = ood_loader, model = model, exp_path = exp_path) 163 | 164 | print('Evaluating...') 165 | evaluator(ep = -1, subset = 'test', data_loader = test_loader, model = model, exp_path = exp_path) 166 | if ood_analysis: 167 | evaluator(ep = -1, subset = 'ood', data_loader = ood_loader, model = model, exp_path = exp_path) 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: torch12 2 | channels: 3 | - jmcmurray 4 | - pytorch 5 | - astra-toolbox 6 | - anaconda 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_kmp_llvm 12 | - absl-py=0.15.0=pyhd3eb1b0_0 13 | - alsa-lib=1.2.7.2=h166bdaf_0 14 | - astra-toolbox=2.1.2=py39h4319877_0 15 | - asttokens=2.0.5=pyhd3eb1b0_0 16 | - attr=2.5.1=h166bdaf_1 17 | - attrs=22.1.0=pyh71513ae_1 18 | - backcall=0.2.0=pyhd3eb1b0_0 19 | - blas=1.0=mkl 20 | - brotli=1.0.9=h166bdaf_7 21 | - brotli-bin=1.0.9=h166bdaf_7 22 | - brotlipy=0.7.0=py39hb9d737c_1004 23 | - bzip2=1.0.8=h7f98852_4 24 | - ca-certificates=2022.9.24=ha878542_0 25 | - certifi=2022.9.24=pyhd8ed1ab_0 26 | - cffi=1.15.1=py39he91dace_0 27 | - charset-normalizer=2.1.1=pyhd8ed1ab_0 28 | - cloudpickle=2.0.0=pyhd3eb1b0_0 29 | - colorama=0.4.5=pyhd8ed1ab_0 30 | - contextlib2=21.6.0=pyhd8ed1ab_0 31 | - contourpy=1.0.5=py39hf939315_0 32 | - cryptography=37.0.4=py39hd97740a_0 33 | - cudatoolkit=11.6.0=hecad31d_10 34 | - cycler=0.11.0=pyhd8ed1ab_0 35 | - cytoolz=0.11.0=py39h27cfd23_0 36 | - dask-core=2022.7.0=py39h06a4308_0 37 | - dbus=1.13.6=h5008d03_3 38 | - decorator=5.1.1=pyhd3eb1b0_0 39 | - execnet=1.9.0=pyhd8ed1ab_0 40 | - executing=0.8.3=pyhd3eb1b0_0 41 | - expat=2.4.9=h27087fc_0 42 | - ffmpeg=4.3=hf484d3e_0 43 | - fftw=3.3.10=nompi_hf0379b8_105 44 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 45 | - font-ttf-inconsolata=3.000=h77eed37_0 46 | - font-ttf-source-code-pro=2.038=h77eed37_0 47 | - font-ttf-ubuntu=0.83=hab24e00_0 48 | - 
fontconfig=2.14.0=hc2a2eb6_1 49 | - fonts-conda-ecosystem=1=0 50 | - fonts-conda-forge=1=0 51 | - fonttools=4.37.3=py39hb9d737c_0 52 | - freetype=2.12.1=hca18f0e_0 53 | - fsspec=2022.7.1=py39h06a4308_0 54 | - gettext=0.19.8.1=h73d1719_1008 55 | - glib=2.72.1=h6239696_0 56 | - glib-tools=2.72.1=h6239696_0 57 | - gmp=6.2.1=h58526e2_0 58 | - gnutls=3.6.13=h85f3911_1 59 | - gst-plugins-base=1.20.3=h57caac4_2 60 | - gstreamer=1.20.3=hd4edc92_2 61 | - icu=70.1=h27087fc_0 62 | - idna=3.4=pyhd8ed1ab_0 63 | - imageio=2.22.0=pyhfa7a67d_0 64 | - iniconfig=1.1.1=pyh9f0ad1d_0 65 | - ipdb=0.13.9=pyhd8ed1ab_0 66 | - ipython=8.4.0=py39h06a4308_0 67 | - jack=1.9.18=h8c3723f_1003 68 | - jedi=0.18.1=py39h06a4308_1 69 | - jpeg=9e=h166bdaf_2 70 | - keyutils=1.6.1=h166bdaf_0 71 | - kiwisolver=1.4.4=py39hf939315_0 72 | - krb5=1.19.3=h3790be6_0 73 | - lame=3.100=h7f98852_1001 74 | - lcms2=2.12=hddcbb42_0 75 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 76 | - lerc=4.0.0=h27087fc_0 77 | - libastra=2.1.2=cuda_11.6_0 78 | - libbrotlicommon=1.0.9=h166bdaf_7 79 | - libbrotlidec=1.0.9=h166bdaf_7 80 | - libbrotlienc=1.0.9=h166bdaf_7 81 | - libcap=2.65=ha37c62d_0 82 | - libclang=14.0.6=default_h2e3cab8_0 83 | - libclang13=14.0.6=default_h3a83d3e_0 84 | - libcups=2.3.3=h3e49a29_2 85 | - libdb=6.2.32=h9c3ff4c_0 86 | - libdeflate=1.14=h166bdaf_0 87 | - libedit=3.1.20191231=he28a2e2_2 88 | - libevent=2.1.10=h9b69904_4 89 | - libffi=3.4.2=h7f98852_5 90 | - libflac=1.3.4=h27087fc_0 91 | - libgcc-ng=12.1.0=h8d9b700_16 92 | - libgfortran-ng=12.1.0=h69a702a_16 93 | - libgfortran5=12.1.0=hdcd56e2_16 94 | - libglib=2.72.1=h2d90d5f_0 95 | - libiconv=1.16=h516909a_0 96 | - libllvm14=14.0.6=he0ac6c6_0 97 | - libnsl=2.0.0=h7f98852_0 98 | - libogg=1.3.4=h7f98852_1 99 | - libopus=1.3.1=h7f98852_1 100 | - libpng=1.6.38=h753d276_0 101 | - libpq=14.5=hd77ab85_0 102 | - libsndfile=1.0.31=h9c3ff4c_1 103 | - libsqlite=3.39.3=h753d276_0 104 | - libstdcxx-ng=12.1.0=ha89aaad_16 105 | - libtiff=4.4.0=h55922b4_4 106 | - libtool=2.4.6=h9c3ff4c_1008 107 | - libudev1=249=h166bdaf_4 108 | - libuuid=2.32.1=h7f98852_1000 109 | - libvorbis=1.3.7=h9c3ff4c_0 110 | - libwebp-base=1.2.4=h166bdaf_0 111 | - libxcb=1.13=h7f98852_1004 112 | - libxkbcommon=1.0.3=he3ba5ed_0 113 | - libxml2=2.9.14=h22db469_4 114 | - libzlib=1.2.12=h166bdaf_3 115 | - llvm-openmp=14.0.4=he0ac6c6_0 116 | - locket=1.0.0=py39h06a4308_0 117 | - matplotlib=3.6.0=py39hf3d152e_0 118 | - matplotlib-base=3.6.0=py39hf9fd14e_0 119 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 120 | - mkl=2021.4.0=h8d4b97c_729 121 | - mkl-service=2.4.0=py39h7e14d7c_0 122 | - mkl_fft=1.3.1=py39h0c7bc48_1 123 | - mkl_random=1.2.2=py39hde0f152_0 124 | - mock=4.0.3=py39hf3d152e_3 125 | - munkres=1.1.4=pyh9f0ad1d_0 126 | - mysql-common=8.0.30=haf5c9bc_1 127 | - mysql-libs=8.0.30=h28c427c_1 128 | - ncurses=6.3=h27087fc_1 129 | - nettle=3.6=he412f7d_0 130 | - networkx=2.8.4=py39h06a4308_0 131 | - nspr=4.32=h9c3ff4c_1 132 | - nss=3.78=h2350873_0 133 | - numpy=1.22.3=py39he7a7128_0 134 | - numpy-base=1.22.3=py39hf524024_0 135 | - openh264=2.1.1=h780b84a_0 136 | - openjpeg=2.5.0=h7d73246_1 137 | - openssl=1.1.1s=h166bdaf_0 138 | - os=0.1.4=0 139 | - packaging=21.3=pyhd8ed1ab_0 140 | - parso=0.8.3=pyhd3eb1b0_0 141 | - partd=1.2.0=pyhd3eb1b0_1 142 | - path=16.4.0=py39hf3d152e_1 143 | - path.py=12.5.0=0 144 | - pcre=8.45=h9c3ff4c_0 145 | - pexpect=4.8.0=pyhd3eb1b0_3 146 | - pickleshare=0.7.5=pyhd3eb1b0_1003 147 | - pillow=9.2.0=py39hd5dbb17_2 148 | - pip=22.2.2=pyhd8ed1ab_0 149 | - pluggy=1.0.0=py39hf3d152e_3 150 | - ply=3.11=py_1 151 | 
- portaudio=19.6.0=h8e90077_6 152 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 153 | - pthread-stubs=0.4=h36c2ea0_1001 154 | - ptyprocess=0.7.0=pyhd3eb1b0_2 155 | - pulseaudio=14.0=h0868958_9 156 | - pure_eval=0.2.2=pyhd3eb1b0_0 157 | - py=1.11.0=pyh6c4a22f_0 158 | - pycparser=2.21=pyhd8ed1ab_0 159 | - pygments=2.11.2=pyhd3eb1b0_0 160 | - pyopenssl=22.0.0=pyhd8ed1ab_0 161 | - pyparsing=3.0.9=pyhd8ed1ab_0 162 | - pyqt=5.15.7=py39h18e9c17_0 163 | - pyqt5-sip=12.11.0=py39h5a03fae_0 164 | - pysocks=1.7.1=pyha2e5f31_6 165 | - pytest=7.1.3=py39hf3d152e_0 166 | - pytest-shutil=1.7.0=pyhd8ed1ab_1 167 | - python=3.9.13=h9a8a25e_0_cpython 168 | - python-dateutil=2.8.2=pyhd8ed1ab_0 169 | - python-lmdb=1.3.0=py39h5a03fae_1 170 | - python_abi=3.9=2_cp39 171 | - pytorch=1.12.1=py3.9_cuda11.6_cudnn8.3.2_0 172 | - pytorch-mutex=1.0=cuda 173 | - pywavelets=1.3.0=py39h7f8727e_0 174 | - pyyaml=6.0=py39h7f8727e_1 175 | - qt-main=5.15.6=hc525480_0 176 | - qutil=3.2.1=6 177 | - readline=8.1.2=h0f457ee_0 178 | - requests=2.28.1=pyhd8ed1ab_1 179 | - scikit-image=0.16.2=py39ha9443f7_0 180 | - scipy=1.7.3=py39h6c91a56_2 181 | - setuptools=65.3.0=pyhd8ed1ab_1 182 | - sip=6.6.2=py39h5a03fae_0 183 | - six=1.16.0=pyh6c4a22f_0 184 | - sqlite=3.39.3=h4ff8645_0 185 | - stack_data=0.2.0=pyhd3eb1b0_0 186 | - tbb=2021.5.0=h924138e_2 187 | - termcolor=2.0.1=pyhd8ed1ab_1 188 | - tk=8.6.12=h27826a3_0 189 | - toml=0.10.2=pyhd8ed1ab_0 190 | - tomli=2.0.1=pyhd8ed1ab_0 191 | - toolz=0.11.2=pyhd3eb1b0_0 192 | - torchaudio=0.12.1=py39_cu116 193 | - torchvision=0.13.1=py39_cu116 194 | - tornado=6.2=py39hb9d737c_0 195 | - tqdm=4.64.1=pyhd8ed1ab_0 196 | - traitlets=5.1.1=pyhd3eb1b0_0 197 | - typing_extensions=4.3.0=pyha770c72_0 198 | - tzdata=2022c=h191b570_0 199 | - unicodedata2=14.0.0=py39hb9d737c_1 200 | - urllib3=1.26.11=pyhd8ed1ab_0 201 | - wcwidth=0.2.5=pyhd3eb1b0_0 202 | - wheel=0.37.1=pyhd8ed1ab_0 203 | - xcb-util=0.4.0=h166bdaf_0 204 | - xcb-util-image=0.4.0=h166bdaf_0 205 | - xcb-util-keysyms=0.4.0=h166bdaf_0 206 | - xcb-util-renderutil=0.3.9=h166bdaf_0 207 | - xcb-util-wm=0.4.1=h166bdaf_0 208 | - xorg-libxau=1.0.9=h7f98852_0 209 | - xorg-libxdmcp=1.1.3=h7f98852_0 210 | - xz=5.2.6=h166bdaf_0 211 | - yaml=0.2.5=h7b6447c_0 212 | - zlib=1.2.12=h166bdaf_3 213 | - zstd=1.5.2=h6239696_4 214 | - pip: 215 | - future==0.18.2 216 | - geomloss==0.2.5 217 | - keopscore==2.1 218 | - normflows==1.4 219 | - odl==1.0.0.dev0 220 | - pybind11==2.10.0 221 | - pykeops==2.1 222 | - torch-summary==1.4.5 223 | prefix: /users/staff/dmi-dmi/debarn0000/anaconda3/envs/torch12 224 | -------------------------------------------------------------------------------- /ops/traveltime_lib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | EPS = 1e-12 10 | 11 | def sense_2(sensors, img): 12 | 13 | img_size = img.shape[0] 14 | 15 | # make a line 16 | x1 = sensors[0, 0] 17 | y1 = sensors[0, 1] 18 | 19 | x2 = sensors[1, 0] 20 | y2 = sensors[1, 1] 21 | 22 | slope = (y2 - y1) / (x2 - x1) 23 | 24 | per_pixel_width = 1.0/img_size 25 | n_pts_x = int(np.abs(x2-x1)/per_pixel_width) 26 | n_pts_y = int(np.abs(y2-y1)/per_pixel_width) 27 | 28 | 29 | intersect_vert = None 30 | intersect_horz = None 31 | 32 | if n_pts_x > 0: 33 | xs = x1 + np.arange(1, n_pts_x + 1) * per_pixel_width * np.sign(x2-x1) 34 | ys = y2 - slope * (x2 - xs) 35 | intersect_vert = np.stack((xs, ys), axis=-1) 36 | 37 | if 
n_pts_y > 0: 38 | ys = y1 + np.arange(1, n_pts_y + 1) * per_pixel_width * np.sign(y2-y1) 39 | xs = x2 - (y2 - ys) / slope 40 | intersect_horz = np.stack((xs, ys), axis=-1) 41 | 42 | all_pts = np.concatenate((sensors, intersect_horz, intersect_vert), axis=0) 43 | idx= np.argsort(all_pts, axis=0)[:, 0] 44 | all_pts = all_pts[idx] # Sorts acc. to x coordinate. 45 | 46 | # find line segment length x pixels 47 | midpoints = (all_pts[:-1] + all_pts[1:])/2 48 | lengths = np.linalg.norm(all_pts[:-1] - all_pts[1:], axis=-1) 49 | 50 | # pick correct pixels x line segment length 51 | midpoints = np.clip((midpoints + 1) / 2.0, 0, 1-EPS) 52 | pixel_idx = np.floor(midpoints/per_pixel_width).astype(np.int32) 53 | pixel_intensities = img[pixel_idx[:, 0], pixel_idx[:, 1]] 54 | 55 | 56 | return np.sum(lengths * pixel_intensities) 57 | 58 | 59 | class tt_2sensors(nn.Module): 60 | """Traveltime for 2 sensors""" 61 | def __init__(self, sensors, img_size): 62 | super(tt_2sensors, self).__init__() 63 | self.lengths, self.idx = self.__build(sensors, img_size) 64 | 65 | @staticmethod 66 | def __build(sensors, img_size): 67 | if isinstance(sensors, np.ndarray): 68 | sensors = torch.from_numpy(sensors.astype(np.float32)).cuda() 69 | 70 | # make a line 71 | x1 = sensors[0, 0] 72 | y1 = sensors[0, 1] 73 | 74 | x2 = sensors[1, 0] 75 | y2 = sensors[1, 1] 76 | 77 | slope = (y2 - y1) / (x2 - x1) 78 | 79 | per_pixel_width = 1.0/img_size 80 | n_pts_x = torch.abs(x2-x1)/per_pixel_width 81 | n_pts_x = n_pts_x.type(torch.int) 82 | n_pts_y = torch.abs(y2-y1)/per_pixel_width 83 | n_pts_y = n_pts_y.type(torch.int) 84 | 85 | 86 | intersect_vert = None 87 | intersect_horz = None 88 | 89 | if n_pts_x > 0: 90 | xs = x1 + torch.arange( 91 | 1, n_pts_x + 1, device='cuda') * per_pixel_width * torch.sign(x2-x1) 92 | ys = y2 - slope * (x2 - xs) 93 | intersect_vert = torch.stack((xs, ys), dim=-1) 94 | 95 | if n_pts_y > 0: 96 | ys = y1 + torch.arange( 97 | 1, n_pts_y + 1, device='cuda') * per_pixel_width * torch.sign(y2-y1) 98 | xs = x2 - (y2 - ys) / slope 99 | intersect_horz = torch.stack((xs, ys), dim=-1) 100 | 101 | all_pts = sensors.clone().cuda() 102 | if intersect_horz is not None: 103 | all_pts = torch.cat((sensors, intersect_horz), dim=0) 104 | if intersect_vert is not None: 105 | all_pts = torch.cat((all_pts, intersect_vert), dim=0) 106 | 107 | idx = torch.argsort(all_pts, dim=0)[:, 0] 108 | all_pts = all_pts[idx] # Sorts acc. to x coordinate. 109 | 110 | # find line segment length x pixels 111 | midpoints = (all_pts[:-1] + all_pts[1:])/2 112 | lengths = torch.norm(all_pts[:-1] - all_pts[1:], dim=-1) 113 | 114 | # pick correct pixels x line segment length 115 | midpoints = torch.clip((midpoints + 1) / 2.0, 0, 1-EPS) 116 | pixel_idx = torch.floor(midpoints/per_pixel_width).type( 117 | torch.cuda.LongTensor) 118 | 119 | return lengths, pixel_idx 120 | 121 | 122 | def forward(self, img): 123 | pixel_intensities = img[self.idx[:,0], self.idx[:, 1]] 124 | return torch.sum(pixel_intensities * self.lengths) 125 | 126 | 127 | class TravelTimeOperator(nn.Module): 128 | """Builds the linear traveltime tomography operator.""" 129 | 130 | def __init__(self, sensors, img_size): 131 | """Initializes a linear TT operator. 
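        A usage sketch (hypothetical sizes; a CUDA device is assumed because
        the implementation hard-codes `.cuda()` calls and `device='cuda'`):

            sensors = 2 * torch.rand(5, 2).cuda() - 1   # 5 sensors in [-1, 1]^2
            sensors.requires_grad_(True)
            A = TravelTimeOperator(sensors, img_size=64)
            y = A(torch.rand(64, 64).cuda())            # 5*4/2 = 10 traveltimes, one per sensor pair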
132 | Args: 133 | sensors (torch.Tensor): Locations of sensors [-1, 1] 134 | img_size (int): Size of the image 135 | """ 136 | super(TravelTimeOperator, self).__init__() 137 | self.sensors = sensors 138 | self.lengths, self.idx, self.nelems = self._build(sensors, img_size) 139 | self.optimizable_params = sensors 140 | 141 | 142 | @staticmethod 143 | def _build(sensors, img_size): 144 | N = len(sensors) 145 | lengths = [] 146 | nelems = [] 147 | idx = [] 148 | 149 | for i in range(N-1): 150 | for j in range(i+1, N): 151 | sense2 = torch.stack((sensors[i], sensors[j]), axis=0) 152 | row = tt_2sensors(sense2, img_size) 153 | lengths.append(row.lengths) 154 | idx.append(row.idx) 155 | nelems.append(len(row.lengths)) 156 | 157 | return torch.cat(lengths), torch.cat(idx, dim=0), nelems 158 | 159 | 160 | 161 | def forward(self, img): 162 | pixel_intensities = img[self.idx[:,0], self.idx[:, 1]] 163 | y_measured = pixel_intensities * self.lengths 164 | measured_chunks = torch.split(y_measured, self.nelems) 165 | y_measured = torch.stack([ 166 | torch.sum(c) for c in measured_chunks] 167 | ) 168 | 169 | return y_measured 170 | 171 | 172 | def time_op(): 173 | print('\n Timing the op') 174 | from time import time 175 | IMG_SIZE = 64 176 | NS = 50 177 | CVAL = 0.3 178 | img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.float32) * CVAL 179 | 180 | sensors = np.random.rand(NS, 2) 181 | sensors = torch.tensor(sensors, requires_grad=True).cuda() 182 | 183 | A = TravelTimeOperator(sensors, IMG_SIZE).cuda() 184 | 185 | print('Operator built.') 186 | 187 | img = torch.from_numpy(img).cuda() 188 | t = time() 189 | for _ in range(500): 190 | y = A(img) 191 | 192 | print(f'Time per forward: {(time()-t)/500}s') 193 | 194 | def unit_test(): 195 | print('\n Unit testing the op') 196 | from time import time 197 | IMG_SIZE = 1024 198 | NS = 50 199 | CVAL = 0.3 200 | img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.float32) * CVAL 201 | 202 | sensors = np.random.rand(NS, 2) 203 | sensors = torch.tensor(sensors).cuda() 204 | sensors.requires_grad_(True) 205 | 206 | A = TravelTimeOperator(sensors, IMG_SIZE) 207 | print(f'A length is on {A.lengths.device}') 208 | print(f'A idx is on {A.idx.device}') 209 | 210 | y = A(torch.from_numpy(img).cuda()) 211 | 212 | obj = torch.sum(y) 213 | obj.backward() 214 | 215 | y_t = y.cpu().detach().numpy() 216 | print(sensors.grad) 217 | 218 | 219 | gt = np.zeros(int(NS * (NS-1) / 2)) 220 | c = 0 221 | for i in range(NS-1): 222 | for j in range(i+1, NS): 223 | gt[c] = torch.norm(sensors[i] - sensors[j]) * CVAL 224 | c += 1 225 | 226 | assert np.allclose(y_t, gt, atol=1e-5), f'Max abs error = {np.abs(y_t - gt).max()}.' 227 | 228 | 229 | def unit_test_2sensor(): 230 | """Unit testing the 2 sensor row. 231 | 232 | KNOWN BUG: 233 | - In the numpy version of this code, there is a chance 234 | that two sensors are so close in either x or y that when computing 235 | the intersection points along x or y output can be None. This causes 236 | failure but since this happens rarely it is not fixed. 
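        A possible guard (a sketch only, not applied in this repo) mirrors the
        torch implementation above and concatenates only the non-None
        intersection arrays:

            pts = [sensors]
            if intersect_horz is not None:
                pts.append(intersect_horz)
            if intersect_vert is not None:
                pts.append(intersect_vert)
            all_pts = np.concatenate(pts, axis=0)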
237 | """ 238 | 239 | print('\n Unit testing each row of the op') 240 | 241 | IMG_SIZE = 64 242 | img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.float32) * 0.3 243 | # img = np.random.rand(IMG_SIZE, IMG_SIZE) 244 | 245 | # sensors = np.array([[-1.0, -1.0], [1.0, 1.0]]) 246 | sensors = np.random.rand(2, 2).astype(np.float32) 247 | sensors_t = torch.from_numpy(sensors).cuda() 248 | sensors_t.requires_grad_(True) 249 | A = tt_2sensors(sensors_t, IMG_SIZE).cuda() 250 | 251 | print(f'tt_2sensors length is on {A.lengths.device}') 252 | print(f'tt_2sensors idx is on {A.idx.device}') 253 | 254 | y = A(torch.from_numpy(img).cuda()) 255 | 256 | y.backward() 257 | print(sensors_t.grad) 258 | 259 | y_t = y.cpu().detach().numpy() 260 | 261 | y_np = sense_2(sensors, img).astype(np.float32) 262 | # gt = np.diag(img).sum()/32.0*np.sqrt(2) 263 | gt = np.linalg.norm(sensors[0] - sensors[1]) * 0.3 264 | 265 | assert np.abs(y_t - y_np) < 1e-6, f'Got {y_t}, expected {y_np}.' 266 | 267 | assert np.isclose(y_t, gt), f'Got {y_t}, expected {gt}.' 268 | 269 | 270 | if __name__ == '__main__': 271 | # unit_test_2sensor() 272 | # unit_test() 273 | time_op() 274 | 275 | 276 | -------------------------------------------------------------------------------- /normflow/nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from . import utils 4 | 5 | # Try importing ResNet dependencies 6 | try: 7 | from residual_flows.layers.base import Swish, InducedNormLinear, InducedNormConv2d 8 | except: 9 | print('Warning: Dependencies for Residual Networks could ' 10 | 'not be loaded. Other models can still be used.') 11 | 12 | 13 | class MLP(nn.Module): 14 | """ 15 | A multilayer perceptron with Leaky ReLU nonlinearities 16 | """ 17 | 18 | def __init__(self, layers, leaky=0.0, score_scale=None, output_fn=None, 19 | output_scale=None, init_zeros=False, dropout=None): 20 | """ 21 | :param layers: list of layer sizes from start to end 22 | :param leaky: slope of the leaky part of the ReLU, 23 | if 0.0, standard ReLU is used 24 | :param score_scale: Factor to apply to the scores, i.e. output before 25 | output_fn. 26 | :param output_fn: String, function to be applied to the output, either 27 | None, "sigmoid", "relu", "tanh", or "clampexp" 28 | :param output_scale: Rescale outputs if output_fn is specified, i.e. 
scale * output_fn(out / scale) 30 | :param init_zeros: Flag, if true, weights and biases of last layer 31 | are initialized with zeros (helpful for deep models, see arXiv 1807.03039) 32 | :param dropout: Float, if specified, dropout is done before last layer; 33 | if None, no dropout is done 34 | """ 35 | super().__init__() 36 | net = nn.ModuleList([]) 37 | for k in range(len(layers)-2): 38 | net.append(nn.Linear(layers[k], layers[k+1])) 39 | net.append(nn.LeakyReLU(leaky)) 40 | if dropout is not None: 41 | net.append(nn.Dropout(p=dropout)) 42 | net.append(nn.Linear(layers[-2], layers[-1])) 43 | if init_zeros: 44 | nn.init.zeros_(net[-1].weight) 45 | nn.init.zeros_(net[-1].bias) 46 | if output_fn is not None: 47 | if score_scale is not None: 48 | net.append(utils.ConstScaleLayer(score_scale)) 49 | if output_fn == "sigmoid": 50 | net.append(nn.Sigmoid()) 51 | elif output_fn == "relu": 52 | net.append(nn.ReLU()) 53 | elif output_fn == "tanh": 54 | net.append(nn.Tanh()) 55 | elif output_fn == "clampexp": 56 | net.append(utils.ClampExp()) 57 | else: 58 | raise NotImplementedError("This output function is not implemented.") 59 | if output_scale is not None: 60 | net.append(utils.ConstScaleLayer(output_scale)) 61 | self.net = nn.Sequential(*net) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | class ConvNet2d(nn.Module): 68 | """ 69 | Convolutional Neural Network with leaky ReLU nonlinearities 70 | """ 71 | 72 | def __init__(self, channels, kernel_size, leaky=0.0, init_zeros=True, 73 | actnorm=False, weight_std=None): 74 | """ 75 | Constructor 76 | :param channels: List of channels of conv layers, first entry is in_channels 77 | :param kernel_size: List of kernel sizes, same for height and width 78 | :param leaky: Leaky part of ReLU 79 | :param init_zeros: Flag whether last layer shall be initialized with zeros 80 | (note: output log-scaling is not implemented in this version; use 81 | actnorm or weight_std to control the scale of the layers instead) 82 | :param actnorm: Flag whether activation normalization shall be done after 83 | each conv layer except output 84 | :param weight_std: Fixed std used to initialize every layer 85 | """ 86 | super().__init__() 87 | # Build network 88 | net = nn.ModuleList([]) 89 | for i in range(len(kernel_size) - 1): 90 | conv = nn.Conv2d(channels[i], channels[i + 1], kernel_size[i], 91 | padding=kernel_size[i] // 2, bias=(not actnorm)) 92 | if weight_std is not None: 93 | conv.weight.data.normal_(mean=0.0, std=weight_std) 94 | net.append(conv) 95 | if actnorm: 96 | net.append(utils.ActNorm((channels[i + 1],) + (1, 1))) 97 | net.append(nn.LeakyReLU(leaky)) 98 | i = len(kernel_size) 99 | net.append(nn.Conv2d(channels[i - 1], channels[i], kernel_size[i - 1], 100 | padding=kernel_size[i - 1] // 2)) 101 | if init_zeros: 102 | nn.init.zeros_(net[-1].weight) 103 | nn.init.zeros_(net[-1].bias) 104 | self.net = nn.Sequential(*net) 105 | 106 | def forward(self, x): 107 | return self.net(x) 108 | 109 | 110 | # Lipschitz continuous neural nets for residual flow 111 | 112 | class LipschitzMLP(nn.Module): 113 | """ 114 | Fully connected neural net which is Lipschitz continuous 115 | with Lipschitz constant L < 1 116 | """ 117 | def __init__(self, channels, lipschitz_const=0.97, max_lipschitz_iter=5, 118 | lipschitz_tolerance=None, init_zeros=True): 119 | """ 120 | Constructor 121 | :param channels: Integer list with the number of channels of 122 | the layers 123 | :param lipschitz_const: Maximum Lipschitz constant of each layer 124 |
:param max_lipschitz_iter: Maximum number of iterations used to 125 | ensure that layers are Lipschitz continuous with L smaller than 126 | set maximum; if None, tolerance is used 127 | :param lipschitz_tolerance: Float, tolerance used to ensure 128 | Lipschitz continuity if max_lipschitz_iter is None, typically 1e-3 129 | :param init_zeros: Flag, whether to initialize last layer 130 | approximately with zeros 131 | """ 132 | super().__init__() 133 | 134 | self.n_layers = len(channels) - 1 135 | self.channels = channels 136 | self.lipschitz_const = lipschitz_const 137 | self.max_lipschitz_iter = max_lipschitz_iter 138 | self.lipschitz_tolerance = lipschitz_tolerance 139 | self.init_zeros = init_zeros 140 | 141 | layers = [] 142 | for i in range(self.n_layers): 143 | layers += [Swish(), 144 | InducedNormLinear(in_features=channels[i], 145 | out_features=channels[i + 1], coeff=lipschitz_const, 146 | domain=2, codomain=2, n_iterations=max_lipschitz_iter, 147 | atol=lipschitz_tolerance, rtol=lipschitz_tolerance, 148 | zero_init=init_zeros if i == (self.n_layers - 1) else False)] 149 | 150 | self.net = nn.Sequential(*layers) 151 | 152 | def forward(self, x): 153 | return self.net(x) 154 | 155 | 156 | class LipschitzCNN(nn.Module): 157 | """ 158 | Convolutional neural network which is Lipschitz continuous 159 | with Lipschitz constant L < 1 160 | """ 161 | def __init__(self, channels, kernel_size, lipschitz_const=0.97, 162 | max_lipschitz_iter=5, lipschitz_tolerance=None, 163 | init_zeros=True): 164 | """ 165 | Constructor 166 | :param channels: Integer list with the number of channels of 167 | the layers 168 | :param kernel_size: Integer list of kernel sizes of the layers 169 | :param lipschitz_const: Maximum Lipschitz constant of each layer 170 | :param max_lipschitz_iter: Maximum number of iterations used to 171 | ensure that layers are Lipschitz continuous with L smaller than 172 | set maximum; if None, tolerance is used 173 | :param lipschitz_tolerance: Float, tolerance used to ensure 174 | Lipschitz continuity if max_lipschitz_iter is None, typically 1e-3 175 | :param init_zeros: Flag, whether to initialize last layer 176 | approximately with zeros 177 | """ 178 | super().__init__() 179 | 180 | self.n_layers = len(kernel_size) 181 | self.channels = channels 182 | self.kernel_size = kernel_size 183 | self.lipschitz_const = lipschitz_const 184 | self.max_lipschitz_iter = max_lipschitz_iter 185 | self.lipschitz_tolerance = lipschitz_tolerance 186 | self.init_zeros = init_zeros 187 | 188 | layers = [] 189 | for i in range(self.n_layers): 190 | layers += [Swish(), 191 | InducedNormConv2d(in_channels=channels[i], 192 | out_channels=channels[i + 1], kernel_size=kernel_size[i], 193 | stride=1, padding=kernel_size[i] // 2, bias=True, 194 | coeff=lipschitz_const, domain=2, codomain=2, 195 | n_iterations=max_lipschitz_iter, atol=lipschitz_tolerance, 196 | rtol=lipschitz_tolerance, 197 | zero_init=init_zeros if i == self.n_layers - 1 else False)] 198 | 199 | self.net = nn.Sequential(*layers) 200 | 201 | def forward(self, x): 202 | return self.net(x) -------------------------------------------------------------------------------- /normflow/distributions/prior.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | 7 | class PriorDistribution: 8 | def __init__(self): 9 | raise NotImplementedError 10 | 11 | def log_prob(self, z): 12 | """ 13 | :param z: value or batch of latent variable 14 | :return: 
log probability of the distribution for z 15 | """ 16 | raise NotImplementedError 17 | 18 | 19 | class ImagePrior(nn.Module): 20 | """ 21 | Intensities of an image determine probability density of prior 22 | """ 23 | 24 | def __init__(self, image, x_range=[-3, 3], y_range=[-3, 3], eps=1.e-10): 25 | """ 26 | Constructor 27 | :param image: image as np matrix 28 | :param x_range: x range to position image at 29 | :param y_range: y range to position image at 30 | :param eps: small value to add to image to avoid log(0) problems 31 | """ 32 | super().__init__() 33 | image_ = np.flip(image, 0).transpose() + eps 34 | self.image_cpu = torch.tensor(image_ / np.max(image_)) 35 | self.image_size_cpu = self.image_cpu.size() 36 | self.x_range = torch.tensor(x_range) 37 | self.y_range = torch.tensor(y_range) 38 | 39 | self.register_buffer('image', self.image_cpu) 40 | self.register_buffer('image_size', torch.tensor(self.image_size_cpu).unsqueeze(0)) 41 | self.register_buffer('density', torch.log(self.image_cpu / torch.sum(self.image_cpu))) 42 | self.register_buffer('scale', torch.tensor([[self.x_range[1] - self.x_range[0], 43 | self.y_range[1] - self.y_range[0]]])) 44 | self.register_buffer('shift', torch.tensor([[self.x_range[0], self.y_range[0]]])) 45 | 46 | def log_prob(self, z): 47 | """ 48 | :param z: value or batch of latent variable 49 | :return: log probability of the distribution for z 50 | """ 51 | z_ = torch.clamp((z - self.shift) / self.scale, max=1, min=0) 52 | ind = (z_ * (self.image_size - 1)).long() 53 | return self.density[ind[:, 0], ind[:, 1]] 54 | 55 | def rejection_sampling(self, num_steps=1): 56 | """ 57 | Perform rejection sampling on image distribution 58 | :param num_steps: Number of rejection sampling steps to perform 59 | :return: Accepted samples 60 | """ 61 | z_ = torch.rand((num_steps, 2), dtype=self.image.dtype, device=self.image.device) 62 | prob = torch.rand(num_steps, dtype=self.image.dtype, device=self.image.device) 63 | ind = (z_ * (self.image_size - 1)).long() 64 | intensity = self.image[ind[:, 0], ind[:, 1]] 65 | accept = intensity > prob 66 | z = z_[accept, :] * self.scale + self.shift 67 | return z 68 | 69 | def sample(self, num_samples=1): 70 | """ 71 | Sample from image distribution through rejection sampling 72 | :param num_samples: Number of samples to draw 73 | :return: Samples 74 | """ 75 | z = torch.ones((0, 2), dtype=self.image.dtype, device=self.image.device) 76 | while len(z) < num_samples: 77 | z_ = self.rejection_sampling(num_samples) 78 | ind = np.min([len(z_), num_samples - len(z)]) 79 | z = torch.cat([z, z_[:ind, :]], 0) 80 | return z 81 | 82 | 83 | class TwoModes(PriorDistribution): 84 | def __init__(self, loc, scale): 85 | """ 86 | Distribution 2d with two modes at z[0] = -loc and z[0] = loc 87 | :param loc: distance of modes from the origin 88 | :param scale: scale of modes 89 | """ 90 | self.loc = loc 91 | self.scale = scale 92 | 93 | def log_prob(self, z): 94 | """ 95 | log(p) = 1/2 * ((norm(z) - loc) / (2 * scale)) ** 2 96 | - log(exp(-1/2 * ((z[0] - loc) / (3 * scale)) ** 2) + exp(-1/2 * ((z[0] + loc) / (3 * scale)) ** 2)) 97 | :param z: value or batch of latent variable 98 | :return: log probability of the distribution for z 99 | """ 100 | a = torch.abs(z[:, 0]) 101 | eps = torch.abs(torch.tensor(self.loc)) 102 | 103 | log_prob = - 0.5 * ((torch.norm(z, dim=1) - self.loc) / (2 * self.scale)) ** 2 \ 104 | - 0.5 * ((a - eps) / (3 * self.scale)) ** 2 \ 105 | + torch.log(1 + torch.exp(-2 * (a * eps) / (3 * self.scale) ** 2)) 106 | 107 | return 
log_prob 108 | 109 | 110 | class Sinusoidal(PriorDistribution): 111 | def __init__(self, scale, period): 112 | """ 113 | Distribution 2d with sinusoidal density 114 | :param loc: distance of modes from the origin 115 | :param scale: scale of modes 116 | """ 117 | self.scale = scale 118 | self.period = period 119 | 120 | def log_prob(self, z): 121 | """ 122 | log(p) = - 1/2 * ((z[1] - w_1(z)) / (2 * scale)) ** 2 123 | w_1(z) = sin(2*pi / period * z[0]) 124 | :param z: value or batch of latent variable 125 | :return: log probability of the distribution for z 126 | """ 127 | if z.dim() > 1: 128 | z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1))) 129 | else: 130 | z_ = z 131 | 132 | w_1 = lambda x: torch.sin(2 * np.pi / self.period * z_[0]) 133 | log_prob = - 0.5 * ((z_[1] - w_1(z_)) / (self.scale)) ** 2 \ 134 | - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4 # add Gaussian envelope for valid p(z) 135 | 136 | return log_prob 137 | 138 | 139 | class Sinusoidal_gap(PriorDistribution): 140 | def __init__(self, scale, period): 141 | """ 142 | Distribution 2d with sinusoidal density with gap 143 | :param loc: distance of modes from the origin 144 | :param scale: scale of modes 145 | """ 146 | self.scale = scale 147 | self.period = period 148 | self.w2_scale = 0.6 149 | self.w2_amp = 3.0 150 | self.w2_mu = 1.0 151 | 152 | def log_prob(self, z): 153 | """ 154 | :param z: value or batch of latent variable 155 | :return: log probability of the distribution for z 156 | """ 157 | if z.dim() > 1: 158 | z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1))) 159 | else: 160 | z_ = z 161 | 162 | w_1 = lambda x: torch.sin(2 * np.pi / self.period * z_[0]) 163 | w_2 = lambda x: self.w2_amp * torch.exp(-0.5 * ((z_[0] - self.w2_mu) / self.w2_scale) ** 2) 164 | 165 | eps = torch.abs(w_2(z_) / 2) 166 | a = torch.abs(z_[1] - w_1(z_) + w_2(z_) / 2) 167 | 168 | log_prob = -0.5 * ((a - eps) / self.scale) ** 2 + \ 169 | torch.log(1 + torch.exp(-2 * (eps * a) / self.scale ** 2)) \ 170 | - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4 171 | 172 | return log_prob 173 | 174 | 175 | class Sinusoidal_split(PriorDistribution): 176 | def __init__(self, scale, period): 177 | """ 178 | Distribution 2d with sinusoidal density with split 179 | :param loc: distance of modes from the origin 180 | :param scale: scale of modes 181 | """ 182 | self.scale = scale 183 | self.period = period 184 | self.w3_scale = 0.3 185 | self.w3_amp = 3.0 186 | self.w3_mu = 1.0 187 | 188 | def log_prob(self, z): 189 | """ 190 | :param z: value or batch of latent variable 191 | :return: log probability of the distribution for z 192 | """ 193 | if z.dim() > 1: 194 | z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1))) 195 | else: 196 | z_ = z 197 | 198 | w_1 = lambda x: torch.sin(2 * np.pi / self.period * z_[0]) 199 | w_3 = lambda x: self.w3_amp * torch.sigmoid((z_[0] - self.w3_mu) / self.w3_scale) 200 | 201 | eps = torch.abs(w_3(z_) / 2) 202 | a = torch.abs(z_[1] - w_1(z_) + w_3(z_) / 2) 203 | 204 | log_prob = -0.5 * ((a - eps) / (self.scale)) ** 2 + \ 205 | torch.log(1 + torch.exp(-2 * (eps * a) / self.scale ** 2)) \ 206 | - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4 207 | 208 | return log_prob 209 | 210 | 211 | class Smiley(PriorDistribution): 212 | def __init__(self, scale): 213 | """ 214 | Distribution 2d of a smiley :) 215 | :param loc: distance of modes from the origin 216 | :param scale: scale of modes 217 | """ 218 | self.scale = scale 219 | self.loc = 2.0 220 | 221 | def 
log_prob(self, z): 222 | """ 223 | :param z: value or batch of latent variable 224 | :return: log probability of the distribution for z 225 | """ 226 | if z.dim() > 1: 227 | z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1))) 228 | else: 229 | z_ = z 230 | 231 | log_prob = - 0.5 * ((torch.norm(z_, dim=0) - self.loc) / (2 * self.scale)) ** 2 \ 232 | - 0.5 * ((torch.abs(z_[1] + 0.8) - 1.2) / (2 * self.scale)) ** 2 233 | 234 | return log_prob -------------------------------------------------------------------------------- /ops/odl_lib.py: -------------------------------------------------------------------------------- 1 | """Holds all the odl related utilities.""" 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import torch 5 | 6 | import odl 7 | 8 | # from odl.contrib.torch import OperatorFunction 9 | # from odl.contrib.torch import OperatorFunction 10 | from ops.ODLHelper import OperatorFunction 11 | 12 | class OperatorAsAutogradFunction(object): 13 | """Dummy class around OperatorFunction. 14 | 15 | Check https://github.com/swing-research/dual_implicit_rep/issues/1 . 16 | """ 17 | def __init__(self, odl_op): 18 | self.op = odl_op 19 | 20 | def __call__(self, x): 21 | return OperatorFunction.apply(self.op, x) 22 | 23 | 24 | def apply_angle_noise(angles, noise): 25 | """Applies operator noise in the angles of the operator. 26 | 27 | SNR = 20 log (P_angles/P_noise) 28 | Args: 29 | angles (np.ndarray): 1D array of angles used in the Radon operator. 30 | noise (float): SNR of the noise. 31 | Returns: 32 | noisy_angles (np.ndarray): Noisy angles at `noise` dB. 33 | """ 34 | if noise > 200: 35 | return angles 36 | 37 | noise_std_dev = 10**(-noise/20) 38 | 39 | noisy_angles = angles * (1 + np.random.randn(*angles.shape)*noise_std_dev) 40 | 41 | return noisy_angles 42 | 43 | 44 | class ParallelBeamGeometryOp(object): 45 | """Creates an `img_size` mesh parallel geometry tomography operator.""" 46 | 47 | def __init__(self, img_size, num_angles, op_snr=500): 48 | self.img_size = img_size 49 | self.num_angles = num_angles 50 | self.reco_space = odl.uniform_discr( 51 | min_pt=[-20, -20], 52 | max_pt=[20, 20], 53 | shape=[img_size, img_size], 54 | dtype='float32' 55 | ) 56 | 57 | self.geometry = odl.tomo.parallel_beam_geometry( 58 | self.reco_space, num_angles) 59 | 60 | self.num_detectors = self.geometry.detector.size 61 | self.op_snr = op_snr 62 | self.angles = apply_angle_noise(self.geometry.angles, op_snr) 63 | 64 | self.optimizable_params = torch.tensor( 65 | self.angles, dtype=torch.float32) # Convert to torch.Tensor. 66 | 67 | self.op = odl.tomo.RayTransform( 68 | self.reco_space, 69 | self.geometry, 70 | impl='astra_cuda') 71 | 72 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op) 73 | 74 | def __call__(self, x): 75 | return OperatorFunction.apply(self.op, x) 76 | 77 | def pinv(self, y): 78 | return OperatorFunction.apply(self.fbp, y) 79 | 80 | class ParallelBeamGeometryOpBroken(ParallelBeamGeometryOp): 81 | """Creates a noisy angle instance of ParallelBeamGeometryOp 82 | 83 | # Steps taken from implementation of odl.tomo.parallel_beam_geometry 84 | # https://github.com/odlgroup/odl/blob/master/odl/tomo/geometry/parallel.py#L1471 85 | 86 | Notes 87 | ----- 88 | According to [NW2001]_, pages 72--74, a function 89 | :math:`f : \mathbb{R}^2 \to \mathbb{R}` that has compact support 90 | .. math:: 91 | \|x\| > \rho \implies f(x) = 0, 92 | and is essentially bandlimited 93 | .. math:: 94 | \|\xi\| > \Omega \implies \hat{f}(\xi) \approx 0, 95 | can be fully reconstructed from a parallel beam ray transform 96 | if (1) the projection angles are sampled with a spacing of 97 | :math:`\Delta \psi` such that 98 | .. math:: 99 | \Delta \psi \leq \frac{\pi}{\rho \Omega}, 100 | and (2) the detector is sampled with an interval :math:`\Delta s` 101 | that satisfies 102 | .. math:: 103 | \Delta s \leq \frac{\pi}{\Omega}. 104 | The geometry returned by this function satisfies these conditions exactly. 105 | If the domain is 3-dimensional, the geometry is "separable", in that each 106 | slice along the z-dimension of the data is treated as independent 2d data. 107 | References 108 | ---------- 109 | .. [NW2001] Natterer, F and Wuebbeling, F. 110 | *Mathematical Methods in Image Reconstruction*. 111 | SIAM, 2001. 112 | https://dx.doi.org/10.1137/1.9780898718324 113 | """ 114 | def __init__(self, clean_operator, op_snr): 115 | super().__init__(clean_operator.img_size, clean_operator.num_angles, op_snr) 116 | 117 | space = self.reco_space 118 | 119 | # Find maximum distance from rotation axis 120 | corners = space.domain.corners()[:, :2] 121 | rho = np.max(np.linalg.norm(corners, axis=1)) 122 | 123 | # Find default values according to Nyquist criterion. 124 | 125 | # We assume that the function is bandlimited by a wave along the x or y 126 | # axis. The highest frequency we can measure is then a standing wave with 127 | # period of twice the inter-node distance. 128 | min_side = min(space.partition.cell_sides[:2]) 129 | omega = np.pi / min_side 130 | num_px_horiz = 2 * int(np.ceil(rho * omega / np.pi)) + 1 131 | det_min_pt = -rho 132 | det_max_pt = rho 133 | det_shape = num_px_horiz 134 | det_partition = odl.discr.uniform_partition(det_min_pt, det_max_pt, det_shape) 135 | 136 | self.angles = apply_angle_noise(clean_operator.geometry.angles, op_snr) 137 | 138 | self.optimizable_params = torch.tensor(clean_operator.geometry.angles, dtype=torch.float32) 139 | 140 | # angle partition is changed to not be uniform 141 | angle_partition = odl.discr.nonuniform_partition(np.sort(self.angles)) 142 | 143 | self.geometry = odl.tomo.Parallel2dGeometry(angle_partition, det_partition) 144 | 145 | self.num_detectors = self.geometry.detector.size 146 | 147 | self.op = odl.tomo.RayTransform( 148 | self.reco_space, 149 | self.geometry, 150 | impl='astra_cuda') 151 | 152 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op) 153 | 154 | 155 | class ParallelBeamGeometryOpNonUniform(ParallelBeamGeometryOp): 156 | """Creates a ParallelBeamGeometryOp instance with user-specified non-uniform angles 157 | 158 | # Steps taken from implementation of odl.tomo.parallel_beam_geometry 159 | # https://github.com/odlgroup/odl/blob/master/odl/tomo/geometry/parallel.py#L1471 160 | 161 | Notes 162 | ----- 163 | According to [NW2001]_, pages 72--74, a function 164 | :math:`f : \mathbb{R}^2 \to \mathbb{R}` that has compact support 165 | .. math:: 166 | \|x\| > \rho \implies f(x) = 0, 167 | and is essentially bandlimited 168 | .. math:: 169 | \|\xi\| > \Omega \implies \hat{f}(\xi) \approx 0, 170 | can be fully reconstructed from a parallel beam ray transform 171 | if (1) the projection angles are sampled with a spacing of 172 | :math:`\Delta \psi` such that 173 | .. math:: 174 | \Delta \psi \leq \frac{\pi}{\rho \Omega}, 175 | and (2) the detector is sampled with an interval :math:`\Delta s` 176 | that satisfies 177 | .. math:: 178 | \Delta s \leq \frac{\pi}{\Omega}. 179 | The geometry returned by this function satisfies these conditions exactly. 180 | If the domain is 3-dimensional, the geometry is "separable", in that each 181 | slice along the z-dimension of the data is treated as independent 2d data. 182 | References 183 | ---------- 184 | .. [NW2001] Natterer, F and Wuebbeling, F. 185 | *Mathematical Methods in Image Reconstruction*. 186 | SIAM, 2001. 187 | https://dx.doi.org/10.1137/1.9780898718324 188 | """ 189 | def __init__(self, clean_operator, nonuniform_angles): 190 | super().__init__(clean_operator.img_size, clean_operator.num_angles, clean_operator.op_snr) 191 | 192 | space = self.reco_space 193 | 194 | # Find maximum distance from rotation axis 195 | corners = space.domain.corners()[:, :2] 196 | rho = np.max(np.linalg.norm(corners, axis=1)) 197 | 198 | # Find default values according to Nyquist criterion. 199 | 200 | # We assume that the function is bandlimited by a wave along the x or y 201 | # axis. The highest frequency we can measure is then a standing wave with 202 | # period of twice the inter-node distance. 203 | min_side = min(space.partition.cell_sides[:2]) 204 | omega = np.pi / min_side 205 | num_px_horiz = 2 * int(np.ceil(rho * omega / np.pi)) + 1 206 | det_min_pt = -rho 207 | det_max_pt = rho 208 | det_shape = num_px_horiz 209 | det_partition = odl.discr.uniform_partition(det_min_pt, det_max_pt, det_shape) 210 | 211 | del self.angles 212 | 213 | # angle partition is changed to not be uniform 214 | angle_partition = odl.discr.nonuniform_partition(np.sort(nonuniform_angles)) 215 | 216 | self.geometry = odl.tomo.Parallel2dGeometry(angle_partition, det_partition) 217 | 218 | self.num_detectors = self.geometry.detector.size 219 | 220 | self.op = odl.tomo.RayTransform( 221 | self.reco_space, 222 | self.geometry, 223 | impl='astra_cuda') 224 | 225 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op) 226 | 227 | 228 | def unit_test(): 229 | op_64_32 = ParallelBeamGeometryOp(1024, 32) 230 | phantom = odl.phantom.shepp_logan( 231 | op_64_32.reco_space, modified=True) 232 | 233 | x = torch.from_numpy(phantom.data) 234 | y = op_64_32(x) 235 | 236 | print(y.shape) 237 | print(x.shape) 238 | 239 | def unit_test_nonuniform(): 240 | 241 | operator1 = ParallelBeamGeometryOp(64, 45) 242 | x = torch.from_numpy(odl.phantom.shepp_logan(operator1.reco_space, modified=True).data) 243 | y1 = operator1(x) 244 | 245 | operator2 = ParallelBeamGeometryOpNonUniform(operator1, operator1.geometry.angles) 246 | y2 = operator2(x) 247 | 248 | print(y1 - y2) 249 | 250 | if __name__ == '__main__': 251 | unit_test() 252 | unit_test_nonuniform() 253 | -------------------------------------------------------------------------------- /normflow/flows/mixing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .base import Flow 5 | 6 | # Try importing Neural Spline Flow dependencies 7 | try: 8 | from neural_spline_flows.nde.transforms.permutations import RandomPermutation 9 | from neural_spline_flows.nde.transforms.lu import LULinear 10 | except: 11 | print('Warning: Dependencies for Neural Spline Flows could ' 12 | 'not be loaded. 
Other models can still be used.') 13 | 14 | 15 | class Permute(Flow): 16 | """ 17 | Permutation features along the channel dimension 18 | """ 19 | 20 | def __init__(self, num_channels, mode='shuffle'): 21 | """ 22 | Constructor 23 | :param num_channel: Number of channels 24 | :param mode: Mode of permuting features, can be shuffle for 25 | random permutation or swap for interchanging upper and lower part 26 | """ 27 | super().__init__() 28 | self.mode = mode 29 | self.num_channels = num_channels 30 | if self.mode == 'shuffle': 31 | perm = torch.randperm(self.num_channels) 32 | inv_perm = torch.empty_like(perm).scatter_(dim=0, index=perm, 33 | src=torch.arange(self.num_channels)) 34 | self.register_buffer("perm", perm) 35 | self.register_buffer("inv_perm", inv_perm) 36 | 37 | def forward(self, z): 38 | if self.mode == 'shuffle': 39 | z = z[:, self.perm, ...] 40 | elif self.mode == 'swap': 41 | z1 = z[:, :self.num_channels // 2, ...] 42 | z2 = z[:, self.num_channels // 2:, ...] 43 | z = torch.cat([z2, z1], dim=1) 44 | else: 45 | raise NotImplementedError('The mode ' + self.mode + ' is not implemented.') 46 | log_det = 0 47 | return z, log_det 48 | 49 | def inverse(self, z): 50 | if self.mode == 'shuffle': 51 | z = z[:, self.inv_perm, ...] 52 | elif self.mode == 'swap': 53 | z1 = z[:, :(self.num_channels + 1) // 2, ...] 54 | z2 = z[:, (self.num_channels + 1) // 2:, ...] 55 | z = torch.cat([z2, z1], dim=1) 56 | else: 57 | raise NotImplementedError('The mode ' + self.mode + ' is not implemented.') 58 | log_det = 0 59 | return z, log_det 60 | 61 | 62 | class Invertible1x1Conv(Flow): 63 | """ 64 | Invertible 1x1 convolution introduced in the Glow paper 65 | Assumes 4d input/output tensors of the form NCHW 66 | """ 67 | 68 | def __init__(self, num_channels, use_lu=False): 69 | """ 70 | Constructor 71 | :param num_channels: Number of channels of the data 72 | :param use_lu: Flag whether to parametrize weights through the LU decomposition 73 | """ 74 | super().__init__() 75 | self.num_channels = num_channels 76 | self.use_lu = use_lu 77 | Q = torch.qr(torch.randn(self.num_channels, self.num_channels))[0] 78 | if use_lu: 79 | P, L, U = torch.lu_unpack(*Q.lu()) 80 | self.register_buffer('P', P) # remains fixed during optimization 81 | self.L = nn.Parameter(L) # lower triangular portion 82 | S = U.diag() # "crop out" the diagonal to its own parameter 83 | self.register_buffer("sign_S", torch.sign(S)) 84 | self.log_S = nn.Parameter(torch.log(torch.abs(S))) 85 | self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S 86 | self.register_buffer("eye", torch.diag(torch.ones(self.num_channels))) 87 | else: 88 | self.W = nn.Parameter(Q) 89 | 90 | def _assemble_W(self, inverse=False): 91 | # assemble W from its components (P, L, U, S) 92 | L = torch.tril(self.L, diagonal=-1) + self.eye 93 | U = torch.triu(self.U, diagonal=1) + torch.diag(self.sign_S * torch.exp(self.log_S)) 94 | if inverse: 95 | if self.log_S.dtype == torch.float64: 96 | L_inv = torch.inverse(L) 97 | U_inv = torch.inverse(U) 98 | else: 99 | L_inv = torch.inverse(L.double()).type(self.log_S.dtype) 100 | U_inv = torch.inverse(U.double()).type(self.log_S.dtype) 101 | W = U_inv @ L_inv @ self.P.t() 102 | else: 103 | W = self.P @ L @ U 104 | return W 105 | 106 | def forward(self, z): 107 | if self.use_lu: 108 | W = self._assemble_W(inverse=True) 109 | log_det = -torch.sum(self.log_S) 110 | else: 111 | W_dtype = self.W.dtype 112 | if W_dtype == torch.float64: 113 | W = torch.inverse(self.W) 114 | else: 115 | W = 
torch.inverse(self.W.double()).type(W_dtype) 116 | W = W.view(*W.size(), 1, 1) 117 | log_det = -torch.slogdet(self.W)[1] 118 | W = W.view(self.num_channels, self.num_channels, 1, 1) 119 | z_ = torch.nn.functional.conv2d(z, W) 120 | log_det = log_det * z.size(2) * z.size(3) 121 | return z_, log_det 122 | 123 | def inverse(self, z): 124 | if self.use_lu: 125 | W = self._assemble_W() 126 | log_det = torch.sum(self.log_S) 127 | else: 128 | W = self.W 129 | log_det = torch.slogdet(self.W)[1] 130 | W = W.view(self.num_channels, self.num_channels, 1, 1) 131 | z_ = torch.nn.functional.conv2d(z, W) 132 | log_det = log_det * z.size(2) * z.size(3) 133 | return z_, log_det 134 | 135 | 136 | class InvertibleAffine(Flow): 137 | """ 138 | Invertible affine transformation without shift, i.e. one-dimensional 139 | version of the invertible 1x1 convolutions 140 | """ 141 | 142 | def __init__(self, num_channels, use_lu=True): 143 | """ 144 | Constructor 145 | :param num_channels: Number of channels of the data 146 | :param use_lu: Flag whether to parametrize weights through the 147 | LU decomposition 148 | """ 149 | super().__init__() 150 | self.num_channels = num_channels 151 | self.use_lu = use_lu 152 | Q = torch.qr(torch.randn(self.num_channels, self.num_channels))[0] 153 | if use_lu: 154 | P, L, U = torch.lu_unpack(*Q.lu()) 155 | self.register_buffer('P', P) # remains fixed during optimization 156 | self.L = nn.Parameter(L) # lower triangular portion 157 | S = U.diag() # "crop out" the diagonal to its own parameter 158 | self.register_buffer("sign_S", torch.sign(S)) 159 | self.log_S = nn.Parameter(torch.log(torch.abs(S))) 160 | self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S 161 | self.register_buffer("eye", torch.diag(torch.ones(self.num_channels))) 162 | else: 163 | self.W = nn.Parameter(Q) 164 | 165 | def _assemble_W(self, inverse=False): 166 | # assemble W from its components (P, L, U, S) 167 | L = torch.tril(self.L, diagonal=-1) + self.eye 168 | U = torch.triu(self.U, diagonal=1) + torch.diag(self.sign_S * torch.exp(self.log_S)) 169 | if inverse: 170 | if self.log_S.dtype == torch.float64: 171 | L_inv = torch.inverse(L) 172 | U_inv = torch.inverse(U) 173 | else: 174 | L_inv = torch.inverse(L.double()).type(self.log_S.dtype) 175 | U_inv = torch.inverse(U.double()).type(self.log_S.dtype) 176 | W = U_inv @ L_inv @ self.P.t() 177 | else: 178 | W = self.P @ L @ U 179 | return W 180 | 181 | def forward(self, z): 182 | if self.use_lu: 183 | W = self._assemble_W(inverse=True) 184 | log_det = -torch.sum(self.log_S) 185 | else: 186 | W_dtype = self.W.dtype 187 | if W_dtype == torch.float64: 188 | W = torch.inverse(self.W) 189 | else: 190 | W = torch.inverse(self.W.double()).type(W_dtype) 191 | log_det = -torch.slogdet(self.W)[1] 192 | z_ = z @ W 193 | return z_, log_det 194 | 195 | def inverse(self, z): 196 | if self.use_lu: 197 | W = self._assemble_W() 198 | log_det = torch.sum(self.log_S) 199 | else: 200 | W = self.W 201 | log_det = torch.slogdet(self.W)[1] 202 | z_ = z @ W 203 | return z_, log_det 204 | 205 | 206 | class LULinearPermute(Flow): 207 | """ 208 | Fixed permutation combined with a linear transformation parametrized 209 | using the LU decomposition, used in https://arxiv.org/abs/1906.04032 210 | """ 211 | def __init__(self, num_channels, identity_init=True, reverse=True): 212 | """ 213 | Constructor 214 | :param num_channels: Number of dimensions of the data 215 | :param identity_init: Flag, whether to initialize linear 216 | transform as identity matrix 217 | 
:param reverse: Flag, change forward and inverse transform 218 | """ 219 | # Initialize 220 | super().__init__() 221 | self.reverse = reverse 222 | 223 | # Define modules 224 | self.permutation = RandomPermutation(num_channels) 225 | self.linear = LULinear(num_channels, identity_init=identity_init) 226 | 227 | def forward(self, z): 228 | if self.reverse: 229 | z, log_det = self.linear.inverse(z) 230 | z, _ = self.permutation.inverse(z) 231 | else: 232 | z, _ = self.permutation(z) 233 | z, log_det = self.linear(z) 234 | return z, log_det.view(-1) 235 | 236 | def inverse(self, z): 237 | if self.reverse: 238 | z, _ = self.permutation(z) 239 | z, log_det = self.linear(z) 240 | else: 241 | z, log_det = self.linear.inverse(z) 242 | z, _ = self.permutation.inverse(z) 243 | return z, log_det.view(-1) 244 | -------------------------------------------------------------------------------- /normflow/flows/affine_coupling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from .base import Flow 6 | from .reshape import Split, Merge 7 | 8 | 9 | 10 | class AffineConstFlow(Flow): 11 | """ 12 | scales and shifts with learned constants per dimension. In the NICE paper there is a 13 | scaling layer which is a special case of this where t is None 14 | """ 15 | 16 | def __init__(self, shape, scale=True, shift=True): 17 | """ 18 | Constructor 19 | :param shape: Shape of the coupling layer 20 | :param scale: Flag whether to apply scaling 21 | :param shift: Flag whether to apply shift 22 | :param logscale_factor: Optional factor which can be used to control 23 | the scale of the log scale factor 24 | """ 25 | super().__init__() 26 | if scale: 27 | self.s = nn.Parameter(torch.zeros(shape)[None]) 28 | else: 29 | self.register_buffer('s', torch.zeros(shape)[None]) 30 | if shift: 31 | self.t = nn.Parameter(torch.zeros(shape)[None]) 32 | else: 33 | self.register_buffer('t', torch.zeros(shape)[None]) 34 | self.n_dim = self.s.dim() 35 | self.batch_dims = torch.nonzero(torch.tensor(self.s.shape) == 1, as_tuple=False)[:, 0].tolist() 36 | 37 | def forward(self, z): 38 | z_ = z * torch.exp(self.s) + self.t 39 | if len(self.batch_dims) > 1: 40 | prod_batch_dims = np.prod([z.size(i) for i in self.batch_dims[1:]]) 41 | else: 42 | prod_batch_dims = 1 43 | log_det = prod_batch_dims * torch.sum(self.s) 44 | return z_, log_det 45 | 46 | def inverse(self, z): 47 | z_ = (z - self.t) * torch.exp(-self.s) 48 | if len(self.batch_dims) > 1: 49 | prod_batch_dims = np.prod([z.size(i) for i in self.batch_dims[1:]]) 50 | else: 51 | prod_batch_dims = 1 52 | log_det = -prod_batch_dims * torch.sum(self.s) 53 | return z_, log_det 54 | 55 | 56 | class CCAffineConst(Flow): 57 | """ 58 | Affine constant flow layer with class-conditional parameters 59 | """ 60 | 61 | def __init__(self, shape, num_classes): 62 | super().__init__() 63 | self.shape = shape 64 | self.s = nn.Parameter(torch.zeros(shape)[None]) 65 | self.t = nn.Parameter(torch.zeros(shape)[None]) 66 | self.s_cc = nn.Parameter(torch.zeros(num_classes, np.prod(shape))) 67 | self.t_cc = nn.Parameter(torch.zeros(num_classes, np.prod(shape))) 68 | self.n_dim = self.s.dim() 69 | self.batch_dims = torch.nonzero(torch.tensor(self.s.shape) == 1, as_tuple=False)[:, 0].tolist() 70 | 71 | def forward(self, z, y): 72 | s = self.s + (y @ self.s_cc).view(-1, *self.shape) 73 | t = self.t + (y @ self.t_cc).view(-1, *self.shape) 74 | z_ = z * torch.exp(s) + t 75 | if len(self.batch_dims) > 1: 76 | 
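            # log|det J| of an elementwise affine map is the sum of the log-scales;
            # singleton (broadcast) dims of the stored log-scale act on every element
            # they are expanded over, so the sum is multiplied by their total extent.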
prod_batch_dims = np.prod([z.size(i) for i in self.batch_dims[1:]]) 77 | else: 78 | prod_batch_dims = 1 79 | log_det = prod_batch_dims * torch.sum(s, dim=list(range(1, self.n_dim))) 80 | return z_, log_det 81 | 82 | def inverse(self, z, y): 83 | s = self.s + (y @ self.s_cc).view(-1, *self.shape) 84 | t = self.t + (y @ self.t_cc).view(-1, *self.shape) 85 | z_ = (z - t) * torch.exp(-s) 86 | if len(self.batch_dims) > 1: 87 | prod_batch_dims = np.prod([z.size(i) for i in self.batch_dims[1:]]) 88 | else: 89 | prod_batch_dims = 1 90 | log_det = -prod_batch_dims * torch.sum(s, dim=list(range(1, self.n_dim))) 91 | return z_, log_det 92 | 93 | 94 | class AffineCoupling(Flow): 95 | """ 96 | Affine Coupling layer as introduced RealNVP paper, see arXiv: 1605.08803 97 | """ 98 | 99 | def __init__(self, param_map, scale=True, scale_map='exp'): 100 | """ 101 | Constructor 102 | :param param_map: Maps features to shift and scale parameter (if applicable) 103 | :param scale: Flag whether scale shall be applied 104 | :param scale_map: Map to be applied to the scale parameter, can be 'exp' as in 105 | RealNVP or 'sigmoid' as in Glow, 'sigmoid_inv' uses multiplicative sigmoid 106 | scale when sampling from the model 107 | """ 108 | super().__init__() 109 | self.add_module('param_map', param_map) 110 | self.scale = scale 111 | self.scale_map = scale_map 112 | 113 | def forward(self, z): 114 | """ 115 | z is a list of z1 and z2; z = [z1, z2] 116 | z1 is left constant and affine map is applied to z2 with parameters depending 117 | on z1 118 | """ 119 | z1, z2 = z 120 | param = self.param_map(z1) 121 | if self.scale: 122 | shift = param[:, 0::2, ...] 123 | scale_ = param[:, 1::2, ...] 124 | if self.scale_map == 'exp': 125 | z2 = z2 * torch.exp(scale_) + shift 126 | log_det = torch.sum(scale_, dim=list(range(1, shift.dim()))) 127 | elif self.scale_map == 'sigmoid': 128 | scale = torch.sigmoid(scale_ + 2) 129 | z2 = z2 / scale + shift 130 | log_det = -torch.sum(torch.log(scale), 131 | dim=list(range(1, shift.dim()))) 132 | elif self.scale_map == 'sigmoid_inv': 133 | scale = torch.sigmoid(scale_ + 2) 134 | z2 = z2 * scale + shift 135 | log_det = torch.sum(torch.log(scale), 136 | dim=list(range(1, shift.dim()))) 137 | else: 138 | raise NotImplementedError('This scale map is not implemented.') 139 | else: 140 | z2 += param 141 | log_det = 0 142 | return [z1, z2], log_det 143 | 144 | def inverse(self, z): 145 | z1, z2 = z 146 | param = self.param_map(z1) 147 | if self.scale: 148 | shift = param[:, 0::2, ...] 149 | scale_ = param[:, 1::2, ...] 
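                # Invert each scale map used in forward(), with s = sigmoid(scale_ + 2)
                # where applicable:
                #   'exp':         forward z2 * exp(scale_) + shift  ->  (z2 - shift) * exp(-scale_)
                #   'sigmoid':     forward z2 / s + shift            ->  (z2 - shift) * s
                #   'sigmoid_inv': forward z2 * s + shift            ->  (z2 - shift) / s
                # The log-determinants below flip sign relative to forward().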
150 | if self.scale_map == 'exp': 151 | z2 = (z2 - shift) * torch.exp(-scale_) 152 | log_det = -torch.sum(scale_, dim=list(range(1, shift.dim()))) 153 | elif self.scale_map == 'sigmoid': 154 | scale = torch.sigmoid(scale_ + 2) 155 | z2 = (z2 - shift) * scale 156 | log_det = torch.sum(torch.log(scale), 157 | dim=list(range(1, shift.dim()))) 158 | elif self.scale_map == 'sigmoid_inv': 159 | scale = torch.sigmoid(scale_ + 2) 160 | z2 = (z2 - shift) / scale 161 | log_det = -torch.sum(torch.log(scale), 162 | dim=list(range(1, shift.dim()))) 163 | else: 164 | raise NotImplementedError('This scale map is not implemented.') 165 | else: 166 | z2 -= param 167 | log_det = 0 168 | return [z1, z2], log_det 169 | 170 | 171 | class MaskedAffineFlow(Flow): 172 | """ 173 | RealNVP as introduced in arXiv: 1605.08803 174 | Masked affine flow f(z) = b * z + (1 - b) * (z * exp(s(b * z)) + t) 175 | class AffineHalfFlow(Flow): is MaskedAffineFlow with alternating bit mask 176 | NICE is AffineFlow with only shifts (volume preserving) 177 | """ 178 | 179 | def __init__(self, b, t=None, s=None): 180 | """ 181 | Constructor 182 | :param b: mask for features, i.e. tensor of same size as latent data point filled with 0s and 1s 183 | :param t: translation mapping, i.e. neural network, where first input dimension is batch dim, 184 | if None no translation is applied 185 | :param s: scale mapping, i.e. neural network, where first input dimension is batch dim, 186 | if None no scale is applied 187 | """ 188 | super().__init__() 189 | self.b_cpu = b.view(1, *b.size()) 190 | self.register_buffer('b', self.b_cpu) 191 | 192 | if s is None: 193 | self.s = lambda x: torch.zeros_like(x) 194 | else: 195 | self.add_module('s', s) 196 | 197 | if t is None: 198 | self.t = lambda x: torch.zeros_like(x) 199 | else: 200 | self.add_module('t', t) 201 | 202 | def forward(self, z): 203 | z_masked = self.b * z 204 | scale = self.s(z_masked) 205 | nan = torch.tensor(np.nan, dtype=z.dtype, device=z.device) 206 | scale = torch.where(torch.isfinite(scale), scale, nan) 207 | trans = self.t(z_masked) 208 | trans = torch.where(torch.isfinite(trans), trans, nan) 209 | z_ = z_masked + (1 - self.b) * (z * torch.exp(scale) + trans) 210 | log_det = torch.sum((1 - self.b) * scale, dim=list(range(1, self.b.dim()))) 211 | return z_, log_det 212 | 213 | def inverse(self, z): 214 | z_masked = self.b * z 215 | scale = self.s(z_masked) 216 | nan = torch.tensor(np.nan, dtype=z.dtype, device=z.device) 217 | scale = torch.where(torch.isfinite(scale), scale, nan) 218 | trans = self.t(z_masked) 219 | trans = torch.where(torch.isfinite(trans), trans, nan) 220 | z_ = z_masked + (1 - self.b) * (z - trans) * torch.exp(-scale) 221 | log_det = -torch.sum((1 - self.b) * scale, dim=list(range(1, self.b.dim()))) 222 | return z_, log_det 223 | 224 | 225 | class AffineCouplingBlock(Flow): 226 | """ 227 | Affine Coupling layer including split and merge operation 228 | """ 229 | def __init__(self, param_map, scale=True, scale_map='exp', split_mode='channel'): 230 | """ 231 | Constructor 232 | :param param_map: Maps features to shift and scale parameter (if applicable) 233 | :param scale: Flag whether scale shall be applied 234 | :param scale_map: Map to be applied to the scale parameter, can be 'exp' as in 235 | RealNVP or 'sigmoid' as in Glow 236 | :param split_mode: Splitting mode, for possible values see Split class 237 | """ 238 | super().__init__() 239 | self.flows = nn.ModuleList([]) 240 | # Split layer 241 | self.flows += [Split(split_mode)] 242 | # Affine coupling 
layer 243 | self.flows += [AffineCoupling(param_map, scale, scale_map)] 244 | # Merge layer 245 | self.flows += [Merge(split_mode)] 246 | 247 | def forward(self, z): 248 | log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device) 249 | for flow in self.flows: 250 | z, log_det = flow(z) 251 | log_det_tot += log_det 252 | return z, log_det_tot 253 | 254 | def inverse(self, z): 255 | log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device) 256 | for i in range(len(self.flows) - 1, -1, -1): 257 | z, log_det = self.flows[i].inverse(z) 258 | log_det_tot += log_det 259 | return z, log_det_tot -------------------------------------------------------------------------------- /train_generative.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from timeit import default_timer 5 | from torch.optim import Adam 6 | import os 7 | import imageio 8 | import matplotlib.pyplot as plt 9 | from autoencoder_model import autoencoder, encoder, decoder 10 | from flow_model import real_nvp 11 | from utils import * 12 | from datasets import * 13 | from laplacian_loss import LaplacianPyramidLoss 14 | import config_generative as config 15 | 16 | torch.manual_seed(0) 17 | np.random.seed(0) 18 | 19 | epochs_flow = config.epochs_flow 20 | epochs_aeder = config.epochs_aeder 21 | flow_depth = config.flow_depth 22 | latent_dim = config.latent_dim 23 | batch_size = config.batch_size 24 | dataset = config.dataset 25 | gpu_num = config.gpu_num 26 | exp_desc = config.exp_desc 27 | image_size = config.image_size 28 | c = config.c 29 | train_aeder = config.train_aeder 30 | train_flow = config.train_flow 31 | restore_flow = config.restore_flow 32 | restore_aeder = config.restore_aeder 33 | 34 | enable_cuda = True 35 | device = torch.device('cuda:' + str(gpu_num) if torch.cuda.is_available() and enable_cuda else 'cpu') 36 | 37 | 38 | all_experiments = 'experiments/' 39 | if os.path.exists(all_experiments) == False: 40 | os.mkdir(all_experiments) 41 | 42 | # experiment path 43 | exp_path = all_experiments + 'generator_' + dataset + '_' \ 44 | + str(flow_depth) + '_' + str(latent_dim) + '_' + str(image_size) + '_' + exp_desc 45 | 46 | if os.path.exists(exp_path) == False: 47 | os.mkdir(exp_path) 48 | 49 | 50 | learning_rate = 1e-4 51 | step_size = 50 52 | gamma = 0.5 53 | lam = 0.01 54 | 55 | # Print the experiment setup: 56 | print('Experiment setup:') 57 | print('---> epochs_aeder: {}'.format(epochs_aeder)) 58 | print('---> epochs_flow: {}'.format(epochs_flow)) 59 | print('---> flow_depth: {}'.format(flow_depth)) 60 | print('---> batch_size: {}'.format(batch_size)) 61 | print('---> dataset: {}'.format(dataset)) 62 | print('---> Learning rate: {}'.format(learning_rate)) 63 | print('---> experiment path: {}'.format(exp_path)) 64 | print('---> latent dim: {}'.format(latent_dim)) 65 | print('---> image size: {}'.format(image_size)) 66 | 67 | 68 | # Dataset: 69 | train_dataset = Dataset_loader(dataset = 'train' ,size = (image_size,image_size), c = c, quantize = False) 70 | test_dataset = Dataset_loader(dataset = 'test' ,size = (image_size,image_size), c = c, quantize = False) 71 | 72 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=40, shuffle = True) 73 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=25, num_workers=8) 74 | 75 | ntrain = len(train_loader.dataset) 76 | n_test = len(test_loader.dataset) 77 | print('---> Number of 
training, test samples: {}, {}'.format(ntrain,n_test)) 78 | plot_per_num_epoch = 1 if ntrain > 20000 else 20000//ntrain 79 | 80 | # Loss 81 | dum_samples = next(iter(test_loader)).to(device) 82 | mse_l = F.mse_loss 83 | pyramid_l = LaplacianPyramidLoss(max_levels=3, channels=c, kernel_size=5, 84 | sigma=1, device=device, dtype=dum_samples.dtype) 85 | vgg = Vgg16().to(device) 86 | for param in vgg.parameters(): 87 | param.requires_grad = False 88 | 89 | # 1. Training Autoencoder: 90 | enc = encoder(latent_dim = latent_dim, in_res = image_size , c = c).to(device) 91 | dec = decoder(latent_dim = latent_dim, in_res = image_size , c = c).to(device) 92 | aeder = autoencoder(encoder = enc , decoder = dec).to(device) 93 | 94 | num_param_aeder = count_parameters(aeder) 95 | print('---> Number of trainable parameters of Autoencoder: {}'.format(num_param_aeder)) 96 | 97 | optimizer_aeder = Adam(aeder.parameters(), lr=learning_rate) 98 | scheduler_aeder = torch.optim.lr_scheduler.StepLR(optimizer_aeder, step_size=step_size, gamma=gamma) 99 | 100 | checkpoint_autoencoder_path = os.path.join(exp_path, 'autoencoder.pt') 101 | if os.path.exists(checkpoint_autoencoder_path) and restore_aeder == True: 102 | checkpoint_autoencoder = torch.load(checkpoint_autoencoder_path) 103 | aeder.load_state_dict(checkpoint_autoencoder['model_state_dict']) 104 | optimizer_aeder.load_state_dict(checkpoint_autoencoder['optimizer_state_dict']) 105 | print('Autoencoder is restored...') 106 | 107 | 108 | if train_aeder: 109 | 110 | if plot_per_num_epoch == -1: 111 | plot_per_num_epoch = epochs_aeder + 1 # only plot in the last epoch 112 | 113 | loss_ae_plot = np.zeros([epochs_aeder]) 114 | for ep in range(epochs_aeder): 115 | aeder.train() 116 | t1 = default_timer() 117 | loss_ae_epoch = 0 118 | 119 | # Train the first 100 epochs on the style loss only, then on the combined style and MSE loss 120 | loss_type = 'style' if ep < 100 else 'style_mse' 121 | for image in train_loader: 122 | 123 | batch_size = image.shape[0] 124 | image = image.to(device) 125 | 126 | optimizer_aeder.zero_grad() 127 | image_mat = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2) 128 | 129 | embed = aeder.encoder(image_mat) 130 | image_recon = aeder.decoder(embed) 131 | 132 | recon_loss = aeder_loss(image_mat, image_recon, loss_type = loss_type, 133 | pyramid_l = pyramid_l, mse_l = mse_l, vgg = vgg) 134 | regularization = mse_l(embed, torch.zeros(embed.shape).to(device)) 135 | ae_loss = recon_loss + lam * regularization 136 | 137 | ae_loss.backward() 138 | optimizer_aeder.step() 139 | loss_ae_epoch += ae_loss.item() 140 | 141 | scheduler_aeder.step() 142 | t2 = default_timer() 143 | loss_ae_epoch /= ntrain 144 | loss_ae_plot[ep] = loss_ae_epoch 145 | 146 | plt.plot(np.arange(epochs_aeder)[:ep+1], loss_ae_plot[:ep+1], 'o-', linewidth=2) 147 | plt.title('AE_loss') 148 | plt.xlabel('epoch') 149 | plt.ylabel('loss') 150 | 151 | plt.savefig(os.path.join(exp_path, 'Autoencoder_loss.jpg')) 152 | np.save(os.path.join(exp_path, 'Autoencoder_loss.npy'), loss_ae_plot[:ep+1]) 153 | plt.close() 154 | 155 | torch.save({ 156 | 'model_state_dict': aeder.state_dict(), 157 | 'optimizer_state_dict': optimizer_aeder.state_dict() 158 | }, checkpoint_autoencoder_path) 159 | 160 | 161 | samples_folder = os.path.join(exp_path, 'Results') 162 | if not os.path.exists(samples_folder): 163 | os.mkdir(samples_folder) 164 | image_path_reconstructions = os.path.join( 165 | samples_folder, 'Reconstructions_aeder') 166 | 167 | if not os.path.exists(image_path_reconstructions): 168 | 
os.mkdir(image_path_reconstructions) 169 | 170 | if (ep + 1) % plot_per_num_epoch == 0 or ep + 1 == epochs_aeder: 171 | sample_number = 25 172 | ngrid = int(np.sqrt(sample_number)) 173 | 174 | test_images = next(iter(test_loader)).to(device)[:sample_number] 175 | test_images = test_images.reshape(-1, image_size, image_size, c).permute(0,3,1,2) 176 | 177 | image_np = test_images.permute(0,2,3,1).detach().cpu().numpy() 178 | image_write = image_np[:sample_number].reshape( 179 | ngrid, ngrid, 180 | image_size, image_size,c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0 181 | image_write = image_write.clip(0, 255).astype(np.uint8) 182 | imageio.imwrite(os.path.join(image_path_reconstructions, '%d_gt.png' % (ep,)),image_write) 183 | 184 | embed = aeder.encoder(test_images) 185 | image_recon = aeder.decoder(embed) 186 | image_recon_np = image_recon.detach().cpu().numpy().transpose(0,2,3,1) 187 | image_recon_write = image_recon_np[:sample_number].reshape( 188 | ngrid, ngrid, 189 | image_size, image_size, c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0 190 | 191 | image_recon_write = image_recon_write.clip(0, 255).astype(np.uint8) 192 | imageio.imwrite(os.path.join(image_path_reconstructions, '%d_aeder_recon.png' % (ep,)), 193 | image_recon_write) 194 | 195 | snr_aeder = SNR(image_np , image_recon_np) 196 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file: 197 | file.write('ep: %03d/%03d | time: %.0f | aeder_loss %.4f | SNR_aeder %.4f' %(ep, epochs_aeder,t2-t1, 198 | loss_ae_epoch, snr_aeder)) 199 | file.write('\n') 200 | print('ep: %03d/%03d | time: %.0f | aeder_loss %.4f | SNR_aeder %.4f' %(ep, epochs_aeder,t2-t1, 201 | loss_ae_epoch, snr_aeder)) 202 | 203 | 204 | # Training the flow model 205 | nfm = real_nvp(latent_dim = latent_dim, K = flow_depth) 206 | nfm = nfm.to(device) 207 | num_param_nfm = count_parameters(nfm) 208 | print('Number of trainable parameters of flow: {}'.format(num_param_nfm)) 209 | 210 | loss_hist = np.array([]) 211 | optimizer_flow = torch.optim.Adam(nfm.parameters(), lr=1e-4, weight_decay=1e-5) 212 | scheduler_flow = torch.optim.lr_scheduler.StepLR(optimizer_flow, step_size=step_size, gamma=gamma) 213 | 214 | # Initialize ActNorm 215 | batch_img = next(iter(train_loader)).to(device) 216 | batch_img = batch_img.reshape(-1, image_size, image_size, c).permute(0,3,1,2) 217 | dummy_samples = aeder.encoder(batch_img) 218 | # dummy_samples = model.reference_latents(torch.tensor(0).to(device)) 219 | dummy_samples = dummy_samples.view(-1, latent_dim) 220 | # dummy_samples = torch.tensor(dummy_samples).float().to(device) 221 | likelihood = nfm.log_prob(dummy_samples) 222 | 223 | checkpoint_flow_path = os.path.join(exp_path, 'flow.pt') 224 | if os.path.exists(checkpoint_flow_path) and restore_flow == True: 225 | checkpoint_flow = torch.load(checkpoint_flow_path) 226 | nfm.load_state_dict(checkpoint_flow['model_state_dict']) 227 | optimizer_flow.load_state_dict(checkpoint_flow['optimizer_state_dict']) 228 | print('Flow model is restored...') 229 | 230 | if train_flow: 231 | 232 | for ep in range(epochs_flow): 233 | 234 | nfm.train() 235 | t1 = default_timer() 236 | loss_flow_epoch = 0 237 | for image in train_loader: 238 | optimizer_flow.zero_grad() 239 | image = image.to(device) 240 | image = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2) 241 | 242 | x = aeder.encoder(image) 243 | # Compute loss 244 | loss_flow = nfm.forward_kld(x) 245 | 246 | if ~(torch.isnan(loss_flow) | torch.isinf(loss_flow)): 247 | loss_flow.backward() 248 | 
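                    # The non-finite guard above skips this update step when the
                    # KLD estimate diverges, instead of corrupting the weights.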
optimizer_flow.step() 249 | 250 | # Make layers Lipschitz continuous 251 | # nf.utils.update_lipschitz(nfm, 5) 252 | loss_flow_epoch += loss_flow.item() 253 | # Log loss 254 | loss_hist = np.append(loss_hist, loss_flow.to('cpu').data.numpy()) 255 | 256 | scheduler_flow.step() 257 | t2 = default_timer() 258 | loss_flow_epoch /= ntrain 259 | 260 | torch.save({ 261 | 'model_state_dict': nfm.state_dict(), 262 | 'optimizer_state_dict': optimizer_flow.state_dict() 263 | }, checkpoint_flow_path) 264 | 265 | 266 | if (ep + 1) % plot_per_num_epoch == 0 or ep + 1 == epochs_flow: 267 | samples_folder = os.path.join(exp_path, 'Results') 268 | if not os.path.exists(samples_folder): 269 | os.mkdir(samples_folder) 270 | image_path_generated = os.path.join( 271 | samples_folder, 'generated') 272 | 273 | if not os.path.exists(image_path_generated): 274 | os.mkdir(image_path_generated) 275 | sample_number = 25 276 | ngrid = int(np.sqrt(sample_number)) 277 | 278 | generated_embed, _ = nfm.sample(torch.tensor(sample_number).to(device)) 279 | 280 | generated_samples = aeder.decoder(generated_embed) 281 | generated_samples = generated_samples.detach().cpu().numpy().transpose(0,2,3,1) 282 | 283 | generated_samples = generated_samples[:sample_number].reshape( 284 | ngrid, ngrid, 285 | image_size, image_size, c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0 286 | generated_samples = generated_samples.clip(0, 255).astype(np.uint8) 287 | 288 | 289 | imageio.imwrite(os.path.join(image_path_generated, 'epoch %d.png' % (ep,)), generated_samples) # training images 290 | 291 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file: 292 | file.write('ep: %03d/%03d | time: %.0f | ML_loss %.4f' %(ep, epochs_flow, t2-t1, loss_flow_epoch)) 293 | file.write('\n') 294 | 295 | print('ep: %03d/%03d | time: %.0f | ML_loss %.4f' %(ep, epochs_flow, t2-t1, loss_flow_epoch)) 296 | 297 | -------------------------------------------------------------------------------- /results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import os 5 | import imageio 6 | from utils import * 7 | import config_funknn as config 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def evaluator(ep, subset, data_loader, model, exp_path): 12 | 13 | samples_folder = os.path.join(exp_path, 'Results') 14 | if not os.path.exists(samples_folder): 15 | os.mkdir(samples_folder) 16 | image_path_reconstructions = os.path.join( 17 | samples_folder, 'Reconstructions') 18 | 19 | if not os.path.exists(image_path_reconstructions): 20 | os.mkdir(image_path_reconstructions) 21 | 22 | max_scale = config.max_scale 23 | recursive = config.recursive 24 | sample_number = config.sample_number 25 | image_size = config.image_size 26 | c = config.c 27 | 28 | if subset == 'ood': 29 | max_scale = 2 30 | recursive = False 31 | 32 | device = model.ws1.device 33 | num_samples_write = sample_number if sample_number < 26 else 25 34 | ngrid = int(np.sqrt(num_samples_write)) 35 | num_samples_write = int(ngrid **2) 36 | 37 | images_k = next(iter(data_loader)).to(device)[:sample_number] 38 | images_k = images_k.reshape(-1, max_scale*image_size, max_scale*image_size, c).permute(0,3,1,2) 39 | images = F.interpolate(images_k, size = image_size, antialias = True, mode = 'bilinear') 40 | 41 | scales = [i+1 for i in range(int(np.log2(max_scale)))] 42 | scales = np.power(2, scales) 43 | 44 | print('Evaluation over {} set:'.format(subset)) 45 | with 
open(os.path.join(exp_path, 'results.txt'), 'a') as file: 46 | file.write('Evaluation over {} set:'.format(subset)) 47 | file.write('\n') 48 | 49 | if recursive: 50 | 51 | images_down = images 52 | 53 | for i in range(len(scales)): 54 | # Recursive image generation starting from factor 2, well-suited for the 'factor' training mode 55 | res = scales[i]*image_size 56 | # GT: 57 | images_temp = F.interpolate(images_k, size = res , antialias = True, mode = 'bilinear') 58 | images_np = images_temp.permute(0, 2, 3, 1).detach().cpu().numpy() 59 | image_write = images_np[:num_samples_write].reshape( 60 | ngrid, ngrid, 61 | res, res,c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0 62 | image_write = image_write.clip(0, 255).astype(np.uint8) 63 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_gt_%d.png' % (ep,scales[i])), 64 | image_write) 65 | 66 | # Recon: 67 | coords = get_mgrid(res).reshape(-1, 2) 68 | coords = torch.unsqueeze(coords, dim = 0) 69 | coords = coords.expand(images_k.shape[0] , -1, -1).to(device) 70 | recon_np = batch_sampling(images_down, coords,c, model) 71 | recon_np = np.reshape(recon_np, [-1, res, res, c]) 72 | recon_write = recon_np[:num_samples_write].reshape( 73 | ngrid, ngrid, res, res, c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0 74 | recon_write = recon_write.clip(0, 255).astype(np.uint8) 75 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_recursive_FunkNN_%d.png' % (ep,scales[i])), 76 | recon_write) 77 | 78 | # Interpolate: 79 | interpolate = F.interpolate(images, size = res, mode = 'bilinear') 80 | interpolate_np = interpolate.detach().cpu().numpy().transpose(0,2,3,1) 81 | interpolate_write = interpolate_np[:num_samples_write].reshape( 82 | ngrid, ngrid, 83 | res, res, c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0 84 | interpolate_write = interpolate_write.clip(0, 255).astype(np.uint8) 85 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_interpolate_%d.png' % (ep,scales[i])), 86 | interpolate_write) # mesh_based_recon 87 | 88 | snr_recon = SNR(images_np, recon_np) 89 | snr_interpolate = SNR(images_np, interpolate_np) 90 | recon_np = recon_np.transpose([0,3,1,2]) 91 | images_down = torch.tensor(recon_np, dtype = images_down.dtype).to(device) 92 | 93 | print('SNR_interpolate_recursive_f{}: {:.1f} | SNR_FunkNN_recursive_f{}: {:.1f}'.format(scales[i], 94 | snr_interpolate, scales[i], snr_recon)) 95 | 96 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file: 97 | file.write('SNR_interpolate_recursive_f{}: {:.1f} | SNR_FunkNN_recursive_f{}: {:.1f} | '.format(scales[i], 98 | snr_interpolate, scales[i], snr_recon)) 99 | file.write('\n') 100 | if subset == 'ood': 101 | file.write('\n') 102 | 103 | else: 104 | for i in range(len(scales)): 105 | # Direct image generation, well-suited for the 'single' and 'continuous' training modes 106 | res = scales[i]*image_size 107 | # GT: 108 | images_temp = F.interpolate(images_k, size = res , antialias = True, mode = 'bilinear') 109 | images_np = images_temp.permute(0, 2, 3, 1).detach().cpu().numpy() 110 | image_write = images_np[:num_samples_write].reshape( 111 | ngrid, ngrid, 112 | res, res,c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0 113 | image_write = image_write.clip(0, 255).astype(np.uint8) 114 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_gt_%d.png' % (ep,scales[i])), 115 | image_write) 116 | 117 | # Recon: 118 | coords = get_mgrid(res).reshape(-1, 2) 119 | coords = torch.unsqueeze(coords, dim = 0) 120 | coords = 
coords.expand(images_k.shape[0] , -1, -1).to(device) 121 | recon_np = batch_sampling(images, coords,c, model) 122 | recon_np = np.reshape(recon_np, [-1, res, res, c]) 123 | recon_write = recon_np[:num_samples_write].reshape( 124 | ngrid, ngrid, res, res, c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0 125 | recon_write = recon_write.clip(0, 255).astype(np.uint8) 126 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_FunkNN_%d.png' % (ep,scales[i])), 127 | recon_write) 128 | 129 | # Interpolate: 130 | interpolate = F.interpolate(images, size = res, mode = 'bilinear') 131 | interpolate_np = interpolate.detach().cpu().numpy().transpose(0,2,3,1) 132 | interpolate_write = interpolate_np[:num_samples_write].reshape( 133 | ngrid, ngrid, 134 | res, res, c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0 135 | interpolate_write = interpolate_write.clip(0, 255).astype(np.uint8) 136 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_interpolate_%d.png' % (ep,scales[i])), 137 | interpolate_write) # mesh_based_recon 138 | 139 | snr_recon = SNR(images_np, recon_np) 140 | snr_interpolate = SNR(images_np, interpolate_np) 141 | 142 | print('SNR_interpolate_f{}: {:.1f} | SNR_FunkNN_f{}: {:.1f}'.format(scales[i], 143 | snr_interpolate, scales[i], snr_recon)) 144 | 145 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file: 146 | file.write('SNR_interpolate_f{}: {:.1f} | SNR_FunkNN_f{}: {:.1f} | '.format(scales[i], 147 | snr_interpolate, scales[i], snr_recon)) 148 | file.write('\n') 149 | if subset == 'ood': 150 | file.write('\n') 151 | 152 | 153 | 154 | 155 | if config.derivatives_evaluation: 156 | # Gradients: 157 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2) 158 | coords_2k = torch.unsqueeze(coords_2k, dim = 0) 159 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device) 160 | coords_2k = coords_2k.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. 
input 161 | funknn_grad = batch_grad(images, coords_2k,c, model) 162 | funknn_grad = torch.tensor(funknn_grad, dtype=coords_2k.dtype) 163 | funknn_grad = torch.norm(funknn_grad, dim = 2).cpu().detach().numpy() 164 | funknn_grad = np.reshape(funknn_grad, [-1, 2*image_size, 2*image_size,1]) 165 | funknn_grad_write = funknn_grad[:sample_number, :, :].reshape( 166 | ngrid, ngrid, 167 | 2*image_size, 2*image_size, 1).swapaxes(1, 2).reshape(ngrid*2*image_size, -1, 1)*255.0 168 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_grad_funknn.png' % (ep,)), 169 | funknn_grad_write[:,:,0], cmap='seismic') 170 | 171 | 172 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2) 173 | coords_2k = torch.unsqueeze(coords_2k, dim = 0) 174 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device) 175 | coords_2k = coords_2k.clone().detach().requires_grad_(True) 176 | funknn = torch.tensor(batch_sampling(images, coords_2k,c, model), dtype = coords_2k.dtype).to(device) 177 | funknn_mat = funknn.reshape([-1, 2*image_size, 2*image_size, c]).permute(0,3,1,2) 178 | true_grad = image_derivative(funknn_mat , c)[1] 179 | true_grad = true_grad.permute(0,2,3,1).detach().cpu().numpy() 180 | true_grad_write = true_grad[:sample_number, :, :].reshape( 181 | ngrid, ngrid, 182 | 2*image_size, 2*image_size, 1).swapaxes(1, 2).reshape(ngrid*2*image_size, -1, 1)*255.0 183 | 184 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_grad_finite_diff.png' % (ep,)), 185 | true_grad_write[:,:,0], cmap='seismic') 186 | 187 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2) 188 | coords_2k = torch.unsqueeze(coords_2k, dim = 0) 189 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device) 190 | coords_2k = coords_2k.clone().detach().requires_grad_(True).to(device) 191 | coords_2k_grad = coords_2k.reshape((-1,2*image_size,2*image_size,2)) 192 | coords_2k_grad = 2 * torch.flip(coords_2k_grad , dims = [3]) 193 | img_tmp = model.grid_sample_customized(images, coords_2k_grad, mode = config.interpolation_kernel) 194 | img_tmp = img_tmp.permute(0,2,3,1).reshape(1 , -1 , c) 195 | st_grad = gradient(img_tmp, coords_2k, grad_outputs=None) 196 | st_grad = torch.norm(st_grad, dim = 2).cpu().detach().numpy() 197 | st_grad = np.reshape(st_grad, [-1, 2*image_size, 2*image_size, 1]) 198 | st_grad_write = st_grad[:sample_number, :, :].reshape( 199 | ngrid, ngrid, 200 | 2*image_size, 2*image_size, 1).swapaxes(1, 2).reshape(ngrid*2*image_size, -1, 1)*255.0 201 | 202 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_grad_ST.png' % (ep,)), 203 | st_grad_write[:,:,0], cmap='seismic') 204 | 205 | 206 | 207 | ############################################################################################ 208 | # Laplacian: 209 | shift = 4 210 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2) 211 | coords_2k = torch.unsqueeze(coords_2k, dim = 0) 212 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device) 213 | coords_2k = coords_2k.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. 
input 214 | funknn_laplace = batch_laplace(images, coords_2k,c, model) 215 | funknn_laplace = np.reshape(funknn_laplace, [-1, 2*image_size, 2*image_size,1]) 216 | 217 | funknn_laplace_write = funknn_laplace[:images_k.shape[0], shift:2*image_size-shift, shift:2*image_size-shift].reshape( 218 | ngrid, ngrid, 219 | 2*image_size -2*shift, 2*image_size-2*shift, 1).swapaxes(1, 2).reshape(ngrid*(2*image_size-2*shift), -1, 1) 220 | 221 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_laplace_funknn.png' % (ep,)), 222 | funknn_laplace_write[:,:,0], cmap='seismic') 223 | 224 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2) 225 | coords_2k = torch.unsqueeze(coords_2k, dim = 0) 226 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device) 227 | coords_2k = coords_2k.clone().detach().requires_grad_(True).to(device) 228 | coords_2k_grad = coords_2k.reshape((-1,2*image_size,2*image_size,2)) 229 | coords_2k_grad = 2 * torch.flip(coords_2k_grad , dims = [3]) 230 | img_tmp = model.grid_sample_customized(images, coords_2k_grad, mode = config.interpolation_kernel) 231 | img_tmp = img_tmp.permute(0,2,3,1).reshape(1 , -1 , c) 232 | st_laplace = laplace(img_tmp, coords_2k).detach() # Ground truth derivatives 233 | st_laplace = np.reshape(st_laplace.cpu().numpy(), [-1, 2*image_size, 2*image_size, 1]) 234 | st_laplace_write = st_laplace[:images_k.shape[0], shift:2*image_size-shift, shift:2*image_size-shift].reshape( 235 | ngrid, ngrid, 236 | 2*image_size -2*shift, 2*image_size-2*shift, 1).swapaxes(1, 2).reshape(ngrid*(2*image_size-2*shift), -1, 1) 237 | 238 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_laplace_ST.png' % (ep,)), 239 | st_laplace_write[:,:,0], cmap='seismic') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torchvision.models import vgg16 5 | from skimage.transform import radon, iradon 6 | from scipy import optimize 7 | from skimage.metrics import peak_signal_noise_ratio as psnr 8 | 9 | 10 | def image_derivative(x , c): 11 | '''x must be a (b, c, h, w) image tensor''' 12 | 13 | horiz_derive = np.array([[1, 0, -1],[2, 0, -2],[1,0,-1]], dtype = np.float64) 14 | horiz_derive = horiz_derive[None,None,:] 15 | horiz_derive = np.repeat(horiz_derive,c,axis = 1) 16 | 17 | vert_derive = np.array([[1,2,1],[0,0,0], [-1,-2,-1]]) 18 | vert_derive = vert_derive[None,None,:] 19 | vert_derive = np.repeat(vert_derive,c,axis = 1) 20 | 21 | conv_horiz = torch.nn.Conv2d(1, c, kernel_size=3, stride=1, padding='same', padding_mode = 'replicate',bias=False) 22 | conv_horiz.weight.data = torch.from_numpy(horiz_derive).float().to(x.device) 23 | 24 | conv_vert = torch.nn.Conv2d(1, c, kernel_size=3, stride=1, padding='same', padding_mode = 'replicate', bias=False) 25 | conv_vert.weight.data = torch.from_numpy(vert_derive).float().to(x.device) 26 | 27 | G_x = conv_horiz(x) 28 | G_y = conv_vert(x) 29 | G = torch.cat((G_x , G_y) , axis = 1) 30 | G_mag = torch.sqrt(torch.pow(G_x,2)+ torch.pow(G_y,2)) 31 | 32 | return G, G_mag 33 | 34 | 35 | def PSNR(x_true , x_pred): 36 | 37 | s = 0 38 | for i in range(np.shape(x_pred)[0]): 39 | s += psnr(x_pred[i], 40 | x_true[i], 41 | data_range=x_true[i].max() - x_true[i].min()) 42 | 43 | return s/np.shape(x_pred)[0] 44 | 45 | def SNR(x_true , x_pred): 46 | '''Calculate the average SNR over a batch of ground-truth signals and their estimates''' 47 | 48 | # x_true = 
np.reshape(x_true , [np.shape(x_true)[0] , -1]) 49 | # x_pred = np.reshape(x_pred , [np.shape(x_pred)[0] , -1]) 50 | 51 | snr = 0 52 | for i in range(x_true.shape[0]): 53 | Noise = x_true[i] - x_pred[i] 54 | Noise_power = np.sum(np.square(np.abs(Noise))) 55 | Signal_power = np.sum(np.square(np.abs(x_true[i]))) 56 | snr += 10*np.log10(Signal_power/Noise_power) 57 | 58 | return snr/x_true.shape[0] 59 | 60 | 61 | 62 | def SNR_rescale(x_true , x_pred): 63 | '''Calculate SNR after fitting an affine rescale (weights[0]*x_pred + weights[1]) to the estimates''' 64 | snr = 0 65 | for i in range(x_true.shape[0]): 66 | 67 | def func(weights): 68 | Noise = x_true[i] - (weights[0]*x_pred[i]+weights[1]) 69 | Noise_power = np.sum(np.square(np.abs(Noise))) 70 | Signal_power = np.sum(np.square(np.abs(x_true[i]))) 71 | SNR = 10*np.log10(np.mean(Signal_power/(Noise_power+1e-12))) 72 | return SNR 73 | opt = optimize.minimize(lambda x: -func(x),x0=np.array([1,0])) 74 | snr += -opt.fun 75 | weights = opt.x 76 | return snr/x_true.shape[0] 77 | 78 | 79 | def PSNR_rescale(x_true , x_pred): 80 | '''Calculate PSNR after fitting an affine rescale to the estimates; also returns the fitted weights''' 81 | snr = 0 82 | for i in range(x_true.shape[0]): 83 | 84 | def func(weights): 85 | x_pred_rescale = weights[0]*x_pred[i]+weights[1] 86 | s = psnr(x_pred_rescale, 87 | x_true[i], 88 | data_range=x_true[i].max() - x_true[i].min()) 89 | 90 | return s 91 | opt = optimize.minimize(lambda x: -func(x),x0=np.array([1,0])) 92 | snr += -opt.fun 93 | weights = opt.x 94 | return snr/x_true.shape[0], weights 95 | 96 | 97 | 98 | def count_parameters(model): 99 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 100 | 101 | 102 | def relative_mse_loss(x_true, x_pred): 103 | noise_power = np.sqrt(np.sum(np.square(x_true - x_pred) , axis = (1,2,3))) 104 | signal_power = np.sqrt(np.sum(np.square(x_true) , axis = (1,2,3))) 105 | 106 | return np.mean(noise_power/signal_power) 107 | 108 | 109 | def batch_sampling(image_recon, coords, c, model): 110 | s = 64 111 | 112 | outs = np.zeros([np.shape(coords)[0], np.shape(coords)[1], c]) 113 | for i in range(np.shape(coords)[1]//s): 114 | 115 | batch_coords = coords[:,i*s: (i+1)*s] 116 | out = model(batch_coords, image_recon).detach().cpu().numpy() 117 | outs[:,i*s: (i+1)*s] = out 118 | 119 | return outs 120 | 121 | 122 | def batch_grad(image_recon, coords, c, model): 123 | s = 64 124 | 125 | out_grads = np.zeros([np.shape(coords)[0], np.shape(coords)[1], 2]) 126 | for i in range(np.shape(coords)[1]//s): 127 | 128 | batch_coords = coords[:,i*s: (i+1)*s] 129 | out = model(batch_coords, image_recon) 130 | out_grad = gradient(out, batch_coords).detach().cpu().numpy() 131 | out_grads[:,i*s: (i+1)*s] = out_grad 132 | 133 | return out_grads 134 | 135 | 136 | def batch_laplace(image_recon, coords, c, model): 137 | s = 64 138 | 139 | out_laplaces = np.zeros([np.shape(coords)[0], np.shape(coords)[1],1]) 140 | for i in range(np.shape(coords)[1]//s): 141 | 142 | batch_coords = coords[:,i*s: (i+1)*s] 143 | out = model(batch_coords, image_recon) 144 | out_laplace = laplace(out, batch_coords).detach().cpu().numpy() 145 | out_laplaces[:,i*s: (i+1)*s] = out_laplace 146 | 147 | return out_laplaces 148 | 149 | 150 | def batch_grad_pde(image_recon, coords, c, model): 151 | s = 64 152 | 153 | out_grads = np.zeros([np.shape(coords)[0], np.shape(coords)[1], 2]) 154 | for i in range(np.shape(coords)[1]//s): 155 | 156 | batch_coords = coords[:,i*s: (i+1)*s] 157 | out = model(batch_coords, image_recon) 158 | out_grad = gradient(out, batch_coords).detach().cpu().numpy() 
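# As in batch_grad/batch_laplace above, the coordinate grid is processed in chunks of s = 64 points so each autograd call stays small; points beyond the last full chunk are left at zero by the floor division.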
159 | out_grads[:,i*s: (i+1)*s] = out_grad 160 | 161 | return out_grads 162 | 163 | 164 | 165 | 166 | 167 | def simpleaxis(ax): 168 | ax.spines['top'].set_visible(False) 169 | ax.spines['right'].set_visible(False) 170 | ax.get_xaxis().tick_bottom() 171 | ax.get_yaxis().tick_left() 172 | 173 | 174 | def get_mgrid(sidelen): 175 | # Generate 2D pixel coordinates for a sidelen x sidelen image 176 | pixel_coords = np.stack(np.mgrid[:sidelen,:sidelen], axis=-1)[None,...].astype(np.float32) 177 | pixel_coords /= sidelen 178 | pixel_coords -= 0.5 179 | pixel_coords = torch.Tensor(pixel_coords).view(-1, 2) 180 | return pixel_coords 181 | 182 | def get_mgrid_unbalanced(sidelen1,sidelen2): 183 | # Generate 2D pixel coordinates for a sidelen1 x sidelen2 image 184 | pixel_coords = np.stack(np.mgrid[:sidelen1,:sidelen2], axis=-1)[None,...].astype(np.float32) 185 | pixel_coords = torch.Tensor(pixel_coords).view(-1, 2) 186 | pixel_coords = pixel_coords/(pixel_coords.max(dim = 0)[0]+1) 187 | pixel_coords -= 0.5 188 | return pixel_coords 189 | 190 | 191 | def lin2img(tensor): 192 | batch_size, num_samples, channels = tensor.shape 193 | sidelen = np.sqrt(num_samples).astype(int) 194 | return tensor.view(batch_size, channels, sidelen, sidelen) 195 | 196 | 197 | def plot_sample_image(img_batch, ax): 198 | # Plot the first item in the batch 199 | img = lin2img(img_batch)[0].detach().cpu().numpy() 200 | img += 1 201 | img /= 2. 202 | img = np.clip(img, 0., 1.) 203 | ax.set_axis_off() 204 | ax.imshow(img) 205 | 206 | 207 | def gradient(y, x, grad_outputs=None): 208 | if grad_outputs is None: 209 | grad_outputs = torch.ones_like(y) 210 | grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0] 211 | return grad 212 | 213 | 214 | def divergence(y, x): 215 | div = 0. 
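# Accumulates d y_i / d x_i over the output dimensions, one autograd call per component; applied to y = gradient(u, x) this sum is exactly the Laplacian of u.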
216 | for i in range(y.shape[-1]): 217 | div += torch.autograd.grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1] 218 | return div 219 | 220 | 221 | def multiple_gradient(y , x): 222 | out = torch.zeros([y.shape[0] , y.shape[1], 6]) 223 | for i in range(y.shape[-1]): 224 | a = torch.autograd.grad(y[...,i], x, torch.ones_like(y[...,i]), create_graph=True)[0] 225 | out[...,i:i +2] = a 226 | return out 227 | 228 | def laplace(y, x): 229 | grad = gradient(y, x) 230 | return divergence(grad, x) 231 | 232 | 233 | def aeder_loss(x_true, x_hat, loss_type = 'mse', pyramid_l = None, mse_l = None, vgg = None): 234 | batch_size = x_true.shape[0] 235 | if loss_type == 'mse': 236 | loss = mse_l(x_hat.reshape(batch_size, -1) , x_true.reshape(batch_size, -1)) 237 | 238 | elif loss_type == 'pyramid': 239 | loss = pyramid_l(x_hat , x_true) 240 | 241 | elif loss_type == 'style': 242 | 243 | reg_weight = 1e-6 244 | style_weight = 1 245 | feature_weight = 10 246 | pure_feat = vgg(x_true) 247 | recon_feat = vgg(x_hat) 248 | 249 | loss_style = 0 250 | loss_feature = 0 251 | 252 | for k in range(len(pure_feat)): 253 | 254 | bs, ch, h, w = pure_feat[k].size() 255 | 256 | pure_re_feat= pure_feat[k].view(bs, ch, h*w) 257 | gram_pure = torch.matmul(pure_re_feat, torch.transpose(pure_re_feat,1,2))/(ch*h*w) 258 | 259 | recon_re_feat= recon_feat[k].view(bs, ch, h*w) 260 | gram_recon = torch.matmul(recon_re_feat, torch.transpose(recon_re_feat,1,2))/(ch*h*w) 261 | 262 | loss_style = loss_style+ mse_l(gram_pure.view(batch_size,-1),gram_recon.view(batch_size,-1)) 263 | 264 | loss_feature = loss_feature + mse_l(pure_feat[k].reshape(batch_size,-1),recon_feat[k].reshape(batch_size,-1))#/(pure_feat[k].size(1)*pure_feat[k].size(2)*pure_feat[k].size(3)) 265 | 266 | 267 | loss_style = style_weight * loss_style 268 | loss_feature = feature_weight * loss_feature 269 | 270 | loss_tv = reg_weight * ( 271 | torch.sum(torch.abs(x_hat[:, :, :, :-1] - x_hat[:, :, :, 1:])) + 272 | torch.sum(torch.abs(x_hat[:, :, :-1, :] - x_hat[:, :, 1:, :]))) 273 | 274 | loss = loss_feature + loss_style + loss_tv 275 | 276 | elif loss_type == 'style_mse': 277 | 278 | reg_weight = 1e-6 279 | style_weight = 1 280 | feature_weight = 10 281 | mse_weight = 5000 282 | pure_feat = vgg(x_true) 283 | recon_feat = vgg(x_hat) 284 | 285 | loss_style = 0 286 | loss_feature = 0 287 | 288 | for k in range(len(pure_feat)): 289 | 290 | bs, ch, h, w = pure_feat[k].size() 291 | 292 | pure_re_feat= pure_feat[k].view(bs, ch, h*w) 293 | gram_pure = torch.matmul(pure_re_feat, torch.transpose(pure_re_feat,1,2))/(ch*h*w) 294 | 295 | recon_re_feat= recon_feat[k].view(bs, ch, h*w) 296 | gram_recon = torch.matmul(recon_re_feat, torch.transpose(recon_re_feat,1,2))/(ch*h*w) 297 | 298 | loss_style = loss_style+ mse_l(gram_pure.view(batch_size,-1),gram_recon.view(batch_size,-1)) 299 | 300 | loss_feature = loss_feature + mse_l(pure_feat[k].reshape(batch_size,-1),recon_feat[k].reshape(batch_size,-1))#/(pure_feat[k].size(1)*pure_feat[k].size(2)*pure_feat[k].size(3)) 301 | 302 | 303 | loss_style = style_weight * loss_style 304 | loss_feature = feature_weight * loss_feature 305 | 306 | loss_tv = reg_weight * ( 307 | torch.sum(torch.abs(x_hat[:, :, :, :-1] - x_hat[:, :, :, 1:])) + 308 | torch.sum(torch.abs(x_hat[:, :, :-1, :] - x_hat[:, :, 1:, :]))) 309 | 310 | loss_mse = mse_weight * mse_l(x_hat.reshape(batch_size, -1) , x_true.reshape(batch_size, -1)) 311 | 312 | loss = loss_feature + loss_style + loss_tv + loss_mse 313 | 314 | 315 | return loss 316 | 317 | 318 
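# Hedged sanity-check sketch: `_demo_autograd_derivatives` below is an illustrative addition,
# not part of the original training or evaluation code. For y = x1**2 + x2**2 the analytic
# gradient is (2*x1, 2*x2) and the Laplacian is 4 everywhere, which gives a closed-form check
# of the gradient() and laplace() helpers defined above.
def _demo_autograd_derivatives():
    coords = torch.rand(1, 8, 2, requires_grad=True)  # (batch, points, 2), like get_mgrid output
    y = (coords ** 2).sum(dim=-1, keepdim=True)       # y = x1**2 + x2**2, shape (1, 8, 1)
    g = gradient(y, coords)                           # autograd gradient, should equal 2 * coords
    lap = laplace(y, coords)                          # divergence of the gradient, should equal 4
    assert torch.allclose(g, 2 * coords, atol=1e-5)
    assert torch.allclose(lap, torch.full_like(lap, 4.0), atol=1e-4)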
| 319 | 320 | 321 | def training_strategy(x_true, image_size, factor = None , mode = 'continuous', image_recon = None): 322 | 323 | image_recon = x_true if image_recon is None else image_recon 324 | 325 | if mode == 'continuous': 326 | 327 | image_size_random = np.random.randint(low = image_size//4, high = image_size//2, size = 1)[0] 328 | x_low = F.interpolate(image_recon, size = image_size_random, antialias = True, mode = 'bilinear') 329 | x_high = x_true 330 | image_size_high = image_size 331 | 332 | 333 | elif mode == 'factor': 334 | 335 | image_size_low = np.random.randint(low = image_size//8, high = image_size//2, size = 1)[0] 336 | image_size_high = 2 * image_size_low 337 | x_high = F.interpolate(x_true, size = image_size_high, antialias = True, mode = 'bilinear') 338 | x_low = F.interpolate(image_recon, size = image_size_low, antialias = True, mode = 'bilinear') 339 | 340 | elif mode == 'single': 341 | x_high = x_true 342 | x_low = F.interpolate(image_recon, size = image_size//2, antialias = True, mode = 'bilinear') 343 | image_size_high = image_size 344 | 345 | return x_high, x_low, image_size_high 346 | 347 | 348 | class Vgg16(torch.nn.Module): 349 | def __init__(self): 350 | super(Vgg16, self).__init__() 351 | features = list(vgg16(pretrained = True).features)[:23] 352 | self.features = torch.nn.ModuleList(features).eval() 353 | 354 | def forward(self, x): 355 | results = [] 356 | for ii,model in enumerate(self.features): 357 | x = model(x) 358 | if ii in {3,8,15,22}: 359 | results.append(x) 360 | 361 | return results 362 | 363 | 364 | 365 | def fbp_batch(x): 366 | n_measure = x.shape[2] 367 | theta = np.linspace(0., 180., n_measure, endpoint=False) 368 | 369 | fbps = [] 370 | for i in range(x.shape[0]): 371 | fbps.append(iradon(x[i], theta=theta, circle = False)) 372 | 373 | fbps = np.array(fbps) 374 | return fbps -------------------------------------------------------------------------------- /funknn_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import config_funknn as config 5 | import numpy as np 6 | 7 | 8 | def squeeze(x , f): 9 | x = x.permute(0,2,3,1) 10 | b, N1, N2, nch = x.shape 11 | x = torch.reshape( 12 | torch.permute( 13 | torch.reshape(x, shape=[b, N1//f, f, N2//f, f, nch]), 14 | [0, 1, 3, 2, 4, 5]), 15 | [b, N1//f, N2//f, nch*f*f]) 16 | x = x.permute(0,3,1,2) 17 | return x 18 | 19 | 20 | def reflect_coords(ix, min_val, max_val): 21 | 22 | pos_delta = ix[ix>max_val] - max_val 23 | 24 | neg_delta = min_val - ix[ix < min_val] 25 | 26 | ix[ix>max_val] = ix[ix>max_val] - 2*pos_delta 27 | ix[ix