├── figures
│   └── network.png
├── normflow
│   ├── distributions
│   │   ├── __init__.py
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   ├── linear_interpolation.py
│   │   ├── mh_proposal.py
│   │   ├── prior.py
│   │   └── target.py
│   ├── flows
│   │   ├── __init__.py
│   │   ├── affine_coupling.py
│   │   ├── base.py
│   │   ├── glow.py
│   │   ├── mixing.py
│   │   ├── neural_spline.py
│   │   ├── normalization.py
│   │   ├── planar.py
│   │   ├── radial.py
│   │   ├── reshape.py
│   │   ├── residual.py
│   │   └── stochastic.py
│   ├── __init__.py
│   ├── HAIS.py
│   ├── nets.py
│   ├── transforms.py
│   └── utils.py
├── ops
│   ├── __init__.py
│   ├── odl_lib.py
│   ├── operator.py
│   ├── radon_3d_lib.py
│   └── traveltime_lib.py
├── LICENSE
├── README.md
├── autoencoder_model.py
├── config_IP_solver.py
├── config_funknn.py
├── config_generative.py
├── datasets.py
├── environment.yml
├── flow_model.py
├── funknn_model.py
├── laplacian_loss.py
├── results.py
├── train_funknn.py
├── train_generative.py
└── utils.py
/figures/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swing-research/FunkNN/HEAD/figures/network.png
--------------------------------------------------------------------------------
/normflow/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from .core import *
5 | from . import flows
6 | from . import distributions
7 | from . import transforms
8 | from . import nets
9 | from . import utils
10 | from . import HAIS
11 |
12 | __version__ = '1.1'
13 |
--------------------------------------------------------------------------------
/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from .odl_lib import OperatorAsAutogradFunction
2 | from .odl_lib import ParallelBeamGeometryOp
3 | from .odl_lib import ParallelBeamGeometryOpBroken
4 | from .odl_lib import ParallelBeamGeometryOpNonUniform
5 |
6 | from .radon_3d_lib import ParallelBeamGeometry3DOp, ParallelBeamGeometry3DOpBroken
7 |
8 | from .traveltime_lib import TravelTimeOperator
9 |
10 | from .operator import get_operator_dict
11 |
12 |
--------------------------------------------------------------------------------
/config_IP_solver.py:
--------------------------------------------------------------------------------
1 | gpu_num = 0
2 | image_size = 256 # Working resolution for solving inverse problems
3 | problem = 'PDE' # Inverse problem: {CT, PDE}
4 | sparse_derivatives = False # Sparse-derivative option; only used for the PDE problem
5 | funknn_path = 'experiments/' + 'funknn_celeba-hq_128_factor_default' # Folder of the trained FunkNN model
6 | autoencoder_path = 'experiments/' + 'generator_celeba-hq_5_256_128_default' # Folder of the trained generative autoencoder
7 | exp_desc = 'default' # A note indicating which FunkNN and generative autoencoder versions are combined
--------------------------------------------------------------------------------
/normflow/flows/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 |
5 | # Generic flow module
6 | class Flow(nn.Module):
7 | """
8 | Generic class for flow functions
9 | """
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, z):
14 | """
15 | :param z: input variable, first dimension is batch dim
16 | :return: transformed z and log of absolute determinant
17 | """
18 | raise NotImplementedError('Forward pass has not been implemented.')
19 |
20 | def inverse(self, z):
21 | raise NotImplementedError('This flow has no algebraic inverse.')
--------------------------------------------------------------------------------
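For orientation, here is a minimal sketch of how a concrete flow subclasses this interface: forward must return the transformed batch together with the log absolute determinant of its Jacobian. The `Scale` class below is hypothetical, written only to illustrate the contract, not a flow shipped in this package:

```python
import torch
from normflow.flows import Flow

class Scale(Flow):
    """Hypothetical elementwise scaling flow illustrating the Flow interface."""
    def __init__(self, dim):
        super().__init__()
        self.log_scale = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, z):
        # log|det J| of an elementwise scaling is the sum of the log scales,
        # repeated for every sample in the batch
        z = z * torch.exp(self.log_scale)
        log_det = self.log_scale.sum().expand(z.shape[0])
        return z, log_det

    def inverse(self, z):
        z = z * torch.exp(-self.log_scale)
        log_det = -self.log_scale.sum().expand(z.shape[0])
        return z, log_det
```
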
/normflow/distributions/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseDistribution, DiagGaussian, ClassCondDiagGaussian, \
2 | GlowBase, AffineGaussian, GaussianMixture, GaussianPCA
3 | from .target import Target, TwoMoons, CircularGaussianMixture, RingMixture
4 |
5 | from .encoder import BaseEncoder, Dirac, Uniform, NNDiagGaussian
6 | from .decoder import BaseDecoder, NNDiagGaussianDecoder, NNBernoulliDecoder
7 | from .prior import PriorDistribution, ImagePrior, TwoModes, Sinusoidal, \
8 | Sinusoidal_split, Sinusoidal_gap, Smiley
9 |
10 | from .mh_proposal import MHProposal, DiagGaussianProposal
11 |
12 | from .linear_interpolation import LinearInterpolation
--------------------------------------------------------------------------------
/normflow/flows/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Flow
2 |
3 | from .reshape import Merge, Split, Squeeze
4 | from .mixing import Permute, InvertibleAffine, Invertible1x1Conv, LULinearPermute
5 | from .normalization import BatchNorm, ActNorm
6 |
7 | from .planar import Planar
8 | from .radial import Radial
9 |
10 | from .affine_coupling import AffineConstFlow, CCAffineConst, AffineCoupling, MaskedAffineFlow, AffineCouplingBlock
11 | from .glow import GlowBlock
12 |
13 | from .residual import Residual
14 | from .neural_spline import CoupledRationalQuadraticSpline, AutoregressiveRationalQuadraticSpline
15 |
16 | from .stochastic import MetropolisHastings, HamiltonianMonteCarlo
17 |
--------------------------------------------------------------------------------
/config_generative.py:
--------------------------------------------------------------------------------
1 | epochs_aeder = 200 # number of epochs to train autoencoder
2 | epochs_flow = 200 # number of epochs to train flow
3 | flow_depth = 5
4 | latent_dim = 256 # Latent dimension of the flow model
5 | batch_size = 64
6 | dataset = 'celeb-hq'
7 | gpu_num = 0 # GPU number
8 | exp_desc = 'default' # Add a small descriptor to the experiment
9 | image_size = 128 # Resolution of the dataset
10 | c = 3 # Number of channels of the dataset
11 | train_aeder = True # Train autoencoder or just reload
12 | train_flow = True # Train flow or just reload
13 | restore_aeder = False # Restore the trained autoencoder if it exists
14 | restore_flow = False # Restore the trained flow if it exists
15 |
--------------------------------------------------------------------------------
/normflow/distributions/linear_interpolation.py:
--------------------------------------------------------------------------------
1 | class LinearInterpolation:
2 | """
3 | Linear interpolation of two distributions in the log space
4 | """
5 | def __init__(self, dist1, dist2, alpha):
6 | """
7 | Constructor
8 | :param dist1: First distribution
9 | :param dist2: Second distribution
10 | :param alpha: Interpolation parameter,
11 | log_p = alpha * log_p_1 + (1 - alpha) * log_p_2
12 | """
13 | self.alpha = alpha
14 | self.dist1 = dist1
15 | self.dist2 = dist2
16 |
17 | def log_prob(self, z):
18 | return self.alpha * self.dist1.log_prob(z)\
19 | + (1 - self.alpha) * self.dist2.log_prob(z)
--------------------------------------------------------------------------------
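A short usage sketch: any two objects exposing `log_prob` can be interpolated, for example the `DiagGaussian` base distribution used elsewhere in this repository (assuming its `log_prob` follows the upstream `normflows` API this package mirrors):

```python
import torch
import normflow as nf

p1 = nf.distributions.DiagGaussian(2)
p2 = nf.distributions.DiagGaussian(2)
mix = nf.distributions.LinearInterpolation(p1, p2, alpha=0.3)

z = torch.randn(5, 2)
# 0.3 * log p1(z) + 0.7 * log p2(z), evaluated per sample
print(mix.log_prob(z))
```
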
/flow_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import normflow as nf
3 | import numpy as np
4 |
5 | def real_nvp(latent_dim, K=64):
6 |
7 |
8 | b = torch.Tensor([1 if i % 2 == 0 else 0 for i in range(latent_dim)])
9 | flows = []
10 | for i in range(K):
11 | s = nf.nets.MLP([latent_dim, 2 * latent_dim, latent_dim], init_zeros=True)
12 | t = nf.nets.MLP([latent_dim, 2 * latent_dim, latent_dim], init_zeros=True)
13 | if i % 2 == 0:
14 | flows += [nf.flows.MaskedAffineFlow(b, t, s)]
15 | else:
16 | flows += [nf.flows.MaskedAffineFlow(1 - b, t, s)]
17 | flows += [nf.flows.ActNorm(latent_dim)]
18 |
19 | q0 = nf.distributions.DiagGaussian(latent_dim)
20 |
21 | # Construct flow model
22 | nfm = nf.NormalizingFlow(q0=q0, flows=flows)
23 |
24 | return nfm
25 |
--------------------------------------------------------------------------------
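A hedged usage sketch: the returned model can be trained by maximum likelihood on latent codes produced by the autoencoder. This assumes `NormalizingFlow` in normflow/core.py exposes `forward_kld` as in the upstream `normflows` package it mirrors; the random tensor stands in for real latent codes:

```python
import torch
from flow_model import real_nvp

nfm = real_nvp(latent_dim=16, K=8)
optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-4)

x = torch.randn(32, 16)    # stand-in for autoencoder latent codes
loss = nfm.forward_kld(x)  # negative log-likelihood objective
loss.backward()
optimizer.step()
```
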
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 swing-research
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/normflow/flows/radial.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 | from .base import Flow
6 |
7 |
8 |
9 | class Radial(Flow):
10 | """
11 | Radial flow as introduced in arXiv: 1505.05770
12 | f(z) = z + beta * h(alpha, r) * (z - z_0)
13 | """
14 | def __init__(self, shape, z_0=None):
15 | """
16 | Constructor of the radial flow
17 | :param shape: shape of the latent variable z
18 | :param z_0: parameter of the radial flow
19 | """
20 | super().__init__()
21 | self.d_cpu = torch.prod(torch.tensor(shape))
22 | self.register_buffer('d', self.d_cpu)
23 | self.beta = nn.Parameter(torch.empty(1))
24 | lim = 1.0 / np.prod(shape)
25 | nn.init.uniform_(self.beta, -lim - 1.0, lim - 1.0)
26 | self.alpha = nn.Parameter(torch.empty(1))
27 | nn.init.uniform_(self.alpha, -lim, lim)
28 |
29 | if z_0 is not None:
30 | self.z_0 = nn.Parameter(z_0)
31 | else:
32 | self.z_0 = nn.Parameter(torch.randn(shape)[None])
33 |
34 | def forward(self, z):
35 | beta = torch.log(1 + torch.exp(self.beta)) - torch.abs(self.alpha)
36 | dz = z - self.z_0
37 | r = torch.norm(dz, dim=list(range(1, self.z_0.dim())))
38 | h_arr = beta / (torch.abs(self.alpha) + r)
39 | h_arr_ = - beta * r / (torch.abs(self.alpha) + r) ** 2
40 | z_ = z + h_arr.unsqueeze(1) * dz
41 | log_det = (self.d - 1) * torch.log(1 + h_arr) + torch.log(1 + h_arr + h_arr_)
42 | return z_, log_det
--------------------------------------------------------------------------------
/config_funknn.py:
--------------------------------------------------------------------------------
1 | epochs_funknn = 200 # number of epochs to train funknn network
2 | batch_size = 64
3 | dataset = 'celeb-hq'
4 | gpu_num = 0 # GPU number
5 | exp_desc = 'default' # Add a small descriptor to the experiment
6 | image_size = 128 # Maximum resolution of the training dataset
7 | c = 3 # Number of channels of the dataset
8 | train_funknn = True # Train or just reload to test
9 | restore_funknn = True
10 | training_mode = 'factor' # Training modes for funknn: {continuous, factor, single}
11 | ood_analysis = True # Evaluate the model on out-of-distribution data (LSUN-bedroom)
12 | interpolation_kernel = 'bicubic' # Interpolation kernels: {'cubic_conv', 'bilinear', 'bicubic'}
13 | # interpolation_kernel = 'bicubic' cannot be used for solving PDEs, as its derivatives are not computed,
14 | # but it can safely be used for the super-resolution task.
15 | # interpolation_kernel = 'bilinear' is fast but can only be used for PDEs with first-order derivatives.
16 | # interpolation_kernel = 'cubic_conv' is slow but can be used for solving PDEs with first- and second-order derivatives.
17 | network = 'MLP' # The network can be a 'CNN' or 'MLP' ('MLP' is supposed to be significantly faster)
18 | activation = 'relu' # Activation function 'relu' or 'sin' ('sin' for more accurate spatial derivatives)
19 |
20 | # Evaluation arguments
21 | max_scale = 2 # Maximum scale to generate at test time (2, 4 or 8) (<=8 for celeba-hq and 2 for other datasets)
22 | recursive = True # Recursive image reconstructions (Use just for factor training mode)
23 | sample_number = 25 # Number of samples in evaluation
24 | derivatives_evaluation = False # To evaluate the performance of the model for computing the derivatives
25 | # (Keep it False for 'bicubic' kernel)
26 |
27 | # Datasets paths:
28 | ood_path = 'datasets/lsun_bedroom_val'
29 | train_path = 'datasets/celeba_hq/celeba_hq_1024_train/'
30 | test_path = 'datasets/celeba_hq/celeba_hq_1024_test/'
31 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | import torchvision
4 | from torchvision.datasets import ImageFolder
5 | from utils import *
6 | import config_funknn as config
7 |
8 |
9 | class Dataset_loader(torch.utils.data.Dataset):
10 | def __init__(self, dataset, size=(1024, 1024), c=3, quantize=False):
11 |
12 | if c==1:
13 | self.transform = transforms.Compose([
14 | transforms.Resize(size),
15 | transforms.Grayscale(),
16 | transforms.ToTensor(),
17 | ])
18 | else:
19 | self.transform = transforms.Compose([
20 | transforms.Resize(size),
21 | transforms.ToTensor(),
22 | ])
23 |
24 | self.c = c
25 | self.dataset = dataset
26 | self.meshgrid = get_mgrid(size[0])
27 | self.im_size = size
28 | self.quantize = quantize
29 |
30 | if self.dataset == 'train':
31 | self.img_dataset = ImageFolder(config.train_path, self.transform)
32 |
33 | elif self.dataset == 'test':
34 | self.img_dataset = ImageFolder(config.test_path, self.transform)
35 |
36 | elif self.dataset == 'ood':
37 | lsun_class = ['bedroom_val']
38 | self.img_dataset = torchvision.datasets.LSUN(config.ood_path,
39 | classes=lsun_class, transform=self.transform)
40 |
41 | def __len__(self):
42 | return len(self.img_dataset)
43 |
44 | def __getitem__(self, item):
45 | img = self.img_dataset[item][0]
46 | img = transforms.ToPILImage()(img)
47 | img = self.transform(img).permute([1,2,0])
48 | img = img.reshape(-1, self.c)
49 |
50 | if self.quantize:
51 | img = img * 255.0
52 | img = torch.multiply(8.0, torch.div(img, 8, rounding_mode='floor'))
53 | img = img/255.0
54 |
55 | return img
56 |
57 |
--------------------------------------------------------------------------------
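A minimal loading sketch, assuming `config_funknn.train_path` points at an ImageFolder-style directory (a folder containing class subfolders); the batch size here is arbitrary:

```python
from torch.utils.data import DataLoader
from datasets import Dataset_loader

train_set = Dataset_loader('train', size=(128, 128), c=3)
loader = DataLoader(train_set, batch_size=4, shuffle=True)

# each batch holds flattened pixels: shape (4, 128*128, 3)
imgs = next(iter(loader))
```
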
/normflow/flows/residual.py:
--------------------------------------------------------------------------------
1 | from .base import Flow
2 |
3 | # Try importing Residual Flow dependencies
4 | try:
5 | from residual_flows.layers import iResBlock
6 | except ImportError:
7 | print('Warning: Dependencies for Residual Flows could '
8 | 'not be loaded. Other models can still be used.')
9 |
10 |
11 |
12 | class Residual(Flow):
13 | """
14 | Invertible residual net block, wrapper to the implementation of Chen et al.,
15 | see https://github.com/rtqichen/residual-flows
16 | """
17 | def __init__(self, net, n_exact_terms=2, n_samples=1, reduce_memory=True,
18 | reverse=True):
19 | """
20 | Constructor
21 | :param net: Neural network, must be Lipschitz continuous with L < 1
22 | :param n_exact_terms: Number of terms always included in the power series
23 | :param n_samples: Number of samples used to estimate power series
24 | :param reduce_memory: Flag, if true Neumann series and precomputations
25 | for backward pass in forward pass are done
26 | :param reverse: Flag, if true the map f(x) = x + net(x) is applied in
27 | the inverse pass, otherwise it is done in forward
28 | """
29 | super().__init__()
30 | self.reverse = reverse
31 | self.iresblock = iResBlock(net, n_samples=n_samples,
32 | n_exact_terms=n_exact_terms,
33 | neumann_grad=reduce_memory,
34 | grad_in_forward=reduce_memory)
35 |
36 | def forward(self, z):
37 | if self.reverse:
38 | z, log_det = self.iresblock.inverse(z, 0)
39 | else:
40 | z, log_det = self.iresblock.forward(z, 0)
41 | return z, -log_det.view(-1)
42 |
43 | def inverse(self, z):
44 | if self.reverse:
45 | z, log_det = self.iresblock.forward(z, 0)
46 | else:
47 | z, log_det = self.iresblock.inverse(z, 0)
48 | return z, -log_det.view(-1)
--------------------------------------------------------------------------------
/normflow/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from . import flows
4 |
5 | # Transforms to be applied to data as preprocessing
6 |
7 | class Logit(flows.Flow):
8 | """
9 | Logit mapping of image tensor, see RealNVP paper
10 | logit(alpha + (1 - 2 * alpha) * x) where logit(x) = log(x / (1 - x))
11 | """
12 | def __init__(self, alpha=0.05):
13 | """
14 | Constructor
15 | :param alpha: Alpha parameter, see above
16 | """
17 | super().__init__()
18 | self.alpha = alpha
19 |
20 | def forward(self, z):
21 | beta = 1 - 2 * self.alpha
22 | sum_dims = list(range(1, z.dim()))
23 | ls = torch.sum(torch.nn.functional.logsigmoid(z), dim=sum_dims)
24 | mls = torch.sum(torch.nn.functional.logsigmoid(-z), dim=sum_dims)
25 | log_det = -np.log(beta) * np.prod([*z.shape[1:]]) + ls + mls
26 | z = (torch.sigmoid(z) - self.alpha) / beta
27 | return z, log_det
28 |
29 | def inverse(self, z):
30 | beta = 1 - 2 * self.alpha
31 | z = self.alpha + beta * z
32 | logz = torch.log(z)
33 | log1mz = torch.log(1 - z)
34 | z = logz - log1mz
35 | sum_dims = list(range(1, z.dim()))
36 | log_det = np.log(beta) * np.prod([*z.shape[1:]]) \
37 | - torch.sum(logz, dim=sum_dims) \
38 | - torch.sum(log1mz, dim=sum_dims)
39 | return z, log_det
40 |
41 |
42 | class Shift(flows.Flow):
43 | """
44 | Shift data by a fixed constant, default is -0.5 to shift data from
45 | interval [0, 1] to [-0.5, 0.5]
46 | """
47 | def __init__(self, shift=-0.5):
48 | """
49 | Constructor
50 | :param shift: Shift to apply to the data
51 | """
52 | super().__init__()
53 | self.shift = shift
54 |
55 | def forward(self, z):
56 | z -= self.shift
57 | log_det = 0.
58 | return z, log_det
59 |
60 | def inverse(self, z):
61 | z += self.shift
62 | log_det = 0.
63 | return z, log_det
--------------------------------------------------------------------------------
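Because `Logit` implements both directions, a quick round-trip check verifies the transform and its log-determinants. A minimal sketch, assuming inputs in [0, 1] as in image preprocessing:

```python
import torch
import normflow as nf

logit = nf.transforms.Logit(alpha=0.05)
x = torch.rand(8, 3, 4, 4)               # data in [0, 1]
z, log_det_inv = logit.inverse(x)        # to logit space
x_rec, log_det_fwd = logit.forward(z)    # back to data space

print(torch.allclose(x, x_rec, atol=1e-5))        # True
print(torch.allclose(log_det_inv + log_det_fwd,
                     torch.zeros(8), atol=1e-4))  # True: determinants cancel
```
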
/normflow/HAIS.py:
--------------------------------------------------------------------------------
1 | ### Implementation of Hamiltonian Annealed Importance Sampling ###
2 |
3 | import torch
4 | from . import distributions
5 | from . import flows
6 |
7 | class HAIS():
8 | """
9 | Class which performs HAIS
10 | """
11 | def __init__(self, betas, prior, target, num_leapfrog, step_size, log_mass):
12 | """
13 | :param betas: Annealing schedule, the jth target is f_j(x) =
14 | f_0(x)^{\beta_j} f_n(x)^{1-\beta_j} where the target is proportional
15 | to f_0 and the prior is proportional to f_n. The number of
16 | intermediate steps is inferred from the shape of betas.
17 | Should be of the form 1 = \beta_0 > \beta_1 > ... > \beta_n = 0
18 | :param prior: The prior distribution to start the HAIS chain.
19 | :param target: The target distribution from which we would like to draw
20 | weighted samples.
21 | :param num_leapfrog: Number of leapfrog steps in the HMC transitions.
22 | :param step_size: step_size to use for HMC transitions.
23 | :param log_mass: log_mass to use for HMC transitions.
24 | """
25 | self.prior = prior
26 | self.target = target
27 | self.layers = []
28 | n = betas.shape[0] - 1
29 | for i in range(n-1, 0, -1):
30 | intermediate_target = distributions.LinearInterpolation(self.target,
31 | self.prior, betas[i])
32 | self.layers += [flows.HamiltonianMonteCarlo(intermediate_target,
33 | num_leapfrog, torch.log(step_size), log_mass)]
34 |
35 | def sample(self, num_samples):
36 | """
37 | Run HAIS to draw samples from the target with appropriate weights.
38 | :param num_samples: The number of samples to draw.
39 | """
40 | samples, log_weights = self.prior.forward(num_samples)
41 | log_weights = -log_weights
42 | for i in range(len(self.layers)):
43 | samples, log_weights_addition = self.layers[i].forward(samples)
44 | log_weights += log_weights_addition
45 | log_weights += self.target.log_prob(samples)
46 | return samples, log_weights
47 |
--------------------------------------------------------------------------------
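A usage sketch under stated assumptions: `prior` must implement `forward(num_samples) -> (z, log_q)` and `target` must implement `log_prob(z)`, which matches how they are called above. `DiagGaussian` and `TwoMoons` from this package are used as plausible stand-ins, and log Z is estimated from the importance weights:

```python
import torch
import normflow as nf

dim = 2
prior = nf.distributions.DiagGaussian(dim)
target = nf.distributions.TwoMoons()
betas = torch.linspace(1.0, 0.0, steps=11)  # 1 = beta_0 > ... > beta_n = 0

hais = nf.HAIS.HAIS(betas, prior, target, num_leapfrog=5,
                    step_size=torch.tensor(0.1),
                    log_mass=torch.zeros(dim))
samples, log_w = hais.sample(256)

# log of the average importance weight estimates log Z
log_Z = torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(256.))
```
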
/normflow/distributions/mh_proposal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
6 |
7 | class MHProposal(nn.Module):
8 | """
9 | Proposal distribution for the Metropolis Hastings algorithm
10 | """
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def sample(self, z):
15 | """
16 | Sample new value based on previous z
17 | """
18 | raise NotImplementedError
19 |
20 | def log_prob(self, z_, z):
21 | """
22 | :param z_: Potential new sample
23 | :param z: Previous sample
24 | :return: Log probability of proposal distribution
25 | """
26 | raise NotImplementedError
27 |
28 | def forward(self, z):
29 | """
30 | Draw samples given z and compute log probability difference
31 | log(p(z | z_new)) - log(p(z_new | z))
32 | :param z: Previous samples
33 | :return: Proposal, difference of log probability ratio
34 | """
35 | raise NotImplementedError
36 |
37 |
38 | class DiagGaussianProposal(MHProposal):
39 | """
40 | Diagonal Gaussian distribution with previous value as mean
41 | as a proposal for Metropolis Hastings algorithm
42 | """
43 | def __init__(self, shape, scale):
44 | """
45 | Constructor
46 | :param shape: Shape of variables to sample
47 | :param scale: Standard deviation of distribution
48 | """
49 | super().__init__()
50 | self.shape = shape
51 | self.scale_cpu = torch.tensor(scale)
52 | self.register_buffer("scale", self.scale_cpu.unsqueeze(0))
53 |
54 | def sample(self, z):
55 | num_samples = len(z)
56 | eps = torch.randn((num_samples,) + self.shape, dtype=z.dtype, device=z.device)
57 | z_ = eps * self.scale + z
58 | return z_
59 |
60 | def log_prob(self, z_, z):
61 | log_p = - 0.5 * np.prod(self.shape) * np.log(2 * np.pi) \
62 | - torch.sum(torch.log(self.scale) + 0.5 * torch.pow((z_ - z) / self.scale, 2),
63 | list(range(1, z.dim())))
64 | return log_p
65 |
66 | def forward(self, z):
67 | num_samples = len(z)
68 | eps = torch.randn((num_samples,) + self.shape, dtype=z.dtype, device=z.device)
69 | z_ = eps * self.scale + z
70 | log_p_diff = torch.zeros(num_samples, dtype=z.dtype, device=z.device)
71 | return z_, log_p_diff
--------------------------------------------------------------------------------
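A brief sketch of the proposal in isolation; since the Gaussian random walk is symmetric, `forward` returns a zero log-ratio:

```python
import torch
from normflow.distributions import DiagGaussianProposal

prop = DiagGaussianProposal(shape=(2,), scale=0.1)
z = torch.zeros(8, 2)

z_new = prop.sample(z)            # random-walk step around z
log_p = prop.log_prob(z_new, z)   # log N(z_new; z, scale^2 I)
z_prop, log_ratio = prop.forward(z)
print(log_ratio)                  # zeros: symmetric proposal
```
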
/normflow/flows/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import Flow
4 | from .affine_coupling import AffineConstFlow
5 |
6 |
7 |
8 | class ActNorm(AffineConstFlow):
9 | """
10 | An AffineConstFlow but with a data-dependent initialization,
11 | where on the very first batch we cleverly initialize s and t so that the
12 | output is unit Gaussian, as described in the Glow paper.
13 | """
14 |
15 | def __init__(self, *args, **kwargs):
16 | super().__init__(*args, **kwargs)
17 | self.data_dep_init_done_cpu = torch.tensor(0.)
18 | self.register_buffer('data_dep_init_done', self.data_dep_init_done_cpu)
19 |
20 | def forward(self, z):
21 | # first batch is used for initialization, c.f. batchnorm
22 | if not self.data_dep_init_done > 0.:
23 | assert self.s is not None and self.t is not None
24 | s_init = -torch.log(z.std(dim=self.batch_dims, keepdim=True) + 1e-6)
25 | self.s.data = s_init.data
26 | self.t.data = (-z.mean(dim=self.batch_dims, keepdim=True) * torch.exp(self.s)).data
27 | self.data_dep_init_done = torch.tensor(1.)
28 | return super().forward(z)
29 |
30 | def inverse(self, z):
31 | # first batch is used for initialization, c.f. batchnorm
32 | if not self.data_dep_init_done:
33 | assert self.s is not None and self.t is not None
34 | s_init = torch.log(z.std(dim=self.batch_dims, keepdim=True) + 1e-6)
35 | self.s.data = s_init.data
36 | self.t.data = z.mean(dim=self.batch_dims, keepdim=True).data
37 | self.data_dep_init_done = torch.tensor(1.)
38 | return super().inverse(z)
39 |
40 |
41 | class BatchNorm(Flow):
42 | """
43 | Batch Normalization without considering the derivatives of the batch statistics, see arXiv: 1605.08803
44 | """
45 | def __init__(self, eps=1.e-10):
46 | super().__init__()
47 | self.eps_cpu = torch.tensor(eps)
48 | self.register_buffer('eps', self.eps_cpu)
49 |
50 | def forward(self, z):
51 | """
52 | Do batch norm over batch and sample dimension
53 | """
54 | mean = torch.mean(z, dim=0, keepdims=True)
55 | std = torch.std(z, dim=0, keepdims=True)
56 | z_ = (z - mean) / torch.sqrt(std ** 2 + self.eps)
57 | log_det = torch.log(1 / torch.prod(torch.sqrt(std ** 2 + self.eps))).repeat(z.size()[0])
58 | return z_, log_det
--------------------------------------------------------------------------------
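The data-dependent initialization can be observed directly: the first forward pass fits s and t to the incoming batch, so that batch comes out approximately standardized. A small sketch, using the same `ActNorm(shape)` call seen in flow_model.py:

```python
import torch
import normflow as nf

actnorm = nf.flows.ActNorm(4)
z = 2.0 + 3.0 * torch.randn(512, 4)       # shifted, scaled batch

z_, log_det = actnorm(z)                  # first call triggers initialization
print(z_.mean().item(), z_.std().item())  # approximately 0 and 1
```
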
/normflow/distributions/decoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
6 |
7 | class BaseDecoder(nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 |
11 | def forward(self, z):
12 | """
13 | Decodes z to x
14 | :param z: latent variable
15 | :return: x, std of x
16 | """
17 | raise NotImplementedError
18 |
19 | def log_prob(self, x, z):
20 | """
21 | :param x: observable
22 | :param z: latent variable
23 | :return: log(p) of x given z
24 | """
25 | raise NotImplementedError
26 |
27 |
28 | class NNDiagGaussianDecoder(BaseDecoder):
29 | """
30 | BaseDecoder representing a diagonal Gaussian distribution with mean and std parametrized by a NN
31 | """
32 | def __init__(self, net):
33 | """
34 | Constructor
35 | :param net: neural network parametrizing mean and standard deviation of diagonal Gaussian
36 | """
37 | super().__init__()
38 | self.net = net
39 |
40 | def forward(self, z):
41 | z_size = z.size()
42 | mean_std = self.net(z.view(-1, *z_size[2:])).view(z_size)
43 | n_hidden = mean_std.size()[2] // 2
44 | mean = mean_std[:, :, :n_hidden, ...]
45 | std = torch.exp(0.5 * mean_std[:, :, n_hidden:(2 * n_hidden), ...])
46 | return mean, std
47 |
48 | def log_prob(self, x, z):
49 | mean_std = self.net(z.view(-1, *z.size()[2:])).view(*z.size()[:2], x.size(1) * 2, *x.size()[3:])
50 | n_hidden = mean_std.size()[2] // 2
51 | mean = mean_std[:, :, :n_hidden, ...]
52 | var = torch.exp(mean_std[:, :, n_hidden:(2 * n_hidden), ...])
53 | log_p = - 0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(2 * np.pi) \
54 | - 0.5 * torch.sum(torch.log(var) + (x.unsqueeze(1) - mean) ** 2 / var, list(range(2, z.dim())))
55 | return log_p
56 |
57 |
58 | class NNBernoulliDecoder(BaseDecoder):
59 | """
60 | BaseDecoder representing a Bernoulli distribution with mean parametrized by a NN
61 | """
62 |
63 | def __init__(self, net):
64 | """
65 | Constructor
66 | :param net: neural network parametrizing the Bernoulli mean (mean = sigmoid(nn_out))
67 | """
68 | super().__init__()
69 | self.net = net
70 |
71 | def forward(self, z):
72 | mean = torch.sigmoid(self.net(z))
73 | return mean
74 |
75 | def log_prob(self, x, z):
76 | score = self.net(z)
77 | x = x.unsqueeze(1)
78 | x = x.repeat(1, z.size()[0] // x.size()[0], *((x.dim() - 2) * [1])).view(-1, *x.size()[2:])
79 | log_sig = lambda a: -torch.relu(-a) - torch.log(1 + torch.exp(-torch.abs(a)))
80 | log_p = torch.sum(x * log_sig(score) + (1 - x) * log_sig(-score), list(range(1, x.dim())))
81 | return log_p
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FunkNN: Neural Interpolation for Functional Generation
2 |
3 |
4 | [](https://arxiv.org/abs/2212.14042)
5 | [](https://paperswithcode.com/paper/funknn-neural-interpolation-for-functional)
6 |
7 | This repository is the official PyTorch implementation of "[FunkNN: Neural Interpolation for Functional Generation](https://openreview.net/forum?id=BT4N_v7CLrk)", published at ICLR 2023.
8 |
9 | | [**Project Page**](https://sada.dmi.unibas.ch/en/research/implicit-neural-representation) |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | ## Requirements
19 | (This code is tested with PyTorch 1.12.1, Python 3.8.3, CUDA 11.6 and cuDNN 7.)
20 | - numpy
21 | - scipy
22 | - matplotlib
23 | - odl
24 | - imageio
25 | - torch==1.12.1
26 | - torchvision==0.13.1
27 | - astra-toolbox
28 |
29 | ## Installation
30 |
31 | Run the following command to create the conda environment from "environment.yml":
32 | ```sh
33 | conda env create -f environment.yml
34 | ```
35 |
36 | ## Datasets
37 | You can download the [CelebA-HQ](https://drive.switch.ch/index.php/s/pA6X3TY9x4jgcxb), [LoDoPaB-CT](https://drive.switch.ch/index.php/s/lQeYWmAIYcEEdlc) and [LSUN-bedroom](https://drive.switch.ch/index.php/s/d1MNcrUZkPpK0zx) validation datasets, split them into train and test sets, and put them in the data folder. Specify the data folder paths in config_funknn.py and config_generative.py.
38 |
39 | ## Experiments
40 | ### Train FunkNN
41 | All arguments for training the FunkNN model are explained in config_funknn.py. After specifying your arguments, you can run the following command to train the model:
42 | ```sh
43 | python3 train_funknn.py
44 | ```
45 |
46 | ### Train generative autoencoder
47 | All arguments for training the generative autoencoder are explained in config_generative.py. After specifying your arguments, you can run the following command to train the model:
48 | ```sh
49 | python3 train_generative.py
50 | ```
51 |
52 |
53 | ### Solving inverse problems
54 | All arguments for solving inverse problems by combining FunkNN and the generative autoencoder are explained in config_IP_solver.py. After specifying your arguments, including the folder addresses of the trained FunkNN model and generator, you can run the following command to solve the inverse problem of your choice (CT or PDE):
55 | ```sh
56 | python3 IP_solver.py
57 | ```
58 |
59 | ## Citation
60 | If you find the code useful in your research, please consider citing the paper.
61 |
62 | ```
63 | @inproceedings{
64 | khorashadizadeh2023funknn,
65 | title={Funk{NN}: Neural Interpolation for Functional Generation},
66 | author={AmirEhsan Khorashadizadeh and Anadi Chaman and Valentin Debarnot and Ivan Dokmani{\'c}},
67 | booktitle={The Eleventh International Conference on Learning Representations},
68 | year={2023},
69 | url={https://openreview.net/forum?id=BT4N_v7CLrk}
70 | }
71 | ```
72 |
73 |
--------------------------------------------------------------------------------
/normflow/flows/planar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 | from .base import Flow
6 |
7 |
8 |
9 | class Planar(Flow):
10 | """
11 | Planar flow as introduced in arXiv: 1505.05770
12 | f(z) = z + u * h(w * z + b)
13 | """
14 |
15 | def __init__(self, shape, act="tanh", u=None, w=None, b=None):
16 | """
17 | Constructor of the planar flow
18 | :param shape: shape of the latent variable z
19 | :param act: name of the nonlinearity h, either "tanh" or "leaky_relu" (see definition of f above)
20 | :param u,w,b: optional initialization for parameters
21 | """
22 | super().__init__()
23 | lim_w = np.sqrt(2. / np.prod(shape))
24 | lim_u = np.sqrt(2)
25 |
26 | if u is not None:
27 | self.u = nn.Parameter(u)
28 | else:
29 | self.u = nn.Parameter(torch.empty(shape)[None])
30 | nn.init.uniform_(self.u, -lim_u, lim_u)
31 | if w is not None:
32 | self.w = nn.Parameter(w)
33 | else:
34 | self.w = nn.Parameter(torch.empty(shape)[None])
35 | nn.init.uniform_(self.w, -lim_w, lim_w)
36 | if b is not None:
37 | self.b = nn.Parameter(b)
38 | else:
39 | self.b = nn.Parameter(torch.zeros(1))
40 |
41 | self.act = act
42 | if act == "tanh":
43 | self.h = torch.tanh
44 | elif act == "leaky_relu":
45 | self.h = torch.nn.LeakyReLU(negative_slope=0.2)
46 | else:
47 | raise NotImplementedError('Nonlinearity is not implemented.')
48 |
49 | def forward(self, z):
50 | lin = torch.sum(self.w * z, list(range(1, self.w.dim()))) + self.b
51 | if self.act == "tanh":
52 | inner = torch.sum(self.w * self.u)
53 | u = self.u + (torch.log(1 + torch.exp(inner)) - 1 - inner) * self.w / torch.sum(self.w ** 2)
54 | h_ = lambda x: 1 / torch.cosh(x) ** 2
55 | elif self.act == "leaky_relu":
56 | inner = torch.sum(self.w * self.u)
57 | u = self.u + (torch.log(1 + torch.exp(inner)) - 1 - inner) * self.w / torch.sum(
58 | self.w ** 2) # enforce the constraint w^T u > -1
59 | h_ = lambda x: (x < 0) * (self.h.negative_slope - 1.0) + 1.0
60 |
61 | z_ = z + u * self.h(lin.unsqueeze(1))
62 | log_det = torch.log(torch.abs(1 + torch.sum(self.w * u) * h_(lin)))
63 | return z_, log_det
64 |
65 | def inverse(self, z):
66 | if self.act != "leaky_relu":
67 | raise NotImplementedError('This flow has no algebraic inverse.')
68 | lin = torch.sum(self.w * z, list(range(2, self.w.dim())), keepdim=True) + self.b
69 | inner = torch.sum(self.w * self.u)
70 | a = ((lin + self.b) / (1 + inner) < 0) * (self.h.negative_slope - 1.0) + 1.0 # absorb leakyReLU slope into u
71 | u = a * (self.u + (torch.log(1 + torch.exp(inner)) - 1 - inner) * self.w / torch.sum(self.w ** 2))
72 | z_ = z - 1 / (1 + inner) * (lin + u * self.b)
73 | log_det = -torch.log(torch.abs(1 + torch.sum(self.w * u)))
74 | if log_det.dim() == 0:
75 | log_det = log_det.unsqueeze(0)
76 | if log_det.dim() == 1:
77 | log_det = log_det.unsqueeze(1)
78 | return z_, log_det
--------------------------------------------------------------------------------
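A quick shape check of the two classical flows from arXiv: 1505.05770 defined in this package; both return the transformed batch and a per-sample log-determinant:

```python
import torch
import normflow as nf

planar = nf.flows.Planar((2,))
radial = nf.flows.Radial((2,))

z = torch.randn(10, 2)
z_p, log_det_p = planar(z)   # z_p: (10, 2), log_det_p: (10,)
z_r, log_det_r = radial(z)
```
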
/laplacian_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 | def gaussian_kernel(size=5, device=torch.device('cpu'), channels=3, sigma=1, dtype=torch.float):
8 | # Create the Gaussian kernel in NumPy
9 | interval = (2*sigma + 1)/(size) # note: this variable is unused
10 | ax = np.linspace(-(size - 1)/ 2., (size-1)/2., size)
11 | xx, yy = np.meshgrid(ax, ax)
12 | kernel = np.exp(-0.5 * (np.square(xx)+ np.square(yy)) / np.square(sigma))
13 | kernel /= np.sum(kernel)
14 | # Convert the kernel to a PyTorch tensor and repeat it to shape (channels, 1, size, size)
15 | kernel_tensor = torch.as_tensor(kernel, dtype=dtype)
16 | kernel_tensor = kernel_tensor.repeat(channels, 1 , 1, 1)
17 | kernel_tensor = kernel_tensor.to(device)
18 | return kernel_tensor
19 |
20 | def gaussian_conv2d(x, g_kernel, dtype=torch.float):
21 | #Assumes input of x is of shape: (minibatch, depth, height, width)
22 | #Infer depth automatically based on the shape
23 | channels = g_kernel.shape[0]
24 | padding = g_kernel.shape[-1] // 2 # Kernel size needs to be odd number
25 | if len(x.shape) != 4:
26 | raise IndexError('Expected input tensor to be of shape: (batch, depth, height, width) but got: ' + str(x.shape))
27 | y = F.conv2d(x, weight=g_kernel, stride=1, padding=padding, groups=channels)
28 | return y
29 |
30 | def downsample(x):
31 | # Downsamples along image (H,W). Takes every 2 pixels. output (H, W) = input (H/2, W/2)
32 | return x[:, :, ::2, ::2]
33 |
34 | def create_laplacian_pyramid(x, kernel, levels):
35 | upsample = torch.nn.Upsample(scale_factor=2) # Default mode is nearest: [[1 2],[3 4]] -> [[1 1 2 2],[3 3 4 4]]
36 | pyramids = []
37 | current_x = x
38 | for level in range(0, levels):
39 | gauss_filtered_x = gaussian_conv2d(current_x, kernel)
40 | down = downsample(gauss_filtered_x)
41 | # Original Algorithm does indeed: L_i = G_i - expand(G_i+1), with L_i as current laplacian layer, and G_i as current gaussian filtered image, and G_i+1 the next.
42 | # Some implementations skip expand(G_i+1) and use gaussian_conv(G_i). We decided to use expand, as is the original algorithm
43 | laplacian = current_x - upsample(down)
44 | pyramids.append(laplacian)
45 | current_x = down
46 | pyramids.append(current_x)
47 | return pyramids
48 |
49 | class LaplacianPyramidLoss(torch.nn.Module):
50 | def __init__(self, max_levels=3, channels=3, kernel_size=5, sigma=1, device=torch.device('cpu'), dtype=torch.float):
51 | super(LaplacianPyramidLoss, self).__init__()
52 | self.max_levels = max_levels
53 | self.kernel = gaussian_kernel(size=kernel_size,device=device, channels=channels, sigma=sigma, dtype=dtype)
54 |
55 | def forward(self, x, target):
56 | input_pyramid = create_laplacian_pyramid(x, self.kernel, self.max_levels)
57 | target_pyramid = create_laplacian_pyramid(target, self.kernel, self.max_levels)
58 | loss = 0.0
59 | cnt = 0
60 | for x, y in zip(input_pyramid, target_pyramid):
61 | w = 2**cnt
62 | loss = loss + w*torch.nn.functional.l1_loss(x,y)
63 | cnt = cnt + 1
64 | # return sum(torch.nn.functional.l1_loss(x,y) for x, y in zip(input_pyramid, target_pyramid))
65 | return loss
66 |
--------------------------------------------------------------------------------
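Typical use as a training criterion; a minimal sketch with random tensors standing in for predicted and target images:

```python
import torch
from laplacian_loss import LaplacianPyramidLoss

loss_fn = LaplacianPyramidLoss(max_levels=3, channels=3)
pred = torch.rand(2, 3, 64, 64, requires_grad=True)
target = torch.rand(2, 3, 64, 64)

loss = loss_fn(pred, target)  # L1 per pyramid level, weighted by 2**level
loss.backward()
```
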
/normflow/flows/glow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .base import Flow
5 | from .affine_coupling import AffineCouplingBlock
6 | from .mixing import Invertible1x1Conv
7 | from .normalization import ActNorm
8 | from .. import nets
9 |
10 |
11 |
12 | class GlowBlock(Flow):
13 | """
14 | Glow: Generative Flow with Invertible 1×1 Convolutions, arXiv: 1807.03039
15 | One block of the Glow model, consisting of
16 | AffineCouplingBlock (affine coupling layer),
17 | Invertible1x1Conv (dropped if there is only one channel) and
18 | ActNorm (first batch used for initialization).
19 | """
20 | def __init__(self, channels, hidden_channels, scale=True, scale_map='sigmoid',
21 | split_mode='channel', leaky=0.0, init_zeros=True, use_lu=True,
22 | net_actnorm=False):
23 | """
24 | Constructor
25 | :param channels: Number of channels of the data
26 | :param hidden_channels: number of channels in the hidden layer of the ConvNet
27 | :param scale: Flag, whether to include scale in affine coupling layer
28 | :param scale_map: Map to be applied to the scale parameter, can be 'exp' as in
29 | RealNVP or 'sigmoid' as in Glow
30 | :param split_mode: Splitting mode, for possible values see Split class
31 | :param leaky: Leaky parameter of LeakyReLUs of ConvNet2d
32 | :param init_zeros: Flag whether to initialize last conv layer with zeros
33 | :param use_lu: Flag whether to parametrize weights through the LU decomposition
34 | in invertible 1x1 convolution layers
35 | :param net_actnorm: Flag whether to use activation normalization in
36 | the ConvNet2d parameter network
37 | """
38 | super().__init__()
39 | self.flows = nn.ModuleList([])
40 | # Coupling layer
41 | kernel_size = (3, 1, 3)
42 | num_param = 2 if scale else 1
43 | if 'channel' == split_mode:
44 | channels_ = (channels // 2,) + 2 * (hidden_channels,)
45 | channels_ += (num_param * ((channels + 1) // 2),)
46 | elif 'channel_inv' == split_mode:
47 | channels_ = ((channels + 1) // 2,) + 2 * (hidden_channels,)
48 | channels_ += (num_param * (channels // 2),)
49 | elif 'checkerboard' in split_mode:
50 | channels_ = (channels,) + 2 * (hidden_channels,)
51 | channels_ += (num_param * channels,)
52 | else:
53 | raise NotImplementedError('Mode ' + split_mode + ' is not implemented.')
54 | param_map = nets.ConvNet2d(channels_, kernel_size, leaky, init_zeros,
55 | actnorm=net_actnorm)
56 | self.flows += [AffineCouplingBlock(param_map, scale, scale_map, split_mode)]
57 | # Invertible 1x1 convolution
58 | if channels > 1:
59 | self.flows += [Invertible1x1Conv(channels, use_lu)]
60 | # Activation normalization
61 | self.flows += [ActNorm((channels,) + (1, 1))]
62 |
63 | def forward(self, z):
64 | log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
65 | for flow in self.flows:
66 | z, log_det = flow(z)
67 | log_det_tot += log_det
68 | return z, log_det_tot
69 |
70 | def inverse(self, z):
71 | log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
72 | for i in range(len(self.flows) - 1, -1, -1):
73 | z, log_det = self.flows[i].inverse(z)
74 | log_det_tot += log_det
75 | return z, log_det_tot
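76 | # Minimal round-trip sketch (editor's addition, not from the original file);
77 | # run as a module from the repo root, e.g. `python -m normflow.flows.glow`.
78 | if __name__ == "__main__":
79 |     block = GlowBlock(channels=4, hidden_channels=16)
80 |     z = torch.randn(3, 4, 8, 8)
81 |     with torch.no_grad():
82 |         z_fwd, log_det = block(z)  # first call also initializes ActNorm
83 |         z_rec, log_det_inv = block.inverse(z_fwd)
84 |     print(torch.max(torch.abs(z_rec - z)))  # ~0: inverse undoes forward
85 |     print(torch.max(torch.abs(log_det + log_det_inv)))  # ~0 per sample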
--------------------------------------------------------------------------------
/ops/operator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from absl import logging
5 |
6 | from .odl_lib import ParallelBeamGeometryOp, ParallelBeamGeometryOpBroken
7 | from .radon_3d_lib import ParallelBeamGeometry3DOp, ParallelBeamGeometry3DOpBroken
8 | from .traveltime_lib import TravelTimeOperator
9 |
10 | def get_operator_dict(config):
11 | if config.problem == 'radon':
12 | operator = ParallelBeamGeometryOp(
13 | config.img_size,
14 | config.op_param,
15 | op_snr=np.inf,
16 | angle_max=config.angle_max)
17 |
18 | operator_dict = {}
19 |
20 | if config.opt_strat == 'dip_noisy':
21 | broken_operator = ParallelBeamGeometryOpBroken(operator, config.op_snr)
22 | operator_dict.update({'original_operator': broken_operator})
23 | operator_dict.update({'noisy_operator': operator})
24 | elif config.opt_strat in ['dip', 'joint', 'joint_input']:
25 | dense_operator = ParallelBeamGeometryOp(
26 | config.img_size, config.dense_op_param, op_snr=np.inf, angle_max=config.angle_max)
27 | operator_dict.update({'dense_operator': dense_operator})
28 | broken_operator = ParallelBeamGeometryOpBroken(operator, config.op_snr)
29 | operator_dict.update({'original_operator': broken_operator})
30 | elif config.opt_strat == 'broken_machine':
31 | dense_operator = ParallelBeamGeometryOp(
32 | config.img_size,
33 | config.op_param,
34 | op_snr=np.inf,
35 | angle_max=config.angle_max)
36 | operator_dict = {'original_operator_clean': dense_operator}
37 |
38 | broken_operator = ParallelBeamGeometryOpBroken(dense_operator, config.op_snr)
39 | operator_dict.update({'original_operator': broken_operator})
40 |
41 | dense_operator = ParallelBeamGeometryOp(
42 | config.img_size,
43 | config.dense_op_param,
44 | op_snr=np.inf,
45 | angle_max=config.angle_max)
46 | operator_dict['dense_operator'] = dense_operator
47 |
48 | logging.info(f'operator_dict: {operator_dict}')
49 |
50 | else:
51 | raise ValueError(f'Did not recognize opt_strat={config.opt_strat}.')
52 |
53 | elif config.problem == 'radon_3d':
54 | operator_dict = {}
55 |
56 | if config.opt_strat == 'broken_machine':
57 | dense_operator = ParallelBeamGeometry3DOp(config.img_size, config.op_param, op_snr=np.inf, angle_max=config.angle_max)
58 | operator_dict = {'original_operator_clean': dense_operator}
59 |
60 | broken_operator = ParallelBeamGeometry3DOpBroken(dense_operator, config.op_snr)
61 | operator_dict.update({'original_operator': broken_operator})
62 |
63 | dense_operator = ParallelBeamGeometry3DOp(config.img_size, config.dense_op_param, op_snr=np.inf, angle_max=config.angle_max)
64 | operator_dict['dense_operator'] = dense_operator
65 |
66 | logging.info(f'operator_dict: {operator_dict}')
67 |
68 | else:
69 | raise ValueError(f'Did not recognize opt_strat={config.opt_strat}.')
70 |
71 |
72 | elif config.problem == 'traveltime':
73 | original_sensors = torch.rand(
74 | config.op_param, 2, dtype=torch.float32)
75 |
76 | operator = TravelTimeOperator(
77 | original_sensors,
78 | config.img_size)
79 | operator_dict = {'original_operator': operator}
80 |
81 | # Add new sensors for dense operator.
82 | new_sensors = torch.rand(
83 | config.dense_op_param - config.op_param, 2, dtype=torch.float32)
84 | sensors = torch.cat((original_sensors, new_sensors), dim=0)
85 |
86 | dense_operator = TravelTimeOperator(
87 | sensors, config.img_size)
88 | operator_dict.update({'dense_operator': dense_operator})
89 |
90 | else:
91 | raise ValueError('Inverse problem unrecognized.')
92 |
93 | return operator_dict
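94 | # Minimal usage sketch (editor's addition, not from the original file).
95 | # The 'traveltime' branch avoids the ASTRA backend, but the module-level
96 | # ODL imports above must still succeed; the config fields are assumptions
97 | # mirroring the reads above. Run e.g. as `python -m ops.operator`.
98 | if __name__ == "__main__":
99 |     from types import SimpleNamespace
100 |     config = SimpleNamespace(problem='traveltime', img_size=64,
101 |                              op_param=20, dense_op_param=30)
102 |     print(sorted(get_operator_dict(config).keys()))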
--------------------------------------------------------------------------------
/normflow/flows/stochastic.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import Flow
4 |
5 |
6 |
7 | class MetropolisHastings(Flow):
8 | """
9 | Sampling through Metropolis-Hastings in Stochastic Normalizing
10 | Flows, see arXiv: 2002.06707
11 | """
12 | def __init__(self, dist, proposal, steps):
13 | """
14 | Constructor
15 | :param dist: Distribution to sample from
16 | :param proposal: Proposal distribution
17 | :param steps: Number of MCMC steps to perform
18 | """
19 | super().__init__()
20 | self.dist = dist
21 | self.proposal = proposal
22 | self.steps = steps
23 |
24 | def forward(self, z):
25 | # Initialize number of samples and log(det)
26 | num_samples = len(z)
27 | log_det = torch.zeros(num_samples, dtype=z.dtype, device=z.device)
28 | # Get log(p) for current samples
29 | log_p = self.dist.log_prob(z)
30 | for i in range(self.steps):
31 | # Make proposal and get log(p)
32 | z_, log_p_diff = self.proposal(z)
33 | log_p_ = self.dist.log_prob(z_)
34 | # Make acceptance decision
35 | w = torch.rand(num_samples, dtype=z.dtype, device=z.device)
36 | log_w_accept = log_p_ - log_p + log_p_diff
37 | w_accept = torch.clamp(torch.exp(log_w_accept), max=1)
38 | accept = w <= w_accept
39 | # Update samples, log(det), and log(p)
40 | z = torch.where(accept.unsqueeze(1), z_, z)
41 | log_det_ = log_p - log_p_
42 | log_det = torch.where(accept, log_det + log_det_, log_det)
43 | log_p = torch.where(accept, log_p_, log_p)
44 | return z, log_det
45 |
46 | def inverse(self, z):
47 | # Equivalent to forward pass
48 | return self.forward(z)
49 |
50 |
51 | class HamiltonianMonteCarlo(Flow):
52 | """
53 | Flow layer using the HMC proposal in Stochastic Normalizing Flows,
54 | see arXiv: 2002.06707
55 | """
56 | def __init__(self, target, steps, log_step_size, log_mass):
57 | """
58 | Constructor
59 | :param target: The stationary distribution of this Markov transition; must expose a differentiable log_prob
60 | :param steps: The number of leapfrog steps
61 | :param log_step_size: The log step size used in the leapfrog integrator. shape (dim)
62 | :param log_mass: The log_mass determining the variance of the momentum samples. shape (dim)
63 | """
64 | super().__init__()
65 | self.target = target
66 | self.steps = steps
67 | self.register_parameter('log_step_size', torch.nn.Parameter(log_step_size))
68 | self.register_parameter('log_mass', torch.nn.Parameter(log_mass))
69 |
70 | def forward(self, z):
71 | # Draw momentum
72 | p = torch.randn_like(z) * torch.exp(0.5 * self.log_mass)
73 |
74 | # leapfrog
75 | z_new = z.clone()
76 | p_new = p.clone()
77 | step_size = torch.exp(self.log_step_size)
78 | for i in range(self.steps):
79 | p_half = p_new + (step_size / 2.0) * self.gradlogP(z_new)
80 | z_new = z_new + step_size * (p_half / torch.exp(self.log_mass))
81 | p_new = p_half + (step_size / 2.0) * self.gradlogP(z_new)
82 |
83 | # Metropolis Hastings correction
84 | probabilities = torch.exp(
85 | self.target.log_prob(z_new) - self.target.log_prob(z) - \
86 | 0.5 * torch.sum(p_new ** 2 / torch.exp(self.log_mass), 1) + \
87 | 0.5 * torch.sum(p ** 2 / torch.exp(self.log_mass), 1))
88 | uniforms = torch.rand_like(probabilities)
89 | mask = uniforms < probabilities
90 | z_out = torch.where(mask.unsqueeze(1), z_new, z)
91 |
92 | return z_out, self.target.log_prob(z) - self.target.log_prob(z_out)
93 |
94 | def inverse(self, z):
95 | return self.forward(z)
96 |
97 | def gradlogP(self, z):
98 | z_ = z.detach().requires_grad_()
99 | logp = self.target.log_prob(z_)
100 | return torch.autograd.grad(logp, z_,
101 | grad_outputs=torch.ones_like(logp))[0]
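102 | # Minimal HMC sketch (editor's addition, not from the original file):
103 | # `_StdNormal` is a stand-in target exposing a differentiable per-sample
104 | # log_prob. Run e.g. as `python -m normflow.flows.stochastic`.
105 | if __name__ == "__main__":
106 |     class _StdNormal:
107 |         def log_prob(self, z):
108 |             return -0.5 * (z ** 2).sum(dim=1)
109 |
110 |     hmc = HamiltonianMonteCarlo(_StdNormal(), steps=5,
111 |                                 log_step_size=-2.0 * torch.ones(2),
112 |                                 log_mass=torch.zeros(2))
113 |     z_out, log_det = hmc(torch.randn(8, 2))
114 |     print(z_out.shape, log_det.shape)  # torch.Size([8, 2]) torch.Size([8])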
--------------------------------------------------------------------------------
/normflow/flows/reshape.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base import Flow
4 |
5 |
6 |
7 | # Flow layers to reshape the latent features
8 |
9 | class Split(Flow):
10 | """
11 | Split features into two sets
12 | """
13 | def __init__(self, mode='channel'):
14 | """
15 | Constructor
16 | :param mode: Splitting mode, can be
17 | channel: Splits first feature dimension, usually channels, into two halves
18 | channel_inv: Same as channel, but with z1 and z2 flipped
19 | checkerboard: Splits features using a checkerboard pattern (last feature dimension must be even)
20 | checkerboard_inv: Same as checkerboard, but with inverted coloring
21 | """
22 | super().__init__()
23 | self.mode = mode
24 |
25 | def forward(self, z):
26 | if self.mode == 'channel':
27 | z1, z2 = z.chunk(2, dim=1)
28 | elif self.mode == 'channel_inv':
29 | z2, z1 = z.chunk(2, dim=1)
30 | elif 'checkerboard' in self.mode:
31 | n_dims = z.dim()
32 | cb0 = 0
33 | cb1 = 1
34 | for i in range(1, n_dims):
35 | cb0_ = cb0
36 | cb1_ = cb1
37 | cb0 = [cb0_ if j % 2 == 0 else cb1_ for j in range(z.size(n_dims - i))]
38 | cb1 = [cb1_ if j % 2 == 0 else cb0_ for j in range(z.size(n_dims - i))]
39 | cb = cb1 if 'inv' in self.mode else cb0
40 | cb = torch.tensor(cb)[None].repeat(len(z), *((n_dims - 1) * [1]))
41 | cb = cb.to(z.device)
42 | z_size = z.size()
43 | z1 = z.reshape(-1)[torch.nonzero(cb.view(-1), as_tuple=False)].view(*z_size[:-1], -1)
44 | z2 = z.reshape(-1)[torch.nonzero((1 - cb).view(-1), as_tuple=False)].view(*z_size[:-1], -1)
45 | else:
46 | raise NotImplementedError('Mode ' + self.mode + ' is not implemented.')
47 | log_det = 0
48 | return [z1, z2], log_det
49 |
50 | def inverse(self, z):
51 | z1, z2 = z
52 | if self.mode == 'channel':
53 | z = torch.cat([z1, z2], 1)
54 | elif self.mode == 'channel_inv':
55 | z = torch.cat([z2, z1], 1)
56 | elif 'checkerboard' in self.mode:
57 | n_dims = z1.dim()
58 | z_size = list(z1.size())
59 | z_size[-1] *= 2
60 | cb0 = 0
61 | cb1 = 1
62 | for i in range(1, n_dims):
63 | cb0_ = cb0
64 | cb1_ = cb1
65 | cb0 = [cb0_ if j % 2 == 0 else cb1_ for j in range(z_size[n_dims - i])]
66 | cb1 = [cb1_ if j % 2 == 0 else cb0_ for j in range(z_size[n_dims - i])]
67 | cb = cb1 if 'inv' in self.mode else cb0
68 | cb = torch.tensor(cb)[None].repeat(z_size[0], *((n_dims - 1) * [1]))
69 | cb = cb.to(z1.device)
70 | z1 = z1[..., None].repeat(*(n_dims * [1]), 2).view(*z_size[:-1], -1)
71 | z2 = z2[..., None].repeat(*(n_dims * [1]), 2).view(*z_size[:-1], -1)
72 | z = cb * z1 + (1 - cb) * z2
73 | else:
74 | raise NotImplementedError('Mode ' + self.mode + ' is not implemented.')
75 | log_det = 0
76 | return z, log_det
77 |
78 |
79 | class Merge(Split):
80 | """
81 | Same as Split but with forward and backward pass interchanged
82 | """
83 | def __init__(self, mode='channel'):
84 | super().__init__(mode)
85 |
86 | def forward(self, z):
87 | return super().inverse(z)
88 |
89 | def inverse(self, z):
90 | return super().forward(z)
91 |
92 |
93 | class Squeeze(Flow):
94 | """
95 | Squeeze operation of multi-scale architecture, RealNVP or Glow paper
96 | """
97 | def __init__(self):
98 | """
99 | Constructor
100 | """
101 | super().__init__()
102 |
103 | def forward(self, z):
104 | log_det = 0
105 | s = z.size()
106 | z = z.view(s[0], s[1] // 4, 2, 2, s[2], s[3])
107 | z = z.permute(0, 1, 4, 2, 5, 3).contiguous()
108 | z = z.view(s[0], s[1] // 4, 2 * s[2], 2 * s[3])
109 | return z, log_det
110 |
111 | def inverse(self, z):
112 | log_det = 0
113 | s = z.size()
114 | z = z.view(*s[:2], s[2] // 2, 2, s[3] // 2, 2)
115 | z = z.permute(0, 1, 3, 5, 2, 4).contiguous()
116 | z = z.view(s[0], 4 * s[1], s[2] // 2, s[3] // 2)
117 | return z, log_det
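118 | # Minimal round-trip sketch (editor's addition, not from the original file):
119 | # Squeeze.forward trades channels for spatial resolution; inverse is the
120 | # exact space-to-depth counterpart. Run as `python -m normflow.flows.reshape`.
121 | if __name__ == "__main__":
122 |     sq = Squeeze()
123 |     z = torch.randn(2, 4, 4, 4)  # channel count must be divisible by 4
124 |     z_up, _ = sq(z)  # (2, 1, 8, 8)
125 |     z_rec, _ = sq.inverse(z_up)  # back to (2, 4, 4, 4)
126 |     print(torch.equal(z_rec, z))  # True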
--------------------------------------------------------------------------------
/ops/radon_3d_lib.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import odl
5 |
6 | # from odl.contrib.torch import OperatorFunction
7 | from ops.ODLHelper import OperatorFunction
8 |
9 | from .odl_lib import apply_angle_noise
10 |
11 | class ParallelBeamGeometry3DOp(object):
12 | def __init__(self, img_size, num_angles, op_snr, angle_max=np.pi/3):
13 | self.img_size = img_size
14 | self.num_angles = num_angles
15 | self.angle_max = angle_max
16 | self.reco_space = odl.uniform_discr(
17 | min_pt=[-20, -20, -20],
18 | max_pt=[20, 20, 20],
19 | shape=[img_size, img_size, img_size],
20 | dtype='float32'
21 | )
22 |
23 | # Make a 3d single-axis parallel beam geometry with flat detector
24 | # Angles: uniformly spaced, n = 180, min = 0, max = pi
25 | # self.angle_partition = odl.uniform_partition(0, np.pi, 180)
26 | self.angle_partition = odl.uniform_partition(-angle_max, angle_max, num_angles)
27 | # Detector: uniformly sampled, n = (512, 512), min = (-30, -30), max = (30, 30)
28 | # self.detector_partition = odl.uniform_partition([-30, -30], [30, 30], [256, 256])
29 | # self.detector_partition = odl.tomo.parallel_beam_geometry(self.reco_space).det_partition
30 | self.detector_partition = odl.tomo.parallel_beam_geometry(self.reco_space, det_shape=(2 * img_size, 2 * img_size)).det_partition
31 | self.geometry = odl.tomo.Parallel3dAxisGeometry(self.angle_partition, self.detector_partition)
32 |
33 | self.num_detectors_x, self.num_detectors_y = self.geometry.detector.shape
34 |
35 | self.angles = apply_angle_noise(self.geometry.angles, op_snr)
36 | self.optimizable_params = torch.tensor(self.angles, dtype=torch.float32) # Convert to torch.Tensor.
37 |
38 | self.op = odl.tomo.RayTransform(self.reco_space, self.geometry, impl='astra_cuda')
39 |
40 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op)
41 |
42 | def __call__(self, x):
43 | return OperatorFunction.apply(self.op, x)
44 |
45 | def pinv(self, y):
46 | return OperatorFunction.apply(self.fbp, y)
47 |
48 | class ParallelBeamGeometry3DOpBroken(ParallelBeamGeometry3DOp):
49 | def __init__(self, clean_operator, op_snr):
50 | super().__init__(clean_operator.img_size, clean_operator.num_angles, op_snr, clean_operator.angle_max)
51 |
52 | self.optimizable_params = torch.tensor(clean_operator.geometry.angles, dtype=torch.float32)
53 |
54 | self.angles = apply_angle_noise(clean_operator.geometry.angles, op_snr)
55 | # angle partition is changed to not be uniform
56 | self.angle_partition = odl.discr.nonuniform_partition(np.sort(self.angles))
57 |
58 | self.geometry = odl.tomo.Parallel3dAxisGeometry(self.angle_partition, self.detector_partition)
59 |
60 | self.num_detectors_x, self.num_detectors_y = self.geometry.detector.shape
61 |
62 | self.op = odl.tomo.RayTransform(self.reco_space, self.geometry, impl='astra_cuda')
63 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op)
64 |
65 |
66 | class ParallelBeamGeometry3DOpAngles(ParallelBeamGeometry3DOp):
67 | def __init__(self, img_size, angles, op_snr, angle_max=np.pi/3):
68 | super().__init__(img_size, angles.shape[0], op_snr, angle_max)
69 |
70 | self.optimizable_params = torch.tensor(self.geometry.angles, dtype=torch.float32)
71 |
72 | self.angles = angles
73 | # angle partition is changed to not be uniform
74 | self.angle_partition = odl.discr.nonuniform_partition(np.sort(self.angles))
75 |
76 | # TODO: change following axes to get rotation around another axis
77 | self.geometry = odl.tomo.Parallel3dAxisGeometry(self.angle_partition, self.detector_partition,
78 | axis=(1, 0, 0), det_axes_init=[(1, 0, 0), (0, 1, 0)], det_pos_init=(0, 0, 1))
79 |
80 |
81 | self.num_detectors_x, self.num_detectors_y = self.geometry.detector.shape
82 |
83 | self.op = odl.tomo.RayTransform(self.reco_space, self.geometry, impl='astra_cuda')
84 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op)
85 |
86 |
87 | def unit_test():
88 | img_size = 64
89 | num_angles = 60
90 | A = ParallelBeamGeometry3DOp(img_size, num_angles, np.inf)
91 |
92 | x = torch.rand([img_size, img_size, img_size])
93 | y = A(x)
94 | x_hat = A.pinv(y)
95 | print(x.shape)
96 | print(y.shape)
97 | print(x_hat.shape)
98 |
99 | if __name__ == "__main__":
100 | unit_test()
--------------------------------------------------------------------------------
/normflow/distributions/target.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
6 |
7 | class Target(nn.Module):
8 | """
9 | Sample target distributions to test models
10 | """
11 | def __init__(self, prop_scale=torch.tensor(6.),
12 | prop_shift=torch.tensor(-3.)):
13 | """
14 | Constructor
15 | :param prop_scale: Scale for the uniform proposal
16 | :param prop_shift: Shift for the uniform proposal
17 | """
18 | super().__init__()
19 | self.register_buffer("prop_scale", prop_scale)
20 | self.register_buffer("prop_shift", prop_shift)
21 |
22 | def log_prob(self, z):
23 | """
24 | :param z: value or batch of latent variable
25 | :return: log probability of the distribution for z
26 | """
27 | raise NotImplementedError('The log probability is not implemented yet.')
28 |
29 | def rejection_sampling(self, num_steps=1):
30 | """
31 | Perform rejection sampling on the target distribution
32 | :param num_steps: Number of proposals to draw
33 | :return: Accepted samples
34 | """
35 | eps = torch.rand((num_steps, self.n_dims), dtype=self.prop_scale.dtype,
36 | device=self.prop_scale.device)
37 | z_ = self.prop_scale * eps + self.prop_shift
38 | prob = torch.rand(num_steps, dtype=self.prop_scale.dtype,
39 | device=self.prop_scale.device)
40 | prob_ = torch.exp(self.log_prob(z_) - self.max_log_prob)
41 | accept = prob_ > prob
42 | z = z_[accept, :]
43 | return z
44 |
45 | def sample(self, num_samples=1):
46 | """
47 | Sample from the target distribution through rejection sampling
48 | :param num_samples: Number of samples to draw
49 | :return: Samples
50 | """
51 | z = torch.zeros((0, self.n_dims), dtype=self.prop_scale.dtype,
52 | device=self.prop_scale.device)
53 | while len(z) < num_samples:
54 | z_ = self.rejection_sampling(num_samples)
55 | ind = np.min([len(z_), num_samples - len(z)])
56 | z = torch.cat([z, z_[:ind, :]], 0)
57 | return z
58 |
59 |
60 | class TwoMoons(Target):
61 | """
62 | Bimodal two-dimensional distribution
63 | """
64 | def __init__(self):
65 | super().__init__()
66 | self.n_dims = 2
67 | self.max_log_prob = 0.
68 |
69 | def log_prob(self, z):
70 | """
71 | log(p) = - 1/2 * ((norm(z) - 2) / 0.2) ** 2
72 | + log( exp(-1/2 * ((z[0] - 2) / 0.3) ** 2)
73 | + exp(-1/2 * ((z[0] + 2) / 0.3) ** 2))
74 | :param z: value or batch of latent variable
75 | :return: log probability of the distribution for z
76 | """
77 | a = torch.abs(z[:, 0])
78 | log_prob = - 0.5 * ((torch.norm(z, dim=1) - 2) / 0.2) ** 2 \
79 | - 0.5 * ((a - 2) / 0.3) ** 2 \
80 | + torch.log(1 + torch.exp(-4 * a / 0.09))
81 | return log_prob
82 |
83 |
84 | class CircularGaussianMixture(nn.Module):
85 | """
86 | Two-dimensional Gaussian mixture arranged in a circle
87 | """
88 | def __init__(self, n_modes=8):
89 | """
90 | Constructor
91 | :param n_modes: Number of modes
92 | """
93 | super(CircularGaussianMixture, self).__init__()
94 | self.n_modes = n_modes
95 | self.register_buffer("scale", torch.tensor(2 / 3 * np.sin(np.pi / self.n_modes)).float())
96 |
97 | def log_prob(self, z):
98 | d = torch.zeros((len(z), 0), dtype=z.dtype, device=z.device)
99 | for i in range(self.n_modes):
100 | d_ = ((z[:, 0] - 2 * np.sin(2 * np.pi / self.n_modes * i)) ** 2
101 | + (z[:, 1] - 2 * np.cos(2 * np.pi / self.n_modes * i)) ** 2)\
102 | / (2 * self.scale ** 2)
103 | d = torch.cat((d, d_[:, None]), 1)
104 | log_p = - torch.log(2 * np.pi * self.scale ** 2 * self.n_modes) \
105 | + torch.logsumexp(-d, 1)
106 | return log_p
107 |
108 | def sample(self, num_samples=1):
109 | eps = torch.randn((num_samples, 2), dtype=self.scale.dtype, device=self.scale.device)
110 | phi = 2 * np.pi / self.n_modes * torch.randint(0, self.n_modes, (num_samples,),
111 | device=self.scale.device)
112 | loc = torch.stack((2 * torch.sin(phi), 2 * torch.cos(phi)), 1).type(eps.dtype)
113 | return eps * self.scale + loc
114 |
115 |
116 | class RingMixture(Target):
117 | """
118 | Mixture of ring distributions in two dimensions
119 | """
120 | def __init__(self, n_rings=2):
121 | super().__init__()
122 | self.n_dims = 2
123 | self.max_log_prob = 0.
124 | self.n_rings = n_rings
125 | self.scale = 1 / 4 / self.n_rings
126 |
127 | def log_prob(self, z):
128 | d = torch.zeros((len(z), 0), dtype=z.dtype, device=z.device)
129 | for i in range(self.n_rings):
130 | d_ = ((torch.norm(z, dim=1) - 2 / self.n_rings * (i + 1)) ** 2) \
131 | / (2 * self.scale ** 2)
132 | d = torch.cat((d, d_[:, None]), 1)
133 | return torch.logsumexp(-d, 1)
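134 | # Minimal sampling sketch (editor's addition, not from the original file),
135 | # using the rejection sampler inherited from Target. Run e.g. as
136 | # `python -m normflow.distributions.target`.
137 | if __name__ == "__main__":
138 |     target = TwoMoons()
139 |     samples = target.sample(num_samples=256)  # (256, 2)
140 |     print(samples.shape, target.log_prob(samples).shape)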
--------------------------------------------------------------------------------
/normflow/distributions/encoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
6 |
7 | class BaseEncoder(nn.Module):
8 | """
9 | Base distribution of a flow-based variational autoencoder
10 | Parameters of the distribution depend on the target variable x
11 | """
12 |
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def forward(self, x, num_samples=1):
17 | """
18 | :param x: Variable to condition on, first dimension is batch size
19 | :param num_samples: number of samples to draw per element of mini-batch
20 | :return: sample of z for x, log probability for sample
21 | """
22 | raise NotImplementedError
23 |
24 | def log_prob(self, z, x):
25 | """
26 | :param z: Primary random variable, first dimension is batch size
27 | :param x: Variable to condition on, first dimension is batch size
28 | :return: log probability of z given x
29 | """
30 | raise NotImplementedError
31 |
32 |
33 | class Dirac(BaseEncoder):
34 | def __init__(self):
35 | super().__init__()
36 |
37 | def forward(self, x, num_samples=1):
38 | z = x.unsqueeze(1).repeat(1, num_samples, 1)
39 | log_p = torch.zeros(z.size()[0:2], device=z.device)
40 | return z, log_p
41 |
42 | def log_prob(self, z, x):
43 | log_p = torch.zeros(z.size()[0:2], device=z.device)
44 | return log_p
45 |
46 |
47 | class Uniform(BaseEncoder):
48 | def __init__(self, zmin=0.0, zmax=1.0):
49 | super().__init__()
50 | self.zmin = zmin
51 | self.zmax = zmax
52 | self.log_p = -np.log(zmax - zmin)  # torch.log does not accept Python floats
53 |
54 | def forward(self, x, num_samples=1):
55 | z = x.unsqueeze(1).repeat(1, num_samples, 1).uniform_(self.zmin, self.zmax)
56 | log_p = torch.zeros(z.size()[0:2], device=z.device).fill_(self.log_p)
57 | return z, log_p
58 |
59 | def log_prob(self, z, x):
60 | log_p = torch.zeros(z.size()[0:2], device=z.device).fill_(self.log_p)
61 | return log_p
62 |
63 |
64 | class ConstDiagGaussian(BaseEncoder):
65 | def __init__(self, loc, scale):
66 | """
67 | Multivariate Gaussian distribution with diagonal covariance and parameters being constant wrt x
68 | :param loc: mean vector of the distribution
69 | :param scale: vector of the standard deviations on the diagonal of the covariance matrix
70 | """
71 | super().__init__()
72 | self.d = len(loc)
73 | if not torch.is_tensor(loc):
74 | loc = torch.tensor(loc).float()
75 | if not torch.is_tensor(scale):
76 | scale = torch.tensor(scale).float()
77 | self.loc = nn.Parameter(loc.reshape((1, 1, self.d)))
78 | self.scale = nn.Parameter(scale)
79 |
80 | def forward(self, x=None, num_samples=1):
81 | """
82 | :param x: Variable to condition on, will only be used to determine the batch size
83 | :param num_samples: number of samples to draw per element of mini-batch
84 | :return: sample of z for x, log probability for sample
85 | """
86 | if x is not None:
87 | batch_size = len(x)
88 | else:
89 | batch_size = 1
90 | eps = torch.randn((batch_size, num_samples, self.d), device=self.loc.device)  # also covers x=None
91 | z = self.loc + self.scale * eps
92 | log_p = - 0.5 * self.d * np.log(2 * np.pi) - torch.sum(torch.log(self.scale) + 0.5 * torch.pow(eps, 2), 2)
93 | return z, log_p
94 |
95 | def log_prob(self, z, x):
96 | """
97 | :param z: Primary random variable, first dimension is batch dimension
98 | :param x: Variable to condition on, first dimension is batch dimension
99 | :return: log probability of z given x
100 | """
101 | if z.dim() == 1:
102 | z = z.unsqueeze(0)
103 | if z.dim() == 2:
104 | z = z.unsqueeze(0)
105 | log_p = - 0.5 * self.d * np.log(2 * np.pi) - torch.sum(
106 | torch.log(self.scale) + 0.5 * ((z - self.loc) / self.scale) ** 2, 2)
107 | return log_p
108 |
109 |
110 | class NNDiagGaussian(BaseEncoder):
111 | """
112 | Diagonal Gaussian distribution with mean and variance determined by a neural network
113 | """
114 |
115 | def __init__(self, net):
116 | """
117 | Constructor
118 | :param net: net computing mean (first n / 2 outputs), standard deviation (second n / 2 outputs)
119 | """
120 | super().__init__()
121 | self.net = net
122 |
123 | def forward(self, x, num_samples=1):
124 | """
125 | :param x: Variable to condition on
126 | :param num_samples: number of samples to draw per element of mini-batch
127 | :return: sample of z for x, log probability for sample
128 | """
129 | batch_size = len(x)
130 | mean_std = self.net(x)
131 | n_hidden = mean_std.size()[1] // 2
132 | mean = mean_std[:, :n_hidden, ...].unsqueeze(1)
133 | std = torch.exp(0.5 * mean_std[:, n_hidden:(2 * n_hidden), ...].unsqueeze(1))
134 | eps = torch.randn((batch_size, num_samples) + tuple(mean.size()[2:]), device=x.device)
135 | z = mean + std * eps
136 | log_p = - 0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(2 * np.pi) \
137 | - torch.sum(torch.log(std) + 0.5 * torch.pow(eps, 2), list(range(2, z.dim())))
138 | return z, log_p
139 |
140 | def log_prob(self, z, x):
141 | """
142 | :param z: Primary random variable, first dimension is batch dimension
143 | :param x: Variable to condition on, first dimension is batch dimension
144 | :return: log probability of z given x
145 | """
146 | if z.dim() == 1:
147 | z = z.unsqueeze(0)
148 | if z.dim() == 2:
149 | z = z.unsqueeze(0)
150 | mean_std = self.net(x)
151 | n_hidden = mean_std.size()[1] // 2
152 | mean = mean_std[:, :n_hidden, ...].unsqueeze(1)
153 | var = torch.exp(mean_std[:, n_hidden:(2 * n_hidden), ...].unsqueeze(1))
154 | log_p = - 0.5 * torch.prod(torch.tensor(z.size()[2:])) * np.log(2 * np.pi) \
155 | - 0.5 * torch.sum(torch.log(var) + (z - mean) ** 2 / var, 2)
156 | return log_p
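157 | # Minimal usage sketch (editor's addition, not from the original file):
158 | # a tiny net mapping x to concatenated [mean, log-variance] of z. Run
159 | # e.g. as `python -m normflow.distributions.encoder`.
160 | if __name__ == "__main__":
161 |     x_dim, z_dim = 10, 4
162 |     q = NNDiagGaussian(nn.Linear(x_dim, 2 * z_dim))
163 |     z, log_q = q(torch.randn(5, x_dim), num_samples=3)
164 |     print(z.shape, log_q.shape)  # (5, 3, 4) and (5, 3)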
--------------------------------------------------------------------------------
/normflow/flows/neural_spline.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from .base import Flow
4 |
5 | # Try importing Neural Spline Flow dependencies
6 | try:
7 | from neural_spline_flows.nde.transforms.coupling import PiecewiseRationalQuadraticCouplingTransform
8 | from neural_spline_flows.nde.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform
9 | from neural_spline_flows.nn import ResidualNet
10 | from neural_spline_flows.utils import create_alternating_binary_mask
11 | except ImportError:
12 | print('Warning: Dependencies for Neural Spline Flows could '
13 | 'not be loaded. Other models can still be used.')
14 |
15 |
16 |
17 | class CoupledRationalQuadraticSpline(Flow):
18 | """
19 | Neural spline flow coupling layer, wrapper for the implementation
20 | of Durkan et al., see https://github.com/bayesiains/nsf
21 | """
22 | def __init__(
23 | self,
24 | num_input_channels,
25 | num_blocks,
26 | num_hidden_channels,
27 | num_bins=8,
28 | tail_bound=3,
29 | activation=nn.ReLU,
30 | dropout_probability=0.,
31 | reverse_mask=False,
32 | reverse=True
33 | ):
34 | """
35 | Constructor
36 | :param num_input_channels: Flow dimension
37 | :type num_input_channels: Int
38 | :param num_blocks: Number of residual blocks of the parameter NN
39 | :type num_blocks: Int
40 | :param num_hidden_channels: Number of hidden units of the NN
41 | :type num_hidden_channels: Int
42 | :param num_bins: Number of bins
43 | :type num_bins: Int
44 | :param tail_bound: Bound of the spline tails
45 | :type tail_bound: Int
46 | :param activation: Activation function
47 | :type activation: torch module
48 | :param dropout_probability: Dropout probability of the NN
49 | :type dropout_probability: Float
50 | :param reverse_mask: Flag whether the reverse mask should be used
51 | :type reverse_mask: Boolean
52 | :param reverse: Flag whether forward and backward pass shall be swapped
53 | :type reverse: Boolean
54 | """
55 | super().__init__()
56 | self.reverse = reverse
57 |
58 | def transform_net_create_fn(in_features, out_features):
59 | return ResidualNet(
60 | in_features=in_features,
61 | out_features=out_features,
62 | context_features=None,
63 | hidden_features=num_hidden_channels,
64 | num_blocks=num_blocks,
65 | activation=activation(),
66 | dropout_probability=dropout_probability,
67 | use_batch_norm=False
68 | )
69 |
70 | self.prqct = PiecewiseRationalQuadraticCouplingTransform(
71 | mask=create_alternating_binary_mask(
72 | num_input_channels,
73 | even=reverse_mask
74 | ),
75 | transform_net_create_fn=transform_net_create_fn,
76 | num_bins=num_bins,
77 | tails='linear',
78 | tail_bound=tail_bound,
79 |
80 | # Setting True corresponds to equations (4), (5), (6) in the NSF paper:
81 | apply_unconditional_transform=True
82 | )
83 |
84 | def forward(self, z):
85 | if self.reverse:
86 | z, log_det = self.prqct.inverse(z)
87 | else:
88 | z, log_det = self.prqct(z)
89 | return z, log_det.view(-1)
90 |
91 | def inverse(self, z):
92 | if self.reverse:
93 | z, log_det = self.prqct(z)
94 | else:
95 | z, log_det = self.prqct.inverse(z)
96 | return z, log_det.view(-1)
97 |
98 |
99 | class AutoregressiveRationalQuadraticSpline(Flow):
100 | """
101 | Neural spline flow autoregressive layer, wrapper for the implementation
102 | of Durkan et al., see https://github.com/bayesiains/nsf
103 | """
104 | def __init__(
105 | self,
106 | num_input_channels,
107 | num_blocks,
108 | num_hidden_channels,
109 | num_bins=8,
110 | tail_bound=3,
111 | activation=nn.ReLU,
112 | dropout_probability=0.,
113 | reverse=True
114 | ):
115 | """
116 | Constructor
117 | :param num_input_channels: Flow dimension
118 | :type num_input_channels: Int
119 | :param num_blocks: Number of residual blocks of the parameter NN
120 | :type num_blocks: Int
121 | :param num_hidden_channels: Number of hidden units of the NN
122 | :type num_hidden_channels: Int
123 | :param num_bins: Number of bins
124 | :type num_bins: Int
125 | :param tail_bound: Bound of the spline tails
126 | :type tail_bound: Int
127 | :param activation: Activation function
128 | :type activation: torch module
129 | :param dropout_probability: Dropout probability of the NN
130 | :type dropout_probability: Float
131 | :param reverse: Flag whether forward and backward pass shall be swapped
132 | :type reverse: Boolean
133 | """
134 | super().__init__()
135 | self.reverse = reverse
136 |
137 | self.mprqat = MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
138 | features=num_input_channels,
139 | hidden_features=num_hidden_channels,
140 | context_features=None,
141 | num_bins=num_bins,
142 | tails='linear',
143 | tail_bound=tail_bound,
144 | num_blocks=num_blocks,
145 | use_residual_blocks=True,
146 | random_mask=False,
147 | activation=activation(),
148 | dropout_probability=dropout_probability,
149 | use_batch_norm=False)
150 |
151 | def forward(self, z):
152 | if self.reverse:
153 | z, log_det = self.mprqat.inverse(z)
154 | else:
155 | z, log_det = self.mprqat(z)
156 | return z, log_det.view(-1)
157 |
158 | def inverse(self, z):
159 | if self.reverse:
160 | z, log_det = self.mprqat(z)
161 | else:
162 | z, log_det = self.mprqat.inverse(z)
163 | return z, log_det.view(-1)
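164 | # Minimal usage sketch (editor's addition, not from the original file);
165 | # it only runs when the optional neural_spline_flows package imported
166 | # above is available.
167 | if __name__ == "__main__":
168 |     import torch
169 |     flow = CoupledRationalQuadraticSpline(num_input_channels=4,
170 |                                           num_blocks=2,
171 |                                           num_hidden_channels=32)
172 |     z = torch.randn(8, 4)
173 |     z_out, log_det = flow(z)  # (8, 4) and (8,)
174 |     z_rec, _ = flow.inverse(z_out)
175 |     print(torch.allclose(z_rec, z, atol=1e-4))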
--------------------------------------------------------------------------------
/autoencoder_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Autoencoder (to be used in generative network)
3 | """
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch
7 | from funknn_model import squeeze
8 |
9 |
10 | class autoencoder(nn.Module):
11 |
12 | def __init__(self, encoder=None, decoder=None):
13 | super(autoencoder, self).__init__()
14 |
15 | self.encoder = encoder
16 | self.decoder = decoder
17 |
18 |
19 | class encoder(nn.Module):
20 |
21 | def __init__(self, latent_dim=256, in_res=64, c=3):
22 | super(encoder, self).__init__()
23 |
24 | self.in_res = in_res
25 | self.c = c
26 | prev_ch = c
27 | c_last = 256
28 | CNNs = []
29 | CNNs_add = []
30 | num_layers = [64,128,128,256]
31 | for i in range(len(num_layers)):
32 | CNNs.append(nn.Conv2d(prev_ch, num_layers[i], 3,
33 | padding='same'))
34 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i], 3,
35 | padding='same'))
36 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i], 3,
37 | padding='same'))
38 | prev_ch = num_layers[i]
39 |
40 | if in_res == 64:
41 | num_layers = [c_last]
42 | for i in range(len(num_layers)):
43 | CNNs.append(nn.Conv2d(prev_ch, num_layers[i], 3,
44 | padding='same'))
45 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i], 3,
46 | padding='same'))
47 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i], 3,
48 | padding='same'))
49 | prev_ch = num_layers[i]
50 |
51 | if in_res == 128:
52 | num_layers = [256,c_last]
53 | for i in range(len(num_layers)):
54 | CNNs.append(nn.Conv2d(prev_ch, num_layers[i], 3,
55 | padding='same'))
56 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i], 3,
57 | padding='same'))
58 | CNNs_add.append(nn.Conv2d(num_layers[i], num_layers[i], 3,
59 | padding='same'))
60 | prev_ch = num_layers[i]
61 |
62 |
63 |
64 |
65 | self.CNNs = nn.ModuleList(CNNs)
66 | self.CNNs_add = nn.ModuleList(CNNs_add)
67 | self.maxpool = nn.MaxPool2d(2, 2)
68 |
69 | feature_dim = 2 * 2 * c_last
70 | mlps = []
71 | mlps.append(nn.Linear(feature_dim, latent_dim))
72 |
73 | self.mlps = nn.ModuleList(mlps)
74 |
75 |
76 | def forward(self, x):
77 |
78 | x_skip = torch.mean(x, dim=1, keepdim=True)
79 | for i in range(len(self.CNNs)):
80 | x = self.CNNs[i](x)
81 | x = F.relu(x)
82 | xm = self.maxpool(x)
83 | if i < 4:
84 |
85 | f = 2**(i+1)
86 | xm = xm + squeeze(x_skip, f).repeat_interleave(xm.shape[1] // (f ** 2), dim=1)
87 |
88 | x = self.CNNs_add[i*2](xm)
89 | x = F.relu(x)
90 | x = self.CNNs_add[i*2 + 1](x)
91 | x = F.relu(x)
92 | x = x + xm
93 |
94 | x = torch.flatten(x, 1)
95 | for i in range(len(self.mlps)-1):
96 | x = self.mlps[i](x)
97 | x = F.relu(x)
98 |
99 | x = self.mlps[-1](x)
100 |
101 | return x
102 |
103 |
104 | class decoder(nn.Module):
105 |
106 | def __init__(self, latent_dim=256, in_res=64, c=3):
107 | super(decoder, self).__init__()
108 |
109 | self.in_res = in_res
110 | self.c = c
111 | prev_ch = 256
112 | t_CNNs = []
113 | CNNs = []
114 |
115 | if in_res == 128:
116 | num_layers = [256,256,128,128,64,self.c]
117 | for i in range(len(num_layers)):
118 | # t_CNNs.append(nn.ConvTranspose2d(prev_ch, num_layers[i] ,3,
119 | # stride=2,padding = 1, output_padding=1))
120 | c_inter = 64 if num_layers[i] == self.c else num_layers[i]
121 | t_CNNs.append(nn.Conv2d(prev_ch, c_inter, 3,
122 | padding='same'))
123 | CNNs.append(nn.Conv2d(c_inter, c_inter, 3,
124 | padding='same'))
125 | CNNs.append(nn.Conv2d(c_inter, num_layers[i], 3,
126 | padding='same'))
127 | prev_ch = num_layers[i]
128 |
129 | elif in_res == 64:
130 |
131 | num_layers = [256,128,128,64,self.c]
132 | for i in range(len(num_layers)):
133 |
134 | c_inter = 64 if num_layers[i] == self.c else num_layers[i]
135 | t_CNNs.append(nn.Conv2d(prev_ch, c_inter, 3,
136 | padding='same'))
137 | CNNs.append(nn.Conv2d(c_inter, c_inter, 3,
138 | padding='same'))
139 | CNNs.append(nn.Conv2d(c_inter, num_layers[i], 3,
140 | padding='same'))
141 | prev_ch = num_layers[i]
142 |
143 |
144 |
145 | self.t_CNNs = nn.ModuleList(t_CNNs)
146 | self.CNNs = nn.ModuleList(CNNs)
147 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
148 |
149 | self.feature_dim = 2 * 2 * 256
150 | mlps = []
151 | mlps.append(nn.Linear(latent_dim, self.feature_dim))
152 |
153 | self.mlps = nn.ModuleList(mlps)
154 |
155 | def forward(self, x):
156 |
157 | for i in range(len(self.mlps)):
158 | x = self.mlps[i](x)
159 | x = F.relu(x)
160 | # x = squeeze(x , 16)
161 | b = x.shape[0]
162 | x = x.reshape([b, 256, 2, 2])
163 |
164 | for i in range(len(self.t_CNNs)-1):
165 | x = self.upsample(x)
166 | x = self.t_CNNs[i](x)
167 | xr = F.relu(x)
168 | x = self.CNNs[i*2](xr)
169 | x = F.relu(x)
170 | x = self.CNNs[i*2+1](x)
171 | x = F.relu(x)
172 | x = x + xr
173 |
174 | x = self.upsample(x)
175 | x = self.t_CNNs[-1](x)
176 | x = F.relu(x)
177 | x = self.CNNs[-2](x)
178 | x = F.relu(x)
179 | x = self.CNNs[-1](x)
180 |
181 | return x
182 |
183 |
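184 | # Minimal shape-check sketch (editor's addition, not from the original
185 | # file); note that encoder.forward relies on funknn_model.squeeze,
186 | # imported at the top of this module.
187 | if __name__ == "__main__":
188 |     enc = encoder(latent_dim=256, in_res=64, c=3)
189 |     dec = decoder(latent_dim=256, in_res=64, c=3)
190 |     x = torch.randn(2, 3, 64, 64)
191 |     z = enc(x)
192 |     print(z.shape, dec(z).shape)  # (2, 256) and (2, 3, 64, 64)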
--------------------------------------------------------------------------------
/normflow/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 | from . import flows
6 |
7 | # Try importing ResNet dependencies
8 | try:
9 | from residual_flows.layers.base import InducedNormLinear, InducedNormConv2d
10 | except ImportError:
11 | print('Warning: Dependencies for Residual Networks could '
12 | 'not be loaded. Other models can still be used.')
13 |
14 |
15 | def set_requires_grad(module, flag):
16 | """
17 | Sets requires_grad flag of all parameters of a torch.nn.module
18 | :param module: torch.nn.module
19 | :param flag: Flag to set requires_grad to
20 | """
21 |
22 | for param in module.parameters():
23 | param.requires_grad = flag
24 |
25 |
26 | class ConstScaleLayer(nn.Module):
27 | """
28 | Scaling features by a fixed factor
29 | """
30 | def __init__(self, scale=1.):
31 | """
32 | Constructor
33 | :param scale: Scale to apply to features
34 | """
35 | super().__init__()
36 | self.scale_cpu = torch.tensor(scale)
37 | self.register_buffer("scale", self.scale_cpu)
38 |
39 | def forward(self, input):
40 | return input * self.scale
41 |
42 |
43 | class ActNorm(nn.Module):
44 | """
45 | ActNorm layer with just one forward pass
46 | """
47 | def __init__(self, shape, logscale_factor=None):
48 | """
49 | Constructor
50 | :param shape: Same as shape in flows.ActNorm
51 | :param logscale_factor: Same as logscale_factor in flows.ActNorm
52 | """
53 | super().__init__()
54 | self.actNorm = flows.ActNorm(shape, logscale_factor=logscale_factor)
55 |
56 | def forward(self, input):
57 | out, _ = self.actNorm(input)
58 | return out
59 |
60 |
61 | # Dataset transforms
62 |
63 | class Logit():
64 | """
65 | Transform for dataloader
66 | logit(alpha + (1 - alpha) * x) where logit(x) = log(x / (1 - x))
67 | """
68 | def __init__(self, alpha=0):
69 | """
70 | Constructor
71 | :param alpha: see above
72 | """
73 | self.alpha = alpha
74 |
75 | def __call__(self, x):
76 | x_ = self.alpha + (1 - self.alpha) * x
77 | return torch.log(x_ / (1 - x_))
78 |
79 | def inverse(self, x):
80 | return (torch.sigmoid(x) - self.alpha) / (1 - self.alpha)
81 |
82 |
83 | class Jitter():
84 | """
85 | Transform for dataloader
86 | Adds uniform jitter noise to data
87 | """
88 | def __init__(self, scale=1./256):
89 | """
90 | Constructor
91 | :param scale: Scaling factor for noise
92 | """
93 | self.scale = scale
94 |
95 | def __call__(self, x):
96 | eps = torch.rand_like(x) * self.scale
97 | x_ = x + eps
98 | return x_
99 |
100 |
101 | class Scale():
102 | """
103 | Transform for dataloader
104 | Scales data by a fixed factor
105 | """
106 | def __init__(self, scale=255./256.):
107 | """
108 | Constructor
109 | :param scale: Scaling factor applied to the data
110 | """
111 | self.scale = scale
112 |
113 | def __call__(self, x):
114 | return x * self.scale
115 |
116 |
117 | # Nonlinearities
118 |
119 | class ClampExp(torch.nn.Module):
120 | """
121 | Nonlinearity min(exp(x), 1)
122 | """
123 | def __init__(self):
124 | """
125 | Constructor
127 | """
128 | super(ClampExp, self).__init__()
129 |
130 | def forward(self, x):
131 | one = torch.tensor(1., device=x.device, dtype=x.dtype)
132 | return torch.min(torch.exp(x), one)
133 |
134 |
135 | # Functions for model analysis
136 |
137 | def bitsPerDim(model, x, y=None, trans='logit', trans_param=[0.05]):
138 | """
139 | Computes the bits per dim for a batch of data
140 | :param model: Model to compute bits per dim for
141 | :param x: Batch of data
142 | :param y: Class labels for batch of data if base distribution is class conditional
143 | :param trans: Transformation to be applied to images during training
144 | :param trans_param: List of parameters of the transformation
145 | :return: Bits per dim for data batch under model
146 | """
147 | dims = torch.prod(torch.tensor(x.size()[1:]))
148 | if trans == 'logit':
149 | if y is None:
150 | log_q = model.log_prob(x)
151 | else:
152 | log_q = model.log_prob(x, y)
153 | sum_dims = list(range(1, x.dim()))
154 | ls = torch.nn.LogSigmoid()
155 | sig_ = torch.sum(ls(x) / np.log(2), sum_dims)
156 | sig_ += torch.sum(ls(-x) / np.log(2), sum_dims)
157 | b = - log_q / dims / np.log(2) - np.log2(1 - trans_param[0]) + 8
158 | b += sig_ / dims
159 | else:
160 | raise NotImplementedError('The transformation ' + trans + ' is not implemented.')
161 | return b
162 |
163 |
164 | def bitsPerDimDataset(model, data_loader, class_cond=True, trans='logit',
165 | trans_param=[0.05]):
166 | """
167 | Computes average bits per dim for an entire dataset given by a data loader
168 | :param model: Model to compute bits per dim for
169 | :param data_loader: Data loader of dataset
170 | :param class_cond: Flag indicating whether model is class_conditional
171 | :param trans: Transformation to be applied to images during training
172 | :param trans_param: List of parameters of the transformation
173 | :return: Average bits per dim for dataset
174 | """
175 | n = 0
176 | b_cum = 0
177 | with torch.no_grad():
178 | for x, y in iter(data_loader):
179 | b_ = bitsPerDim(model, x, y.to(x.device) if class_cond else None,
180 | trans, trans_param)
181 | b_np = b_.to('cpu').numpy()
182 | b_cum += np.nansum(b_np)
183 | n += len(x) - np.sum(np.isnan(b_np))
184 | b = b_cum / n
185 | return b
186 |
187 |
188 | def clear_grad(model):
189 | """
190 | Set gradients of model parameters to None as this speeds up training,
191 | see https://www.youtube.com/watch?v=9mS1fIYj1So
192 | :param model: Model to clear gradients of
193 | """
194 | for param in model.parameters():
195 | param.grad = None
196 |
197 |
198 | def update_lipschitz(model, n_iterations):
199 | for m in model.modules():
200 | if isinstance(m, InducedNormConv2d) or isinstance(m, InducedNormLinear):
201 | m.compute_weight(update=True, n_iterations=n_iterations)
202 |
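203 | # Minimal sketch (editor's addition, not from the original file): the
204 | # dataloader transforms composed in the usual order for 8-bit images.
205 | # Run e.g. as `python -m normflow.utils`.
206 | if __name__ == "__main__":
207 |     x = torch.rand(1, 3, 8, 8)  # image in [0, 1]
208 |     x = Scale()(x)  # rescale to [0, 255/256]
209 |     x = Jitter()(x)  # dequantize with uniform noise
210 |     logit = Logit(alpha=0.05)
211 |     y = logit(x)  # map to unconstrained space
212 |     print(torch.max(torch.abs(logit.inverse(y) - x)))  # ~0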
--------------------------------------------------------------------------------
/train_funknn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from timeit import default_timer
5 | from torch.optim import Adam
6 | import os
7 | import matplotlib.pyplot as plt
8 | from funknn_model import FunkNN
9 | from utils import *
10 | from datasets import *
11 | from results import evaluator
12 | import config_funknn as config
13 |
14 | torch.manual_seed(0)
15 | np.random.seed(0)
16 |
17 | epochs_funknn = config.epochs_funknn
18 | batch_size = config.batch_size
19 | dataset = config.dataset
20 | gpu_num = config.gpu_num
21 | exp_desc = config.exp_desc
22 | image_size = config.image_size
23 | c = config.c
24 | train_funknn = config.train_funknn
25 | training_mode = config.training_mode
26 | ood_analysis = config.ood_analysis
27 |
28 | enable_cuda = True
29 | device = torch.device('cuda:' + str(gpu_num) if torch.cuda.is_available() and enable_cuda else 'cpu')
30 |
31 | all_experiments = 'experiments/'
32 | if not os.path.exists(all_experiments):
33 | os.mkdir(all_experiments)
34 |
35 | # experiment path
36 | exp_path = all_experiments + 'funknn_' + dataset + '_' \
37 | + str(image_size) + '_' + training_mode + '_' + exp_desc
38 |
39 |
40 | if not os.path.exists(exp_path):
41 | os.mkdir(exp_path)
42 |
43 |
44 | learning_rate = 1e-4
45 | step_size = 50
46 | gamma = 0.5
47 | # myloss = F.mse_loss
48 | myloss = F.l1_loss
49 | num_batch_pixels = 3 # The number of iterations over each batch
50 | batch_pixels = 512 # Number of pixels to optimize in each iteration
51 | k = 2 # super resolution factor for training
52 |
53 | # Print the experiment setup:
54 | print('Experiment setup:')
55 | print('---> epochs_funknn: {}'.format(epochs_funknn))
56 | print('---> batch_size: {}'.format(batch_size))
57 | print('---> dataset: {}'.format(dataset))
58 | print('---> Learning rate: {}'.format(learning_rate))
59 | print('---> experiment path: {}'.format(exp_path))
60 | print('---> image size: {}'.format(image_size))
61 |
62 | # Dataset:
63 | train_dataset = Dataset_loader(dataset = 'train' ,size = (image_size,image_size), c = c)
64 | test_dataset = Dataset_loader(dataset = 'test' ,size = (config.max_scale*image_size,config.max_scale*image_size), c = c)
65 |
66 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=32, shuffle = True)
67 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, num_workers=32)
68 |
69 | ntrain = len(train_loader.dataset)
70 | n_test = len(test_loader.dataset)
71 |
72 | n_ood = 0
73 | if ood_analysis:
74 | ood_dataset = Dataset_loader(dataset = 'ood',size = (2*image_size,2*image_size), c = c)
75 | ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=300, num_workers=32)
76 | n_ood = len(ood_loader.dataset)
77 |
78 | print('---> Number of training, test and ood samples: {}, {}, {}'.format(ntrain,n_test, n_ood))
79 |
80 | # Loading model
81 | plot_per_num_epoch = 10 if ntrain > 10000 else 10000//ntrain
82 | model = FunkNN(c=c).to(device)
83 | # model = torch.nn.DataParallel(model) # Using multiple GPUs
84 | num_param_funknn = count_parameters(model)
85 | print('---> Number of trainable parameters of funknn: {}'.format(num_param_funknn))
86 |
87 | optimizer_funknn = Adam(model.parameters(), lr=learning_rate)
88 | scheduler_funknn = torch.optim.lr_scheduler.StepLR(optimizer_funknn, step_size=step_size, gamma=gamma)
89 |
90 | checkpoint_exp_path = os.path.join(exp_path, 'funknn.pt')
91 | if os.path.exists(checkpoint_exp_path) and config.restore_funknn:
92 | checkpoint_funknn = torch.load(checkpoint_exp_path)
93 | model.load_state_dict(checkpoint_funknn['model_state_dict'])
94 | optimizer_funknn.load_state_dict(checkpoint_funknn['optimizer_state_dict'])
95 | print('funknn is restored...')
96 |
97 | if train_funknn:
98 | print('Training...')
99 |
100 | if plot_per_num_epoch == -1:
101 | plot_per_num_epoch = epochs_funknn + 1 # only plot in the last epoch
102 |
103 | loss_funknn_plot = np.zeros([epochs_funknn])
104 | for ep in range(epochs_funknn):
105 | model.train()
106 | t1 = default_timer()
107 | loss_funknn_epoch = 0
108 |
109 | for image in train_loader:
110 |
111 | batch_size = image.shape[0]
112 | image = image.to(device)
113 |
114 | for i in range(num_batch_pixels):
115 | image_mat = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2)
116 |
117 | image_high, image_low, image_size_high = training_strategy(image_mat, image_size, factor = k , mode = training_mode)
118 | coords = get_mgrid(image_size_high).reshape(-1, 2)
119 | coords = torch.unsqueeze(coords, dim = 0)
120 | coords = coords.expand(batch_size , -1, -1).to(device)
121 |
122 | image_high = image_high.permute(0,2,3,1).reshape(-1, image_size_high * image_size_high, c)
123 | optimizer_funknn.zero_grad()
124 | pixels = np.random.randint(low = 0, high = image_size_high**2, size = batch_pixels)
125 | batch_coords = coords[:,pixels]
126 | batch_image = image_high[:,pixels]
127 |
128 | out = model(batch_coords, image_low)
129 | rec_loss = myloss(out.reshape(batch_size, -1), batch_image.reshape(batch_size, -1))
130 | total_loss = rec_loss
131 | total_loss.backward()
132 | optimizer_funknn.step()
133 | loss_funknn_epoch += total_loss.item()
134 |
135 | scheduler_funknn.step()
136 | t2 = default_timer()
137 | loss_funknn_epoch /= ntrain
138 | loss_funknn_plot[ep] = loss_funknn_epoch
139 |
140 | plt.plot(np.arange(epochs_funknn)[:ep], loss_funknn_plot[:ep], 'o-', linewidth=2)
141 | plt.title('FunkNN_loss')
142 | plt.xlabel('epoch')
143 | plt.ylabel('L1 loss')
144 | plt.savefig(os.path.join(exp_path, 'funknn_loss.jpg'))
145 | np.save(os.path.join(exp_path, 'funknn_loss.npy'), loss_funknn_plot[:ep])
146 | plt.close()
147 |
148 | torch.save({
149 | 'model_state_dict': model.state_dict(),
150 | 'optimizer_state_dict': optimizer_funknn.state_dict()
151 | }, checkpoint_exp_path)
152 |
153 | print('ep: {}/{} | time: {:.0f} | FunkNN_loss {:.6f} '.format(ep, epochs_funknn, t2-t1,loss_funknn_epoch))
154 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
155 | file.write('ep: {}/{} | time: {:.0f} | FunkNN_loss {:.6f} '.format(ep, epochs_funknn, t2-t1,loss_funknn_epoch))
156 | file.write('\n')
157 |
158 | if ep % plot_per_num_epoch == 0 or (ep + 1) == epochs_funknn:
159 |
160 | evaluator(ep = ep, subset = 'test', data_loader = test_loader, model = model, exp_path = exp_path)
161 | if ood_analysis:
162 | evaluator(ep = ep, subset = 'ood', data_loader = ood_loader, model = model, exp_path = exp_path)
163 |
164 | print('Evaluating...')
165 | evaluator(ep = -1, subset = 'test', data_loader = test_loader, model = model, exp_path = exp_path)
166 | if ood_analysis:
167 | evaluator(ep = -1, subset = 'ood', data_loader = ood_loader, model = model, exp_path = exp_path)
168 |
169 |
170 |
171 |
172 |
173 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: torch12
2 | channels:
3 | - jmcmurray
4 | - pytorch
5 | - astra-toolbox
6 | - anaconda
7 | - conda-forge
8 | - defaults
9 | dependencies:
10 | - _libgcc_mutex=0.1=conda_forge
11 | - _openmp_mutex=4.5=2_kmp_llvm
12 | - absl-py=0.15.0=pyhd3eb1b0_0
13 | - alsa-lib=1.2.7.2=h166bdaf_0
14 | - astra-toolbox=2.1.2=py39h4319877_0
15 | - asttokens=2.0.5=pyhd3eb1b0_0
16 | - attr=2.5.1=h166bdaf_1
17 | - attrs=22.1.0=pyh71513ae_1
18 | - backcall=0.2.0=pyhd3eb1b0_0
19 | - blas=1.0=mkl
20 | - brotli=1.0.9=h166bdaf_7
21 | - brotli-bin=1.0.9=h166bdaf_7
22 | - brotlipy=0.7.0=py39hb9d737c_1004
23 | - bzip2=1.0.8=h7f98852_4
24 | - ca-certificates=2022.9.24=ha878542_0
25 | - certifi=2022.9.24=pyhd8ed1ab_0
26 | - cffi=1.15.1=py39he91dace_0
27 | - charset-normalizer=2.1.1=pyhd8ed1ab_0
28 | - cloudpickle=2.0.0=pyhd3eb1b0_0
29 | - colorama=0.4.5=pyhd8ed1ab_0
30 | - contextlib2=21.6.0=pyhd8ed1ab_0
31 | - contourpy=1.0.5=py39hf939315_0
32 | - cryptography=37.0.4=py39hd97740a_0
33 | - cudatoolkit=11.6.0=hecad31d_10
34 | - cycler=0.11.0=pyhd8ed1ab_0
35 | - cytoolz=0.11.0=py39h27cfd23_0
36 | - dask-core=2022.7.0=py39h06a4308_0
37 | - dbus=1.13.6=h5008d03_3
38 | - decorator=5.1.1=pyhd3eb1b0_0
39 | - execnet=1.9.0=pyhd8ed1ab_0
40 | - executing=0.8.3=pyhd3eb1b0_0
41 | - expat=2.4.9=h27087fc_0
42 | - ffmpeg=4.3=hf484d3e_0
43 | - fftw=3.3.10=nompi_hf0379b8_105
44 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
45 | - font-ttf-inconsolata=3.000=h77eed37_0
46 | - font-ttf-source-code-pro=2.038=h77eed37_0
47 | - font-ttf-ubuntu=0.83=hab24e00_0
48 | - fontconfig=2.14.0=hc2a2eb6_1
49 | - fonts-conda-ecosystem=1=0
50 | - fonts-conda-forge=1=0
51 | - fonttools=4.37.3=py39hb9d737c_0
52 | - freetype=2.12.1=hca18f0e_0
53 | - fsspec=2022.7.1=py39h06a4308_0
54 | - gettext=0.19.8.1=h73d1719_1008
55 | - glib=2.72.1=h6239696_0
56 | - glib-tools=2.72.1=h6239696_0
57 | - gmp=6.2.1=h58526e2_0
58 | - gnutls=3.6.13=h85f3911_1
59 | - gst-plugins-base=1.20.3=h57caac4_2
60 | - gstreamer=1.20.3=hd4edc92_2
61 | - icu=70.1=h27087fc_0
62 | - idna=3.4=pyhd8ed1ab_0
63 | - imageio=2.22.0=pyhfa7a67d_0
64 | - iniconfig=1.1.1=pyh9f0ad1d_0
65 | - ipdb=0.13.9=pyhd8ed1ab_0
66 | - ipython=8.4.0=py39h06a4308_0
67 | - jack=1.9.18=h8c3723f_1003
68 | - jedi=0.18.1=py39h06a4308_1
69 | - jpeg=9e=h166bdaf_2
70 | - keyutils=1.6.1=h166bdaf_0
71 | - kiwisolver=1.4.4=py39hf939315_0
72 | - krb5=1.19.3=h3790be6_0
73 | - lame=3.100=h7f98852_1001
74 | - lcms2=2.12=hddcbb42_0
75 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
76 | - lerc=4.0.0=h27087fc_0
77 | - libastra=2.1.2=cuda_11.6_0
78 | - libbrotlicommon=1.0.9=h166bdaf_7
79 | - libbrotlidec=1.0.9=h166bdaf_7
80 | - libbrotlienc=1.0.9=h166bdaf_7
81 | - libcap=2.65=ha37c62d_0
82 | - libclang=14.0.6=default_h2e3cab8_0
83 | - libclang13=14.0.6=default_h3a83d3e_0
84 | - libcups=2.3.3=h3e49a29_2
85 | - libdb=6.2.32=h9c3ff4c_0
86 | - libdeflate=1.14=h166bdaf_0
87 | - libedit=3.1.20191231=he28a2e2_2
88 | - libevent=2.1.10=h9b69904_4
89 | - libffi=3.4.2=h7f98852_5
90 | - libflac=1.3.4=h27087fc_0
91 | - libgcc-ng=12.1.0=h8d9b700_16
92 | - libgfortran-ng=12.1.0=h69a702a_16
93 | - libgfortran5=12.1.0=hdcd56e2_16
94 | - libglib=2.72.1=h2d90d5f_0
95 | - libiconv=1.16=h516909a_0
96 | - libllvm14=14.0.6=he0ac6c6_0
97 | - libnsl=2.0.0=h7f98852_0
98 | - libogg=1.3.4=h7f98852_1
99 | - libopus=1.3.1=h7f98852_1
100 | - libpng=1.6.38=h753d276_0
101 | - libpq=14.5=hd77ab85_0
102 | - libsndfile=1.0.31=h9c3ff4c_1
103 | - libsqlite=3.39.3=h753d276_0
104 | - libstdcxx-ng=12.1.0=ha89aaad_16
105 | - libtiff=4.4.0=h55922b4_4
106 | - libtool=2.4.6=h9c3ff4c_1008
107 | - libudev1=249=h166bdaf_4
108 | - libuuid=2.32.1=h7f98852_1000
109 | - libvorbis=1.3.7=h9c3ff4c_0
110 | - libwebp-base=1.2.4=h166bdaf_0
111 | - libxcb=1.13=h7f98852_1004
112 | - libxkbcommon=1.0.3=he3ba5ed_0
113 | - libxml2=2.9.14=h22db469_4
114 | - libzlib=1.2.12=h166bdaf_3
115 | - llvm-openmp=14.0.4=he0ac6c6_0
116 | - locket=1.0.0=py39h06a4308_0
117 | - matplotlib=3.6.0=py39hf3d152e_0
118 | - matplotlib-base=3.6.0=py39hf9fd14e_0
119 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2
120 | - mkl=2021.4.0=h8d4b97c_729
121 | - mkl-service=2.4.0=py39h7e14d7c_0
122 | - mkl_fft=1.3.1=py39h0c7bc48_1
123 | - mkl_random=1.2.2=py39hde0f152_0
124 | - mock=4.0.3=py39hf3d152e_3
125 | - munkres=1.1.4=pyh9f0ad1d_0
126 | - mysql-common=8.0.30=haf5c9bc_1
127 | - mysql-libs=8.0.30=h28c427c_1
128 | - ncurses=6.3=h27087fc_1
129 | - nettle=3.6=he412f7d_0
130 | - networkx=2.8.4=py39h06a4308_0
131 | - nspr=4.32=h9c3ff4c_1
132 | - nss=3.78=h2350873_0
133 | - numpy=1.22.3=py39he7a7128_0
134 | - numpy-base=1.22.3=py39hf524024_0
135 | - openh264=2.1.1=h780b84a_0
136 | - openjpeg=2.5.0=h7d73246_1
137 | - openssl=1.1.1s=h166bdaf_0
138 | - os=0.1.4=0
139 | - packaging=21.3=pyhd8ed1ab_0
140 | - parso=0.8.3=pyhd3eb1b0_0
141 | - partd=1.2.0=pyhd3eb1b0_1
142 | - path=16.4.0=py39hf3d152e_1
143 | - path.py=12.5.0=0
144 | - pcre=8.45=h9c3ff4c_0
145 | - pexpect=4.8.0=pyhd3eb1b0_3
146 | - pickleshare=0.7.5=pyhd3eb1b0_1003
147 | - pillow=9.2.0=py39hd5dbb17_2
148 | - pip=22.2.2=pyhd8ed1ab_0
149 | - pluggy=1.0.0=py39hf3d152e_3
150 | - ply=3.11=py_1
151 | - portaudio=19.6.0=h8e90077_6
152 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
153 | - pthread-stubs=0.4=h36c2ea0_1001
154 | - ptyprocess=0.7.0=pyhd3eb1b0_2
155 | - pulseaudio=14.0=h0868958_9
156 | - pure_eval=0.2.2=pyhd3eb1b0_0
157 | - py=1.11.0=pyh6c4a22f_0
158 | - pycparser=2.21=pyhd8ed1ab_0
159 | - pygments=2.11.2=pyhd3eb1b0_0
160 | - pyopenssl=22.0.0=pyhd8ed1ab_0
161 | - pyparsing=3.0.9=pyhd8ed1ab_0
162 | - pyqt=5.15.7=py39h18e9c17_0
163 | - pyqt5-sip=12.11.0=py39h5a03fae_0
164 | - pysocks=1.7.1=pyha2e5f31_6
165 | - pytest=7.1.3=py39hf3d152e_0
166 | - pytest-shutil=1.7.0=pyhd8ed1ab_1
167 | - python=3.9.13=h9a8a25e_0_cpython
168 | - python-dateutil=2.8.2=pyhd8ed1ab_0
169 | - python-lmdb=1.3.0=py39h5a03fae_1
170 | - python_abi=3.9=2_cp39
171 | - pytorch=1.12.1=py3.9_cuda11.6_cudnn8.3.2_0
172 | - pytorch-mutex=1.0=cuda
173 | - pywavelets=1.3.0=py39h7f8727e_0
174 | - pyyaml=6.0=py39h7f8727e_1
175 | - qt-main=5.15.6=hc525480_0
176 | - qutil=3.2.1=6
177 | - readline=8.1.2=h0f457ee_0
178 | - requests=2.28.1=pyhd8ed1ab_1
179 | - scikit-image=0.16.2=py39ha9443f7_0
180 | - scipy=1.7.3=py39h6c91a56_2
181 | - setuptools=65.3.0=pyhd8ed1ab_1
182 | - sip=6.6.2=py39h5a03fae_0
183 | - six=1.16.0=pyh6c4a22f_0
184 | - sqlite=3.39.3=h4ff8645_0
185 | - stack_data=0.2.0=pyhd3eb1b0_0
186 | - tbb=2021.5.0=h924138e_2
187 | - termcolor=2.0.1=pyhd8ed1ab_1
188 | - tk=8.6.12=h27826a3_0
189 | - toml=0.10.2=pyhd8ed1ab_0
190 | - tomli=2.0.1=pyhd8ed1ab_0
191 | - toolz=0.11.2=pyhd3eb1b0_0
192 | - torchaudio=0.12.1=py39_cu116
193 | - torchvision=0.13.1=py39_cu116
194 | - tornado=6.2=py39hb9d737c_0
195 | - tqdm=4.64.1=pyhd8ed1ab_0
196 | - traitlets=5.1.1=pyhd3eb1b0_0
197 | - typing_extensions=4.3.0=pyha770c72_0
198 | - tzdata=2022c=h191b570_0
199 | - unicodedata2=14.0.0=py39hb9d737c_1
200 | - urllib3=1.26.11=pyhd8ed1ab_0
201 | - wcwidth=0.2.5=pyhd3eb1b0_0
202 | - wheel=0.37.1=pyhd8ed1ab_0
203 | - xcb-util=0.4.0=h166bdaf_0
204 | - xcb-util-image=0.4.0=h166bdaf_0
205 | - xcb-util-keysyms=0.4.0=h166bdaf_0
206 | - xcb-util-renderutil=0.3.9=h166bdaf_0
207 | - xcb-util-wm=0.4.1=h166bdaf_0
208 | - xorg-libxau=1.0.9=h7f98852_0
209 | - xorg-libxdmcp=1.1.3=h7f98852_0
210 | - xz=5.2.6=h166bdaf_0
211 | - yaml=0.2.5=h7b6447c_0
212 | - zlib=1.2.12=h166bdaf_3
213 | - zstd=1.5.2=h6239696_4
214 | - pip:
215 | - future==0.18.2
216 | - geomloss==0.2.5
217 | - keopscore==2.1
218 | - normflows==1.4
219 | - odl==1.0.0.dev0
220 | - pybind11==2.10.0
221 | - pykeops==2.1
222 | - torch-summary==1.4.5
223 | prefix: /users/staff/dmi-dmi/debarn0000/anaconda3/envs/torch12
224 |
--------------------------------------------------------------------------------
/ops/traveltime_lib.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | EPS = 1e-12
10 |
11 | def sense_2(sensors, img):
12 |
13 | img_size = img.shape[0]
14 |
15 | # make a line
16 | x1 = sensors[0, 0]
17 | y1 = sensors[0, 1]
18 |
19 | x2 = sensors[1, 0]
20 | y2 = sensors[1, 1]
21 |
22 | slope = (y2 - y1) / (x2 - x1)
23 |
24 | per_pixel_width = 1.0/img_size
25 | n_pts_x = int(np.abs(x2-x1)/per_pixel_width)
26 | n_pts_y = int(np.abs(y2-y1)/per_pixel_width)
27 |
28 |
29 | intersect_vert = None
30 | intersect_horz = None
31 |
32 | if n_pts_x > 0:
33 | xs = x1 + np.arange(1, n_pts_x + 1) * per_pixel_width * np.sign(x2-x1)
34 | ys = y2 - slope * (x2 - xs)
35 | intersect_vert = np.stack((xs, ys), axis=-1)
36 |
37 | if n_pts_y > 0:
38 | ys = y1 + np.arange(1, n_pts_y + 1) * per_pixel_width * np.sign(y2-y1)
39 | xs = x2 - (y2 - ys) / slope
40 | intersect_horz = np.stack((xs, ys), axis=-1)
41 |     # Skip intersection arrays that stayed None (sensors nearly aligned in x or y).
42 |     pts = [p for p in (sensors, intersect_horz, intersect_vert) if p is not None]
43 |     all_pts = np.concatenate(pts, axis=0)
44 |     all_pts = all_pts[np.argsort(all_pts[:, 0])]  # Sort by x coordinate.
45 |
46 | # find line segment length x pixels
47 | midpoints = (all_pts[:-1] + all_pts[1:])/2
48 | lengths = np.linalg.norm(all_pts[:-1] - all_pts[1:], axis=-1)
49 |
50 | # pick correct pixels x line segment length
51 | midpoints = np.clip((midpoints + 1) / 2.0, 0, 1-EPS)
52 | pixel_idx = np.floor(midpoints/per_pixel_width).astype(np.int32)
53 | pixel_intensities = img[pixel_idx[:, 0], pixel_idx[:, 1]]
54 |
55 |
56 | return np.sum(lengths * pixel_intensities)
57 |
58 |
59 | class tt_2sensors(nn.Module):
60 | """Traveltime for 2 sensors"""
61 | def __init__(self, sensors, img_size):
62 | super(tt_2sensors, self).__init__()
63 | self.lengths, self.idx = self.__build(sensors, img_size)
64 |
65 | @staticmethod
66 | def __build(sensors, img_size):
67 | if isinstance(sensors, np.ndarray):
68 | sensors = torch.from_numpy(sensors.astype(np.float32)).cuda()
69 |
70 | # make a line
71 | x1 = sensors[0, 0]
72 | y1 = sensors[0, 1]
73 |
74 | x2 = sensors[1, 0]
75 | y2 = sensors[1, 1]
76 |
77 | slope = (y2 - y1) / (x2 - x1)
78 |
79 | per_pixel_width = 1.0/img_size
80 | n_pts_x = torch.abs(x2-x1)/per_pixel_width
81 | n_pts_x = n_pts_x.type(torch.int)
82 | n_pts_y = torch.abs(y2-y1)/per_pixel_width
83 | n_pts_y = n_pts_y.type(torch.int)
84 |
85 |
86 | intersect_vert = None
87 | intersect_horz = None
88 |
89 | if n_pts_x > 0:
90 | xs = x1 + torch.arange(
91 | 1, n_pts_x + 1, device='cuda') * per_pixel_width * torch.sign(x2-x1)
92 | ys = y2 - slope * (x2 - xs)
93 | intersect_vert = torch.stack((xs, ys), dim=-1)
94 |
95 | if n_pts_y > 0:
96 | ys = y1 + torch.arange(
97 | 1, n_pts_y + 1, device='cuda') * per_pixel_width * torch.sign(y2-y1)
98 | xs = x2 - (y2 - ys) / slope
99 | intersect_horz = torch.stack((xs, ys), dim=-1)
100 |
101 | all_pts = sensors.clone().cuda()
102 | if intersect_horz is not None:
103 | all_pts = torch.cat((sensors, intersect_horz), dim=0)
104 | if intersect_vert is not None:
105 | all_pts = torch.cat((all_pts, intersect_vert), dim=0)
106 |
107 | idx = torch.argsort(all_pts, dim=0)[:, 0]
108 | all_pts = all_pts[idx] # Sorts acc. to x coordinate.
109 |
110 | # find line segment length x pixels
111 | midpoints = (all_pts[:-1] + all_pts[1:])/2
112 | lengths = torch.norm(all_pts[:-1] - all_pts[1:], dim=-1)
113 |
114 | # pick correct pixels x line segment length
115 | midpoints = torch.clip((midpoints + 1) / 2.0, 0, 1-EPS)
116 | pixel_idx = torch.floor(midpoints/per_pixel_width).type(
117 | torch.cuda.LongTensor)
118 |
119 | return lengths, pixel_idx
120 |
121 |
122 | def forward(self, img):
123 | pixel_intensities = img[self.idx[:,0], self.idx[:, 1]]
124 | return torch.sum(pixel_intensities * self.lengths)
125 |
126 |
127 | class TravelTimeOperator(nn.Module):
128 | """Builds the linear traveltime tomography operator."""
129 |
130 | def __init__(self, sensors, img_size):
131 | """Initializes a linear TT operator.
132 | Args:
133 | sensors (torch.Tensor): Locations of sensors [-1, 1]
134 | img_size (int): Size of the image
135 | """
136 | super(TravelTimeOperator, self).__init__()
137 | self.sensors = sensors
138 | self.lengths, self.idx, self.nelems = self._build(sensors, img_size)
139 | self.optimizable_params = sensors
140 |
141 |
142 | @staticmethod
143 | def _build(sensors, img_size):
144 | N = len(sensors)
145 | lengths = []
146 | nelems = []
147 | idx = []
148 |
149 | for i in range(N-1):
150 | for j in range(i+1, N):
151 | sense2 = torch.stack((sensors[i], sensors[j]), axis=0)
152 | row = tt_2sensors(sense2, img_size)
153 | lengths.append(row.lengths)
154 | idx.append(row.idx)
155 | nelems.append(len(row.lengths))
156 |
157 | return torch.cat(lengths), torch.cat(idx, dim=0), nelems
158 |
159 |
160 |
161 | def forward(self, img):
162 | pixel_intensities = img[self.idx[:,0], self.idx[:, 1]]
163 | y_measured = pixel_intensities * self.lengths
164 | measured_chunks = torch.split(y_measured, self.nelems)
165 | y_measured = torch.stack([
166 | torch.sum(c) for c in measured_chunks]
167 | )
168 |
169 | return y_measured
170 |
171 |
172 | def time_op():
173 | print('\n Timing the op')
174 | from time import time
175 | IMG_SIZE = 64
176 | NS = 50
177 | CVAL = 0.3
178 | img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.float32) * CVAL
179 |
180 | sensors = np.random.rand(NS, 2)
181 |     sensors = torch.tensor(sensors).cuda()  # no gradients needed for timing
182 |
183 | A = TravelTimeOperator(sensors, IMG_SIZE).cuda()
184 |
185 | print('Operator built.')
186 |
187 | img = torch.from_numpy(img).cuda()
188 | t = time()
189 | for _ in range(500):
190 | y = A(img)
191 |
192 | print(f'Time per forward: {(time()-t)/500}s')
193 |
194 | def unit_test():
195 | print('\n Unit testing the op')
196 | from time import time
197 | IMG_SIZE = 1024
198 | NS = 50
199 | CVAL = 0.3
200 | img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.float32) * CVAL
201 |
202 | sensors = np.random.rand(NS, 2)
203 | sensors = torch.tensor(sensors).cuda()
204 | sensors.requires_grad_(True)
205 |
206 | A = TravelTimeOperator(sensors, IMG_SIZE)
207 | print(f'A length is on {A.lengths.device}')
208 | print(f'A idx is on {A.idx.device}')
209 |
210 | y = A(torch.from_numpy(img).cuda())
211 |
212 | obj = torch.sum(y)
213 | obj.backward()
214 |
215 | y_t = y.cpu().detach().numpy()
216 | print(sensors.grad)
217 |
218 |
219 | gt = np.zeros(int(NS * (NS-1) / 2))
220 | c = 0
221 | for i in range(NS-1):
222 | for j in range(i+1, NS):
223 | gt[c] = torch.norm(sensors[i] - sensors[j]) * CVAL
224 | c += 1
225 |
226 | assert np.allclose(y_t, gt, atol=1e-5), f'Max abs error = {np.abs(y_t - gt).max()}.'
227 |
228 |
229 | def unit_test_2sensor():
230 |     """Unit test for the 2-sensor row.
231 | 
232 |     NOTE:
233 |     - Two sensors can be so close in either x or y that no grid line
234 |     is crossed and the corresponding intersection array stays None.
235 |     Both the numpy and torch versions skip a None array when
236 |     concatenating the intersection points, so this case is handled.
237 |     """
238 |
239 | print('\n Unit testing each row of the op')
240 |
241 | IMG_SIZE = 64
242 | img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.float32) * 0.3
243 | # img = np.random.rand(IMG_SIZE, IMG_SIZE)
244 |
245 | # sensors = np.array([[-1.0, -1.0], [1.0, 1.0]])
246 | sensors = np.random.rand(2, 2).astype(np.float32)
247 | sensors_t = torch.from_numpy(sensors).cuda()
248 | sensors_t.requires_grad_(True)
249 | A = tt_2sensors(sensors_t, IMG_SIZE).cuda()
250 |
251 | print(f'tt_2sensors length is on {A.lengths.device}')
252 | print(f'tt_2sensors idx is on {A.idx.device}')
253 |
254 | y = A(torch.from_numpy(img).cuda())
255 |
256 | y.backward()
257 | print(sensors_t.grad)
258 |
259 | y_t = y.cpu().detach().numpy()
260 |
261 | y_np = sense_2(sensors, img).astype(np.float32)
262 | # gt = np.diag(img).sum()/32.0*np.sqrt(2)
263 | gt = np.linalg.norm(sensors[0] - sensors[1]) * 0.3
264 |
265 | assert np.abs(y_t - y_np) < 1e-6, f'Got {y_t}, expected {y_np}.'
266 |
267 | assert np.isclose(y_t, gt), f'Got {y_t}, expected {gt}.'
268 |
269 |
270 | if __name__ == '__main__':
271 | # unit_test_2sensor()
272 | # unit_test()
273 | time_op()
274 |
275 |
276 |
--------------------------------------------------------------------------------
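A minimal usage sketch for ops/traveltime_lib.py above, mirroring unit_test. It is illustrative only: it assumes a CUDA device (tt_2sensors hardcodes device='cuda'), and the image size and sensor count are arbitrary. Gradients reach the sensor locations through the precomputed segment lengths; the integer pixel indices themselves are non-differentiable.

import numpy as np
import torch
from ops.traveltime_lib import TravelTimeOperator

IMG_SIZE, NS = 64, 10
img = torch.full((IMG_SIZE, IMG_SIZE), 0.3).cuda()  # constant slowness map

sensors = torch.tensor(np.random.rand(NS, 2)).cuda()
sensors.requires_grad_(True)

A = TravelTimeOperator(sensors, IMG_SIZE)  # one row per sensor pair
y = A(img)                                 # NS * (NS - 1) / 2 travel times
torch.sum(y).backward()                    # gradient w.r.t. sensor locations
print(y.shape, sensors.grad.shape)

--------------------------------------------------------------------------------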
/normflow/nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from . import utils
4 |
5 | # Try importing ResNet dependencies
6 | try:
7 | from residual_flows.layers.base import Swish, InducedNormLinear, InducedNormConv2d
8 | except ImportError:
9 | print('Warning: Dependencies for Residual Networks could '
10 | 'not be loaded. Other models can still be used.')
11 |
12 |
13 | class MLP(nn.Module):
14 | """
15 | A multilayer perceptron with Leaky ReLU nonlinearities
16 | """
17 |
18 | def __init__(self, layers, leaky=0.0, score_scale=None, output_fn=None,
19 | output_scale=None, init_zeros=False, dropout=None):
20 | """
21 | :param layers: list of layer sizes from start to end
22 | :param leaky: slope of the leaky part of the ReLU,
23 | if 0.0, standard ReLU is used
24 | :param score_scale: Factor to apply to the scores, i.e. output before
25 | output_fn.
26 | :param output_fn: String, function to be applied to the output, either
27 | None, "sigmoid", "relu", "tanh", or "clampexp"
28 | :param output_scale: Rescale outputs if output_fn is specified, i.e.
29 | scale * output_fn(out / scale)
30 | :param init_zeros: Flag, if true, weights and biases of last layer
31 | are initialized with zeros (helpful for deep models, see arXiv 1807.03039)
32 | :param dropout: Float, if specified, dropout is done before last layer;
33 | if None, no dropout is done
34 | """
35 | super().__init__()
36 | net = nn.ModuleList([])
37 | for k in range(len(layers)-2):
38 | net.append(nn.Linear(layers[k], layers[k+1]))
39 | net.append(nn.LeakyReLU(leaky))
40 | if dropout is not None:
41 | net.append(nn.Dropout(p=dropout))
42 | net.append(nn.Linear(layers[-2], layers[-1]))
43 | if init_zeros:
44 | nn.init.zeros_(net[-1].weight)
45 | nn.init.zeros_(net[-1].bias)
46 | if output_fn is not None:
47 | if score_scale is not None:
48 | net.append(utils.ConstScaleLayer(score_scale))
49 | if output_fn == "sigmoid":
50 | net.append(nn.Sigmoid())
51 | elif output_fn == "relu":
52 | net.append(nn.ReLU())
53 | elif output_fn == "tanh":
54 | net.append(nn.Tanh())
55 | elif output_fn == "clampexp":
56 | net.append(utils.ClampExp())
57 | else:
58 |                 raise NotImplementedError("This output function is not implemented.")
59 | if output_scale is not None:
60 | net.append(utils.ConstScaleLayer(output_scale))
61 | self.net = nn.Sequential(*net)
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | class ConvNet2d(nn.Module):
68 | """
69 | Convolutional Neural Network with leaky ReLU nonlinearities
70 | """
71 |
72 | def __init__(self, channels, kernel_size, leaky=0.0, init_zeros=True,
73 | actnorm=False, weight_std=None):
74 | """
75 | Constructor
76 | :param channels: List of channels of conv layers, first entry is in_channels
77 | :param kernel_size: List of kernel sizes, same for height and width
78 | :param leaky: Leaky part of ReLU
79 | :param init_zeros: Flag whether last layer shall be initialized with zeros
80 |         :param actnorm: Flag whether activation normalization shall be done
81 |         after each conv layer except the output layer
82 |         :param weight_std: Fixed standard deviation used to initialize the
83 |         weights of every conv layer; if None, the default PyTorch
84 |         initialization is used
85 | """
86 | super().__init__()
87 | # Build network
88 | net = nn.ModuleList([])
89 | for i in range(len(kernel_size) - 1):
90 | conv = nn.Conv2d(channels[i], channels[i + 1], kernel_size[i],
91 | padding=kernel_size[i] // 2, bias=(not actnorm))
92 | if weight_std is not None:
93 | conv.weight.data.normal_(mean=0.0, std=weight_std)
94 | net.append(conv)
95 | if actnorm:
96 | net.append(utils.ActNorm((channels[i + 1],) + (1, 1)))
97 | net.append(nn.LeakyReLU(leaky))
98 | i = len(kernel_size)
99 | net.append(nn.Conv2d(channels[i - 1], channels[i], kernel_size[i - 1],
100 | padding=kernel_size[i - 1] // 2))
101 | if init_zeros:
102 | nn.init.zeros_(net[-1].weight)
103 | nn.init.zeros_(net[-1].bias)
104 | self.net = nn.Sequential(*net)
105 |
106 | def forward(self, x):
107 | return self.net(x)
108 |
109 |
110 | # Lipschitz continuous neural nets for residual flow
111 |
112 | class LipschitzMLP(nn.Module):
113 | """
114 | Fully connected neural net which is Lipschitz continuous
115 | with Lipschitz constant L < 1
116 | """
117 | def __init__(self, channels, lipschitz_const=0.97, max_lipschitz_iter=5,
118 | lipschitz_tolerance=None, init_zeros=True):
119 | """
120 | Constructor
121 | :param channels: Integer list with the number of channels of
122 | the layers
123 | :param lipschitz_const: Maximum Lipschitz constant of each layer
124 | :param max_lipschitz_iter: Maximum number of iterations used to
125 | ensure that layers are Lipschitz continuous with L smaller than
126 | set maximum; if None, tolerance is used
127 | :param lipschitz_tolerance: Float, tolerance used to ensure
128 | Lipschitz continuity if max_lipschitz_iter is None, typically 1e-3
129 | :param init_zeros: Flag, whether to initialize last layer
130 | approximately with zeros
131 | """
132 | super().__init__()
133 |
134 | self.n_layers = len(channels) - 1
135 | self.channels = channels
136 | self.lipschitz_const = lipschitz_const
137 | self.max_lipschitz_iter = max_lipschitz_iter
138 | self.lipschitz_tolerance = lipschitz_tolerance
139 | self.init_zeros = init_zeros
140 |
141 | layers = []
142 | for i in range(self.n_layers):
143 | layers += [Swish(),
144 | InducedNormLinear(in_features=channels[i],
145 | out_features=channels[i + 1], coeff=lipschitz_const,
146 | domain=2, codomain=2, n_iterations=max_lipschitz_iter,
147 | atol=lipschitz_tolerance, rtol=lipschitz_tolerance,
148 | zero_init=init_zeros if i == (self.n_layers - 1) else False)]
149 |
150 | self.net = nn.Sequential(*layers)
151 |
152 | def forward(self, x):
153 | return self.net(x)
154 |
155 |
156 | class LipschitzCNN(nn.Module):
157 | """
158 | Convolutional neural network which is Lipschitz continuous
159 | with Lipschitz constant L < 1
160 | """
161 | def __init__(self, channels, kernel_size, lipschitz_const=0.97,
162 | max_lipschitz_iter=5, lipschitz_tolerance=None,
163 | init_zeros=True):
164 | """
165 | Constructor
166 | :param channels: Integer list with the number of channels of
167 | the layers
168 | :param kernel_size: Integer list of kernel sizes of the layers
169 | :param lipschitz_const: Maximum Lipschitz constant of each layer
170 | :param max_lipschitz_iter: Maximum number of iterations used to
171 | ensure that layers are Lipschitz continuous with L smaller than
172 | set maximum; if None, tolerance is used
173 | :param lipschitz_tolerance: Float, tolerance used to ensure
174 | Lipschitz continuity if max_lipschitz_iter is None, typically 1e-3
175 | :param init_zeros: Flag, whether to initialize last layer
176 | approximately with zeros
177 | """
178 | super().__init__()
179 |
180 | self.n_layers = len(kernel_size)
181 | self.channels = channels
182 | self.kernel_size = kernel_size
183 | self.lipschitz_const = lipschitz_const
184 | self.max_lipschitz_iter = max_lipschitz_iter
185 | self.lipschitz_tolerance = lipschitz_tolerance
186 | self.init_zeros = init_zeros
187 |
188 | layers = []
189 | for i in range(self.n_layers):
190 | layers += [Swish(),
191 | InducedNormConv2d(in_channels=channels[i],
192 | out_channels=channels[i + 1], kernel_size=kernel_size[i],
193 | stride=1, padding=kernel_size[i] // 2, bias=True,
194 | coeff=lipschitz_const, domain=2, codomain=2,
195 | n_iterations=max_lipschitz_iter, atol=lipschitz_tolerance,
196 | rtol=lipschitz_tolerance,
197 | zero_init=init_zeros if i == self.n_layers - 1 else False)]
198 |
199 | self.net = nn.Sequential(*layers)
200 |
201 | def forward(self, x):
202 | return self.net(x)
--------------------------------------------------------------------------------
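The MLP above is typically used as the parameter map of a coupling layer. A short illustrative sketch, assuming the normflow package is importable and using arbitrary layer sizes: with init_zeros=True the last layer starts at zero, so a coupling layer built on top of it starts as the identity map (see arXiv 1807.03039).

import torch
from normflow.nets import MLP

d = 4  # latent dimensionality
# Map the d/2 conditioning features to 2 * (d/2) outputs (shift and log-scale).
param_map = MLP([d // 2, 64, 64, d], leaky=0.01, init_zeros=True)

z1 = torch.randn(8, d // 2)  # conditioning half of a latent batch
params = param_map(z1)       # all zeros at initialization
print(params.shape, params.abs().max().item())  # torch.Size([8, 4]) 0.0

--------------------------------------------------------------------------------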
/normflow/distributions/prior.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
6 |
7 | class PriorDistribution:
8 | def __init__(self):
9 | raise NotImplementedError
10 |
11 | def log_prob(self, z):
12 | """
13 | :param z: value or batch of latent variable
14 | :return: log probability of the distribution for z
15 | """
16 | raise NotImplementedError
17 |
18 |
19 | class ImagePrior(nn.Module):
20 | """
21 | Intensities of an image determine probability density of prior
22 | """
23 |
24 | def __init__(self, image, x_range=[-3, 3], y_range=[-3, 3], eps=1.e-10):
25 | """
26 | Constructor
27 | :param image: image as np matrix
28 | :param x_range: x range to position image at
29 | :param y_range: y range to position image at
30 | :param eps: small value to add to image to avoid log(0) problems
31 | """
32 | super().__init__()
33 | image_ = np.flip(image, 0).transpose() + eps
34 | self.image_cpu = torch.tensor(image_ / np.max(image_))
35 | self.image_size_cpu = self.image_cpu.size()
36 | self.x_range = torch.tensor(x_range)
37 | self.y_range = torch.tensor(y_range)
38 |
39 | self.register_buffer('image', self.image_cpu)
40 | self.register_buffer('image_size', torch.tensor(self.image_size_cpu).unsqueeze(0))
41 | self.register_buffer('density', torch.log(self.image_cpu / torch.sum(self.image_cpu)))
42 | self.register_buffer('scale', torch.tensor([[self.x_range[1] - self.x_range[0],
43 | self.y_range[1] - self.y_range[0]]]))
44 | self.register_buffer('shift', torch.tensor([[self.x_range[0], self.y_range[0]]]))
45 |
46 | def log_prob(self, z):
47 | """
48 | :param z: value or batch of latent variable
49 | :return: log probability of the distribution for z
50 | """
51 | z_ = torch.clamp((z - self.shift) / self.scale, max=1, min=0)
52 | ind = (z_ * (self.image_size - 1)).long()
53 | return self.density[ind[:, 0], ind[:, 1]]
54 |
55 | def rejection_sampling(self, num_steps=1):
56 | """
57 | Perform rejection sampling on image distribution
58 | :param num_steps: Number of rejection sampling steps to perform
59 | :return: Accepted samples
60 | """
61 | z_ = torch.rand((num_steps, 2), dtype=self.image.dtype, device=self.image.device)
62 | prob = torch.rand(num_steps, dtype=self.image.dtype, device=self.image.device)
63 | ind = (z_ * (self.image_size - 1)).long()
64 | intensity = self.image[ind[:, 0], ind[:, 1]]
65 | accept = intensity > prob
66 | z = z_[accept, :] * self.scale + self.shift
67 | return z
68 |
69 | def sample(self, num_samples=1):
70 | """
71 | Sample from image distribution through rejection sampling
72 | :param num_samples: Number of samples to draw
73 | :return: Samples
74 | """
75 | z = torch.ones((0, 2), dtype=self.image.dtype, device=self.image.device)
76 | while len(z) < num_samples:
77 | z_ = self.rejection_sampling(num_samples)
78 | ind = np.min([len(z_), num_samples - len(z)])
79 | z = torch.cat([z, z_[:ind, :]], 0)
80 | return z
81 |
82 |
83 | class TwoModes(PriorDistribution):
84 | def __init__(self, loc, scale):
85 | """
86 | Distribution 2d with two modes at z[0] = -loc and z[0] = loc
87 | :param loc: distance of modes from the origin
88 | :param scale: scale of modes
89 | """
90 | self.loc = loc
91 | self.scale = scale
92 |
93 | def log_prob(self, z):
94 | """
95 |         log(p) = - 1/2 * ((norm(z) - loc) / (2 * scale)) ** 2
96 |         + log(exp(-1/2 * ((z[0] - loc) / (3 * scale)) ** 2) + exp(-1/2 * ((z[0] + loc) / (3 * scale)) ** 2))
97 | :param z: value or batch of latent variable
98 | :return: log probability of the distribution for z
99 | """
100 | a = torch.abs(z[:, 0])
101 | eps = torch.abs(torch.tensor(self.loc))
102 |
103 | log_prob = - 0.5 * ((torch.norm(z, dim=1) - self.loc) / (2 * self.scale)) ** 2 \
104 | - 0.5 * ((a - eps) / (3 * self.scale)) ** 2 \
105 | + torch.log(1 + torch.exp(-2 * (a * eps) / (3 * self.scale) ** 2))
106 |
107 | return log_prob
108 |
109 |
110 | class Sinusoidal(PriorDistribution):
111 | def __init__(self, scale, period):
112 | """
113 | Distribution 2d with sinusoidal density
114 |         :param scale: scale of the density ridge
115 |         :param period: period of the sinusoid
116 | """
117 | self.scale = scale
118 | self.period = period
119 |
120 | def log_prob(self, z):
121 | """
122 |         log(p) = - 1/2 * ((z[1] - w_1(z)) / scale) ** 2 (plus a quartic envelope term)
123 | w_1(z) = sin(2*pi / period * z[0])
124 | :param z: value or batch of latent variable
125 | :return: log probability of the distribution for z
126 | """
127 | if z.dim() > 1:
128 | z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
129 | else:
130 | z_ = z
131 |
132 | w_1 = lambda x: torch.sin(2 * np.pi / self.period * z_[0])
133 | log_prob = - 0.5 * ((z_[1] - w_1(z_)) / (self.scale)) ** 2 \
134 | - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4 # add Gaussian envelope for valid p(z)
135 |
136 | return log_prob
137 |
138 |
139 | class Sinusoidal_gap(PriorDistribution):
140 | def __init__(self, scale, period):
141 | """
142 | Distribution 2d with sinusoidal density with gap
143 |         :param scale: scale of the density ridge
144 |         :param period: period of the sinusoid
145 | """
146 | self.scale = scale
147 | self.period = period
148 | self.w2_scale = 0.6
149 | self.w2_amp = 3.0
150 | self.w2_mu = 1.0
151 |
152 | def log_prob(self, z):
153 | """
154 | :param z: value or batch of latent variable
155 | :return: log probability of the distribution for z
156 | """
157 | if z.dim() > 1:
158 | z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
159 | else:
160 | z_ = z
161 |
162 | w_1 = lambda x: torch.sin(2 * np.pi / self.period * z_[0])
163 | w_2 = lambda x: self.w2_amp * torch.exp(-0.5 * ((z_[0] - self.w2_mu) / self.w2_scale) ** 2)
164 |
165 | eps = torch.abs(w_2(z_) / 2)
166 | a = torch.abs(z_[1] - w_1(z_) + w_2(z_) / 2)
167 |
168 | log_prob = -0.5 * ((a - eps) / self.scale) ** 2 + \
169 | torch.log(1 + torch.exp(-2 * (eps * a) / self.scale ** 2)) \
170 | - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4
171 |
172 | return log_prob
173 |
174 |
175 | class Sinusoidal_split(PriorDistribution):
176 | def __init__(self, scale, period):
177 | """
178 | Distribution 2d with sinusoidal density with split
179 |         :param scale: scale of the density ridge
180 |         :param period: period of the sinusoid
181 | """
182 | self.scale = scale
183 | self.period = period
184 | self.w3_scale = 0.3
185 | self.w3_amp = 3.0
186 | self.w3_mu = 1.0
187 |
188 | def log_prob(self, z):
189 | """
190 | :param z: value or batch of latent variable
191 | :return: log probability of the distribution for z
192 | """
193 | if z.dim() > 1:
194 | z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
195 | else:
196 | z_ = z
197 |
198 | w_1 = lambda x: torch.sin(2 * np.pi / self.period * z_[0])
199 | w_3 = lambda x: self.w3_amp * torch.sigmoid((z_[0] - self.w3_mu) / self.w3_scale)
200 |
201 | eps = torch.abs(w_3(z_) / 2)
202 | a = torch.abs(z_[1] - w_1(z_) + w_3(z_) / 2)
203 |
204 | log_prob = -0.5 * ((a - eps) / (self.scale)) ** 2 + \
205 | torch.log(1 + torch.exp(-2 * (eps * a) / self.scale ** 2)) \
206 | - 0.5 * (torch.norm(z_, dim=0, p=4) / (20 * self.scale)) ** 4
207 |
208 | return log_prob
209 |
210 |
211 | class Smiley(PriorDistribution):
212 | def __init__(self, scale):
213 | """
214 | Distribution 2d of a smiley :)
215 |         :param scale: scale of the smiley features (the face radius
216 |         is fixed to 2.0 internally)
217 | """
218 | self.scale = scale
219 | self.loc = 2.0
220 |
221 | def log_prob(self, z):
222 | """
223 | :param z: value or batch of latent variable
224 | :return: log probability of the distribution for z
225 | """
226 | if z.dim() > 1:
227 | z_ = z.permute((z.dim() - 1,) + tuple(range(0, z.dim() - 1)))
228 | else:
229 | z_ = z
230 |
231 | log_prob = - 0.5 * ((torch.norm(z_, dim=0) - self.loc) / (2 * self.scale)) ** 2 \
232 | - 0.5 * ((torch.abs(z_[1] + 0.8) - 1.2) / (2 * self.scale)) ** 2
233 |
234 | return log_prob
--------------------------------------------------------------------------------
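An illustrative evaluation of the TwoModes target above, assuming the normflow package is importable; the loc and scale values are arbitrary. Points on the two modes score much higher than the origin, which sits in the density valley between them.

import torch
from normflow.distributions.prior import TwoModes

target = TwoModes(loc=2.0, scale=0.2)
z = torch.tensor([[-2.0, 0.0], [2.0, 0.0], [0.0, 0.0]])
print(target.log_prob(z))  # high log-density on the modes, low at the origin

--------------------------------------------------------------------------------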
/ops/odl_lib.py:
--------------------------------------------------------------------------------
1 | """Holds all the odl related utilities."""
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import torch
5 |
6 | import odl
7 |
8 | # The upstream `from odl.contrib.torch import OperatorFunction` is
9 | # replaced by the local copy below (see the issue linked in the class).
10 | from ops.ODLHelper import OperatorFunction
11 |
12 | class OperatorAsAutogradFunction(object):
13 | """Dummy class around OperatorFunction.
14 |
15 | Check https://github.com/swing-research/dual_implicit_rep/issues/1 .
16 | """
17 | def __init__(self, odl_op):
18 | self.op = odl_op
19 |
20 | def __call__(self, x):
21 | return OperatorFunction.apply(self.op, x)
22 |
23 |
24 | def apply_angle_noise(angles, noise):
25 | """Applies operator noise in the angles of the operator.
26 |
27 | SNR = 20 log (P_angles/P_noise)
28 | Args:
29 | angles (np.ndarray): 1D array of angles used in the Radon operator.
30 | noise (float): SNR of the noise.
31 | Returns:
32 | noisy_angles (np.ndarray): Noisy angles at `noise` dB.
33 | """
34 | if noise > 200:
35 | return angles
36 |
37 | noise_std_dev = 10**(-noise/20)
38 |
39 | noisy_angles = angles * (1 + np.random.randn(*angles.shape)*noise_std_dev)
40 |
41 | return noisy_angles
42 |
43 |
44 | class ParallelBeamGeometryOp(object):
45 | """Creates an `img_size` mesh parallel geometry tomography operator."""
46 |
47 | def __init__(self, img_size, num_angles, op_snr=500):
48 | self.img_size = img_size
49 | self.num_angles = num_angles
50 | self.reco_space = odl.uniform_discr(
51 | min_pt=[-20, -20],
52 | max_pt=[20, 20],
53 | shape=[img_size, img_size],
54 | dtype='float32'
55 | )
56 |
57 | self.geometry = odl.tomo.parallel_beam_geometry(
58 | self.reco_space, num_angles)
59 |
60 | self.num_detectors = self.geometry.detector.size
61 | self.op_snr = op_snr
62 | self.angles = apply_angle_noise(self.geometry.angles, op_snr)
63 |
64 | self.optimizable_params = torch.tensor(
65 | self.angles, dtype=torch.float32) # Convert to torch.Tensor.
66 |
67 | self.op = odl.tomo.RayTransform(
68 | self.reco_space,
69 | self.geometry,
70 | impl='astra_cuda')
71 |
72 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op)
73 |
74 | def __call__(self, x):
75 | return OperatorFunction.apply(self.op, x)
76 |
77 | def pinv(self, y):
78 | return OperatorFunction.apply(self.fbp, y)
79 |
80 | class ParallelBeamGeometryOpBroken(ParallelBeamGeometryOp):
81 | """Creates a noisy angle instance of ParallelBeamGeometryOp
82 |
83 | # Steps taken from implementation of odl.tomo.parallel_beam_geometry
84 | # https://github.com/odlgroup/odl/blob/master/odl/tomo/geometry/parallel.py#L1471
85 |
86 | Notes
87 | -----
88 | According to [NW2001]_, pages 72--74, a function
89 | :math:`f : \mathbb{R}^2 \to \mathbb{R}` that has compact support
90 |     .. math::
91 |         \|x\| > \rho \implies f(x) = 0,
92 |     and is essentially bandlimited
93 |     .. math::
94 |         \|\xi\| > \Omega \implies \hat{f}(\xi) \approx 0,
95 |     can be fully reconstructed from a parallel beam ray transform
96 |     if (1) the projection angles are sampled with a spacing of
97 |     :math:`\Delta \psi` such that
98 |     .. math::
99 |         \Delta \psi \leq \frac{\pi}{\rho \Omega},
100 |     and (2) the detector is sampled with an interval :math:`\Delta s`
101 |     that satisfies
102 |     .. math::
103 |         \Delta s \leq \frac{\pi}{\Omega}.
104 | The geometry returned by this function satisfies these conditions exactly.
105 | If the domain is 3-dimensional, the geometry is "separable", in that each
106 |     slice along the z-dimension of the data is treated as independent 2d data.
107 | References
108 | ----------
109 | .. [NW2001] Natterer, F and Wuebbeling, F.
110 | *Mathematical Methods in Image Reconstruction*.
111 | SIAM, 2001.
112 | https://dx.doi.org/10.1137/1.9780898718324
113 | """
114 | def __init__(self, clean_operator, op_snr):
115 | super().__init__(clean_operator.img_size, clean_operator.num_angles, op_snr)
116 |
117 | space = self.reco_space
118 |
119 | # Find maximum distance from rotation axis
120 | corners = space.domain.corners()[:, :2]
121 | rho = np.max(np.linalg.norm(corners, axis=1))
122 |
123 | # Find default values according to Nyquist criterion.
124 |
125 | # We assume that the function is bandlimited by a wave along the x or y
126 | # axis. The highest frequency we can measure is then a standing wave with
127 | # period of twice the inter-node distance.
128 | min_side = min(space.partition.cell_sides[:2])
129 | omega = np.pi / min_side
130 | num_px_horiz = 2 * int(np.ceil(rho * omega / np.pi)) + 1
131 | det_min_pt = -rho
132 | det_max_pt = rho
133 | det_shape = num_px_horiz
134 | det_partition = odl.discr.uniform_partition(det_min_pt, det_max_pt, det_shape)
135 |
136 | self.angles = apply_angle_noise(clean_operator.geometry.angles, op_snr)
137 |
138 | self.optimizable_params = torch.tensor(clean_operator.geometry.angles, dtype=torch.float32)
139 |
140 | # angle partition is changed to not be uniform
141 | angle_partition = odl.discr.nonuniform_partition(np.sort(self.angles))
142 |
143 | self.geometry = odl.tomo.Parallel2dGeometry(angle_partition, det_partition)
144 |
145 | self.num_detectors = self.geometry.detector.size
146 |
147 | self.op = odl.tomo.RayTransform(
148 | self.reco_space,
149 | self.geometry,
150 | impl='astra_cuda')
151 |
152 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op)
153 |
154 |
155 | class ParallelBeamGeometryOpNonUniform(ParallelBeamGeometryOp):
156 | """Creates a noisy angle instance of ParallelBeamGeometryOp
157 |
158 | # Steps taken from implementation of odl.tomo.parallel_beam_geometry
159 | # https://github.com/odlgroup/odl/blob/master/odl/tomo/geometry/parallel.py#L1471
160 |
161 | Notes
162 | -----
163 | According to [NW2001]_, pages 72--74, a function
164 | :math:`f : \mathbb{R}^2 \to \mathbb{R}` that has compact support
165 |     .. math::
166 |         \|x\| > \rho \implies f(x) = 0,
167 |     and is essentially bandlimited
168 |     .. math::
169 |         \|\xi\| > \Omega \implies \hat{f}(\xi) \approx 0,
170 |     can be fully reconstructed from a parallel beam ray transform
171 |     if (1) the projection angles are sampled with a spacing of
172 |     :math:`\Delta \psi` such that
173 |     .. math::
174 |         \Delta \psi \leq \frac{\pi}{\rho \Omega},
175 |     and (2) the detector is sampled with an interval :math:`\Delta s`
176 |     that satisfies
177 |     .. math::
178 |         \Delta s \leq \frac{\pi}{\Omega}.
179 | The geometry returned by this function satisfies these conditions exactly.
180 | If the domain is 3-dimensional, the geometry is "separable", in that each
181 |     slice along the z-dimension of the data is treated as independent 2d data.
182 | References
183 | ----------
184 | .. [NW2001] Natterer, F and Wuebbeling, F.
185 | *Mathematical Methods in Image Reconstruction*.
186 | SIAM, 2001.
187 | https://dx.doi.org/10.1137/1.9780898718324
188 | """
189 | def __init__(self, clean_operator, nonuniform_angles):
190 | super().__init__(clean_operator.img_size, clean_operator.num_angles, clean_operator.op_snr)
191 |
192 | space = self.reco_space
193 |
194 | # Find maximum distance from rotation axis
195 | corners = space.domain.corners()[:, :2]
196 | rho = np.max(np.linalg.norm(corners, axis=1))
197 |
198 | # Find default values according to Nyquist criterion.
199 |
200 | # We assume that the function is bandlimited by a wave along the x or y
201 | # axis. The highest frequency we can measure is then a standing wave with
202 | # period of twice the inter-node distance.
203 | min_side = min(space.partition.cell_sides[:2])
204 | omega = np.pi / min_side
205 | num_px_horiz = 2 * int(np.ceil(rho * omega / np.pi)) + 1
206 | det_min_pt = -rho
207 | det_max_pt = rho
208 | det_shape = num_px_horiz
209 | det_partition = odl.discr.uniform_partition(det_min_pt, det_max_pt, det_shape)
210 |
211 | del self.angles
212 |
213 | # angle partition is changed to not be uniform
214 | angle_partition = odl.discr.nonuniform_partition(np.sort(nonuniform_angles))
215 |
216 | self.geometry = odl.tomo.Parallel2dGeometry(angle_partition, det_partition)
217 |
218 | self.num_detectors = self.geometry.detector.size
219 |
220 | self.op = odl.tomo.RayTransform(
221 | self.reco_space,
222 | self.geometry,
223 | impl='astra_cuda')
224 |
225 | self.fbp = odl.tomo.analytic.filtered_back_projection.fbp_op(self.op)
226 |
227 |
228 | def unit_test():
229 |     op_1024_32 = ParallelBeamGeometryOp(1024, 32)  # 1024x1024 image, 32 angles
230 |     phantom = odl.phantom.shepp_logan(
231 |         op_1024_32.reco_space, modified=True)
232 | 
233 |     x = torch.from_numpy(phantom.data)
234 |     y = op_1024_32(x)
235 |
236 | print(y.shape)
237 | print(x.shape)
238 |
239 | def unit_test_nonuniform():
240 |
241 | operator1 = ParallelBeamGeometryOp(64, 45)
242 | x = torch.from_numpy(odl.phantom.shepp_logan(operator1.reco_space, modified=True).data)
243 | y1 = operator1(x)
244 |
245 | operator2 = ParallelBeamGeometryOpNonUniform(operator1, operator1.geometry.angles)
246 | y2 = operator2(x)
247 |
248 | print(y1 - y2)
249 |
250 | if __name__ == '__main__':
251 | unit_test()
252 | unit_test_nonuniform()
253 |
--------------------------------------------------------------------------------
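A worked, numpy-only example of the SNR convention used by apply_angle_noise above (no odl installation needed; the 40 dB value is arbitrary). SNR in dB maps to a relative standard deviation of 10**(-snr/20), so 40 dB corresponds to roughly 1% multiplicative jitter per projection angle, and any snr above 200 disables the perturbation entirely.

import numpy as np

angles = np.linspace(0.0, np.pi, 45, endpoint=False)
noise_std = 10 ** (-40 / 20)  # relative std at 40 dB: 0.01
noisy = angles * (1 + np.random.randn(*angles.shape) * noise_std)
print(noise_std, np.abs(noisy - angles).max())

--------------------------------------------------------------------------------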
/normflow/flows/mixing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .base import Flow
5 |
6 | # Try importing Neural Spline Flow dependencies
7 | try:
8 | from neural_spline_flows.nde.transforms.permutations import RandomPermutation
9 | from neural_spline_flows.nde.transforms.lu import LULinear
10 | except ImportError:
11 | print('Warning: Dependencies for Neural Spline Flows could '
12 | 'not be loaded. Other models can still be used.')
13 |
14 |
15 | class Permute(Flow):
16 | """
17 | Permutation features along the channel dimension
18 | """
19 |
20 | def __init__(self, num_channels, mode='shuffle'):
21 | """
22 | Constructor
23 | :param num_channel: Number of channels
24 | :param mode: Mode of permuting features, can be shuffle for
25 | random permutation or swap for interchanging upper and lower part
26 | """
27 | super().__init__()
28 | self.mode = mode
29 | self.num_channels = num_channels
30 | if self.mode == 'shuffle':
31 | perm = torch.randperm(self.num_channels)
32 | inv_perm = torch.empty_like(perm).scatter_(dim=0, index=perm,
33 | src=torch.arange(self.num_channels))
34 | self.register_buffer("perm", perm)
35 | self.register_buffer("inv_perm", inv_perm)
36 |
37 | def forward(self, z):
38 | if self.mode == 'shuffle':
39 | z = z[:, self.perm, ...]
40 | elif self.mode == 'swap':
41 | z1 = z[:, :self.num_channels // 2, ...]
42 | z2 = z[:, self.num_channels // 2:, ...]
43 | z = torch.cat([z2, z1], dim=1)
44 | else:
45 | raise NotImplementedError('The mode ' + self.mode + ' is not implemented.')
46 | log_det = 0
47 | return z, log_det
48 |
49 | def inverse(self, z):
50 | if self.mode == 'shuffle':
51 | z = z[:, self.inv_perm, ...]
52 | elif self.mode == 'swap':
53 | z1 = z[:, :(self.num_channels + 1) // 2, ...]
54 | z2 = z[:, (self.num_channels + 1) // 2:, ...]
55 | z = torch.cat([z2, z1], dim=1)
56 | else:
57 | raise NotImplementedError('The mode ' + self.mode + ' is not implemented.')
58 | log_det = 0
59 | return z, log_det
60 |
61 |
62 | class Invertible1x1Conv(Flow):
63 | """
64 | Invertible 1x1 convolution introduced in the Glow paper
65 | Assumes 4d input/output tensors of the form NCHW
66 | """
67 |
68 | def __init__(self, num_channels, use_lu=False):
69 | """
70 | Constructor
71 | :param num_channels: Number of channels of the data
72 | :param use_lu: Flag whether to parametrize weights through the LU decomposition
73 | """
74 | super().__init__()
75 | self.num_channels = num_channels
76 | self.use_lu = use_lu
77 | Q = torch.qr(torch.randn(self.num_channels, self.num_channels))[0]
78 | if use_lu:
79 | P, L, U = torch.lu_unpack(*Q.lu())
80 | self.register_buffer('P', P) # remains fixed during optimization
81 | self.L = nn.Parameter(L) # lower triangular portion
82 | S = U.diag() # "crop out" the diagonal to its own parameter
83 | self.register_buffer("sign_S", torch.sign(S))
84 | self.log_S = nn.Parameter(torch.log(torch.abs(S)))
85 | self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S
86 | self.register_buffer("eye", torch.diag(torch.ones(self.num_channels)))
87 | else:
88 | self.W = nn.Parameter(Q)
89 |
90 | def _assemble_W(self, inverse=False):
91 | # assemble W from its components (P, L, U, S)
92 | L = torch.tril(self.L, diagonal=-1) + self.eye
93 | U = torch.triu(self.U, diagonal=1) + torch.diag(self.sign_S * torch.exp(self.log_S))
94 | if inverse:
95 | if self.log_S.dtype == torch.float64:
96 | L_inv = torch.inverse(L)
97 | U_inv = torch.inverse(U)
98 | else:
99 | L_inv = torch.inverse(L.double()).type(self.log_S.dtype)
100 | U_inv = torch.inverse(U.double()).type(self.log_S.dtype)
101 | W = U_inv @ L_inv @ self.P.t()
102 | else:
103 | W = self.P @ L @ U
104 | return W
105 |
106 | def forward(self, z):
107 | if self.use_lu:
108 | W = self._assemble_W(inverse=True)
109 | log_det = -torch.sum(self.log_S)
110 | else:
111 | W_dtype = self.W.dtype
112 | if W_dtype == torch.float64:
113 | W = torch.inverse(self.W)
114 | else:
115 | W = torch.inverse(self.W.double()).type(W_dtype)
116 |             # (W is reshaped into a 1x1 convolution kernel below)
117 | log_det = -torch.slogdet(self.W)[1]
118 | W = W.view(self.num_channels, self.num_channels, 1, 1)
119 | z_ = torch.nn.functional.conv2d(z, W)
120 | log_det = log_det * z.size(2) * z.size(3)
121 | return z_, log_det
122 |
123 | def inverse(self, z):
124 | if self.use_lu:
125 | W = self._assemble_W()
126 | log_det = torch.sum(self.log_S)
127 | else:
128 | W = self.W
129 | log_det = torch.slogdet(self.W)[1]
130 | W = W.view(self.num_channels, self.num_channels, 1, 1)
131 | z_ = torch.nn.functional.conv2d(z, W)
132 | log_det = log_det * z.size(2) * z.size(3)
133 | return z_, log_det
134 |
135 |
136 | class InvertibleAffine(Flow):
137 | """
138 | Invertible affine transformation without shift, i.e. one-dimensional
139 | version of the invertible 1x1 convolutions
140 | """
141 |
142 | def __init__(self, num_channels, use_lu=True):
143 | """
144 | Constructor
145 | :param num_channels: Number of channels of the data
146 | :param use_lu: Flag whether to parametrize weights through the
147 | LU decomposition
148 | """
149 | super().__init__()
150 | self.num_channels = num_channels
151 | self.use_lu = use_lu
152 | Q = torch.qr(torch.randn(self.num_channels, self.num_channels))[0]
153 | if use_lu:
154 | P, L, U = torch.lu_unpack(*Q.lu())
155 | self.register_buffer('P', P) # remains fixed during optimization
156 | self.L = nn.Parameter(L) # lower triangular portion
157 | S = U.diag() # "crop out" the diagonal to its own parameter
158 | self.register_buffer("sign_S", torch.sign(S))
159 | self.log_S = nn.Parameter(torch.log(torch.abs(S)))
160 | self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S
161 | self.register_buffer("eye", torch.diag(torch.ones(self.num_channels)))
162 | else:
163 | self.W = nn.Parameter(Q)
164 |
165 | def _assemble_W(self, inverse=False):
166 | # assemble W from its components (P, L, U, S)
167 | L = torch.tril(self.L, diagonal=-1) + self.eye
168 | U = torch.triu(self.U, diagonal=1) + torch.diag(self.sign_S * torch.exp(self.log_S))
169 | if inverse:
170 | if self.log_S.dtype == torch.float64:
171 | L_inv = torch.inverse(L)
172 | U_inv = torch.inverse(U)
173 | else:
174 | L_inv = torch.inverse(L.double()).type(self.log_S.dtype)
175 | U_inv = torch.inverse(U.double()).type(self.log_S.dtype)
176 | W = U_inv @ L_inv @ self.P.t()
177 | else:
178 | W = self.P @ L @ U
179 | return W
180 |
181 | def forward(self, z):
182 | if self.use_lu:
183 | W = self._assemble_W(inverse=True)
184 | log_det = -torch.sum(self.log_S)
185 | else:
186 | W_dtype = self.W.dtype
187 | if W_dtype == torch.float64:
188 | W = torch.inverse(self.W)
189 | else:
190 | W = torch.inverse(self.W.double()).type(W_dtype)
191 | log_det = -torch.slogdet(self.W)[1]
192 | z_ = z @ W
193 | return z_, log_det
194 |
195 | def inverse(self, z):
196 | if self.use_lu:
197 | W = self._assemble_W()
198 | log_det = torch.sum(self.log_S)
199 | else:
200 | W = self.W
201 | log_det = torch.slogdet(self.W)[1]
202 | z_ = z @ W
203 | return z_, log_det
204 |
205 |
206 | class LULinearPermute(Flow):
207 | """
208 | Fixed permutation combined with a linear transformation parametrized
209 | using the LU decomposition, used in https://arxiv.org/abs/1906.04032
210 | """
211 | def __init__(self, num_channels, identity_init=True, reverse=True):
212 | """
213 | Constructor
214 | :param num_channels: Number of dimensions of the data
215 | :param identity_init: Flag, whether to initialize linear
216 | transform as identity matrix
217 | :param reverse: Flag, change forward and inverse transform
218 | """
219 | # Initialize
220 | super().__init__()
221 | self.reverse = reverse
222 |
223 | # Define modules
224 | self.permutation = RandomPermutation(num_channels)
225 | self.linear = LULinear(num_channels, identity_init=identity_init)
226 |
227 | def forward(self, z):
228 | if self.reverse:
229 | z, log_det = self.linear.inverse(z)
230 | z, _ = self.permutation.inverse(z)
231 | else:
232 | z, _ = self.permutation(z)
233 | z, log_det = self.linear(z)
234 | return z, log_det.view(-1)
235 |
236 | def inverse(self, z):
237 | if self.reverse:
238 | z, _ = self.permutation(z)
239 | z, log_det = self.linear(z)
240 | else:
241 | z, log_det = self.linear.inverse(z)
242 | z, _ = self.permutation.inverse(z)
243 | return z, log_det.view(-1)
244 |
--------------------------------------------------------------------------------
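An illustrative round-trip check for Invertible1x1Conv above, assuming the normflow package is importable; tensor sizes are arbitrary. Note the convention in this file: forward convolves with W^{-1} (the normalizing direction) and inverse convolves with W, so composing the two recovers the input and the log-determinants cancel.

import torch
from normflow.flows.mixing import Invertible1x1Conv

flow = Invertible1x1Conv(num_channels=4, use_lu=True)
z = torch.randn(2, 4, 8, 8)                # NCHW input
z_fwd, log_det_fwd = flow(z)               # conv with W^{-1}
z_rec, log_det_inv = flow.inverse(z_fwd)   # conv with W
print(torch.allclose(z, z_rec, atol=1e-5))       # True
print((log_det_fwd + log_det_inv).abs().item())  # ~0.0

--------------------------------------------------------------------------------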
/normflow/flows/affine_coupling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 | from .base import Flow
6 | from .reshape import Split, Merge
7 |
8 |
9 |
10 | class AffineConstFlow(Flow):
11 | """
12 |     Scales and shifts with learned constants per dimension. The scaling
13 |     layer in the NICE paper is a special case of this where t is None.
14 | """
15 |
16 | def __init__(self, shape, scale=True, shift=True):
17 | """
18 | Constructor
19 | :param shape: Shape of the coupling layer
20 | :param scale: Flag whether to apply scaling
21 | :param shift: Flag whether to apply shift
22 |         (s and t are learned constants, initialized to zero and
23 |         broadcast over the batch dimension)
24 | """
25 | super().__init__()
26 | if scale:
27 | self.s = nn.Parameter(torch.zeros(shape)[None])
28 | else:
29 | self.register_buffer('s', torch.zeros(shape)[None])
30 | if shift:
31 | self.t = nn.Parameter(torch.zeros(shape)[None])
32 | else:
33 | self.register_buffer('t', torch.zeros(shape)[None])
34 | self.n_dim = self.s.dim()
35 | self.batch_dims = torch.nonzero(torch.tensor(self.s.shape) == 1, as_tuple=False)[:, 0].tolist()
36 |
37 | def forward(self, z):
38 | z_ = z * torch.exp(self.s) + self.t
39 | if len(self.batch_dims) > 1:
40 | prod_batch_dims = np.prod([z.size(i) for i in self.batch_dims[1:]])
41 | else:
42 | prod_batch_dims = 1
43 | log_det = prod_batch_dims * torch.sum(self.s)
44 | return z_, log_det
45 |
46 | def inverse(self, z):
47 | z_ = (z - self.t) * torch.exp(-self.s)
48 | if len(self.batch_dims) > 1:
49 | prod_batch_dims = np.prod([z.size(i) for i in self.batch_dims[1:]])
50 | else:
51 | prod_batch_dims = 1
52 | log_det = -prod_batch_dims * torch.sum(self.s)
53 | return z_, log_det
54 |
55 |
56 | class CCAffineConst(Flow):
57 | """
58 | Affine constant flow layer with class-conditional parameters
59 | """
60 |
61 | def __init__(self, shape, num_classes):
62 | super().__init__()
63 | self.shape = shape
64 | self.s = nn.Parameter(torch.zeros(shape)[None])
65 | self.t = nn.Parameter(torch.zeros(shape)[None])
66 | self.s_cc = nn.Parameter(torch.zeros(num_classes, np.prod(shape)))
67 | self.t_cc = nn.Parameter(torch.zeros(num_classes, np.prod(shape)))
68 | self.n_dim = self.s.dim()
69 | self.batch_dims = torch.nonzero(torch.tensor(self.s.shape) == 1, as_tuple=False)[:, 0].tolist()
70 |
71 | def forward(self, z, y):
72 | s = self.s + (y @ self.s_cc).view(-1, *self.shape)
73 | t = self.t + (y @ self.t_cc).view(-1, *self.shape)
74 | z_ = z * torch.exp(s) + t
75 | if len(self.batch_dims) > 1:
76 | prod_batch_dims = np.prod([z.size(i) for i in self.batch_dims[1:]])
77 | else:
78 | prod_batch_dims = 1
79 | log_det = prod_batch_dims * torch.sum(s, dim=list(range(1, self.n_dim)))
80 | return z_, log_det
81 |
82 | def inverse(self, z, y):
83 | s = self.s + (y @ self.s_cc).view(-1, *self.shape)
84 | t = self.t + (y @ self.t_cc).view(-1, *self.shape)
85 | z_ = (z - t) * torch.exp(-s)
86 | if len(self.batch_dims) > 1:
87 | prod_batch_dims = np.prod([z.size(i) for i in self.batch_dims[1:]])
88 | else:
89 | prod_batch_dims = 1
90 | log_det = -prod_batch_dims * torch.sum(s, dim=list(range(1, self.n_dim)))
91 | return z_, log_det
92 |
93 |
94 | class AffineCoupling(Flow):
95 | """
96 |     Affine coupling layer as introduced in the RealNVP paper, see arXiv: 1605.08803
97 | """
98 |
99 | def __init__(self, param_map, scale=True, scale_map='exp'):
100 | """
101 | Constructor
102 | :param param_map: Maps features to shift and scale parameter (if applicable)
103 | :param scale: Flag whether scale shall be applied
104 | :param scale_map: Map to be applied to the scale parameter, can be 'exp' as in
105 | RealNVP or 'sigmoid' as in Glow, 'sigmoid_inv' uses multiplicative sigmoid
106 | scale when sampling from the model
107 | """
108 | super().__init__()
109 | self.add_module('param_map', param_map)
110 | self.scale = scale
111 | self.scale_map = scale_map
112 |
113 | def forward(self, z):
114 | """
115 | z is a list of z1 and z2; z = [z1, z2]
116 | z1 is left constant and affine map is applied to z2 with parameters depending
117 | on z1
118 | """
119 | z1, z2 = z
120 | param = self.param_map(z1)
121 | if self.scale:
122 | shift = param[:, 0::2, ...]
123 | scale_ = param[:, 1::2, ...]
124 | if self.scale_map == 'exp':
125 | z2 = z2 * torch.exp(scale_) + shift
126 | log_det = torch.sum(scale_, dim=list(range(1, shift.dim())))
127 | elif self.scale_map == 'sigmoid':
128 | scale = torch.sigmoid(scale_ + 2)
129 | z2 = z2 / scale + shift
130 | log_det = -torch.sum(torch.log(scale),
131 | dim=list(range(1, shift.dim())))
132 | elif self.scale_map == 'sigmoid_inv':
133 | scale = torch.sigmoid(scale_ + 2)
134 | z2 = z2 * scale + shift
135 | log_det = torch.sum(torch.log(scale),
136 | dim=list(range(1, shift.dim())))
137 | else:
138 | raise NotImplementedError('This scale map is not implemented.')
139 | else:
140 | z2 += param
141 | log_det = 0
142 | return [z1, z2], log_det
143 |
144 | def inverse(self, z):
145 | z1, z2 = z
146 | param = self.param_map(z1)
147 | if self.scale:
148 | shift = param[:, 0::2, ...]
149 | scale_ = param[:, 1::2, ...]
150 | if self.scale_map == 'exp':
151 | z2 = (z2 - shift) * torch.exp(-scale_)
152 | log_det = -torch.sum(scale_, dim=list(range(1, shift.dim())))
153 | elif self.scale_map == 'sigmoid':
154 | scale = torch.sigmoid(scale_ + 2)
155 | z2 = (z2 - shift) * scale
156 | log_det = torch.sum(torch.log(scale),
157 | dim=list(range(1, shift.dim())))
158 | elif self.scale_map == 'sigmoid_inv':
159 | scale = torch.sigmoid(scale_ + 2)
160 | z2 = (z2 - shift) / scale
161 | log_det = -torch.sum(torch.log(scale),
162 | dim=list(range(1, shift.dim())))
163 | else:
164 | raise NotImplementedError('This scale map is not implemented.')
165 | else:
166 | z2 -= param
167 | log_det = 0
168 | return [z1, z2], log_det
169 |
170 |
171 | class MaskedAffineFlow(Flow):
172 | """
173 |     Masked affine flow: f(z) = b * z + (1 - b) * (z * exp(s(b * z)) + t(b * z))
174 |     An AffineHalfFlow is a MaskedAffineFlow with an alternating bit mask;
175 |     NICE is the special case with only shifts (volume preserving)
176 | NICE is AffineFlow with only shifts (volume preserving)
177 | """
178 |
179 | def __init__(self, b, t=None, s=None):
180 | """
181 | Constructor
182 | :param b: mask for features, i.e. tensor of same size as latent data point filled with 0s and 1s
183 | :param t: translation mapping, i.e. neural network, where first input dimension is batch dim,
184 | if None no translation is applied
185 | :param s: scale mapping, i.e. neural network, where first input dimension is batch dim,
186 | if None no scale is applied
187 | """
188 | super().__init__()
189 | self.b_cpu = b.view(1, *b.size())
190 | self.register_buffer('b', self.b_cpu)
191 |
192 | if s is None:
193 | self.s = lambda x: torch.zeros_like(x)
194 | else:
195 | self.add_module('s', s)
196 |
197 | if t is None:
198 | self.t = lambda x: torch.zeros_like(x)
199 | else:
200 | self.add_module('t', t)
201 |
202 | def forward(self, z):
203 | z_masked = self.b * z
204 | scale = self.s(z_masked)
205 | nan = torch.tensor(np.nan, dtype=z.dtype, device=z.device)
206 | scale = torch.where(torch.isfinite(scale), scale, nan)
207 | trans = self.t(z_masked)
208 | trans = torch.where(torch.isfinite(trans), trans, nan)
209 | z_ = z_masked + (1 - self.b) * (z * torch.exp(scale) + trans)
210 | log_det = torch.sum((1 - self.b) * scale, dim=list(range(1, self.b.dim())))
211 | return z_, log_det
212 |
213 | def inverse(self, z):
214 | z_masked = self.b * z
215 | scale = self.s(z_masked)
216 | nan = torch.tensor(np.nan, dtype=z.dtype, device=z.device)
217 | scale = torch.where(torch.isfinite(scale), scale, nan)
218 | trans = self.t(z_masked)
219 | trans = torch.where(torch.isfinite(trans), trans, nan)
220 | z_ = z_masked + (1 - self.b) * (z - trans) * torch.exp(-scale)
221 | log_det = -torch.sum((1 - self.b) * scale, dim=list(range(1, self.b.dim())))
222 | return z_, log_det
223 |
224 |
225 | class AffineCouplingBlock(Flow):
226 | """
227 | Affine Coupling layer including split and merge operation
228 | """
229 | def __init__(self, param_map, scale=True, scale_map='exp', split_mode='channel'):
230 | """
231 | Constructor
232 | :param param_map: Maps features to shift and scale parameter (if applicable)
233 | :param scale: Flag whether scale shall be applied
234 | :param scale_map: Map to be applied to the scale parameter, can be 'exp' as in
235 | RealNVP or 'sigmoid' as in Glow
236 | :param split_mode: Splitting mode, for possible values see Split class
237 | """
238 | super().__init__()
239 | self.flows = nn.ModuleList([])
240 | # Split layer
241 | self.flows += [Split(split_mode)]
242 | # Affine coupling layer
243 | self.flows += [AffineCoupling(param_map, scale, scale_map)]
244 | # Merge layer
245 | self.flows += [Merge(split_mode)]
246 |
247 | def forward(self, z):
248 | log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
249 | for flow in self.flows:
250 | z, log_det = flow(z)
251 | log_det_tot += log_det
252 | return z, log_det_tot
253 |
254 | def inverse(self, z):
255 | log_det_tot = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
256 | for i in range(len(self.flows) - 1, -1, -1):
257 | z, log_det = self.flows[i].inverse(z)
258 | log_det_tot += log_det
259 | return z, log_det_tot
--------------------------------------------------------------------------------
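An illustrative sketch of MaskedAffineFlow on 2d data, assuming the normflow package is importable; the two-layer s and t networks are hypothetical stand-ins for a RealNVP conditioner. The alternating 0/1 mask keeps z[0] fixed and transforms z[1] conditioned on it, and the inverse reconstructs the input exactly.

import torch
from torch import nn
from normflow.flows.affine_coupling import MaskedAffineFlow

b = torch.tensor([1.0, 0.0])  # keep z[0], transform z[1]
s = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
t = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2))
flow = MaskedAffineFlow(b, t=t, s=s)

z = torch.randn(5, 2)
z_fwd, log_det = flow(z)
z_rec, log_det_inv = flow.inverse(z_fwd)
print(torch.allclose(z, z_rec, atol=1e-5))         # True
print((log_det + log_det_inv).abs().max().item())  # 0.0

--------------------------------------------------------------------------------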
/train_generative.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from timeit import default_timer
5 | from torch.optim import Adam
6 | import os
7 | import imageio
8 | import matplotlib.pyplot as plt
9 | from autoencoder_model import autoencoder, encoder, decoder
10 | from flow_model import real_nvp
11 | from utils import *
12 | from datasets import *
13 | from laplacian_loss import LaplacianPyramidLoss
14 | import config_generative as config
15 |
16 | torch.manual_seed(0)
17 | np.random.seed(0)
18 |
19 | epochs_flow = config.epochs_flow
20 | epochs_aeder = config.epochs_aeder
21 | flow_depth = config.flow_depth
22 | latent_dim = config.latent_dim
23 | batch_size = config.batch_size
24 | dataset = config.dataset
25 | gpu_num = config.gpu_num
26 | exp_desc = config.exp_desc
27 | image_size = config.image_size
28 | c = config.c
29 | train_aeder = config.train_aeder
30 | train_flow = config.train_flow
31 | restore_flow = config.restore_flow
32 | restore_aeder = config.restore_aeder
33 |
34 | enable_cuda = True
35 | device = torch.device('cuda:' + str(gpu_num) if torch.cuda.is_available() and enable_cuda else 'cpu')
36 |
37 |
38 | all_experiments = 'experiments/'
39 | if not os.path.exists(all_experiments):
40 | os.mkdir(all_experiments)
41 |
42 | # experiment path
43 | exp_path = all_experiments + 'generator_' + dataset + '_' \
44 | + str(flow_depth) + '_' + str(latent_dim) + '_' + str(image_size) + '_' + exp_desc
45 |
46 | if not os.path.exists(exp_path):
47 | os.mkdir(exp_path)
48 |
49 |
50 | learning_rate = 1e-4
51 | step_size = 50
52 | gamma = 0.5
53 | lam = 0.01
54 |
55 | # Print the experiment setup:
56 | print('Experiment setup:')
57 | print('---> epochs_aeder: {}'.format(epochs_aeder))
58 | print('---> epochs_flow: {}'.format(epochs_flow))
59 | print('---> flow_depth: {}'.format(flow_depth))
60 | print('---> batch_size: {}'.format(batch_size))
61 | print('---> dataset: {}'.format(dataset))
62 | print('---> Learning rate: {}'.format(learning_rate))
63 | print('---> experiment path: {}'.format(exp_path))
64 | print('---> latent dim: {}'.format(latent_dim))
65 | print('---> image size: {}'.format(image_size))
66 |
67 |
68 | # Dataset:
69 | train_dataset = Dataset_loader(dataset='train', size=(image_size, image_size), c=c, quantize=False)
70 | test_dataset = Dataset_loader(dataset='test', size=(image_size, image_size), c=c, quantize=False)
71 |
72 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=40, shuffle = True)
73 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=25, num_workers=8)
74 |
75 | ntrain = len(train_loader.dataset)
76 | n_test = len(test_loader.dataset)
77 | print('---> Number of training, test samples: {}, {}'.format(ntrain,n_test))
78 | plot_per_num_epoch = 1 if ntrain > 20000 else 20000//ntrain
79 |
80 | # Loss
81 | dum_samples = next(iter(test_loader)).to(device)
82 | mse_l = F.mse_loss
83 | pyramid_l = LaplacianPyramidLoss(max_levels=3, channels=c, kernel_size=5,
84 | sigma=1, device=device, dtype=dum_samples.dtype)
85 | vgg = Vgg16().to(device)
86 | for param in vgg.parameters():
87 | param.requires_grad = False
88 |
89 | # 1. Training Autoencoder:
90 | enc = encoder(latent_dim = latent_dim, in_res = image_size , c = c).to(device)
91 | dec = decoder(latent_dim = latent_dim, in_res = image_size , c = c).to(device)
92 | aeder = autoencoder(encoder = enc , decoder = dec).to(device)
93 |
94 | num_param_aeder = count_parameters(aeder)
95 | print('---> Number of trainable parameters of Autoencoder: {}'.format(num_param_aeder))
96 |
97 | optimizer_aeder = Adam(aeder.parameters(), lr=learning_rate)
98 | scheduler_aeder = torch.optim.lr_scheduler.StepLR(optimizer_aeder, step_size=step_size, gamma=gamma)
99 |
100 | checkpoint_autoencoder_path = os.path.join(exp_path, 'autoencoder.pt')
101 | if os.path.exists(checkpoint_autoencoder_path) and restore_aeder:
102 | checkpoint_autoencoder = torch.load(checkpoint_autoencoder_path)
103 | aeder.load_state_dict(checkpoint_autoencoder['model_state_dict'])
104 | optimizer_aeder.load_state_dict(checkpoint_autoencoder['optimizer_state_dict'])
105 | print('Autoencoder is restored...')
106 |
107 |
108 | if train_aeder:
109 |
110 | if plot_per_num_epoch == -1:
111 | plot_per_num_epoch = epochs_aeder + 1 # only plot in the last epoch
112 |
113 | loss_ae_plot = np.zeros([epochs_aeder])
114 | for ep in range(epochs_aeder):
115 | aeder.train()
116 | t1 = default_timer()
117 | loss_ae_epoch = 0
118 |
119 | # Train the first 100 epochs with the style loss only, then with the combined style and MSE loss
120 | loss_type = 'style' if ep < 100 else 'style_mse'
121 | for image in train_loader:
122 |
123 | batch_size = image.shape[0]
124 | image = image.to(device)
125 |
126 | optimizer_aeder.zero_grad()
127 | image_mat = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2)
128 |
129 | embed = aeder.encoder(image_mat)
130 | image_recon = aeder.decoder(embed)
131 |
132 | recon_loss = aeder_loss(image_mat, image_recon, loss_type = loss_type,
133 | pyramid_l = pyramid_l, mse_l = mse_l, vgg = vgg)
134 | regularization = mse_l(embed, torch.zeros(embed.shape).to(device))
135 | ae_loss = recon_loss + lam * regularization
136 |
137 | ae_loss.backward()
138 | optimizer_aeder.step()
139 | loss_ae_epoch += ae_loss.item()
140 |
141 | scheduler_aeder.step()
142 | t2 = default_timer()
143 | loss_ae_epoch /= ntrain
144 | loss_ae_plot[ep] = loss_ae_epoch
145 |
146 | plt.plot(np.arange(ep + 1), loss_ae_plot[:ep + 1], 'o-', linewidth=2)
147 | plt.title('Autoencoder loss')
148 | plt.xlabel('epoch')
149 | plt.ylabel('loss')
150 |
151 | plt.savefig(os.path.join(exp_path, 'Autoencoder_loss.jpg'))
152 | np.save(os.path.join(exp_path, 'Autoencoder_loss.npy'), loss_ae_plot[:ep + 1])
153 | plt.close()
154 |
155 | torch.save({
156 | 'model_state_dict': aeder.state_dict(),
157 | 'optimizer_state_dict': optimizer_aeder.state_dict()
158 | }, checkpoint_autoencoder_path)
159 |
160 |
161 | samples_folder = os.path.join(exp_path, 'Results')
162 | if not os.path.exists(samples_folder):
163 | os.mkdir(samples_folder)
164 | image_path_reconstructions = os.path.join(
165 | samples_folder, 'Reconstructions_aeder')
166 |
167 | if not os.path.exists(image_path_reconstructions):
168 | os.mkdir(image_path_reconstructions)
169 |
170 | if (ep + 1) % plot_per_num_epoch == 0 or ep + 1 == epochs_aeder:
171 | sample_number = 25
172 | ngrid = int(np.sqrt(sample_number))
173 |
174 | test_images = next(iter(test_loader)).to(device)[:sample_number]
175 | test_images = test_images.reshape(-1, image_size, image_size, c).permute(0,3,1,2)
176 |
177 | image_np = test_images.permute(0,2,3,1).detach().cpu().numpy()
178 | image_write = image_np[:sample_number].reshape(
179 | ngrid, ngrid,
180 | image_size, image_size,c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0
181 | image_write = image_write.clip(0, 255).astype(np.uint8)
182 | imageio.imwrite(os.path.join(image_path_reconstructions, '%d_gt.png' % (ep,)),image_write)
183 |
184 | embed = aeder.encoder(test_images)
185 | image_recon = aeder.decoder(embed)
186 | image_recon_np = image_recon.detach().cpu().numpy().transpose(0,2,3,1)
187 | image_recon_write = image_recon_np[:sample_number].reshape(
188 | ngrid, ngrid,
189 | image_size, image_size, c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0
190 |
191 | image_recon_write = image_recon_write.clip(0, 255).astype(np.uint8)
192 | imageio.imwrite(os.path.join(image_path_reconstructions, '%d_aeder_recon.png' % (ep,)),
193 | image_recon_write)
194 |
195 | snr_aeder = SNR(image_np , image_recon_np)
196 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
197 | file.write('ep: %03d/%03d | time: %.0f | aeder_loss %.4f | SNR_aeder %.4f' %(ep, epochs_aeder,t2-t1,
198 | loss_ae_epoch, snr_aeder))
199 | file.write('\n')
200 | print('ep: %03d/%03d | time: %.0f | aeder_loss %.4f | SNR_aeder %.4f' %(ep, epochs_aeder,t2-t1,
201 | loss_ae_epoch, snr_aeder))
202 |
203 |
204 | # Training the flow model
205 | nfm = real_nvp(latent_dim = latent_dim, K = flow_depth)
206 | nfm = nfm.to(device)
207 | num_param_nfm = count_parameters(nfm)
208 | print('Number of trainable parameters of flow: {}'.format(num_param_nfm))
209 |
210 | loss_hist = np.array([])
211 | optimizer_flow = torch.optim.Adam(nfm.parameters(), lr=1e-4, weight_decay=1e-5)
212 | scheduler_flow = torch.optim.lr_scheduler.StepLR(optimizer_flow, step_size=step_size, gamma=gamma)
213 |
214 | # Initialize ActNorm
215 | batch_img = next(iter(train_loader)).to(device)
216 | batch_img = batch_img.reshape(-1, image_size, image_size, c).permute(0,3,1,2)
217 | dummy_samples = aeder.encoder(batch_img)
218 | # dummy_samples = model.reference_latents(torch.tensor(0).to(device))
219 | dummy_samples = dummy_samples.view(-1, latent_dim)
220 | # dummy_samples = torch.tensor(dummy_samples).float().to(device)
221 | likelihood = nfm.log_prob(dummy_samples) # a single forward pass triggers the data-dependent ActNorm initialization
222 |
223 | checkpoint_flow_path = os.path.join(exp_path, 'flow.pt')
224 | if os.path.exists(checkpoint_flow_path) and restore_flow:
225 | checkpoint_flow = torch.load(checkpoint_flow_path)
226 | nfm.load_state_dict(checkpoint_flow['model_state_dict'])
227 | optimizer_flow.load_state_dict(checkpoint_flow['optimizer_state_dict'])
228 | print('Flow model is restored...')
229 |
230 | if train_flow:
231 |
232 | for ep in range(epochs_flow):
233 |
234 | nfm.train()
235 | t1 = default_timer()
236 | loss_flow_epoch = 0
237 | for image in train_loader:
238 | optimizer_flow.zero_grad()
239 | image = image.to(device)
240 | image = image.reshape(-1, image_size, image_size, c).permute(0,3,1,2)
241 |
242 | x = aeder.encoder(image)
243 | # Compute loss
244 | loss_flow = nfm.forward_kld(x)
245 |
246 | if not (torch.isnan(loss_flow) or torch.isinf(loss_flow)):
247 | loss_flow.backward()
248 | optimizer_flow.step()
249 |
250 | # Make layers Lipschitz continuous
251 | # nf.utils.update_lipschitz(nfm, 5)
252 | loss_flow_epoch += loss_flow.item()
253 | # Log loss
254 | loss_hist = np.append(loss_hist, loss_flow.to('cpu').data.numpy())
255 |
256 | scheduler_flow.step()
257 | t2 = default_timer()
258 | loss_flow_epoch /= ntrain
259 |
260 | torch.save({
261 | 'model_state_dict': nfm.state_dict(),
262 | 'optimizer_state_dict': optimizer_flow.state_dict()
263 | }, checkpoint_flow_path)
264 |
265 |
266 | if (ep + 1) % plot_per_num_epoch == 0 or ep + 1 == epochs_flow:
267 | samples_folder = os.path.join(exp_path, 'Results')
268 | if not os.path.exists(samples_folder):
269 | os.mkdir(samples_folder)
270 | image_path_generated = os.path.join(
271 | samples_folder, 'generated')
272 |
273 | if not os.path.exists(image_path_generated):
274 | os.mkdir(image_path_generated)
275 | sample_number = 25
276 | ngrid = int(np.sqrt(sample_number))
277 |
278 | generated_embed, _ = nfm.sample(torch.tensor(sample_number).to(device))
279 |
280 | generated_samples = aeder.decoder(generated_embed)
281 | generated_samples = generated_samples.detach().cpu().numpy().transpose(0,2,3,1)
282 |
283 | generated_samples = generated_samples[:sample_number].reshape(
284 | ngrid, ngrid,
285 | image_size, image_size, c).swapaxes(1, 2).reshape(ngrid*image_size, -1, c)*255.0
286 | generated_samples = generated_samples.clip(0, 255).astype(np.uint8)
287 |
288 |
289 | imageio.imwrite(os.path.join(image_path_generated, 'epoch %d.png' % (ep,)), generated_samples) # generated samples
290 |
291 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
292 | file.write('ep: %03d/%03d | time: %.0f | ML_loss %.4f' %(ep, epochs_flow, t2-t1, loss_flow_epoch))
293 | file.write('\n')
294 |
295 | print('ep: %03d/%03d | time: %.0f | ML_loss %.4f' %(ep, epochs_flow, t2-t1, loss_flow_epoch))
296 |
297 |
--------------------------------------------------------------------------------
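For completeness, here is a sketch of how the two checkpoints written above would be consumed at inference time: restore the autoencoder and the flow, draw latents from the flow, and decode them to images. The experiment-folder path is a placeholder (the real name is built from the config exactly as in the script above), and the `real_nvp`/`autoencoder` signatures are assumed to match their use there:

import torch
from autoencoder_model import autoencoder, encoder, decoder
from flow_model import real_nvp
import config_generative as config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
exp_path = 'experiments/<your_experiment_folder>'  # placeholder: same naming scheme as above

enc = encoder(latent_dim=config.latent_dim, in_res=config.image_size, c=config.c).to(device)
dec = decoder(latent_dim=config.latent_dim, in_res=config.image_size, c=config.c).to(device)
aeder = autoencoder(encoder=enc, decoder=dec).to(device)
aeder.load_state_dict(torch.load(exp_path + '/autoencoder.pt', map_location=device)['model_state_dict'])

nfm = real_nvp(latent_dim=config.latent_dim, K=config.flow_depth).to(device)
nfm.load_state_dict(torch.load(exp_path + '/flow.pt', map_location=device)['model_state_dict'])

aeder.eval(); nfm.eval()
with torch.no_grad():
    z, _ = nfm.sample(torch.tensor(25).to(device))  # same call convention as the training loop
    samples = aeder.decoder(z)                      # (25, c, image_size, image_size) images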
/results.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import os
5 | import imageio
6 | from utils import *
7 | import config_funknn as config
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | def evaluator(ep, subset, data_loader, model, exp_path):
12 |
13 | samples_folder = os.path.join(exp_path, 'Results')
14 | if not os.path.exists(samples_folder):
15 | os.mkdir(samples_folder)
16 | image_path_reconstructions = os.path.join(
17 | samples_folder, 'Reconstructions')
18 |
19 | if not os.path.exists(image_path_reconstructions):
20 | os.mkdir(image_path_reconstructions)
21 |
22 | max_scale = config.max_scale
23 | recursive = config.recursive
24 | sample_number = config.sample_number
25 | image_size = config.image_size
26 | c = config.c
27 |
28 | if subset == 'ood':
29 | max_scale = 2
30 | recursive = False
31 |
32 | device = model.ws1.device
33 | num_samples_write = sample_number if sample_number < 26 else 25
34 | ngrid = int(np.sqrt(num_samples_write))
35 | num_samples_write = int(ngrid **2)
36 |
37 | images_k = next(iter(data_loader)).to(device)[:sample_number]
38 | images_k = images_k.reshape(-1, max_scale*image_size, max_scale*image_size, c).permute(0,3,1,2)
39 | images = F.interpolate(images_k, size = image_size, antialias = True, mode = 'bilinear')
40 |
41 | scales = [i+1 for i in range(int(np.log2(max_scale)))]
42 | scales = np.power(2, scales)
43 |
44 | print('Evaluation over {} set:'.format(subset))
45 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
46 | file.write('Evaluation over {} set:'.format(subset))
47 | file.write('\n')
48 |
49 | if recursive:
50 |
51 | images_down = images
52 |
53 | for i in range(len(scales)):
54 | # Recursive image generation starting from factor 2; well-suited for the 'factor' training mode
55 | res = scales[i]*image_size
56 | # GT:
57 | images_temp = F.interpolate(images_k, size = res , antialias = True, mode = 'bilinear')
58 | images_np = images_temp.permute(0, 2, 3, 1).detach().cpu().numpy()
59 | image_write = images_np[:num_samples_write].reshape(
60 | ngrid, ngrid,
61 | res, res,c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0
62 | image_write = image_write.clip(0, 255).astype(np.uint8)
63 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_gt_%d.png' % (ep,scales[i])),
64 | image_write)
65 |
66 | # Recon:
67 | coords = get_mgrid(res).reshape(-1, 2)
68 | coords = torch.unsqueeze(coords, dim = 0)
69 | coords = coords.expand(images_k.shape[0] , -1, -1).to(device)
70 | recon_np = batch_sampling(images_down, coords,c, model)
71 | recon_np = np.reshape(recon_np, [-1, res, res, c])
72 | recon_write = recon_np[:num_samples_write].reshape(
73 | ngrid, ngrid, res, res, c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0
74 | recon_write = recon_write.clip(0, 255).astype(np.uint8)
75 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_recursive_FunkNN_%d.png' % (ep,scales[i])),
76 | recon_write)
77 |
78 | # Interpolate:
79 | interpolate = F.interpolate(images, size = res, mode = 'bilinear')
80 | interpolate_np = interpolate.detach().cpu().numpy().transpose(0,2,3,1)
81 | interpolate_write = interpolate_np[:num_samples_write].reshape(
82 | ngrid, ngrid,
83 | res, res, c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0
84 | interpolate_write = interpolate_write.clip(0, 255).astype(np.uint8)
85 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_interpolate_%d.png' % (ep,scales[i])),
86 | interpolate_write) # mesh_based_recon
87 |
88 | snr_recon = SNR(images_np, recon_np)
89 | snr_interpolate = SNR(images_np, interpolate_np)
90 | recon_np = recon_np.transpose([0,3,1,2])
91 | images_down = torch.tensor(recon_np, dtype = images_down.dtype).to(device)
92 |
93 | print('SNR_interpolate_recursive_f{}: {:.1f} | SNR_FunkNN_recursive_f{}: {:.1f}'.format(scales[i],
94 | snr_interpolate, scales[i], snr_recon))
95 |
96 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
97 | file.write('SNR_interpolate_recursive_f{}: {:.1f} | SNR_FunkNN_recursive_f{}: {:.1f} | '.format(scales[i],
98 | snr_interpolate, scales[i], snr_recon))
99 | file.write('\n')
100 | if subset == 'ood':
101 | file.write('\n')
102 |
103 | else:
104 | for i in range(len(scales)):
105 | # Direct image generation; well-suited for the 'single' and 'continuous' training modes
106 | res = scales[i]*image_size
107 | # GT:
108 | images_temp = F.interpolate(images_k, size = res , antialias = True, mode = 'bilinear')
109 | images_np = images_temp.permute(0, 2, 3, 1).detach().cpu().numpy()
110 | image_write = images_np[:num_samples_write].reshape(
111 | ngrid, ngrid,
112 | res, res,c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0
113 | image_write = image_write.clip(0, 255).astype(np.uint8)
114 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_gt_%d.png' % (ep,scales[i])),
115 | image_write)
116 |
117 | # Recon:
118 | coords = get_mgrid(res).reshape(-1, 2)
119 | coords = torch.unsqueeze(coords, dim = 0)
120 | coords = coords.expand(images_k.shape[0] , -1, -1).to(device)
121 | recon_np = batch_sampling(images, coords,c, model)
122 | recon_np = np.reshape(recon_np, [-1, res, res, c])
123 | recon_write = recon_np[:num_samples_write].reshape(
124 | ngrid, ngrid, res, res, c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0
125 | recon_write = recon_write.clip(0, 255).astype(np.uint8)
126 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_FunkNN_%d.png' % (ep,scales[i])),
127 | recon_write)
128 |
129 | # Interpolate:
130 | interpolate = F.interpolate(images, size = res, mode = 'bilinear')
131 | interpolate_np = interpolate.detach().cpu().numpy().transpose(0,2,3,1)
132 | interpolate_write = interpolate_np[:num_samples_write].reshape(
133 | ngrid, ngrid,
134 | res, res, c).swapaxes(1, 2).reshape(ngrid*res, -1, c)*255.0
135 | interpolate_write = interpolate_write.clip(0, 255).astype(np.uint8)
136 | imageio.imwrite(os.path.join(image_path_reconstructions, subset + '_%d_interpolate_%d.png' % (ep,scales[i])),
137 | interpolate_write) # mesh_based_recon
138 |
139 | snr_recon = SNR(images_np, recon_np)
140 | snr_interpolate = SNR(images_np, interpolate_np)
141 |
142 | print('SNR_interpolate_f{}: {:.1f} | SNR_FunkNN_f{}: {:.1f}'.format(scales[i],
143 | snr_interpolate, scales[i], snr_recon))
144 |
145 | with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
146 | file.write('SNR_interpolate_f{}: {:.1f} | SNR_FunkNN_f{}: {:.1f} | '.format(scales[i],
147 | snr_interpolate, scales[i], snr_recon))
148 | file.write('\n')
149 | if subset == 'ood':
150 | file.write('\n')
151 |
152 |
153 |
154 |
155 | if config.derivatives_evaluation:
156 | # Gradients:
157 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2)
158 | coords_2k = torch.unsqueeze(coords_2k, dim = 0)
159 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device)
160 | coords_2k = coords_2k.clone().detach().requires_grad_(True) # allows taking derivatives w.r.t. the input coordinates
161 | funknn_grad = batch_grad(images, coords_2k,c, model)
162 | funknn_grad = torch.tensor(funknn_grad, dtype=coords_2k.dtype)
163 | funknn_grad = torch.norm(funknn_grad, dim = 2).cpu().detach().numpy()
164 | funknn_grad = np.reshape(funknn_grad, [-1, 2*image_size, 2*image_size,1])
165 | funknn_grad_write = funknn_grad[:num_samples_write, :, :].reshape(
166 | ngrid, ngrid,
167 | 2*image_size, 2*image_size, 1).swapaxes(1, 2).reshape(ngrid*2*image_size, -1, 1)*255.0
168 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_grad_funknn.png' % (ep,)),
169 | funknn_grad_write[:,:,0], cmap='seismic')
170 |
171 |
172 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2)
173 | coords_2k = torch.unsqueeze(coords_2k, dim = 0)
174 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device)
175 | coords_2k = coords_2k.clone().detach().requires_grad_(True)
176 | funknn = torch.tensor(batch_sampling(images, coords_2k,c, model), dtype = coords_2k.dtype).to(device)
177 | funknn_mat = funknn.reshape([-1, 2*image_size, 2*image_size, c]).permute(0,3,1,2)
178 | true_grad = image_derivative(funknn_mat , c)[1]
179 | true_grad = true_grad.permute(0,2,3,1).detach().cpu().numpy()
180 | true_grad_write = true_grad[:num_samples_write, :, :].reshape(
181 | ngrid, ngrid,
182 | 2*image_size, 2*image_size, 1).swapaxes(1, 2).reshape(ngrid*2*image_size, -1, 1)*255.0
183 |
184 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_grad_finite_diff.png' % (ep,)),
185 | true_grad_write[:,:,0], cmap='seismic')
186 |
187 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2)
188 | coords_2k = torch.unsqueeze(coords_2k, dim = 0)
189 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device)
190 | coords_2k = coords_2k.clone().detach().requires_grad_(True).to(device)
191 | coords_2k_grad = coords_2k.reshape((-1,2*image_size,2*image_size,2))
192 | coords_2k_grad = 2 * torch.flip(coords_2k_grad , dims = [3])
193 | img_tmp = model.grid_sample_customized(images, coords_2k_grad, mode = config.interpolation_kernel)
194 | img_tmp = img_tmp.permute(0,2,3,1).reshape(1 , -1 , c)
195 | st_grad = gradient(img_tmp, coords_2k, grad_outputs=None)
196 | st_grad = torch.norm(st_grad, dim = 2).cpu().detach().numpy()
197 | st_grad = np.reshape(st_grad, [-1, 2*image_size, 2*image_size, 1])
198 | st_grad_write = st_grad[:num_samples_write, :, :].reshape(
199 | ngrid, ngrid,
200 | 2*image_size, 2*image_size, 1).swapaxes(1, 2).reshape(ngrid*2*image_size, -1, 1)*255.0
201 |
202 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_grad_ST.png' % (ep,)),
203 | st_grad_write[:,:,0], cmap='seismic')
204 |
205 |
206 |
207 | ############################################################################################
208 | # Laplacian:
209 | shift = 4
210 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2)
211 | coords_2k = torch.unsqueeze(coords_2k, dim = 0)
212 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device)
213 | coords_2k = coords_2k.clone().detach().requires_grad_(True) # allows taking derivatives w.r.t. the input coordinates
214 | funknn_laplace= batch_laplace(images, coords_2k,c, model)
215 | funknn_laplace = np.reshape(funknn_laplace, [-1, 2*image_size, 2*image_size,1])
216 |
217 | funknn_laplace_write = funknn_laplace[:num_samples_write, shift:2*image_size-shift, shift:2*image_size-shift].reshape(
218 | ngrid, ngrid,
219 | 2*image_size -2*shift, 2*image_size-2*shift, 1).swapaxes(1, 2).reshape(ngrid*(2*image_size-2*shift), -1, 1)
220 |
221 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_laplace_funknn.png' % (ep,)),
222 | funknn_laplace_write[:,:,0], cmap='seismic')
223 |
224 | coords_2k = get_mgrid(2*image_size).reshape(-1, 2)
225 | coords_2k = torch.unsqueeze(coords_2k, dim = 0)
226 | coords_2k = coords_2k.expand(images_k.shape[0] , -1, -1).to(device)
227 | coords_2k = coords_2k.clone().detach().requires_grad_(True).to(device)
228 | coords_2k_grad = coords_2k.reshape((-1,2*image_size,2*image_size,2))
229 | coords_2k_grad = 2 * torch.flip(coords_2k_grad , dims = [3])
230 | img_tmp = model.grid_sample_customized(images, coords_2k_grad, mode = config.interpolation_kernel)
231 | img_tmp = img_tmp.permute(0,2,3,1).reshape(1 , -1 , c)
232 | st_laplace = laplace(img_tmp, coords_2k).detach() # Ground truth derivatives
233 | st_laplace = np.reshape(st_laplace.cpu().numpy(), [-1, 2*image_size, 2*image_size, 1])
234 | st_laplace_write = st_laplace[:num_samples_write, shift:2*image_size-shift, shift:2*image_size-shift].reshape(
235 | ngrid, ngrid,
236 | 2*image_size -2*shift, 2*image_size-2*shift, 1).swapaxes(1, 2).reshape(ngrid*(2*image_size-2*shift), -1, 1)
237 |
238 | plt.imsave(os.path.join(image_path_reconstructions, subset + '_%d_laplace_ST.png' % (ep,)),
239 | st_laplace_write[:,:,0], cmap='seismic')
--------------------------------------------------------------------------------
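The derivative evaluation above compares three estimates of the same quantity: automatic differentiation through FunkNN (`batch_grad`/`batch_laplace`), a Sobel finite-difference stencil (`image_derivative`), and autograd through the interpolation kernel itself. The autograd pattern is the one implemented by `gradient`, `divergence`, and `laplace` in /utils.py below; here is a tiny self-contained check of that pattern on a function with a known gradient and Laplacian (illustrative only, no repo code required):

import torch

coords = torch.rand(10, 2, requires_grad=True)    # (N, 2) query points
y = (coords ** 2).sum(dim=1, keepdim=True)        # f(u, v) = u^2 + v^2
grad = torch.autograd.grad(y, coords, torch.ones_like(y), create_graph=True)[0]
lap = torch.zeros(10, 1)
for i in range(grad.shape[-1]):                   # divergence of the gradient
    lap = lap + torch.autograd.grad(grad[..., i], coords, torch.ones_like(grad[..., i]),
                                    create_graph=True)[0][..., i:i+1]
assert torch.allclose(grad, 2 * coords)           # grad f = (2u, 2v)
assert torch.allclose(lap, 4 * torch.ones(10, 1)) # laplacian f = 4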
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torchvision.models import vgg16
5 | from skimage.transform import iradon
6 | from scipy import optimize
7 | from skimage.metrics import peak_signal_noise_ratio as psnr
8 |
9 |
10 | def image_derivative(x , c):
11 | '''x must be a (b, c, h, w) image tensor; returns Sobel gradients and their magnitude'''
12 |
13 | horiz_derive = np.array([[1, 0, -1],[2, 0, -2],[1,0,-1]], dtype = np.float64)
14 | horiz_derive = horiz_derive[None,None,:]
15 | horiz_derive = np.repeat(horiz_derive,c,axis = 1)
16 |
17 | vert_derive = np.array([[1,2,1],[0,0,0], [-1,-2,-1]])
18 | vert_derive = vert_derive[None,None,:]
19 | vert_derive = np.repeat(vert_derive,c,axis = 1)
20 |
21 | conv_horiz = torch.nn.Conv2d(c, 1, kernel_size=3, stride=1, padding='same', padding_mode = 'replicate', bias=False) # Sobel filter applied across all c channels, summed to one output map
22 | conv_horiz.weight.data = torch.from_numpy(horiz_derive).float().to(x.device)
23 |
24 | conv_vert = torch.nn.Conv2d(c, 1, kernel_size=3, stride=1, padding='same', padding_mode = 'replicate', bias=False)
25 | conv_vert.weight.data = torch.from_numpy(vert_derive).float().to(x.device)
26 |
27 | G_x = conv_horiz(x)
28 | G_y = conv_vert(x)
29 | G = torch.cat((G_x , G_y) , axis = 1)
30 | G_mag = torch.sqrt(torch.pow(G_x,2)+ torch.pow(G_y,2))
31 |
32 | return G, G_mag
33 |
34 |
35 | def PSNR(x_true , x_pred):
36 |
37 | s = 0
38 | for i in range(np.shape(x_pred)[0]):
39 | s += psnr(x_pred[i],
40 | x_true[i],
41 | data_range=x_true[i].max() - x_true[i].min())
42 |
43 | return s/np.shape(x_pred)[0]
44 |
45 | def SNR(x_true , x_pred):
46 | '''Average SNR (dB) over a batch of ground-truth images and their estimates'''
47 |
48 | # x_true = np.reshape(x_true , [np.shape(x_true)[0] , -1])
49 | # x_pred = np.reshape(x_pred , [np.shape(x_pred)[0] , -1])
50 |
51 | snr = 0
52 | for i in range(x_true.shape[0]):
53 | Noise = x_true[i] - x_pred[i]
54 | Noise_power = np.sum(np.square(np.abs(Noise)))
55 | Signal_power = np.sum(np.square(np.abs(x_true[i])))
56 | snr += 10*np.log10(Signal_power/Noise_power)
57 |
58 | return snr/x_true.shape[0]
59 |
60 |
61 |
62 | def SNR_rescale(x_true , x_pred):
63 | '''Average SNR (dB) after an optimal affine rescaling of each estimate'''
64 | snr = 0
65 | for i in range(x_true.shape[0]):
66 |
67 | def func(weights):
68 | Noise = x_true[i] - (weights[0]*x_pred[i]+weights[1])
69 | Noise_power = np.sum(np.square(np.abs(Noise)))
70 | Signal_power = np.sum(np.square(np.abs(x_true[i])))
71 | SNR = 10*np.log10(np.mean(Signal_power/(Noise_power+1e-12)))
72 | return SNR
73 | opt = optimize.minimize(lambda x: -func(x),x0=np.array([1,0]))
74 | snr += -opt.fun
75 | weights = opt.x
76 | return snr/x_true.shape[0]
77 |
78 |
79 | def PSNR_rescale(x_true , x_pred):
80 | '''Average PSNR (dB) after an optimal affine rescaling of each estimate'''
81 | snr = 0
82 | for i in range(x_true.shape[0]):
83 |
84 | def func(weights):
85 | x_pred_rescale= weights[0]*x_pred[i]+weights[1]
86 | s = psnr(x_pred_rescale,
87 | x_true[i],
88 | data_range=x_true[i].max() - x_true[i].min())
89 |
90 | return s
91 | opt = optimize.minimize(lambda x: -func(x),x0=np.array([1,0]))
92 | snr += -opt.fun
93 | weights = opt.x
94 | return snr/x_true.shape[0], weights
95 |
96 |
97 |
98 | def count_parameters(model):
99 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
100 |
101 |
102 | def relative_mse_loss(x_true, x_pred):
103 | noise_power = np.sqrt(np.sum(np.square(x_true - x_pred) , axis = (1,2,3)))
104 | signal_power = np.sqrt(np.sum(np.square(x_true) , axis = (1,2,3)))
105 |
106 | return np.mean(noise_power/signal_power)
107 |
108 |
109 | def batch_sampling(image_recon, coords, c, model):
110 | s = 64 # coordinate chunk size; query points beyond the last full chunk of s are skipped
111 |
112 | outs = np.zeros([np.shape(coords)[0], np.shape(coords)[1], c])
113 | for i in range(np.shape(coords)[1]//s):
114 |
115 | batch_coords = coords[:,i*s: (i+1)*s]
116 | out = model(batch_coords, image_recon).detach().cpu().numpy()
117 | outs[:,i*s: (i+1)*s] = out
118 |
119 | return outs
120 |
121 |
122 | def batch_grad(image_recon, coords, c, model):
123 | s = 64
124 |
125 | out_grads = np.zeros([np.shape(coords)[0], np.shape(coords)[1], 2])
126 | for i in range(np.shape(coords)[1]//s):
127 |
128 | batch_coords = coords[:,i*s: (i+1)*s]
129 | out = model(batch_coords, image_recon)
130 | out_grad = gradient(out, batch_coords).detach().cpu().numpy()
131 | out_grads[:,i*s: (i+1)*s] = out_grad
132 |
133 | return out_grads
134 |
135 |
136 | def batch_laplace(image_recon, coords, c, model):
137 | s = 64
138 |
139 | out_laplaces = np.zeros([np.shape(coords)[0], np.shape(coords)[1],1])
140 | for i in range(np.shape(coords)[1]//s):
141 |
142 | batch_coords = coords[:,i*s: (i+1)*s]
143 | out = model(batch_coords, image_recon)
144 | out_laplace = laplace(out, batch_coords).detach().cpu().numpy()
145 | out_laplaces[:,i*s: (i+1)*s] = out_laplace
146 |
147 | return out_laplaces
148 |
149 |
150 | def batch_grad_pde(image_recon, coords, c, model):
151 | s = 64
152 |
153 | out_grads = np.zeros([np.shape(coords)[0], np.shape(coords)[1], 2])
154 | for i in range(np.shape(coords)[1]//s):
155 |
156 | batch_coords = coords[:,i*s: (i+1)*s]
157 | out = model(batch_coords, image_recon)
158 | out_grad = gradient(out, batch_coords).detach().cpu().numpy()
159 | out_grads[:,i*s: (i+1)*s] = out_grad
160 |
161 | return out_grads
162 |
163 |
164 |
165 |
166 |
167 | def simpleaxis(ax):
168 | ax.spines['top'].set_visible(False)
169 | ax.spines['right'].set_visible(False)
170 | ax.get_xaxis().tick_bottom()
171 | ax.get_yaxis().tick_left()
172 |
173 |
174 | def get_mgrid(sidelen):
175 | # Generate 2D pixel coordinates for an image of sidelen x sidelen, normalized to [-0.5, 0.5)
176 | pixel_coords = np.stack(np.mgrid[:sidelen,:sidelen], axis=-1)[None,...].astype(np.float32)
177 | pixel_coords /= sidelen
178 | pixel_coords -= 0.5
179 | pixel_coords = torch.Tensor(pixel_coords).view(-1, 2)
180 | return pixel_coords
181 |
182 | def get_mgrid_unbalanced(sidelen1,sidelen2):
183 | # Generate 2D pixel coordinates for an image of sidelen1 x sidelen2
184 | pixel_coords = np.stack(np.mgrid[:sidelen1,:sidelen2], axis=-1)[None,...].astype(np.float32)
185 | pixel_coords = torch.Tensor(pixel_coords).view(-1, 2)
186 | pixel_coords = pixel_coords/(pixel_coords.max(dim = 0)[0]+1)
187 | pixel_coords -= 0.5
188 | return pixel_coords
189 |
190 |
191 | def lin2img(tensor):
192 | batch_size, num_samples, channels = tensor.shape
193 | sidelen = np.sqrt(num_samples).astype(int)
194 | return tensor.view(batch_size, channels, sidelen, sidelen)
195 |
196 |
197 | def plot_sample_image(img_batch, ax):
198 | # plot the first item in batch
199 | img = lin2img(img_batch)[0].detach().cpu().numpy()
200 | img += 1
201 | img /= 2.
202 | img = np.clip(img, 0., 1.)
203 | ax.set_axis_off()
204 | ax.imshow(img)
205 |
206 |
207 | def gradient(y, x, grad_outputs=None):
208 | if grad_outputs is None:
209 | grad_outputs = torch.ones_like(y)
210 | grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
211 | return grad
212 |
213 |
214 | def divergence(y, x):
215 | div = 0.
216 | for i in range(y.shape[-1]):
217 | div += torch.autograd.grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1]
218 | return div
219 |
220 |
221 | def multiple_gradient(y , x):
222 | out = torch.zeros([y.shape[0] , y.shape[1], 2 * y.shape[-1]])
223 | for i in range(y.shape[-1]):
224 | a = torch.autograd.grad(y[...,i], x, torch.ones_like(y[...,i]), create_graph=True)[0]
225 | out[..., 2*i:2*i + 2] = a # two spatial derivatives per output channel
226 | return out
227 |
228 | def laplace(y, x):
229 | grad = gradient(y, x)
230 | return divergence(grad, x)
231 |
232 |
233 | def aeder_loss(x_true, x_hat, loss_type = 'mse', pyramid_l = None, mse_l = None, vgg = None): # loss_type: 'mse', 'pyramid', 'style' (VGG feature + Gram style + TV) or 'style_mse'
234 | batch_size = x_true.shape[0]
235 | if loss_type == 'mse':
236 | loss = mse_l(x_hat.reshape(batch_size, -1) , x_true.reshape(batch_size, -1))
237 |
238 | elif loss_type == 'pyramid':
239 | loss = pyramid_l(x_hat , x_true)
240 |
241 | elif loss_type == 'style':
242 |
243 | reg_weight = 1e-6
244 | style_weight = 1
245 | feature_weight = 10
246 | pure_feat = vgg(x_true)
247 | recon_feat = vgg(x_hat)
248 |
249 | loss_style = 0
250 | loss_feature = 0
251 |
252 | for k in range(len(pure_feat)):
253 |
254 | bs, ch, h, w = pure_feat[k].size()
255 |
256 | pure_re_feat= pure_feat[k].view(bs, ch, h*w)
257 | gram_pure = torch.matmul(pure_re_feat, torch.transpose(pure_re_feat,1,2))/(ch*h*w)
258 |
259 | recon_re_feat= recon_feat[k].view(bs, ch, h*w)
260 | gram_recon = torch.matmul(recon_re_feat, torch.transpose(recon_re_feat,1,2))/(ch*h*w)
261 |
262 | loss_style = loss_style+ mse_l(gram_pure.view(batch_size,-1),gram_recon.view(batch_size,-1))
263 |
264 | loss_feature = loss_feature + mse_l(pure_feat[k].reshape(batch_size,-1),recon_feat[k].reshape(batch_size,-1))#/(pure_feat[k].size(1)*pure_feat[k].size(2)*pure_feat[k].size(3))
265 |
266 |
267 | loss_style = style_weight * loss_style
268 | loss_feature = feature_weight * loss_feature
269 |
270 | loss_tv = reg_weight * (
271 | torch.sum(torch.abs(x_hat[:, :, :, :-1] - x_hat[:, :, :, 1:])) +
272 | torch.sum(torch.abs(x_hat[:, :, :-1, :] - x_hat[:, :, 1:, :])))
273 |
274 | loss = loss_feature + loss_style + loss_tv
275 |
276 | elif loss_type == 'style_mse':
277 |
278 | reg_weight = 1e-6
279 | style_weight = 1
280 | feature_weight = 10
281 | mse_weight = 5000
282 | pure_feat = vgg(x_true)
283 | recon_feat = vgg(x_hat)
284 |
285 | loss_style = 0
286 | loss_feature = 0
287 |
288 | for k in range(len(pure_feat)):
289 |
290 | bs, ch, h, w = pure_feat[k].size()
291 |
292 | pure_re_feat= pure_feat[k].view(bs, ch, h*w)
293 | gram_pure = torch.matmul(pure_re_feat, torch.transpose(pure_re_feat,1,2))/(ch*h*w)
294 |
295 | recon_re_feat= recon_feat[k].view(bs, ch, h*w)
296 | gram_recon = torch.matmul(recon_re_feat, torch.transpose(recon_re_feat,1,2))/(ch*h*w)
297 |
298 | loss_style = loss_style+ mse_l(gram_pure.view(batch_size,-1),gram_recon.view(batch_size,-1))
299 |
300 | loss_feature = loss_feature + mse_l(pure_feat[k].reshape(batch_size,-1),recon_feat[k].reshape(batch_size,-1))#/(pure_feat[k].size(1)*pure_feat[k].size(2)*pure_feat[k].size(3))
301 |
302 |
303 | loss_style = style_weight * loss_style
304 | loss_feature = feature_weight * loss_feature
305 |
306 | loss_tv = reg_weight * (
307 | torch.sum(torch.abs(x_hat[:, :, :, :-1] - x_hat[:, :, :, 1:])) +
308 | torch.sum(torch.abs(x_hat[:, :, :-1, :] - x_hat[:, :, 1:, :])))
309 |
310 | loss_mse = mse_weight * mse_l(x_hat.reshape(batch_size, -1) , x_true.reshape(batch_size, -1))
311 |
312 | loss = loss_feature + loss_style + loss_tv + loss_mse
313 |
314 |
315 | return loss
316 |
317 |
318 |
319 |
320 |
321 | def training_strategy(x_true, image_size, factor = None , mode = 'continuous', image_recon = None):
322 |
323 | image_recon = x_true if image_recon is None else image_recon
324 |
325 | if mode == 'continuous':
326 |
327 | image_size_random = np.random.randint(low = image_size//4, high = image_size//2, size = 1)[0]
328 | x_low = F.interpolate(image_recon, size = image_size_random, antialias = True, mode = 'bilinear')
329 | x_high = x_true
330 | image_size_high = image_size
331 |
332 |
333 | elif mode == 'factor':
334 |
335 | image_size_low = np.random.randint(low = image_size//8, high = image_size//2, size = 1)[0]
336 | image_size_high = 2 * image_size_low
337 | x_high = F.interpolate(x_true, size = image_size_high, antialias = True, mode = 'bilinear')
338 | x_low = F.interpolate(image_recon, size = image_size_low, antialias = True, mode = 'bilinear')
339 |
340 | elif mode == 'single':
341 | x_high = x_true
342 | x_low = F.interpolate(image_recon, size = image_size//2, antialias = True, mode = 'bilinear')
343 | image_size_high = image_size
344 |
345 | return x_high, x_low, image_size_high
346 |
347 |
348 | class Vgg16(torch.nn.Module):
349 | def __init__(self):
350 | super(Vgg16, self).__init__()
351 | features = list(vgg16(pretrained = True).features)[:23]
352 | self.features = torch.nn.ModuleList(features).eval()
353 |
354 | def forward(self, x):
355 | results = []
356 | for ii,model in enumerate(self.features):
357 | x = model(x)
358 | if ii in {3,8,15,22}: # relu1_2, relu2_2, relu3_3 and relu4_3 activations
359 | results.append(x)
360 |
361 | return results
362 |
363 |
364 |
365 | def fbp_batch(x): # x: (batch, n_detectors, n_angles) sinograms
366 | n_measure = x.shape[2]
367 | theta = np.linspace(0., 180., n_measure, endpoint=False)
368 |
369 | fbps = []
370 | for i in range(x.shape[0]):
371 | fbps.append(iradon(x[i], theta=theta, circle = False))
372 |
373 | fbps = np.array(fbps)
374 | return fbps
--------------------------------------------------------------------------------
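Two conventions in the helpers above are easy to trip over: get_mgrid(s) returns s*s coordinates normalized to the half-open interval [-0.5, 0.5), and SNR averages per-image dB values rather than computing one global power ratio. A quick illustrative check (not part of the repo; the noise level and expected values are only rough):

import numpy as np
from utils import SNR, get_mgrid

coords = get_mgrid(4)                              # shape (16, 2)
print(coords.min().item(), coords.max().item())    # -0.5 and 0.25 for sidelen 4

x_true = np.random.rand(8, 16, 16, 1).astype(np.float32)
x_pred = x_true + 0.01 * np.random.randn(*x_true.shape).astype(np.float32)
print(SNR(x_true, x_pred))                         # roughly 35 dB for this noise level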
/funknn_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | import config_funknn as config
5 | import numpy as np
6 |
7 |
8 | def squeeze(x , f): # space-to-depth: (b, c, N1, N2) -> (b, c*f*f, N1//f, N2//f)
9 | x = x.permute(0,2,3,1)
10 | b, N1, N2, nch = x.shape
11 | x = torch.reshape(
12 | torch.permute(
13 | torch.reshape(x, shape=[b, N1//f, f, N2//f, f, nch]),
14 | [0, 1, 3, 2, 4, 5]),
15 | [b, N1//f, N2//f, nch*f*f])
16 | x = x.permute(0,3,1,2)
17 | return x
18 |
19 |
20 | def reflect_coords(ix, min_val, max_val): # reflect out-of-range coordinates back into [min_val, max_val]
21 |
22 | pos_delta = ix[ix>max_val] - max_val
23 |
24 | neg_delta = min_val - ix[ix < min_val]
25 |
26 | ix[ix>max_val] = ix[ix>max_val] - 2*pos_delta
27 | ix[ix