├── __init__.py ├── .gitignore ├── pytorch_probgraph ├── __init__.py ├── README.md ├── dbn.py ├── utils.py ├── hm.py ├── unitlayer.py ├── interaction.py ├── rbm.py └── dbm.py ├── examples ├── evaluate_models_emnist_conv.sh ├── evaluate_models_emnist.sh ├── evaluate_models_emnist_test.sh ├── evaluate_models_emnist_hm.sh ├── Model_DBN_Conv.py ├── Model_DBN_Conv_ProbMaxPool.py ├── Model_HM_WS.py ├── Model_HM_RWS.py ├── Model_DBN.py ├── Model_DBN_IntModule.py ├── Model_RBM_CD.py ├── Model_RBM_PCD.py ├── Model_DBM_CD.py ├── Model_DBM_PCD.py └── evaluate_emnist.py ├── docs ├── Makefile ├── api.rst ├── conf.py ├── howto.rst └── index.rst ├── setup.py ├── LICENSE └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | docs/_build 3 | pytorch_probgraph/__pycache__ 4 | build 5 | dist 6 | PyTorch_ProbGraph.egg-info 7 | -------------------------------------------------------------------------------- /pytorch_probgraph/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | PyTorch-ProbGraph library 3 | ''' 4 | 5 | from .utils import * 6 | from .unitlayer import * 7 | from .interaction import * 8 | from .rbm import * 9 | from .dbn import * 10 | from .hm import * 11 | from .dbm import * 12 | -------------------------------------------------------------------------------- /examples/evaluate_models_emnist_conv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # RBM 4 | echo 'Evaluate DBN-Conv-ProbMaxPool' 5 | python3 evaluate_emnist.py --directory Eval_DBN_Conv_ProbMaxPool --file Model_DBN_Conv_ProbMaxPool.py --model Model_DBN_Conv_ProbMaxPool --tqdm 6 | echo '\n\n\n' 7 | 8 | #echo 'Evaluate DBN-Conv-ProbMaxProol' 9 | #python3 evaluate_emnist.py --directory Eval_DBN --file Model_DBN.py --model Model_DBN --tqdm 10 | #echo '\n\n\n' 11 | 12 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setup(name='PyTorch-ProbGraph', 7 | version='0.0.1', 8 | description='Hierarchical Probabilistic Graphical Models in PyTorch', 9 | long_description=long_description, 10 | author='Korbinian Poeppel, Hendrik Elvers', 11 | author_email='korbinian.poeppel@tum.de, hendrik.elvers@tum.de', 12 | url='https://github.com/kpoeppel/pytorch_probgraph/', 13 | packages=['pytorch_probgraph'], 14 | install_requires=['torch', 'numpy', 'matplotlib', 'tqdm', 15 | 'sphinx_rtd_theme', 'sphinx', 'setuptools'], 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: BSD License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.6', 22 | ) 23 | -------------------------------------------------------------------------------- /examples/evaluate_models_emnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # RBM 4 | echo 'Evaluate RBM' 5 | python3 evaluate_emnist.py --directory Eval_RBM --file Model_RBM.py --model Model_RBM --tqdm 6 | echo '\n\n\n' 7 | 8 | echo 'Evaluate DBN' 9 | python3 evaluate_emnist.py --directory Eval_DBN --file Model_DBN.py --model Model_DBN --tqdm 10 | echo '\n\n\n' 11 | 12 | echo 'Evaluate DBM_PCD' 13 | python3 evaluate_emnist.py --directory Eval_DBM_PCD --file Model_DBM_PCD.py --model Model_DBM_PCD --tqdm 14 | echo '\n\n\n' 15 | 16 | echo 'Evaluate DBM_CD' 17 | python3 evaluate_emnist.py --directory Eval_DBM_CD --file Model_DBM_CD.py --model Model_DBM_CD --tqdm 18 | echo '\n\n\n' 19 | 20 | echo 'Evaluate HW_WS' 21 | python3 evaluate_emnist.py --directory Eval_HM_WS --file Model_HM_WS.py --model Model_HM_WS --tqdm 22 | echo '\n\n\n' 23 | 24 | echo 'Evaluate HM_RWS' 25 | python3 evaluate_emnist.py --directory Eval_HM_RWS --file Model_HM_RWS.py --model Model_HM_RWS --tqdm 26 | echo '\n\n\n' 27 | 28 | -------------------------------------------------------------------------------- /examples/evaluate_models_emnist_test.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | 4 | # RBM 5 | echo 'Evaluate RBM' 6 | python3 evaluate_emnist.py --directory Check_RBM --file Model_RBM_CD.py --model Model_RBM_CD --tqdm --maxepochs 5 --testing 7 | echo '\n\n\n' 8 | 9 | echo 'Evaluate DBM' 10 | python3 evaluate_emnist.py --directory Check_DBM --file Model_DBM_PCD.py --model Model_DBM_PCD --tqdm --maxepochs 5 --testing 11 | echo '\n\n\n' 12 | 13 | echo 'Evaluate DBN' 14 | python3 evaluate_emnist.py --directory Check_DBN_Test_IntModule --file Model_DBN_IntModule.py --model Model_DBN_IntModule --tqdm --testing --maxepochs 5 15 | echo '\n\n\n' 16 | 17 | echo 'Evaluate DBN' 18 | python3 evaluate_emnist.py --directory Check_DBN_Test --file Model_DBN.py --model Model_DBN --tqdm --testing --maxepochs 5 19 | echo '\n\n\n' 20 | 21 | echo 'Evaluate HW_RWS' 22 | python3 evaluate_emnist.py --directory Check_HM_RWS --file Model_HM_RWS.py --model Model_HM_RWS --tqdm --testing --maxepochs 5 23 | echo '\n\n\n' 24 | -------------------------------------------------------------------------------- /examples/evaluate_models_emnist_hm.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # RBM 4 | #echo 'Evaluate RBM' 5 | #python3 evaluate_emnist.py --directory Eval_RBM --file Model_RBM.py --model Model_RBM --tqdm 6 | #echo '\n\n\n' 7 | 8 | #echo 'Evaluate DBN' 9 | #python3 evaluate_emnist.py --directory Eval_DBN --file Model_DBN.py --model Model_DBN --tqdm 10 | #echo '\n\n\n' 11 | 12 | #echo 'Evaluate DBM_PCD' 13 | #python3 evaluate_emnist.py --directory Eval_DBM_PCD --file Model_DBM_PCD.py --model Model_DBM_PCD --tqdm 14 | #echo '\n\n\n' 15 | 16 | #echo 'Evaluate DBM_CD' 17 | #python3 evaluate_emnist.py --directory Eval_DBM_CD --file Model_DBM_CD.py --model Model_DBM_CD --tqdm 18 | #echo '\n\n\n' 19 | 20 | echo 'Evaluate HW_WS' 21 | python3 evaluate_emnist.py --directory Eval_HM_WS --file Model_HM_WS.py --model Model_HM_WS --tqdm 22 | echo '\n\n\n' 23 | 24 | echo 'Evaluate HM_RWS' 25 | python3 evaluate_emnist.py --directory Eval_HM_RWS --file Model_HM_RWS.py --model Model_HM_RWS --tqdm 26 | echo '\n\n\n' 27 | 28 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | =================================== 2 | The PyTorch-ProbGraph API reference 3 | =================================== 4 | 5 | .. _unitlayer-module: 6 | 7 | ------------------ 8 | “unitlayer”-module 9 | ------------------ 10 | 11 | .. automodule:: pytorch_probgraph.unitlayer 12 | :members: 13 | 14 | -------------------- 15 | “interaction”-module 16 | -------------------- 17 | 18 | .. automodule:: pytorch_probgraph.interaction 19 | :members: 20 | 21 | ------------ 22 | “rbm”-module 23 | ------------ 24 | 25 | .. automodule:: pytorch_probgraph.rbm 26 | :members: 27 | 28 | ------------ 29 | “dbn”-module 30 | ------------ 31 | 32 | .. automodule:: pytorch_probgraph.dbn 33 | :members: 34 | 35 | ------------ 36 | “dbm”-module 37 | ------------ 38 | 39 | .. automodule:: pytorch_probgraph.dbm 40 | :members: 41 | 42 | ----------- 43 | “hm”-module 44 | ----------- 45 | 46 | .. automodule:: pytorch_probgraph.hm 47 | :members: 48 | 49 | -------------- 50 | “utils”-module 51 | -------------- 52 | 53 | .. automodule:: pytorch_probgraph.utils 54 | :members: 55 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, kpoeppel 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /examples/Model_DBN_Conv.py: -------------------------------------------------------------------------------- 1 | import site 2 | site.addsitedir('..') 3 | 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from pytorch_probgraph import BernoulliLayer, DiracDeltaLayer, CategoricalLayer 8 | from pytorch_probgraph import GaussianLayer 9 | from pytorch_probgraph import InteractionLinear, InteractionModule, InteractionPoolMapIn2D, InteractionPoolMapOut2D 10 | from pytorch_probgraph import InteractionSequential 11 | from pytorch_probgraph import RestrictedBoltzmannMachinePCD 12 | from pytorch_probgraph import DeepBeliefNetwork 13 | from itertools import chain 14 | from tqdm import tqdm 15 | 16 | 17 | 18 | class Model_DBN_Conv(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | layer0 = BernoulliLayer(torch.zeros([1, 1, 28, 28], requires_grad=True)) 22 | layer1 = BernoulliLayer(torch.zeros([1, 40, 17, 17], requires_grad=True)) 23 | 24 | interaction0 = InteractionModule(torch.nn.Conv2d(1,40,12)) 25 | 26 | rbm1 = RestrictedBoltzmannMachinePCD(layer0, layer1, interaction0, fantasy_particles=10) 27 | opt = torch.optim.Adam(rbm1.parameters(), lr=1e-3) 28 | self.model = DeepBeliefNetwork([rbm1], opt) 29 | #self.model = self.model.to(device) 30 | #print(interaction.weight.shape) 31 | 32 | def train(self, data, epochs=1, device=None): 33 | self.model.train(data, epochs=epochs, device=device) 34 | 35 | def loglikelihood(self, data): 36 | return -self.model.free_energy_estimate(data) 37 | 38 | def generate(self, N=1): 39 | return self.model.sample(N=N, gibbs_steps=100).cpu() 40 | -------------------------------------------------------------------------------- /examples/Model_DBN_Conv_ProbMaxPool.py: -------------------------------------------------------------------------------- 1 | import site 2 | site.addsitedir('..') 3 | 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from pytorch_probgraph import BernoulliLayer, DiracDeltaLayer, CategoricalLayer 8 | from pytorch_probgraph import GaussianLayer 9 | from pytorch_probgraph import InteractionLinear, InteractionModule, InteractionPoolMapIn2D, InteractionPoolMapOut2D 10 | from pytorch_probgraph import InteractionSequential 11 | from pytorch_probgraph import RestrictedBoltzmannMachinePCD 12 | from pytorch_probgraph import DeepBeliefNetwork 13 | from itertools import chain 14 | from tqdm import tqdm 15 | 16 | 17 | 18 | class Model_DBN_Conv_ProbMaxPool(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | layer0 = BernoulliLayer(torch.zeros([1, 1, 28, 28], requires_grad=True)) 22 | layer1 = CategoricalLayer(torch.zeros([1, 40, 8, 8, 5], requires_grad=True)) 23 | layer2 = CategoricalLayer(torch.zeros([1, 40, 1, 1, 5], requires_grad=True)) 24 | 25 | interaction0 = InteractionSequential(InteractionModule(torch.nn.Conv2d(1,40,12)), 
InteractionPoolMapIn2D(2, 2)) 26 | interaction1 = InteractionSequential(InteractionPoolMapOut2D(2,2), InteractionModule(torch.nn.Conv2d(40,40,6)), InteractionPoolMapIn2D(2, 2)) 27 | 28 | rbm1 = RestrictedBoltzmannMachinePCD(layer0, layer1, interaction0, fantasy_particles=10) 29 | rbm2 = RestrictedBoltzmannMachinePCD(layer1, layer2, interaction1, fantasy_particles=10) 30 | opt = torch.optim.Adam(chain(rbm1.parameters(), rbm2.parameters()), lr=1e-3) 31 | self.model = DeepBeliefNetwork([rbm1, rbm2], opt) 32 | #self.model = self.model.to(device) 33 | #print(interaction.weight.shape) 34 | 35 | def train(self, data, epochs=1, device=None): 36 | self.model.train(data, epochs=epochs, device=device) 37 | 38 | def loglikelihood(self, data): 39 | data = data.reshape(-1, 1, 1, 28, 28) 40 | return -self.model.free_energy_estimate(data) 41 | 42 | def generate(self, N=1): 43 | return self.model.sample(N=N, gibbs_steps=100).cpu() 44 | -------------------------------------------------------------------------------- /examples/Model_HM_WS.py: -------------------------------------------------------------------------------- 1 | 2 | import site 3 | site.addsitedir('..') 4 | 5 | import torch 6 | from pytorch_probgraph import BernoulliLayer 7 | from pytorch_probgraph import InteractionLinear 8 | from pytorch_probgraph import HelmholtzMachine 9 | from itertools import chain 10 | from tqdm import tqdm 11 | 12 | class Model_HM_WS(torch.nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | layer0 = BernoulliLayer(torch.nn.Parameter(torch.zeros([1, 1, 28, 28]), requires_grad=True)) 16 | layer1 = BernoulliLayer(torch.nn.Parameter(torch.zeros([1, 200]), requires_grad=True)) 17 | layer2 = BernoulliLayer(torch.nn.Parameter(torch.zeros([1, 200]), requires_grad=True)) 18 | 19 | interactionUp1 = InteractionLinear(layer0.bias.shape[1:], layer1.bias.shape[1:]) 20 | interactionDown1 = InteractionLinear(layer1.bias.shape[1:], layer0.bias.shape[1:]) 21 | interactionUp2 = InteractionLinear(layer1.bias.shape[1:], layer2.bias.shape[1:]) 22 | interactionDown2 = InteractionLinear(layer2.bias.shape[1:], layer1.bias.shape[1:]) 23 | 24 | parameters = chain(*[m.parameters() for m in [layer0, layer1, layer2, interactionUp1, interactionUp2, interactionDown1, interactionDown2]]) 25 | opt = torch.optim.Adam(parameters) 26 | 27 | self.model = HelmholtzMachine([layer0, layer1, layer2], 28 | [interactionUp1, interactionUp2], 29 | [interactionDown1, interactionDown2], 30 | optimizer=opt) 31 | #print(interaction.weight.shape) 32 | 33 | def train(self, data, epochs=1, device=None): 34 | for epoch in range(epochs): 35 | for dat in data: 36 | self.model.trainWS(dat.to(device)) 37 | if isinstance(data, tqdm): 38 | data = tqdm(data) 39 | #print(torch.sum(self.model.interaction.weight)) 40 | 41 | def loglikelihood(self, data): 42 | return self.model.loglikelihood(data, ksamples=100).cpu().detach() 43 | 44 | def generate(self, N=1): 45 | return self.model.sampleAll(N=N)[0][0].cpu() 46 | -------------------------------------------------------------------------------- /examples/Model_HM_RWS.py: -------------------------------------------------------------------------------- 1 | 2 | import site 3 | site.addsitedir('..') 4 | 5 | import torch 6 | from pytorch_probgraph import BernoulliLayer 7 | from pytorch_probgraph import InteractionLinear 8 | from pytorch_probgraph import HelmholtzMachine 9 | from itertools import chain 10 | from tqdm import tqdm 11 | 12 | class Model_HM_RWS(torch.nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 
| layer0 = BernoulliLayer(torch.nn.Parameter(torch.zeros([1, 1, 28, 28]), requires_grad=True)) 16 | layer1 = BernoulliLayer(torch.nn.Parameter(torch.zeros([1, 200]), requires_grad=True)) 17 | layer2 = BernoulliLayer(torch.nn.Parameter(torch.zeros([1, 200]), requires_grad=True)) 18 | 19 | interactionUp1 = InteractionLinear(layer0.bias.shape[1:], layer1.bias.shape[1:]) 20 | interactionDown1 = InteractionLinear(layer1.bias.shape[1:], layer0.bias.shape[1:]) 21 | interactionUp2 = InteractionLinear(layer1.bias.shape[1:], layer2.bias.shape[1:]) 22 | interactionDown2 = InteractionLinear(layer2.bias.shape[1:], layer1.bias.shape[1:]) 23 | 24 | parameters = chain(*[m.parameters() for m in [layer0, layer1, layer2, interactionUp1, interactionUp2, interactionDown1, interactionDown2]]) 25 | opt = torch.optim.Adam(parameters) 26 | 27 | self.model = HelmholtzMachine([layer0, layer1, layer2], 28 | [interactionUp1, interactionUp2], 29 | [interactionDown1, interactionDown2], 30 | optimizer=opt) 31 | #print(interaction.weight.shape) 32 | 33 | def train(self, data, epochs=1, device=None): 34 | for epoch in range(epochs): 35 | for dat in data: 36 | self.model.trainReweightedWS(dat.to(device), ksamples=5) 37 | if isinstance(data, tqdm): 38 | data = tqdm(data) 39 | #print(torch.sum(self.model.interaction.weight)) 40 | 41 | def loglikelihood(self, data): 42 | return self.model.loglikelihood(data, ksamples=100).cpu().detach() 43 | 44 | def generate(self, N=1): 45 | return self.model.sampleAll(N=N)[0][0].cpu() 46 | -------------------------------------------------------------------------------- /examples/Model_DBN.py: -------------------------------------------------------------------------------- 1 | import site 2 | site.addsitedir('..') 3 | 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from pytorch_probgraph import BernoulliLayer, DiracDeltaLayer, CategoricalLayer 8 | from pytorch_probgraph import GaussianLayer 9 | from pytorch_probgraph import InteractionLinear, InteractionModule, InteractionPoolMapIn2D, InteractionPoolMapOut2D 10 | from pytorch_probgraph import InteractionPoolMapIn1D, InteractionPoolMapOut1D 11 | from pytorch_probgraph import RestrictedBoltzmannMachinePCD 12 | from pytorch_probgraph import DeepBeliefNetwork 13 | from itertools import chain 14 | from typing import Iterable 15 | from tqdm import tqdm 16 | 17 | 18 | 19 | class Model_DBN(torch.nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | layer0 = BernoulliLayer(torch.zeros([1, 1, 28, 28], requires_grad=True)) 23 | layer1 = BernoulliLayer(torch.zeros([1, 200], requires_grad=True)) 24 | layer2 = BernoulliLayer(torch.zeros([1, 200], requires_grad=True)) 25 | interaction0 = InteractionLinear(layer0.bias.shape[1:], layer1.bias.shape[1:]) 26 | interaction1 = InteractionLinear(layer1.bias.shape[1:], layer2.bias.shape[1:]) 27 | rbm1 = RestrictedBoltzmannMachinePCD(layer0, layer1, interaction0, fantasy_particles=10) 28 | rbm2 = RestrictedBoltzmannMachinePCD(layer1, layer2, interaction1, fantasy_particles=10) 29 | opt = torch.optim.Adam(chain(rbm1.parameters(), rbm2.parameters()), lr=1e-3) 30 | self.model = DeepBeliefNetwork([rbm1, rbm2], opt) 31 | #self.model = self.model.to(device) 32 | #print(interaction.weight.shape) 33 | 34 | def train(self, 35 | data: Iterable[torch.tensor], 36 | epochs: int=1, 37 | device: torch.device=None 38 | ) -> None: 39 | self.model.train(data, epochs=epochs, device=device) 40 | 41 | def loglikelihood(self, 42 | data: torch.Tensor 43 | ) -> torch.Tensor: 44 | return 
-self.model.free_energy_estimate(data) 45 | 46 | def generate(self, N=1): 47 | return self.model.sample(N=N, gibbs_steps=100).cpu() 48 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | # sys.path.append(os.path.abspath('sphinxext')) 17 | # import sphinx.ext.autodoc 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'PyTorch-ProbGraph' 22 | copyright = '2020, Korbinian Poeppel, Hendrik Elvers' 23 | author = 'Korbinian Poeppel, Hendrik Elvers' 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = '0.0.1' 27 | 28 | 29 | # -- General configuration --------------------------------------------------- 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon' 35 | ] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ['_templates'] 39 | 40 | # List of patterns, relative to source directory, that match files and 41 | # directories to ignore when looking for source files. 42 | # This pattern also affects html_static_path and html_extra_path. 43 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 44 | 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | 48 | # The theme to use for HTML and HTML Help pages. See the documentation for 49 | # a list of builtin themes. 50 | # 51 | html_theme = 'classic' 52 | 53 | # Add any paths that contain custom static files (such as style sheets) here, 54 | # relative to this directory. They are copied after the builtin static files, 55 | # so a file named "default.css" will overwrite the builtin "default.css". 
56 | html_static_path = ['_static'] 57 | -------------------------------------------------------------------------------- /examples/Model_DBN_IntModule.py: -------------------------------------------------------------------------------- 1 | import site 2 | site.addsitedir('..') 3 | 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from pytorch_probgraph import BernoulliLayer, DiracDeltaLayer, CategoricalLayer 8 | from pytorch_probgraph import GaussianLayer 9 | from pytorch_probgraph import InteractionLinear, InteractionModule, InteractionPoolMapIn2D, InteractionPoolMapOut2D 10 | from pytorch_probgraph import InteractionPoolMapIn1D, InteractionPoolMapOut1D 11 | from pytorch_probgraph import RestrictedBoltzmannMachinePCD 12 | from pytorch_probgraph import DeepBeliefNetwork 13 | from itertools import chain 14 | from tqdm import tqdm 15 | 16 | 17 | 18 | class Model_DBN_IntModule(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | layer0 = BernoulliLayer(torch.zeros([1, 784], requires_grad=True)) 22 | layer1 = BernoulliLayer(torch.zeros([1, 200], requires_grad=True)) 23 | layer2 = BernoulliLayer(torch.zeros([1, 200], requires_grad=True)) 24 | interaction0 = InteractionModule(torch.nn.Linear(layer0.bias.shape[1], layer1.bias.shape[1]), inputShape=layer0.bias.shape) 25 | interaction1 = InteractionModule(torch.nn.Linear(layer1.bias.shape[1], layer2.bias.shape[1]), inputShape=layer1.bias.shape) 26 | rbm1 = RestrictedBoltzmannMachinePCD(layer0, layer1, interaction0, fantasy_particles=10) 27 | rbm2 = RestrictedBoltzmannMachinePCD(layer1, layer2, interaction1, fantasy_particles=10) 28 | opt = torch.optim.Adam(chain(rbm1.parameters(), rbm2.parameters()), lr=1e-3) 29 | self.model = DeepBeliefNetwork([rbm1, rbm2], opt) 30 | #self.model = self.model.to(device) 31 | #print(interaction.weight.shape) 32 | 33 | def train(self, data, epochs=1, device=None): 34 | datnew = [dat.reshape(-1, 784) for dat in data] 35 | if isinstance(data, tqdm): 36 | datnew = tqdm(datnew) 37 | self.model.train(datnew, epochs=epochs, device=device) 38 | 39 | def loglikelihood(self, data): 40 | if data.shape[0] == 1: 41 | dataresh = data.reshape(-1, 784) 42 | else: 43 | dataresh = data.reshape(-1, 784) 44 | return -self.model.free_energy_estimate(dataresh) 45 | 46 | def generate(self, N=1): 47 | return self.model.sample(N=N, gibbs_steps=100).cpu().reshape(-1,28,28) 48 | -------------------------------------------------------------------------------- /pytorch_probgraph/README.md: -------------------------------------------------------------------------------- 1 | # README of "PyTorch-ProbGraph" 2 | 3 | ## What is PyTorch-ProbGraph? 4 | 5 | PyTorch-ProbGraph is a library based on amazing PyTorch (https://pytorch.org) 6 | to easily use and adapt directed and undirected Hierarchical Probabilistic 7 | Graphical Models. These include Restricted Boltzmann Machines, 8 | Deep Belief Networks, Deep Boltzmann Machines and Helmholtz 9 | Machines (Sigmoid Belief Networks). 10 | Models can be set up in a modular fashion, using UnitLayers, layers of Random Units and Interactions between these UnitLayers. 11 | Currently, only Gaussian, Categorical and Bernoulli units are available, but an extension can be made to allow all kinds of distributions from the Exponential family. 
12 | (see https://en.wikipedia.org/wiki/Exponential_family) 13 | The Interactions are usually only linear for undirected models, but can be built 14 | from arbitrary PyTorch torch.nn.Modules (using forward and the backward gradient). 15 | There is a pre-implemented fully-connected InteractionLinear, one for using 16 | existing torch.nn.Modules and some custom Interactions / Mappings to enable 17 | Probabilistic Max-Pooling. Interactions can also be connected without intermediate 18 | Random UnitLayers with InteractionSequential. 19 | 20 | This library was built by Korbinian Poeppel and Hendrik Elvers during a Practical Course "Beyond Deep Learning - Uncertainty Aware Models" at TU Munich. 21 | Disclaimer: It is built as an extension to PyTorch and not directly affiliated. 22 | 23 | ## Documentation 24 | A more detailed documentation is included, using the Sphinx framework. 25 | Go inside directory 'docs' and run 'make html' (having Sphinx installed). 26 | The documentation can then be found inside the _build sub-directory. 27 | 28 | ## Examples 29 | There are some example models, as well as an evaluation script in the `examples` 30 | folder. 31 | 32 | ## License 33 | This library is distributed in a ... license. 34 | 35 | ## References 36 | Ian Goodfellow and Yoshua Bengio and Aaron Courville, 37 | http://www.deeplearningbook.org 38 | 39 | Jörg Bornschein, Yoshua Bengio Reweighted Wake-Sleep 40 | https://arxiv.org/abs/1406.2751 41 | 42 | Geoffrey Hinton, A Practical Guide to Training Restricted Boltzmann Machines 43 | https://www.cs.toronto.edu/~hinton/absps/guideTR.pdf 44 | 45 | Ruslan Salakhutdinov, Learning Deep Generative Models 46 | https://tspace.library.utoronto.ca/handle/1807/19226 47 | 48 | Honglak Lee et al., Convolutional Deep Belief Networks for Scalable Unsupervised Learning of Hierarchical 49 | Representations, ICML09 50 | 51 | G.Hinton, S. Osindero A fast learning algorithm for deep belief nets 52 | -------------------------------------------------------------------------------- /docs/howto.rst: -------------------------------------------------------------------------------- 1 | ================================== 2 | A small HowTo on PyTorch-ProbGraph 3 | ================================== 4 | 5 | ------------ 6 | Introduction 7 | ------------ 8 | 9 | PyTorch-ProbGraph is a library bringing the Probabilistic Graphical Framework 10 | to PyTorch (``_), orthogonal to (``_), 11 | with a narrow focus on the traditional Restricted Boltzmann Machine, 12 | Deep Boltzmann Machine, Deep Belief Network and Helmholtz Machine as well 13 | as their convolutional variants. 14 | 15 | The core modules (all torch.nn.Modules) are the UnitLayer, representing some 16 | random distributed variables and the Interaction, representing (directed or 17 | undirected) interactions/links between these UnitLayers. 18 | 19 | A hierarchical graphical model is now built combining these in one of the 20 | following models or their variants: 21 | - Restricted Boltzmann Machine (Contrastive Divergence / Persistent Contrastive Divergence) 22 | - Deep Boltzmann Machine 23 | - Deep Belief Network 24 | - Helmholtz Machine (Wake-Sleep / Reweighted Wake Sleep) 25 | 26 | -------------- 27 | A simple Model 28 | -------------- 29 | .. 
code-block:: python 30 | 31 | import torch 32 | from itertools import chain 33 | from pytorch_probgraph.unitlayer import BernoulliLayer, CategoricalLayer 34 | from pytorch_probgraph.interaction import InteractionLinear 35 | from pytorch_probgraph.rbm import RestrictedBoltzmannMachineCD 36 | from pytorch_probgraph.dbn import DeepBeliefNetwork 37 | 38 | ## Load data as some iterator over training batches 39 | data = get_data_from_somewhere() 40 | 41 | # Define the layers (always take 1 for the first=batch dimension) 42 | blayer0 = BernoulliLayer(torch.zeros([1, 20], requires_grad=True)) 43 | blayer1 = CategoricalLayer(torch.zeros([1, 5], requires_grad=True)) 44 | blayer2 = BernoulliLayer(torch.zeros([1, 30], requires_grad=True)) 45 | 46 | ## Define interactions between layers 47 | interaction0 = InteractionLinear(blayer0.bias.shape[1:], blayer1.bias.shape[1:]) 48 | interaction1 = InteractionLinear(blayer1.bias.shape[1:], blayer2.bias.shape[1:]) 49 | 50 | ## Define Restricted Boltzmann Machines to be stacked 51 | rbm0 = RestrictedBoltzmannMachineCD(blayer0, blayer1, interaction0) 52 | rbm1 = RestrictedBoltzmannMachineCD(blayer1, blayer2, interaction1) 53 | 54 | ## Define the optimizer and the Deep Belief Network 55 | opt = torch.optim.Adam(chain(rbm0.parameters(), rbm1.parameters())) 56 | dbn = DeepBeliefNetwork([rbm0, rbm1], opt) 57 | 58 | ## Train on data 59 | dbn.train(data, epochs=10) 60 | 61 | ## Generate a batch of 10 samples 62 | dbn.sample(N=10) 63 | 64 | ## Estimate some free energy of the setting 65 | dbn.free_energy_estimate(data) 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README of "PyTorch-ProbGraph" 2 | 3 | ## What is PyTorch-ProbGraph? 4 | 5 | PyTorch-ProbGraph is a library based on the amazing PyTorch (https://pytorch.org) 6 | to easily use and adapt directed and undirected Hierarchical Probabilistic 7 | Graphical Models. These include Restricted Boltzmann Machines, 8 | Deep Belief Networks, Deep Boltzmann Machines and Helmholtz 9 | Machines (Sigmoid Belief Networks). 10 | 11 | Models can be set up in a modular fashion, using UnitLayers (layers of Random Units) and Interactions between these UnitLayers. 12 | Currently, only Gaussian, Categorical and Bernoulli units are available, but an extension can be made to allow all kinds of distributions from the Exponential family 13 | (see https://en.wikipedia.org/wiki/Exponential_family). 14 | 15 | The Interactions are usually only linear for undirected models, but can be built 16 | from arbitrary PyTorch torch.nn.Modules (using forward and the backward gradient). 17 | 18 | There is a pre-implemented fully-connected InteractionLinear, one for using 19 | existing torch.nn.Modules and some custom Interactions / Mappings to enable 20 | Probabilistic Max-Pooling. Interactions can also be connected without intermediate 21 | Random UnitLayers with InteractionSequential. 22 | 23 | This library was built by Korbinian Poeppel and Hendrik Elvers during a Practical Course "Beyond Deep Learning - Uncertainty Aware Models" at TU Munich. 24 | Disclaimer: It is built as an extension to PyTorch and not directly affiliated. 25 | 26 | ## Documentation 27 | More detailed documentation is included, using the Sphinx framework. 28 | Go inside directory 'docs' and run 'make html' (having Sphinx installed). 29 | The documentation can then be found inside the _build sub-directory.
30 | 31 | ## Examples 32 | There are some example models, as well as an evaluation script using the EMNIST dataset, in the `examples` 33 | folder. 34 | 35 | ## License 36 | This library is distributed under the BSD 3-clause license. 37 | 38 | ## Setup 39 | The library is available from the PyPI repository and can be installed with: 40 | pip install pytorch_probgraph 41 | 42 | ## References 43 | Ian Goodfellow, Yoshua Bengio and Aaron Courville: Deep Learning, 44 | http://www.deeplearningbook.org 45 | 46 | Jörg Bornschein, Yoshua Bengio: Reweighted Wake-Sleep, 47 | https://arxiv.org/abs/1406.2751 48 | 49 | Geoffrey Hinton: A Practical Guide to Training Restricted Boltzmann Machines, 50 | https://www.cs.toronto.edu/~hinton/absps/guideTR.pdf 51 | 52 | Ruslan Salakhutdinov: Learning Deep Generative Models, 53 | https://tspace.library.utoronto.ca/handle/1807/19226 54 | 55 | Honglak Lee et al.: Convolutional Deep Belief Networks for Scalable Unsupervised Learning of Hierarchical 56 | Representations, ICML 2009 57 | 58 | G. Hinton, S. Osindero: A fast learning algorithm for deep belief nets 59 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. PyTorch-ProbGraph documentation master file, created by 2 | sphinx-quickstart on Mon Jul 2 13:50:42 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to PyTorch-ProbGraph's documentation! 7 | ============================================= 8 | 9 | What is PyTorch-ProbGraph? 10 | -------------------------- 11 | 12 | PyTorch-ProbGraph is a library based on the amazing `PyTorch <https://pytorch.org>`_ 13 | to easily use and adapt directed and undirected Hierarchical Probabilistic 14 | Graphical Models. These include Restricted Boltzmann Machines, Deep Belief 15 | Networks, Deep Boltzmann Machines and Helmholtz Machines (Sigmoid Belief Networks). 16 | 17 | Models can be set up in a modular fashion, using UnitLayers (layers of 18 | Random Units) and Interactions between these UnitLayers. 19 | Currently, only Gaussian, Categorical and Bernoulli units are available, 20 | but an extension can be made to allow all kinds of distributions 21 | from the Exponential family 22 | (see `<https://en.wikipedia.org/wiki/Exponential_family>`_). 23 | 24 | The Interactions are usually only linear for undirected models, but can be built 25 | from arbitrary PyTorch torch.nn.Modules (using forward and the backward gradient). 26 | There is a pre-implemented fully-connected InteractionLinear, one for using 27 | existing torch.nn.Modules and some custom Interactions / Mappings to enable 28 | Probabilistic Max-Pooling. Interactions can also be connected without intermediate 29 | Random UnitLayers with InteractionSequential. 30 | 31 | Using these UnitLayers and Interactions, Restricted Boltzmann Machines, 32 | Deep Belief Networks, Deep Boltzmann Machines and Helmholtz Machines can be 33 | defined. Undirected models can be trained using Contrastive Divergence / Persistent Contrastive Divergence 34 | learning or Greedy Layerwise Learning / Pretraining (for deep models). 35 | The directed Helmholtz Machine can be trained using either traditional Wake-Sleep 36 | Learning or Reweighted Wake-Sleep, as sketched below. 37 | 38 | This library was built by Korbinian Poeppel and Hendrik Elvers during a 39 | Practical Course "Beyond Deep Learning - Uncertainty Aware Models" at TU Munich. 40 | Disclaimer: It is built as an extension to PyTorch and not directly affiliated.
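As a rough illustration of the directed case mentioned above, a two-layer Helmholtz
Machine can be wired up in the same modular way. This is a condensed sketch based on
``examples/Model_HM_WS.py``; the layer sizes and the ``data`` iterator are placeholders.

.. code-block:: python

   import torch
   from itertools import chain
   from pytorch_probgraph import BernoulliLayer, InteractionLinear, HelmholtzMachine

   # visible 28x28 Bernoulli units and one hidden layer of 200 units
   layer0 = BernoulliLayer(torch.nn.Parameter(torch.zeros([1, 1, 28, 28]), requires_grad=True))
   layer1 = BernoulliLayer(torch.nn.Parameter(torch.zeros([1, 200]), requires_grad=True))

   # directed recognition (up) and generative (down) interactions
   up1 = InteractionLinear(layer0.bias.shape[1:], layer1.bias.shape[1:])
   down1 = InteractionLinear(layer1.bias.shape[1:], layer0.bias.shape[1:])

   parameters = chain(*[m.parameters() for m in [layer0, layer1, up1, down1]])
   hm = HelmholtzMachine([layer0, layer1], [up1], [down1],
                         optimizer=torch.optim.Adam(parameters))

   for batch in data:       # data: some iterator over training batches
       hm.trainWS(batch)    # one wake-sleep step; trainReweightedWS(batch, ksamples=5) for RWS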
41 | 42 | References 43 | ---------- 44 | 45 | Ian Goodfellow and Yoshua Bengio and Aaron Courville, 46 | ``_ 47 | 48 | Jörg Bornschein, Yoshua Bengio Reweighted Wake-Sleep 49 | ``_ 50 | 51 | Geoffrey Hinton, A Practical Guide to Training Restricted Boltzmann Machines 52 | ``_ 53 | 54 | Ruslan Salakhutdinov, Learning Deep Generative Models 55 | ``_ 56 | 57 | Honglak Lee et al., Convolutional Deep Belief Networks for Scalable Unsupervised Learning of Hierarchical 58 | Representations, ICML09 59 | 60 | G.Hinton, S. Osindero A fast learning algorithm for deep belief nets 61 | 62 | 63 | .. toctree:: 64 | :maxdepth: 3 65 | :caption: Contents: 66 | 67 | api 68 | 69 | howto 70 | 71 | 72 | Indices and tables 73 | ================== 74 | 75 | * :ref:`genindex` 76 | * :ref:`modindex` 77 | * :ref:`search` 78 | -------------------------------------------------------------------------------- /examples/Model_RBM_CD.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path = [".."] + sys.path 3 | 4 | import torch 5 | from pytorch_probgraph import BernoulliLayer, GaussianLayer 6 | from pytorch_probgraph import InteractionLinear, InteractionModule 7 | from pytorch_probgraph import RestrictedBoltzmannMachineCD_Smooth 8 | from pytorch_probgraph import DeepBoltzmannMachineLS 9 | 10 | 11 | class Model_RBM_CD(torch.nn.Module): 12 | ''' 13 | Constructs a RBM. 14 | ''' 15 | 16 | def __init__(self): 17 | ''' 18 | Builds the RBM with hidden dimension = 200. 19 | ''' 20 | 21 | super().__init__() 22 | # initialize the bias with zeros 23 | l0bias = torch.zeros([1, 1, 28, 28]) 24 | l0bias.requires_grad = True 25 | l1bias = torch.zeros([1, 200]) 26 | l1bias.requires_grad = True 27 | 28 | # initialize the bernoulli layers 29 | l0 = BernoulliLayer(l0bias) 30 | l1 = BernoulliLayer(l1bias) 31 | 32 | # initialize the interaction layer 33 | i0 = InteractionLinear(l0.bias.shape[1:], l1.bias.shape[1:]) 34 | 35 | # build the RBM 36 | rbm0 = RestrictedBoltzmannMachineCD_Smooth(l0, l1, i0, ksteps=1) 37 | 38 | # get all parameters of the RBM for the optimizer 39 | params = list(rbm0.parameters()) 40 | 41 | # set up the optimizer and the scheduler 42 | self.opt = torch.optim.SGD(params, lr=1e-2, weight_decay=1e-5) 43 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=2000, gamma=0.94) 44 | self.model = rbm0 45 | 46 | #The DBM has a general AIS implementation, which can be used to calculate the log likelihood of the RBM. 
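# Usage sketch (hypothetical values): log_Z = self.get_log_Z(steps=1000, samples=100) runs AIS,
# and self.loglikelihood(data, log_Z=log_Z) then yields the per-image log likelihood (see the methods below).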
47 | self.dbm = DeepBoltzmannMachineLS(rbms=[rbm0], optimizer=self.opt, scheduler=self.scheduler, learning='CD', 48 | nFantasy=100) 49 | 50 | def train(self, data, epochs=1, device=None): 51 | ''' 52 | Function to train the model 53 | :param data: tqdm object 54 | :param epochs: [int], number of epochs to train 55 | :return: None 56 | ''' 57 | 58 | self.model.train(data, epochs=epochs, optimizer=self.opt, scheduler=self.scheduler, device=device) 59 | 60 | def loglikelihood(self, data, log_Z = None): 61 | ''' 62 | Calculates the log likelihood 63 | :param data: tqdm object 64 | :param log_Z: [float], log of the partitioning sum 65 | :return: [torch.tensor(batch size, float)], log likelihood per image 66 | ''' 67 | 68 | return self.dbm.loglikelihood(data, log_Z = log_Z) 69 | 70 | def generate(self, N=1): 71 | ''' 72 | Generates new images according to the model distribution 73 | :param N: number of images to be generated 74 | :return: [torch.tensor(N,28,28)], generated images 75 | ''' 76 | 77 | return self.model.reconstruct(N=N, gibbs_steps=10, mean=True).cpu() 78 | 79 | def get_log_Z(self, steps, samples): 80 | ''' 81 | Calculates the partitioning sum. 82 | :param steps: [int], number of steps for the AIS algorithm 83 | :param samples: [int], number of samples to compute the mean 84 | :return: [torch.tensor(1, float)] log of the partitioning sum 85 | ''' 86 | 87 | return self.dbm.ais(steps, samples) 88 | -------------------------------------------------------------------------------- /examples/Model_RBM_PCD.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path = [".."] + sys.path 3 | 4 | import torch 5 | from pytorch_probgraph import BernoulliLayer, GaussianLayer 6 | from pytorch_probgraph import InteractionLinear, InteractionModule 7 | from pytorch_probgraph import RestrictedBoltzmannMachinePCD 8 | from pytorch_probgraph import DeepBoltzmannMachineLS 9 | 10 | 11 | class Model_RBM_PCD(torch.nn.Module): 12 | ''' 13 | Constructs a RBM. 14 | ''' 15 | 16 | def __init__(self): 17 | ''' 18 | Builds the RBM with hidden dimension = 200. 19 | ''' 20 | 21 | super().__init__() 22 | # initialize the bias with zeros 23 | l0bias = torch.zeros([1, 1, 28, 28]) 24 | l0bias.requires_grad = True 25 | l1bias = torch.zeros([1, 200]) 26 | l1bias.requires_grad = True 27 | 28 | # initialize the bernoulli layers 29 | l0 = BernoulliLayer(l0bias) 30 | l1 = BernoulliLayer(l1bias) 31 | 32 | # initialize the interaction layer 33 | i0 = InteractionLinear(l0.bias.shape[1:], l1.bias.shape[1:]) 34 | 35 | # build the RBM 36 | rbm0 = RestrictedBoltzmannMachinePCD(l0, l1, i0, fantasy_particles=10) 37 | 38 | # get all parameters of the RBM for the optimizer 39 | params = list(rbm0.parameters()) 40 | 41 | # set up the optimizer and the scheduler 42 | self.opt = torch.optim.SGD(params, lr=1e-2, weight_decay=1e-5) 43 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=2000, gamma=0.94) 44 | self.model = rbm0 45 | 46 | #The DBM has a general AIS implementation, which can be used to calculate the log likelihood of the RBM. 
47 | self.dbm = DeepBoltzmannMachineLS(rbms=[rbm0], optimizer=self.opt, scheduler=self.scheduler, learning='CD', 48 | nFantasy=100) 49 | 50 | def train(self, data, epochs=1, device=None): 51 | ''' 52 | Function to train the model 53 | :param data: tqdm object 54 | :param epochs: [int], number of epochs to train 55 | :return: None 56 | ''' 57 | 58 | self.model.train(data, epochs=epochs, optimizer=self.opt, scheduler=self.scheduler, device=device) 59 | 60 | def loglikelihood(self, data, log_Z = None): 61 | ''' 62 | Calculates the log likelihood 63 | :param data: tqdm object 64 | :param log_Z: [float], log of the partitioning sum 65 | :return: [torch.tensor(batch size, float)], log likelihood per image 66 | ''' 67 | 68 | return self.dbm.loglikelihood(data, log_Z = log_Z) 69 | 70 | def generate(self, N=1): 71 | ''' 72 | Generates new images according to the model distribution 73 | :param N: number of images to be generated 74 | :return: [torch.tensor(N,28,28)], generated images 75 | ''' 76 | 77 | return self.model.reconstruct(N=N, gibbs_steps=10, mean=True).cpu() 78 | 79 | def get_log_Z(self, steps, samples): 80 | ''' 81 | Calculates the partitioning sum. 82 | :param steps: [int], number of steps for the AIS algorithm 83 | :param samples: [int], number of samples to compute the mean 84 | :return: [torch.tensor(1, float)] log of the partitioning sum 85 | ''' 86 | 87 | return self.dbm.ais(steps, samples) 88 | -------------------------------------------------------------------------------- /examples/Model_DBM_CD.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path = [".."] + sys.path 3 | 4 | from pytorch_probgraph import BernoulliLayer 5 | from pytorch_probgraph import InteractionLinear 6 | from pytorch_probgraph import RestrictedBoltzmannMachineCD_Smooth 7 | from pytorch_probgraph import DeepBoltzmannMachineLS 8 | 9 | 10 | import torch 11 | 12 | class Model_DBM_CD(torch.nn.Module): # 13 | 14 | 15 | def __init__(self): 16 | ''' 17 | Builds the DBM with hidden dimension = 200, 200. Uses PCD for learning. 18 | ''' 19 | super().__init__() 20 | 21 | # initialize the bias with zeros and lock them. 22 | # This is because the DBM approximates the bias distribution with higher layers. 
23 | l0bias = torch.zeros([1, 1, 28, 28]) 24 | l0bias.requires_grad = False 25 | l1bias = torch.zeros([1, 200]) 26 | l1bias.requires_grad = False 27 | l2bias = torch.zeros([1, 200]) 28 | l2bias.requires_grad = False 29 | 30 | # initialize the bernoulli layers 31 | l0 = BernoulliLayer(l0bias) 32 | l1 = BernoulliLayer(l1bias) 33 | l2 = BernoulliLayer(l2bias) 34 | 35 | # initialize the interaction layers 36 | i0 = InteractionLinear(l0.bias.shape[1:], l1.bias.shape[1:]) 37 | i1 = InteractionLinear(l1.bias.shape[1:], l2.bias.shape[1:]) 38 | 39 | # build two RBMs 40 | rbm0 = RestrictedBoltzmannMachineCD_Smooth(l0, l1, i0, ksteps=1) 41 | rbm1 = RestrictedBoltzmannMachineCD_Smooth(l1, l2, i1, ksteps=1) 42 | 43 | # get all parameters of the RBM for the optimizer 44 | params = list(rbm0.parameters()) + list(rbm1.parameters()) 45 | 46 | # set up the optimizer and the scheduler 47 | opt = torch.optim.SGD(params, lr=1e-3, weight_decay=1e-4) 48 | scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=700, gamma=0.94) 49 | 50 | # build the DBM 51 | self.model = DeepBoltzmannMachineLS(rbms=[rbm0, rbm1], optimizer=opt, scheduler=scheduler, learning='CD', 52 | nFantasy=100, ksteps=1) 53 | 54 | def train(self, data, epochs=1, device=None): 55 | ''' 56 | Function to train the model 57 | :param data: tqdm object 58 | :param epochs: [int], number of epochs to train 59 | :return: None 60 | ''' 61 | self.model.train_model(data, epochs=epochs, device=device) 62 | 63 | def loglikelihood(self, data, log_Z = None): 64 | ''' 65 | Calculates the log likelihood 66 | :param data: tqdm object 67 | :param log_Z: [float], log of the partitioning sum 68 | :return: [torch.tensor(batch size, float)], log likelihood per image 69 | ''' 70 | return self.model.loglikelihood(data, log_Z = log_Z) 71 | 72 | def generate(self, N=1): 73 | ''' 74 | Generates new images according to the model distribution 75 | :param N: number of images to be generated 76 | :return: [torch.tensor(N,28,28)], generated images 77 | ''' 78 | 79 | return self.model.generate(N=N, gibbs_steps=10).cpu() 80 | 81 | def get_log_Z(self, steps, samples): 82 | ''' 83 | Calculates the partitioning sum. 84 | :param steps: [int], number of steps for the AIS algorithm 85 | :param samples: [int], number of samples to compute the mean 86 | :return: [torch.tensor(1, float)] log of the partitioning sum 87 | ''' 88 | 89 | return self.model.ais(steps, samples) 90 | -------------------------------------------------------------------------------- /examples/Model_DBM_PCD.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path = [".."] + sys.path 3 | 4 | import torch 5 | from pytorch_probgraph import BernoulliLayer, GaussianLayer 6 | from pytorch_probgraph import InteractionLinear, InteractionModule 7 | from pytorch_probgraph import RestrictedBoltzmannMachineCD_Smooth 8 | from pytorch_probgraph import DeepBoltzmannMachineLS 9 | 10 | class Model_DBM_PCD(torch.nn.Module): 11 | 12 | 13 | def __init__(self): 14 | ''' 15 | Builds the DBM with hidden dimension = 200, 200. 16 | ''' 17 | super().__init__() 18 | 19 | # initialize the bias with zeros and lock them. 20 | # This is because the DBM approximates the bias distribution with higher layers. 
21 | l0bias = torch.zeros([1, 1, 28, 28]) 22 | l0bias.requires_grad = False 23 | l1bias = torch.zeros([1, 200]) 24 | l1bias.requires_grad = False 25 | l2bias = torch.zeros([1, 200]) 26 | l2bias.requires_grad = False 27 | 28 | # initialize the bernoulli layers 29 | l0 = BernoulliLayer(l0bias) 30 | l1 = BernoulliLayer(l1bias) 31 | l2 = BernoulliLayer(l2bias) 32 | 33 | # initialize the interaction layers 34 | i0 = InteractionLinear(l0.bias.shape[1:], l1.bias.shape[1:]) 35 | i1 = InteractionLinear(l1.bias.shape[1:], l2.bias.shape[1:]) 36 | 37 | # build two RBMs 38 | rbm0 = RestrictedBoltzmannMachineCD_Smooth(l0, l1, i0, ksteps=1) 39 | rbm1 = RestrictedBoltzmannMachineCD_Smooth(l1, l2, i1, ksteps=1) 40 | 41 | # get all parameters of the RBM for the optimizer 42 | params = list(rbm0.parameters()) + list(rbm1.parameters()) 43 | 44 | # set up the optimizer and the scheduler 45 | opt = torch.optim.SGD(params, lr=1e-3, weight_decay=1e-4) 46 | scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=700, gamma=0.94) 47 | 48 | # build the DBM 49 | self.model = DeepBoltzmannMachineLS(rbms=[rbm0, rbm1], optimizer=opt, scheduler=scheduler, learning='PCD', 50 | nFantasy=10, ksteps=1) 51 | 52 | def train(self, data, epochs=1, device=None): 53 | ''' 54 | Function to train the model 55 | :param data: tqdm object 56 | :param epochs: [int], number of epochs to train 57 | :return: None 58 | ''' 59 | self.model.train_model(data, epochs=epochs, device=device) 60 | 61 | def loglikelihood(self, data, log_Z = None): 62 | ''' 63 | Calculates the log likelihood 64 | :param data: tqdm object 65 | :param log_Z: [float], log of the partitioning sum 66 | :return: [torch.tensor(batch size, float)], log likelihood per image 67 | ''' 68 | return self.model.loglikelihood(data, log_Z = log_Z) 69 | 70 | def generate(self, N=1): 71 | ''' 72 | Generates new images according to the model distribution 73 | :param N: number of images to be generated 74 | :return: [torch.tensor(N,28,28)], generated images 75 | ''' 76 | 77 | return self.model.generate(N=N, gibbs_steps=10).cpu() 78 | 79 | def get_log_Z(self, steps, samples): 80 | ''' 81 | Calculates the partitioning sum. 82 | :param steps: [int], number of steps for the AIS algorithm 83 | :param samples: [int], number of samples to compute the mean 84 | :return: [torch.tensor(1, float)] log of the partitioning sum 85 | ''' 86 | 87 | return self.model.ais(steps, samples) 88 | -------------------------------------------------------------------------------- /pytorch_probgraph/dbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from typing import List, Iterable, Optional 5 | from .rbm import RestrictedBoltzmannMachine 6 | from .utils import ListModule 7 | 8 | 9 | class DeepBeliefNetwork(torch.nn.Module): 10 | ''' 11 | From: "On the Quantitative Analysis of Deep Belief Networks" (Salakhutdinov, Murray) 12 | Using greedy learning on RBMs 13 | ''' 14 | def __init__(self, 15 | rbms: List[RestrictedBoltzmannMachine], 16 | optimizer: torch.optim.Optimizer): 17 | ''' 18 | Pass a list of rbms with fitting input and output sizes 19 | The optimizer should be an optimizer on all the rbm's parameters. 
20 | :param rbms: [List[RestrictedBoltzmannMachine]] List of RBMs 21 | :param optimizer: Optimizer to train the RBMs 22 | ''' 23 | super().__init__() 24 | self.rbms = ListModule(*rbms) 25 | self.optimizer = optimizer 26 | 27 | def train_layer(self, 28 | rbm_num: int, 29 | data: Iterable[torch.Tensor], 30 | epochs: int, 31 | device: torch.device=None 32 | ) -> None: 33 | ''' 34 | Train an RBM using some data (usually sampled from below) 35 | ''' 36 | for epoch in range(epochs): 37 | # set all gradients zero first! 38 | for rbm in self.rbms: 39 | rbm.zero_grad() 40 | 41 | for bdat in data: 42 | self.rbms[rbm_num].zero_grad() 43 | self.rbms[rbm_num].step(bdat.to(device)) 44 | self.optimizer.step() 45 | 46 | if isinstance(data, tqdm): 47 | data = tqdm(data) 48 | data.set_description("Epoch {}, RBM {}".format(epoch, rbm_num)) 49 | 50 | def sample_layer_hidden(self, 51 | rbm_num: int, 52 | visible: Optional[torch.Tensor]=None, 53 | N: int=1 54 | ) -> torch.Tensor: 55 | if visible is not None: 56 | return self.rbms[rbm_num].sample_hidden(visible=visible) 57 | else: 58 | return self.rbms[rbm_num].sample_hidden(N=N) 59 | 60 | def train(self, 61 | data : Iterable[torch.Tensor], 62 | epochs: int=1, 63 | skip_rbm: List[int]=list(), 64 | device: torch.device=None 65 | ) -> None: 66 | ''' 67 | Train the Deep Belief Network, data should be an iterator on training 68 | batches. 69 | 70 | :param data: Iterator of Training Batches 71 | :param epochs: Number of epochs to learn each Restricted Boltzmann Machine 72 | :param skip_rbm: Skip Learning for some RBMs (indexes) 73 | :param device: The torch device to move the data to before training 74 | ''' 75 | for rbm_num in range(len(self.rbms)): 76 | if rbm_num not in skip_rbm: 77 | self.train_layer(rbm_num, data, epochs, device=device) 78 | newdat = [] 79 | for bdat in data: 80 | #newdat.append(self.sample_layer_hidden(rbm_num, visible=bdat.to(device)).detach().cpu()) 81 | next_data = self.sample_layer_hidden(rbm_num, visible=bdat.to(device)).detach().cpu() 82 | next_data.requires_grad = True 83 | newdat.append(next_data) 84 | if isinstance(data, tqdm): 85 | data = tqdm(newdat) 86 | else: 87 | data = newdat 88 | 89 | def free_energy_estimate(self, 90 | data: torch.Tensor, 91 | skip_rbm: List[int]=list(), 92 | ) -> torch.Tensor: 93 | ''' 94 | Calculate the sum of the free energies of the RBMs. Upward RBMs are \ 95 | fed with samples from below. 96 | 97 | :param data: Data batch 98 | :param skip_rbm: Skip some RBMs (indices) 99 | ''' 100 | free_energy = 0. 
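# accumulate each RBM's free energy greedily, feeding sampled hidden states upward as data for the next RBM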
101 | for rbm_num in range(len(self.rbms)): 102 | if rbm_num not in skip_rbm: 103 | free_energy = free_energy + self.rbms[rbm_num].free_energy(data) 104 | data = self.sample_layer_hidden(rbm_num, visible=data).detach() 105 | return free_energy 106 | 107 | 108 | def sample(self, 109 | N: int, 110 | gibbs_steps: int=1 111 | ) -> torch.Tensor: 112 | ''' 113 | Sample first from deepest rbm, then sample all the conditionals to visible layer 114 | 115 | :param N: Number of samples (batch size) 116 | :param gibbs_steps: Number of Gibbs steps to use in each RBM 117 | ''' 118 | visible_sample = self.rbms[-1].reconstruct(N=N, gibbs_steps=gibbs_steps) 119 | for i in reversed(range(len(self.rbms)-1)): 120 | visible_sample = self.rbms[i].sample_visible(visible_sample) 121 | return visible_sample 122 | -------------------------------------------------------------------------------- /pytorch_probgraph/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, List 3 | 4 | class ListModule(torch.nn.Module): 5 | ''' 6 | Implements a List Module effectively taking a list of Modules and storing 7 | them. The modules can be indexed as for a usual list. 8 | ''' 9 | def __init__(self, 10 | *args: List[torch.nn.Module]): 11 | ''' 12 | :param *args: List of Modules 13 | ''' 14 | super().__init__() 15 | idx = 0 16 | for module in args: 17 | self.add_module(str(idx), module) 18 | idx += 1 19 | 20 | def __getitem__(self, idx: int): 21 | if idx < 0: 22 | return self.__getitem__(self.__len__() + idx) 23 | if idx < 0 or idx >= len(self._modules): 24 | raise IndexError('index {} is out of range'.format(idx)) 25 | it = iter(self._modules.values()) 26 | for i in range(idx): 27 | next(it) 28 | return next(it) 29 | 30 | def __iter__(self): 31 | return iter(self._modules.values()) 32 | 33 | def __len__(self): 34 | return len(self._modules) 35 | 36 | 37 | class Reverse(torch.nn.Module): 38 | ''' 39 | Reverts any module, with forward replacing backward and 40 | backward replacing forward. 41 | ''' 42 | def __init__(self, module): 43 | ''' 44 | :param module: the module to be reverted 45 | ''' 46 | super().__init__() 47 | self.module = module 48 | 49 | def forward(self, input: torch.Tensor) -> torch.Tensor: 50 | ''' 51 | :param input: input data 52 | :returns: output data 53 | ''' 54 | return self.module.backward(input) 55 | 56 | def backward(self, output: torch.Tensor) -> torch.Tensor: 57 | ''' 58 | :param output: output data 59 | :returns: input data / gradient 60 | ''' 61 | return self.module.forward(input) 62 | 63 | 64 | class Projection(torch.nn.Module): 65 | ''' 66 | | A class for a projection of an input to a different shape \ 67 | effectively mapping from 68 | | [..., inshape[1] .. inshape[-1]] -> [..., outshape[1] .. outshape[-1]] 69 | | only going over the subelements. 70 | | Example input (4,6) to (4,5) (shapes): 71 | | with instart (0, 1) inend (4, 5) outstart (0, 0), outend (4, 4) \ 72 | maps essentially input[:, 1:5] to a new tensor output[:4, 0:4] with shape \ 73 | (4, 5) 74 | | Non-indexed elements in the output are set to zero. 
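| For example, matching the shapes above, Projection((0, 1), (4, 5), (4, 6), (0, 0), (4, 4), (4, 5)) would copy input[:, 1:5] into output[:, 0:4].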
75 | 76 | ''' 77 | def __init__(self, 78 | instart: Tuple[int], 79 | inend: Tuple[int], 80 | inshape: Tuple[int], 81 | outstart: Tuple[int], 82 | outend: Tuple[int], 83 | outshape: Tuple[int]): 84 | ''' 85 | :param instart: List of start indices of different dimensions in input 86 | :param inend: End indices (exclusive) in input 87 | :param inshape: Real input shapes (dimension sizes) 88 | :param outstart: List of start indices of different dimensions in output 89 | :param outend: End indices (exclusive) in output 90 | :param outshape: Real output shapes (dimension sizes) 91 | ''' 92 | super().__init__() 93 | self.inindex = tuple([slice(instart[i], inend[i], 1) for i in range(len(inshape))]) 94 | self.outindex = tuple([slice(outstart[i], outend[i], 1) for i in range(len(outshape))]) 95 | self.inshape = inshape 96 | self.outshape = outshape 97 | 98 | def forward(self, input: torch.Tensor) -> torch.Tensor: 99 | ''' 100 | :param input: Input tensor 101 | :returns: output tensor 102 | ''' 103 | inindex = [slice(None, None, 1) for _ in range(len(input.shape) - len(self.inshape))] 104 | outindex = inindex 105 | inindex = tuple(inindex + list(self.inindex)) 106 | outindex = tuple(outindex + list(self.outindex)) 107 | outshape = [input.shape[i] for i in range(len(input.shape) - len(self.inshape))] 108 | outshape += self.outshape 109 | output = torch.zeros(outshape, device=input.device, requires_grad=False) 110 | output[outindex] += input[inindex] 111 | #print(self.inshape, self.outshape, input.shape, outshape) 112 | #print("Projection", output.shape) 113 | return output 114 | 115 | def backward(self, output: torch.Tensor) -> torch.Tensor: 116 | ''' 117 | :param output: output tensor to backward through module 118 | :returns: input gradient 119 | ''' 120 | outindex = [slice(None, None, 1) for _ in range(len(output.shape) - len(self.outshape))] 121 | inindex = outindex 122 | outindex = tuple(outindex + list(self.outindex)) 123 | inindex = tuple(inindex + list(self.inindex)) 124 | inshape = [output.shape[i] for i in range(len(output.shape) - len(self.inshape))] 125 | inshape += self.inshape 126 | input = torch.zeros(inshape, device=output.device, requires_grad=input.requires_grad) 127 | input[inindex] += output[outindex] 128 | #print("ProjectionBack", input.shape) 129 | return input 130 | 131 | class Expansion1D(torch.nn.Module): 132 | ''' 133 | Adds a dimension to the tensor with certain size, dividing the now second 134 | last dimension. 135 | ''' 136 | def __init__(self, expsize: int): 137 | ''' 138 | :param expsize: Size of new last dimension 139 | ''' 140 | super().__init__() 141 | self.expsize = expsize 142 | 143 | def forward(self, input: torch.Tensor) -> torch.Tensor: 144 | ''' 145 | :param input: input tensor 146 | :returns: output tensor 147 | ''' 148 | shape = list(input.shape) 149 | newshape = shape[:-1] + [shape[-1]/self.expsize, self.expsize] 150 | return input.reshape(newshape) 151 | 152 | class Truncation1D(torch.nn.Module): 153 | ''' 154 | Merges the two last dimensions to one. Last dimension shape is needed for 155 | backward operation. 
156 |     '''
157 |     def __init__(self, shape: int):
158 |         '''
159 |         :param shape: size of the last dimension
160 |         '''
161 |         super().__init__()
162 |         self.shape = shape
163 | 
164 |     def forward(self, input: torch.Tensor) -> torch.Tensor:
165 |         '''
166 |         :param input: input tensor
167 |         :returns: output tensor
168 |         '''
169 |         newshape = list(input.shape[:-2]) + [input.shape[-2] * input.shape[-1]]
170 |         #print("TrFor", input.shape, newshape)
171 |         return input.reshape(newshape)
172 | 
173 |     def backward(self, output: torch.Tensor) -> torch.Tensor:
174 |         '''
175 |         :param output: output to put backwards
176 |         :returns: input gradient
177 |         '''
178 |         #print("TrBack", output.shape, list(output.shape[:-1]) + [output.shape[-1] // self.shape, self.shape])
179 |         return output.reshape(list(output.shape[:-1]) + [output.shape[-1] // self.shape, self.shape])
180 | 
181 | class Expansion2D(torch.nn.Module):
182 |     '''
183 |     Expands a tensor in the last two dimensions, effectively to a coarse grid
184 |     of smaller grids.
185 |     '''
186 |     def __init__(self, expsize1: int, expsize2: int):
187 |         '''
188 |         :param expsize1: size of the second last dimension to be created
189 |         :param expsize2: size of the last dimension to be created
190 |         '''
191 |         super().__init__()
192 |         self.expsize1 = expsize1
193 |         self.expsize2 = expsize2
194 | 
195 |     def forward(self, input: torch.Tensor) -> torch.Tensor:
196 |         '''
197 |         :param input: input tensor
198 |         :returns: output tensor
199 |         '''
200 |         shape = list(input.shape)
201 |         # print(shape)
202 |         newshape = shape[:-2] + \
203 |             [shape[-2]//self.expsize1,
204 |              shape[-1]//self.expsize2,
205 |              self.expsize1,
206 |              self.expsize2]
207 |         sliceshape = list(newshape)
208 |         sliceshape[-4] = 1
209 |         sliceshape[-3] = 1
210 |         output = torch.zeros(newshape, device=input.device)
211 |         baseslice = [slice(None, None, 1) for _ in range(len(shape)-2)]
212 |         for i in range(shape[-2]//self.expsize1):
213 |             for j in range(shape[-1]//self.expsize2):
214 |                 inslice = tuple(baseslice + \
215 |                     [slice(self.expsize1*i, self.expsize1*(i+1)),
216 |                      slice(self.expsize2*j, self.expsize2*(j+1))])
217 |                 outslice = tuple(baseslice + \
218 |                     [i,
219 |                      j,
220 |                      slice(None, None, 1),
221 |                      slice(None, None, 1)])
222 |                 #print(inslice, outslice, input.shape, output.shape)
223 |                 #print(input[inslice].shape)
224 |                 #print(outslice)
225 |                 #print(output[outslice].shape)
226 |                 output[outslice] += input[inslice] #.view(sliceshape)
227 |         return output
228 | 
229 | class Truncation2D(torch.nn.Module):
230 |     '''
231 |     A module merging the last two dimensions, merging coarse scale in grid
232 |     of dimensions -4, -3 and finer resolution in dimensions -2, -1 to
233 |     one fine grained grid with two dimensions less.
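    Illustrative example (shapes are assumptions, not from the library's tests):
    a (B, 2, 3, 4, 5) input is merged to a (B, 8, 15) output, i.e. this is the
    inverse of Expansion2D(4, 5) applied to a (B, 8, 15) tensor.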
234 | ''' 235 | def __init__(self): 236 | super().__init__() 237 | 238 | def forward(self, input: torch.Tensor) -> torch.Tensor: 239 | ''' 240 | :param input: input tensor 241 | :returns: output tensor 242 | ''' 243 | shape = input.shape 244 | outputshape = list(input.shape[:-2]) 245 | expsize1 = input.shape[-2] 246 | expsize2 = input.shape[-1] 247 | outputshape[-2] *= input.shape[-2] 248 | outputshape[-1] *= input.shape[-1] 249 | baseslice = [slice(None, None, 1) for _ in range(len(outputshape)-2)] 250 | output = torch.zeros(outputshape, device=input.device, requires_grad=False) 251 | for i in range(shape[-4]): 252 | for j in range(shape[-3]): 253 | outslice = tuple(baseslice + \ 254 | [slice(expsize1*i, expsize1*(i+1)), 255 | slice(expsize2*j, expsize2*(j+1))]) 256 | inslice = tuple(baseslice + \ 257 | [i, 258 | j, 259 | slice(None, None, 1), 260 | slice(None, None, 1)]) 261 | output[outslice] += input[inslice] 262 | #print("Trunc2D", input.shape, output.shape) 263 | return output 264 | 265 | class InvertGrayscale(torch.nn.Module): 266 | ''' 267 | Invert the input around 1, sensible for grayscale images in [0,1] 268 | distributions. 269 | ''' 270 | def __init__(self): 271 | super().__init__() 272 | def forward(self, input: torch.Tensor) -> torch.Tensor: 273 | ''' 274 | :param input: input tensor 275 | :returns: output tensor 276 | ''' 277 | return 1. - input 278 | -------------------------------------------------------------------------------- /pytorch_probgraph/hm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A library implementing a generic sigmoid belief network aka Helmholtz Machine. 3 | 4 | ''' 5 | 6 | from typing import List, Tuple, Union 7 | from .interaction import Interaction 8 | from .unitlayer import UnitLayer 9 | from itertools import chain 10 | import torch 11 | import numpy as np 12 | from .utils import ListModule 13 | 14 | def logsumexp(x, dim=0, keepdim=False): 15 | maxval = torch.max(x, dim=dim, keepdim=True).values 16 | return torch.log(torch.sum(torch.exp(x - maxval), dim=dim, keepdim=keepdim))\ 17 | + torch.sum(maxval, dim, keepdim=keepdim) 18 | 19 | class HelmholtzMachine(torch.nn.Module): 20 | ''' 21 | A multilayer sigmoid belief network with (reweighted) wake-sleep learning. 22 | Using asymmetric conditional probabilities (interaction weights). 23 | From: 24 | 25 | [1] G.Hinton et al. "The wake-sleep algorithm for unsupervised 26 | neural networks" 27 | 28 | [2] Peter Dayan: Helmholtz Machines and Wake-Sleep Learning 29 | http://www.gatsby.ucl.ac.uk/~dayan/papers/d2000a.pdf 30 | Note that this implementation uses tied biases for generative and 31 | reconstructed probabilities. 
32 | 33 | [3] https://arxiv.org/pdf/1406.2751.pdf 34 | 35 | [4] https://github.com/jbornschein/reweighted-ws 36 | 37 | ''' 38 | 39 | def __init__(self, 40 | layers: List[UnitLayer], 41 | interactionsUp: List[Interaction], 42 | interactionsDown: List[Interaction], 43 | optimizer: torch.optim.Optimizer): 44 | ''' 45 | :param layers: UnitLayers of Random Units 46 | :param interactionsUp: List of Interactions upwards 47 | :param interactionsDown: List of Interactions downwards 48 | :param optimizer: Optimizer for training 49 | ''' 50 | super().__init__() 51 | if len(interactionsUp) != len(interactionsDown) or \ 52 | len(layers)-1 != len(interactionsUp): 53 | raise ValueError('Non fitting layers') 54 | self.layers = ListModule(*layers) 55 | self.intsUp = ListModule(*interactionsUp) 56 | self.intsDown = ListModule(*interactionsDown) 57 | 58 | self.optim = optimizer 59 | 60 | def sampleQ(self, 61 | data: torch.Tensor 62 | ) -> Tuple[List[torch.Tensor], 63 | List[torch.Tensor], 64 | List[torch.Tensor], 65 | torch.Tensor]: 66 | ''' 67 | :param data: Data to sample Q (reconstruction model) from. 68 | :return: Samples/Means/Logprobs from reconstruction distribution (for all layers) + Total LogProb 69 | ''' 70 | samplesUp = [data] 71 | meansUp = [None] 72 | logprobsUp = [0.] 73 | logprobsUp_total = 0. 74 | nlayers = len(self.layers) 75 | for i in range(nlayers-1): 76 | intterm = self.intsUp[i].gradOutput(self.layers[i].transform(samplesUp[i])) 77 | mean = self.layers[i+1].mean_cond(interaction=intterm) 78 | samp = self.layers[i+1].sample_cond(interaction=intterm) 79 | logprob = self.layers[i+1].logprob_cond(samp, intterm) 80 | samplesUp.append(samp) 81 | meansUp.append(mean) 82 | logprobsUp.append(logprob) 83 | logprobsUp_total += logprob 84 | return samplesUp, meansUp, logprobsUp, logprobsUp_total 85 | 86 | def logprobP(self, 87 | total_samples: List[torch.Tensor] 88 | ) -> Tuple[List[torch.Tensor], torch.Tensor]: 89 | ''' 90 | :param total_samples: Samples from all layers 91 | :return: logprob P of generative model of these samples 92 | ''' 93 | logprob = [self.layers[-1].logprob_cond(total_samples[-1], interaction=0.)] 94 | logprob_total = logprob[0] 95 | for n in reversed(range(len(self.layers)-1)): 96 | interaction = self.intsDown[n].gradOutput(self.layers[n+1].transform(total_samples[n+1])) 97 | logprobn = self.layers[n].logprob_cond(total_samples[n], interaction=interaction) 98 | logprob = [logprobn] + logprob 99 | logprob_total += logprobn 100 | return logprob, logprob_total 101 | 102 | def wakePhaseReweighted(self, 103 | data: torch.Tensor, 104 | ksamples: int=1, 105 | kmax_parallel: int=1000, 106 | train: bool=True, 107 | wakePhaseQ: bool=True 108 | ) -> torch.Tensor: 109 | ''' 110 | According to https://github.com/jbornschein/reweighted-ws/blob/master/learning/models/rws.py 111 | So k samples are drawn with each data point in batch. 
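        The returned quantity is the usual importance-weighted estimate of the
        data log likelihood (a sketch, written for a single data point x with
        samples h_k drawn from the reconstruction model q):

            log p(x)  ~=  logsumexp_k [ log p(x, h_k) - log q(h_k | x) ]  -  log(ksamples)

        which corresponds to the logsumexp(logprobP_total - logprobQ_total)
        - log(ksamples) computation at the end of this method.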
112 | 113 | :param data: training batch 114 | :param ksamples: number of samples for reweighting 115 | :param kmax_parallel: max number of samples to run in parallel (for lower memory footprint) 116 | :param train: actually modifiy weights / apply gradients (as this function also returns likelihood) 117 | :param wakePhaseQ: use also wake phase for learning reconstruction model Q 118 | :return: log likelihood of data in the generative model 119 | ''' 120 | 121 | nthrun = 0 122 | left = ksamples 123 | logprobP_total = None 124 | logprobQ_total = None 125 | while left > 0: 126 | take = min(kmax_parallel, left) 127 | left -= take 128 | shape = list(data.shape) 129 | shape_exp = [take] + shape 130 | shape[0] *= take # data is expanded to ksamples*batchsize in dim 0 131 | # print("Nth Run {}, Take: {}".format(nthrun, take)) 132 | nthrun+=1 133 | # sample upward pass q(h | x) 134 | dataExp = data.expand(shape_exp).transpose(0,1).reshape(shape) 135 | samplesUp, meansUp, logprobQ, logprobQ_total_take = self.sampleQ(dataExp) 136 | # 137 | logprobP, logprobP_total_take = self.logprobP(samplesUp) 138 | 139 | logprobP_total_take = logprobP_total_take.reshape((-1, take)) 140 | logprobQ_total_take = logprobQ_total_take.reshape((-1, take)) 141 | if logprobP_total is None: 142 | logprobP_total = logprobP_total_take.detach() 143 | logprobQ_total = logprobQ_total_take.detach() 144 | else: 145 | logprobP_total = torch.cat([logprobP_total, logprobP_total_take.detach()], dim=1) 146 | logprobQ_total = torch.cat([logprobQ_total, logprobQ_total_take.detach()], dim=1) 147 | # loglikelihood 148 | 149 | # calculate sampling weights 150 | if train: 151 | nlayers = len(self.layers)-1 152 | logPQ = (logprobP_total_take - logprobQ_total_take - np.log(take)) 153 | wnorm = logsumexp(logPQ, dim=1) 154 | logw = logPQ - wnorm.reshape(-1, 1) 155 | w = torch.exp(logw).flatten().reshape(-1,1) 156 | # downward pass, taking same batch size 157 | samplesDown = [None]*nlayers + [self.layers[nlayers].sample_cond(N=data.shape[0])] 158 | meansDown = [None]*nlayers + [self.layers[nlayers].mean_cond(N=data.shape[0])] 159 | for i in reversed(range(nlayers)): 160 | intterm = self.intsDown[i].gradOutput(self.layers[i].transform(samplesUp[i+1])) 161 | mean = self.layers[i].mean_cond(interaction=intterm) 162 | samp = self.layers[i].sample_cond(interaction=intterm) 163 | samplesDown[i] = samp 164 | meansDown[i] = mean 165 | # add stochastic batch gradients, ksamples needed because of internal normalziation 166 | for i in range(len(self.layers)-1): 167 | self.layers[i].backward(samplesUp[i] - meansDown[i], factor=-w.view(-1, *([1]*(len(meansDown[i].shape)-1)))*take) 168 | for i in range(len(self.layers)-1): 169 | self.intsDown[i].backward(self.layers[i+1].transform(samplesUp[i+1]), 170 | self.layers[i].transform(samplesUp[i]), 171 | factor=-w*take) 172 | self.intsDown[i].backward(self.layers[i+1].transform(samplesUp[i+1]), 173 | self.layers[i].transform(meansDown[i]), 174 | factor=w*take) 175 | 176 | logPX = logsumexp(logprobP_total - logprobQ_total, dim=1) - np.log(ksamples) 177 | 178 | return logPX 179 | 180 | def sleepPhase(self, 181 | N: int=1, 182 | train: bool=False 183 | ) -> torch.Tensor: 184 | ''' 185 | Learning Q in the sleep phase, generating samples. 
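        In the sleep phase the generative model P first "dreams" a batch top-down;
        the reconstruction model Q is then trained to recover each dreamed hidden
        layer from the layer below it (roughly, increasing log q(h_dream | x_dream)),
        which is the standard wake-sleep recipe from the references above.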
186 | 187 | :param N: number of samples to generate 188 | :param train: actually train weights 189 | :return: (samples, means) N samples and their means generating downwards 190 | ''' 191 | nlayers = len(self.layers)-1 192 | samplesDown = [None]*nlayers + [self.layers[nlayers].sample_cond(N=N)] 193 | meansDown = [None]*nlayers + [self.layers[nlayers].mean_cond(N=N)] 194 | # downward pass 195 | for i in reversed(range(nlayers)): 196 | intterm = self.intsDown[i].gradOutput(self.layers[i+1].transform(samplesDown[i+1])) 197 | mean = self.layers[i].mean_cond(interaction=intterm) 198 | samp = self.layers[i].sample_cond(interaction=intterm) 199 | samplesDown[i] = samp 200 | meansDown[i] = mean 201 | 202 | # upward pass 203 | samplesUp = [None]*(nlayers+1) 204 | meansUp = [None]*(nlayers+1) 205 | for i in range(nlayers): 206 | intterm = self.intsUp[i].gradOutput(self.layers[i].transform(samplesDown[i])) 207 | mean = self.layers[i+1].mean_cond(interaction=intterm) 208 | samp = self.layers[i+1].sample_cond(interaction=intterm) 209 | samplesUp[i+1] = samp 210 | meansUp[i+1] = mean 211 | # add stochastic batch gradients 212 | if train: 213 | for i in range(1, len(self.layers)): 214 | self.layers[i].backward(samplesDown[i] - meansUp[i], factor=-1) 215 | for i in range(len(self.layers)-1): 216 | self.intsUp[i].backward(self.layers[i].transform(samplesDown[i]), 217 | self.layers[i+1].transform(samplesDown[i+1]), 218 | factor=-1) 219 | self.intsUp[i].backward(self.layers[i].transform(samplesDown[i]), 220 | self.layers[i+1].transform(meansUp[i+1]), 221 | factor=1) 222 | # self.interaction.backward() 223 | 224 | return samplesDown, meansDown 225 | 226 | def trainReweightedWS(self, 227 | data: torch.Tensor, 228 | ksamples: int = 1, 229 | sleepPhaseQ: bool = True, 230 | wakePhaseQ: bool = False 231 | ) -> torch.Tensor: 232 | ''' 233 | Reweighted Wake-Sleep following https://arxiv.org/pdf/1406.2751.pdf 234 | 235 | :param data: training batch 236 | :param ksamples: number of samples for reweighting 237 | :param sleepPhaseQ: use sleep phase for learning Q 238 | :param wakePhaseQ: use wake phase for learning Q 239 | :return: (estimated) loglikelihood of data 240 | ''' 241 | 242 | self.zero_grad() 243 | loglik = self.wakePhaseReweighted(data, ksamples=ksamples, train=True, wakePhaseQ=wakePhaseQ) 244 | if sleepPhaseQ: 245 | self.sleepPhase(N=data.shape[0], train=True) 246 | self.optim.step() 247 | return loglik 248 | 249 | def trainWS(self, 250 | data: torch.Tensor 251 | ) -> torch.Tensor: 252 | ''' 253 | Traditional wake sleep-algorithm, using only one sample (no reweighting) 254 | and no wake phase Q learning. 255 | :param data: training data batch 256 | ''' 257 | return self.trainReweightedWS(data, ksamples=1, sleepPhaseQ=True, wakePhaseQ=False) 258 | 259 | def loglikelihood(self, 260 | data: torch.Tensor, 261 | ksamples: int=1, 262 | kmax_parallel: int=1000 263 | ) -> torch.Tensor: 264 | ''' 265 | Estimate log likelihood as a byproduct of reweighting. 266 | 267 | :param data: data batch 268 | :param ksamples: number of reweighting samples 269 | :param kmax_parallel: maximal number of parallel samples (memory footprint) 270 | :return: loglikelihood of each batch sample 271 | ''' 272 | return self.wakePhaseReweighted(data, ksamples=ksamples, kmax_parallel=kmax_parallel, train=False) 273 | 274 | def sampleAll(self, 275 | N: int=1 276 | )-> Tuple[List[torch.Tensor], 277 | List[torch.Tensor], 278 | List[torch.Tensor], 279 | torch.Tensor]: 280 | ''' 281 | Sample all layers from generative P, (list of samples). 
282 | 283 | :param N: number of samples 284 | :return: batch of N generated data samples and their means for each layer 285 | ''' 286 | return self.sleepPhase(N=N, train=False) 287 | 288 | def sample(self, 289 | N: int = 1 290 | ) -> torch.Tensor: 291 | ''' 292 | Sample only visible layer from generative P. 293 | 294 | :param N: number of samples 295 | :return: batch of N generated data samples 296 | ''' 297 | return self.sleepPhase(N=N, train=False)[0][0] 298 | -------------------------------------------------------------------------------- /examples/evaluate_emnist.py: -------------------------------------------------------------------------------- 1 | ''' 2 | An Evaluation Script for Graphical Models with likelihood estimation trained 3 | on EMNIST Digits and tested on EMNIST characters. 4 | 5 | For tests use in conjuction with a simple DBN model (testing limits things to few data points): 6 | 7 | python3 evaluate_emnist.py --directory test --file ModelDBN.py --model ModelDBN --tqdm --maxepochs 5 --testing 8 | 9 | ''' 10 | 11 | 12 | 13 | import torch 14 | import torchvision 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | 18 | import site 19 | 20 | site.addsitedir('..') 21 | 22 | #from dbm import DeepBoltzmannMachine 23 | from itertools import chain 24 | from tqdm import tqdm 25 | from os import mkdir 26 | import json 27 | import argparse 28 | from os.path import basename, dirname, join 29 | from shutil import copyfile 30 | from scipy.integrate import simps 31 | from tqdm import tqdm 32 | 33 | device = None 34 | 35 | def identity(x): 36 | return x 37 | 38 | def build_roc_graph(model, positive, negative, ais=False): 39 | ''' 40 | This function returns the TPR and FPR for 100 discriminator values. 41 | :param model: the model to be tested. 42 | :param positive: positive examples 43 | :param negative: negative examples 44 | :param ais: if the model uses ais to evaluate the probability, set this parameter to True 45 | :return: FPR [torch.tensor(100, float)], TPR [torch.tensor(100, float)] 46 | ''' 47 | 48 | 49 | # For models with AIS, the partitioning sum only has to be evaluated once. 50 | if ais: 51 | log_z = model.get_log_Z(1000, 1000) 52 | points = 100 53 | TP_R = torch.zeros([points]) 54 | FP_R = torch.zeros([points]) 55 | u_positive = [] 56 | u_negative = [] 57 | 58 | # Evaluate the log likelihood for the positive and negative examples. 59 | if ais: 60 | for pos in positive: 61 | u_positive.append(model.loglikelihood(data = pos, log_Z = log_z).cpu()) 62 | for neg in negative: 63 | u_negative.append(model.loglikelihood(data = neg, log_Z = log_z).cpu()) 64 | else: 65 | for pos in positive: 66 | u_positive.append(model.loglikelihood(data = pos).cpu()) 67 | for neg in negative: 68 | u_negative.append(model.loglikelihood(data = neg).cpu()) 69 | u_positive = torch.cat(u_positive, dim=0) 70 | u_negative = torch.cat(u_negative, dim=0) 71 | 72 | # Throw away samples with underflow / overflow. 73 | inf_positive = torch.zeros_like(u_positive) 74 | inf_positive[torch.isinf(u_positive)] = 1 75 | u_positive = u_positive[inf_positive == 0] 76 | 77 | inf_negative = torch.zeros_like(u_negative) 78 | inf_negative[torch.isinf(u_negative)] = 1 79 | u_negative = u_negative[inf_negative == 0] 80 | 81 | # Calculate the step size for the discriminator. 
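    # Explanatory note: the discriminator is a simple threshold t on the estimated
    # log likelihood; a sample is classified as positive (digit) if its log
    # likelihood is >= t, and t is swept in `points` equally spaced steps from the
    # smallest to the largest observed log likelihood.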
82 | min_like = min(torch.min(u_positive), torch.min(u_negative)) 83 | max_like = max(torch.max(u_positive), torch.max(u_negative)) 84 | diff = (max_like - min_like) / points 85 | 86 | # Count, how many samples were classified true positive or false positive. 87 | for i in range(points): 88 | TP_R[i] = (u_positive[u_positive >= min_like + i*diff]).size(0) / u_positive.size(0) 89 | FP_R[i] = (u_negative[u_negative >= min_like + i*diff]).size(0) / u_negative.size(0) 90 | return TP_R, FP_R 91 | 92 | 93 | def plot_roc_graph(fig, TP_R, FP_R, directory): 94 | ''' 95 | This function plots the ROC curve for TPR and FPR with common discriminator values and stores it in the path directory. 96 | :param fig: 97 | :param TP_R: TPR for different discriminator values 98 | :param FP_R: FPR for different discriminator values 99 | :param directory: directory to save the plot 100 | :return: ROC curve 101 | ''' 102 | 103 | # Set up the plot. 104 | plt.clf() 105 | ax = fig.subplots() 106 | ax.plot(FP_R, TP_R) 107 | 108 | ax.set_xlabel('False Positive Rate') 109 | ax.set_ylabel('True Positive Rate') 110 | 111 | # Calculate the Yuoden index 112 | youden_index_list = TP_R - FP_R 113 | youden_index = youden_index_list.max() 114 | 115 | # Delete not strictly monotonic values for the integration. 116 | mask = torch.zeros([TP_R.size(0)], dtype=torch.long) 117 | for i in range(TP_R.size(0) - 1): 118 | if FP_R[i] == FP_R[i + 1]: 119 | mask[i + 1] = 1 120 | 121 | # Integrate the curve. 122 | integral = simps(TP_R[mask == 0], FP_R[mask == 0]) 123 | 124 | # Plot the curve. 125 | fig.suptitle('Youden index = ' + str(round(float(youden_index), 2)) + ' Integral = ' + str(round(abs(integral), 2))) 126 | fig.savefig(directory + "/roc.png") 127 | 128 | def main(): 129 | description = ''' 130 | Evaluate a hierarchical graphical model on emnist. 131 | Models are learnt via a train() method, samples are generated using 132 | generate() and the log likelihood (per sample) is estimated using 133 | loglikelihood(). 134 | This script takes an file defining the model, the model class name, 135 | some model arguments and whether it should be trained anew. 136 | It trains the models using a predefined scheme and uses a validation set 137 | to be able to stop early. 138 | Finally a test likelihood is calculated, the model is saved, some samples 139 | are generated and a discriminator between EMNIST digits and characters 140 | is evaluated, which is based on the loglikelihood() estimation. 141 | This is done via an ROC-curve, which is stored as well. 142 | Note that the likelihood-estimator does not have to be normalized, for 143 | the model to work also as a discriminator. Just the results might not be 144 | as interpretable. 
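    Example invocation (illustrative only, mirroring the module docstring at the
    top of this file):
      python3 evaluate_emnist.py --directory test --file Model_DBN.py --model Model_DBN --tqdm --maxepochs 5 --testing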
145 | ''' 146 | 147 | parser = argparse.ArgumentParser(description=description) 148 | parser.add_argument('--directory', dest='directory', type=str, 149 | help='directory to store the results', 150 | required=True) 151 | parser.add_argument('--file', dest='file', type=str, 152 | help='file to read the model from', 153 | required=True) 154 | parser.add_argument('--model', type=str, dest='model', 155 | help='model class name') 156 | parser.add_argument('--retrain', dest='retrain', 157 | const=True, default=False, action='store_const', 158 | help='retrain previously trained model') 159 | parser.add_argument('--ais', dest='ais', 160 | const=True, default=False, action='store_const', 161 | help='use ais partition sum precalculation') 162 | parser.add_argument('--minepochs', dest='minepochs', type=int, default=20, 163 | help='minimal number of epochs') 164 | parser.add_argument('--maxepochs', dest='maxepochs', type=int, default=1000, 165 | help='maximal number of epochs') 166 | parser.add_argument('--valepochs', dest='valepochs', type=int, default=5, 167 | help='validation loglik after every $ epochs') 168 | parser.add_argument('--reeval', dest='reeval', 169 | const=True, default=False, action='store_const', 170 | help='reevaluate loglikelihood on test and valid data') 171 | parser.add_argument('--tqdm', dest='tqdm', 172 | const=tqdm, default=identity, action='store_const', 173 | help='use tqdm to show progress') 174 | parser.add_argument('--no-binarize', dest='binarize', 175 | const=False, default=True, action='store_const', 176 | help='don\'t binarize EMNIST data') 177 | parser.add_argument('--copy', dest='copy', 178 | const=False, default=True, action='store_const', 179 | help='copy model file to eval directory') 180 | parser.add_argument('--store-intermediate', dest='storeinterm', 181 | const=False, default=True, action='store_const', 182 | help='copy model file to eval directory') 183 | parser.add_argument('--testing', dest='testing', 184 | const=True, default=False, action='store_const', 185 | help='use only 2000 samples from EMNIST to test') 186 | parser.add_argument('--use-labels', dest='uselabels', 187 | const=True, default=False, action='store_const', 188 | help='use labels for training the model (classifier)') 189 | 190 | #parser.add_argument('') 191 | 192 | args = parser.parse_args() 193 | if not args.model: 194 | args.model = basename(args.file).split('.')[0] 195 | 196 | # import model from file 197 | site.addsitedir(dirname(args.file)) 198 | #Model = None 199 | try: 200 | print("from " + basename(args.file).split('.')[0] + " import " + args.model + " as NewModel") 201 | exec("from " + basename(args.file).split('.')[0] + " import " + args.model + " as NewModel") 202 | Model = locals()['NewModel'] 203 | except: 204 | parser.print_help() 205 | exit(-1) 206 | 207 | try: 208 | mkdir(args.directory) 209 | except FileExistsError: 210 | pass 211 | 212 | if args.copy: 213 | copyfile(args.file, join(args.directory, basename(args.file))) 214 | 215 | torch.random.manual_seed(42) 216 | try: 217 | torch.cuda.init() 218 | device = torch.cuda.current_device() 219 | except: 220 | device = torch.device("cpu") 221 | 222 | model = Model().to(device) 223 | 224 | ### Define likelihood averaging 225 | log_Z = 0. 226 | def average_loglikelihood(data, func, log_Z=0.): 227 | res = 0. 228 | n = 0. 
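        # Explanatory note: the per-sample log likelihoods are summed over all
        # batches and finally divided by the total sample count, so the function
        # returns an average log likelihood per sample.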
229 | if args.ais: 230 | for dat in data: 231 | res = res + torch.sum(func(dat.to(device), log_Z=log_Z)).detach() 232 | n += dat.shape[0] 233 | else: 234 | for dat in data: 235 | res = res + torch.sum(func(dat.to(device))).detach() 236 | n += dat.shape[0] 237 | return float(res)/n 238 | 239 | ### Load EMNIST dataset 240 | 241 | from emnist import extract_training_samples 242 | numbers, n_labels = extract_training_samples('digits') 243 | characters, c_labels = extract_training_samples('letters') 244 | 245 | numbers = torch.tensor(numbers/255., dtype=torch.float).cpu().clone().reshape(-1, 1, 28, 28) 246 | characters = torch.tensor(characters/255., dtype=torch.float).cpu().clone().reshape(-1, 1, 28, 28) 247 | 248 | if args.testing: 249 | numbers = numbers[:2000] 250 | characters = characters[:2000] 251 | ### Binarization of the data 252 | 253 | if args.binarize: 254 | numbers[numbers >= 0.5] = 1. 255 | numbers[numbers < 0.5] = 0. 256 | 257 | 258 | ### Split into training, validation and test set 259 | 260 | batch_size_train = 10 261 | batch_size_valid = 10 262 | batch_size_test = 10 263 | train_size = int(3/4*numbers.shape[0]) 264 | valid_size = int(1/8*numbers.shape[0]) 265 | test_size = int(1/8*numbers.shape[0]) 266 | 267 | train_data = numbers[:train_size].reshape(train_size//batch_size_train, batch_size_train, 1, 28, 28) 268 | train_labels = torch.tensor(n_labels[:train_size].reshape(train_size//batch_size_train, batch_size_train)) 269 | valid_data = numbers[train_size:train_size+valid_size].reshape(valid_size//batch_size_valid, batch_size_valid, 1, 28, 28) 270 | test_data = numbers[train_size+valid_size:].reshape(test_size//batch_size_valid, batch_size_valid, 1, 28, 28) 271 | 272 | # train the model if needed 273 | 274 | trained = True 275 | try: 276 | state_dict = torch.load(args.directory + "/model.pt") 277 | model.load_state_dict(state_dict) 278 | except FileNotFoundError: 279 | trained = False 280 | if not trained or args.retrain: 281 | valid_loglikelihoods = [] 282 | valid_loglikelihood = -np.inf 283 | last_valid_loglikelihood = -np.inf 284 | epoch = 0 285 | while valid_loglikelihood >= last_valid_loglikelihood or \ 286 | epoch < args.minepochs: 287 | last_valid_loglikelihood = valid_loglikelihood 288 | if epoch >= args.maxepochs: 289 | break 290 | # Train 291 | if args.uselabels: 292 | model.train(data=args.tqdm(train_data), 293 | labels=train_labels, 294 | epochs=args.valepochs, 295 | device=device) 296 | else: 297 | model.train(data=args.tqdm(train_data), 298 | epochs=args.valepochs, 299 | device=device) 300 | 301 | epoch += args.valepochs 302 | # Evaluate likelihood 303 | if args.ais: 304 | log_Z = model.get_log_Z(1000, 1000) 305 | valid_loglikelihood = average_loglikelihood(args.tqdm(valid_data), model.loglikelihood, log_Z=log_Z) 306 | valid_loglikelihoods.append(valid_loglikelihood) 307 | 308 | if args.storeinterm: 309 | torch.save(model.state_dict(), args.directory + "/model_intermediate_{}.pt".format(epoch)) 310 | 311 | torch.save(model.state_dict(), args.directory + "/model.pt") 312 | 313 | try: 314 | with open(directory + "/results.json", "r") as fp: 315 | resultdict = json.load(fp) 316 | except: 317 | 318 | if args.ais: 319 | log_Z = model.get_log_Z(1000, 1000) 320 | 321 | test_loglikelihood = average_loglikelihood(args.tqdm(test_data), model.loglikelihood, log_Z=log_Z) 322 | train_loglikelihood = average_loglikelihood(args.tqdm(train_data), model.loglikelihood, log_Z=log_Z) 323 | valid_loglikelihood = average_loglikelihood(args.tqdm(valid_data), model.loglikelihood, 
log_Z=log_Z) 324 | 325 | resultdict = {"valid_loglikelihood": valid_loglikelihood, 326 | "test_loglikelihood": test_loglikelihood, 327 | "train_loglikelihood": train_loglikelihood} 328 | 329 | with open(args.directory + "/results.json", "w") as fp: 330 | json.dump(resultdict, fp) 331 | 332 | samples = model.generate(N=32) 333 | fig = plt.figure(figsize=[16,8]) 334 | for i in range(32): 335 | plt.subplot(4, 8, i+1) 336 | plt.imshow(samples[i].detach().cpu().numpy().reshape(28, 28), cmap='gray', interpolation='none') 337 | plt.xticks([]) 338 | plt.yticks([]) 339 | fig.savefig(args.directory + "/generated.png") 340 | 341 | fig = plt.figure(figsize=[16,8]) 342 | 343 | TP_R, FP_R = build_roc_graph(model, numbers[-1000:].reshape(-1,10,1,28,28).to(device), characters[:1000].reshape(-1,10,1,28,28).to(device), ais = args.ais) 344 | 345 | #print(TP_R, FP_R) 346 | 347 | plot_roc_graph(fig, TP_R, FP_R, directory=args.directory) 348 | 349 | exit(0) 350 | 351 | if __name__ == '__main__': 352 | main() 353 | -------------------------------------------------------------------------------- /pytorch_probgraph/unitlayer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A module describing a layer of units and a bias energy (in terms of 3 | probabilistic energy terms). 4 | ''' 5 | import torch 6 | from typing import Tuple, Union 7 | from numpy import pi 8 | import numpy as np 9 | 10 | epsilon = 1e-6 11 | 12 | class UnitLayer(torch.nn.Module): 13 | ''' 14 | Abstract Class for representing layers of random variables of various shape. 15 | ''' 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def logprob_cond(self, 20 | input: torch.Tensor, 21 | interaction: torch.Tensor 22 | ) -> torch.Tensor: 23 | ''' 24 | Returns the conditional logprobability of a sample input given some 25 | interaction term of the same shape. 26 | 27 | :param input: the input sample/batch 28 | :param interaction: the exponential interaction term 29 | :return: the logprobability of an input given some interaction 30 | ''' 31 | return NotImplementedError 32 | 33 | def sample_cond(self, 34 | interaction: torch.Tensor=None, 35 | N: int=1 36 | ) -> torch.Tensor: 37 | ''' 38 | Samples from the conditional probability given some interaction term. 39 | With A being the current unit this is the term B in 40 | the energy exponential :math:`e^{transf(A) * B}`. 41 | 42 | :param interaction: the exponential interaction term (None == 0) 43 | :param N: the number of samples to be drawn in case interaction == None 44 | :return: batch of samples either of size of the interaction batch or N 45 | ''' 46 | return NotImplementedError 47 | 48 | def mean_cond(self, 49 | interaction: torch.Tensor=None, 50 | N: int=1 51 | ) -> torch.Tensor: 52 | ''' 53 | Returns the mean of the conditional probability given some interaction tensor. 54 | 55 | :param interaction: the exponential interaction term (None == 0) 56 | :param N: number of batch copies in case interaction == None 57 | :return: batch of means either of size of the interaction batch or N 58 | ''' 59 | return NotImplementedError 60 | 61 | def transform(self, 62 | input: torch.Tensor 63 | ) -> torch.Tensor: 64 | ''' 65 | Transforms a value such that the result leads to a linear interaction. 
66 | Random variable x :math:`\\rightarrow` exponential family term :math:`e^{transf(x) * y}` 67 | 68 | :param input: input data to be transformed 69 | :return: transformed input 70 | ''' 71 | return NotImplementedError 72 | 73 | def transform_invert(self, 74 | transformed_input: torch.Tensor 75 | ) -> torch.Tensor: 76 | ''' 77 | Transforms a value such that the result leads to a linear interaction. 78 | Random variable exponential family term :math:`e^{x' * y} \\rightarrow x = transinv(x')` 79 | 80 | :param transformed_input: Some transformed variable(s) 81 | :return: The original (batch of) variable(s) x 82 | ''' 83 | return NotImplementedError 84 | 85 | def logprob_joint(self, 86 | input : torch.Tensor 87 | ) -> torch.Tensor: 88 | ''' 89 | Returns an unnormalized probability weight given some input. 90 | This is only an unnormalized probability since there are not interactions. 91 | An interaction can simply be added by + input @ interaction 92 | 93 | :param input: some variable samples 94 | :return: their unnormalized loglikelihood (no interaction terms only bias) 95 | ''' 96 | return NotImplementedError 97 | 98 | def free_energy(self, 99 | interaction: torch.Tensor 100 | ) -> torch.Tensor: 101 | ''' 102 | Computes the partition sum given an interaction term. 103 | Example: binary 104 | 105 | .. math:: 106 | 107 | -\log(1 + e^{bias + interaction}) 108 | 109 | :param interaction: interaction term 110 | :return: partition sum / normalizing factor of Gibbs distribution 111 | ''' 112 | return NotImplementedError 113 | 114 | def backward(self, 115 | input: torch.Tensor, 116 | factor: Union[torch.Tensor, float]=1. 117 | ) -> None: 118 | ''' 119 | Computes the gradient of the internal parameters wrt to the input data. 120 | :param input: input data 121 | :return: None 122 | ''' 123 | pass 124 | 125 | class BernoulliLayer(UnitLayer): 126 | ''' 127 | A UnitLayer of bernoulli units modelled with probabilities as a sigmoid. 128 | ''' 129 | def __init__(self, 130 | bias: torch.Tensor 131 | ) -> torch.Tensor: 132 | ''' 133 | :param bias: Bias for the sigmoid modelling the bernoulli probability. 134 | ''' 135 | super().__init__() 136 | self.register_parameter("bias", torch.nn.Parameter(bias)) 137 | 138 | def mean_cond(self, 139 | interaction: Union[torch.Tensor, None] = None, 140 | N: int=1 141 | ) -> torch.Tensor: 142 | if interaction is not None: 143 | weight = self.bias + interaction 144 | else: 145 | weight = torch.zeros([N] + list(self.bias.shape[1:]), device=self.bias.device) + self.bias 146 | return torch.sigmoid(weight) 147 | 148 | def sample_cond(self, 149 | interaction: Union[torch.Tensor, None]=None, 150 | N: int=1 151 | ) -> torch.Tensor: 152 | return torch.bernoulli(self.mean_cond(interaction=interaction, N=N)) 153 | 154 | def transform(self, 155 | input: torch.Tensor 156 | ) -> torch.Tensor: 157 | return input 158 | 159 | def transform_invert(self, 160 | transformed_input: torch.Tensor 161 | ) -> torch.Tensor: 162 | return transformed_input 163 | 164 | def logprob_cond(self, 165 | input: torch.Tensor, 166 | interaction: Union[torch.Tensor, float]=0. 
167 | ) -> torch.Tensor: 168 | newdim = list(range(1, len(input.shape))) 169 | return torch.sum(torch.log(input*torch.sigmoid(interaction + self.bias) + (1.-input)*(1.-torch.sigmoid(interaction + self.bias))), dim=newdim) 170 | 171 | def logprob_joint(self, 172 | input: torch.Tensor 173 | ) -> torch.Tensor: 174 | return torch.sum(input * self.bias, dim=list(range(1, len(input.shape)))) 175 | 176 | def free_energy(self, 177 | interaction: torch.Tensor 178 | ) -> torch.Tensor: 179 | return -torch.sum(torch.log1p(torch.exp(self.bias + interaction)), dim=list(range(1, len(interaction.shape)))) 180 | 181 | def backward(self, 182 | input: torch.Tensor, 183 | factor: Union[torch.Tensor, float]=1. 184 | ) -> None: 185 | if self.bias.requires_grad: 186 | self.bias.backward((factor*input).sum(dim=0, keepdim=True).detach()/input.shape[0]) 187 | 188 | 189 | class GaussianLayer(UnitLayer): 190 | ''' 191 | A UnitLayer of Gaussian distributed variables. 192 | A layer with :math:`-\\frac{x^2}{2 \sigma^2} + x (bias + interaction)` energy. 193 | ''' 194 | def __init__(self, 195 | bias: torch.Tensor, 196 | logsigma: torch.Tensor): 197 | ''' 198 | :param bias: bias for the Gaussian 199 | :param logsigma: logarithm of the standard deviation sigma 200 | ''' 201 | super().__init__() 202 | self.register_parameter("bias", torch.nn.Parameter(bias)) 203 | self.register_parameter("logsigma", torch.nn.Parameter(logsigma)) 204 | 205 | def transform(self, 206 | input: torch.Tensor 207 | ) -> torch.Tensor: 208 | return input 209 | 210 | def transform_invert(self, 211 | transformed_input: torch.Tensor 212 | ) -> torch.Tensor: 213 | return transformed_input 214 | 215 | def mean_cond(self, 216 | interaction: Union[torch.Tensor, None] = None, 217 | N: int=1 218 | ) -> torch.Tensor: 219 | if interaction is not None: 220 | return torch.exp(self.logsigma*2)*(interaction + self.bias) 221 | else: 222 | return torch.exp(self.logsigma*2)*self.bias.expand(*([N] + list(self.bias.shape[1:]))) 223 | 224 | def sample_cond(self, 225 | interaction: Union[torch.Tensor, None]=None, 226 | N: int=1 227 | ) -> torch.Tensor: 228 | mean = self.mean_cond(interaction=interaction, N=N) 229 | return mean + torch.normal(torch.zeros_like(mean), torch.ones_like(mean))*torch.exp(self.logsigma) 230 | 231 | def free_energy(self, 232 | interaction: torch.Tensor 233 | ) -> torch.Tensor: 234 | return -torch.sum(0.5*(interaction + self.bias)*torch.exp(2*self.logsigma), dim=list(range(1, len(interaction.shape)))) 235 | 236 | def logprob_cond(self, 237 | input: torch.Tensor, 238 | interaction: Union[torch.Tensor, float]=0. 239 | ) -> torch.Tensor: 240 | norm = -0.5*torch.log(2*pi*torch.ones_like(self.logsigma)).sum() - torch.sum(self.logsigma) 241 | exp = input * (interaction + self.bias) - (input)**2/2/torch.exp(2*self.logsigma) 242 | return norm + exp.sum(dim=list(range(1, len(exp.shape)))) 243 | 244 | def logprob_joint(self, 245 | input: torch.Tensor 246 | ) -> torch.Tensor: 247 | norm = -0.5*torch.log(2*pi*torch.ones_like(self.logsigma)).sum() - self.logsigma.sum() 248 | exp = -input**2/2/torch.exp(self.logsigma*2) + input*self.bias - 0.5*torch.log(2 * pi) - 2*self.logsigma 249 | return exp.sum(dim=list(range(1, len(exp.shape)))) + norm 250 | 251 | def backward(self, 252 | input: torch.Tensor, 253 | factor: Union[torch.Tensor, float]=1. 
254 | ) -> None: 255 | if self.bias.requires_grad: 256 | grad_bias = (factor*input).sum(dim=0, keepdim=True) / input.shape[0] 257 | self.bias.backward(grad_bias.detach()) 258 | if self.logsigma.requires_grad: 259 | var = (input**2).sum(dim=0, keepdim=True)/input.shape[0] 260 | grad_logsigma = factor* var* torch.exp(-2*self.logsigma) - 1. 261 | self.logsigma.backward(grad_logsigma.detach()) 262 | 263 | 264 | class DiracDeltaLayer(UnitLayer): 265 | ''' 266 | A Layer to model a dirac delta == copying input-interaction to output 267 | ''' 268 | def __init__(self, base_shape=(1,), deltafactor=1000., device=None): 269 | ''' 270 | :param base_shape: shape of the variables (1st is batch dimension) 271 | :param deltafactor: the inverse variance of the approximating gaussian 272 | :param device: the device to operate on 273 | ''' 274 | super().__init__() 275 | self.deltafactor = deltafactor 276 | self.base_shape = base_shape 277 | self.device=device 278 | 279 | def free_energy(self, 280 | interaction: torch.Tensor 281 | ) -> torch.Tensor: 282 | return 0. 283 | 284 | def backward(self, 285 | input: torch.Tensor, 286 | factor: Union[torch.Tensor, float]=1. 287 | ) -> None: 288 | pass 289 | 290 | def transform(self, 291 | input: torch.Tensor 292 | ) -> torch.Tensor: 293 | return input 294 | 295 | def logprob_joint(self, 296 | input: torch.Tensor 297 | ) -> torch.Tensor: 298 | return torch.zeros([input.shape[0]] + list(self.base_shape[1:])) 299 | 300 | def transform_invert(self, 301 | transformed_input: torch.Tensor 302 | ) -> torch.Tensor: 303 | return transformed_input 304 | 305 | def logprob_cond(self, 306 | input: torch.Tensor, 307 | interaction: Union[torch.Tensor, float]=0. 308 | ) -> torch.Tensor: 309 | return -self.deltafactor*(input - interaction)**2 310 | 311 | def mean_cond(self, 312 | interaction: Union[torch.Tensor, None] = None, 313 | N: int=1 314 | ) -> torch.Tensor: 315 | if interaction is not None: 316 | return interaction 317 | else: 318 | return torch.zeros([N] + list(self.base_shape[1:])) 319 | 320 | def sample_cond(self, 321 | interaction: Union[torch.Tensor, None]=None, 322 | N: int=1 323 | ) -> torch.Tensor: 324 | if interaction is None: 325 | return torch.zeros([N] + list(self.base_shape[1:]), device=self.device) 326 | else: 327 | return interaction 328 | 329 | class CategoricalLayer(UnitLayer): 330 | ''' 331 | Essentially a multinomial distribution in the last layer as a one-hot encoding. 332 | The samples are generated via Gumbel softmax samples. 333 | Mean values via softmax. 
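    Illustrative example (shapes are assumptions, not from the library's tests):
    a bias of shape (1, 10) models a single categorical variable with 10 categories;
    sample_cond(N=32) then returns a (32, 10) batch of one-hot samples.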
334 |     '''
335 |     def __init__(self, bias : torch.Tensor):
336 |         '''
337 |         :param bias: Categorical Bias, Last dim represents categories
338 |         '''
339 |         #if biasIn.shape[:-1] != biasOut.shape:
340 |         #    raise DimensionError('')
341 | 
342 |         super().__init__()
343 |         self.register_parameter("bias", torch.nn.Parameter(bias))
344 | 
345 |     def mean_cond(self,
346 |                   interaction: Union[torch.Tensor, None] = None,
347 |                   N: int=1
348 |                   ) -> torch.Tensor:
349 |         logprobs = self.bias.expand(*([N] + list(self.bias.shape[1:])))
350 |         if interaction is not None:
351 |             #print(logprobs.device, interaction.device)
352 |             logprobs = logprobs + interaction
353 |         return torch.softmax(logprobs, dim=-1)
354 | 
355 |     def sample_cond(self,
356 |                     interaction: Union[torch.Tensor, None]=None,
357 |                     N: int=1
358 |                     ) -> torch.Tensor:
359 |         # print(self.biasIn.shape, self.biasOut.shape)
360 |         shape = self.bias.shape
361 |         logprobs = self.bias.expand(*([N] + list(shape[1:])))
362 |         if interaction is not None:
363 |             logprobs = logprobs + interaction
364 |         return torch.nn.functional.gumbel_softmax(logprobs, dim=-1, hard=True).detach()
365 | 
366 |     def logprob_cond(self,
367 |                      input: torch.Tensor,
368 |                      interaction: Union[torch.Tensor, float]=0.
369 |                      ) -> torch.Tensor:
370 |         res = torch.zeros(input.shape[:-1], device=input.device)
371 |         acc = input.sum(dim=-1)
372 |         res[acc < 1. - epsilon] = -np.inf
373 |         res[acc > 1. + epsilon] = -np.inf
374 |         logprobs = torch.nn.functional.log_softmax(self.bias + interaction, dim=-1)
375 |         logprob = torch.sum(input * logprobs, dim=list(range(1, len(input.shape))))
376 |         return logprob + res.sum(dim=list(range(1, len(res.shape))))
377 | 
378 |     def transform(self,
379 |                   input: torch.Tensor
380 |                   ) -> torch.Tensor:
381 |         return input
382 | 
383 |     def transform_invert(self,
384 |                          transformed_input: torch.Tensor
385 |                          ) -> torch.Tensor:
386 |         return transformed_input
387 | 
388 |     def logprob_joint(self,
389 |                       input: torch.Tensor
390 |                       ) -> torch.Tensor:
391 |         logprob = torch.sum(input * self.bias, dim=list(range(1, len(input.shape))))
392 |         res = torch.zeros(input.shape[:-1], device=input.device)
393 |         acc = input.sum(dim=-1)
394 |         res[acc > 1. + epsilon] = -np.inf
395 |         res[acc < 1. - epsilon] = -np.inf
396 |         logprob += res.sum(dim=list(range(1, len(res.shape))))
397 |         return logprob
398 | 
399 |     def free_energy(self,
400 |                     interaction: torch.Tensor
401 |                     ) -> torch.Tensor:
402 |         logprobs = torch.nn.functional.log_softmax(self.bias + interaction, dim=-1)
403 |         return torch.sum(logprobs, dim=list(range(1, len(interaction.shape))))
404 | 
405 |     def backward(self,
406 |                  input: torch.Tensor,
407 |                  factor: Union[torch.Tensor, float]=1.
408 |                  ) -> None:
409 |         if self.bias.requires_grad:
410 |             self.bias.backward((input*factor).sum(dim=0, keepdim=True)/input.shape[0])
411 | 
--------------------------------------------------------------------------------
/pytorch_probgraph/interaction.py:
--------------------------------------------------------------------------------
1 | '''
2 | A Module providing some Interactions between unitlayers (in the sense of
3 | probabilistic energy terms).
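For a fully connected (linear) interaction, for instance, the negative energy
contributed to the joint log probability is schematically
transform(input) @ W @ transform(output) (the transforms are applied by the
calling layer), which is what InteractionLinear.negenergy computes below.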
4 | ''' 5 | from operator import mul 6 | from functools import reduce 7 | import torch 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import torch.nn as nn 11 | from typing import Union, Tuple 12 | from .utils import ListModule 13 | from .utils import Projection, Expansion1D, Truncation1D 14 | from .utils import Expansion2D, InvertGrayscale 15 | 16 | class Interaction(nn.Module): 17 | ''' 18 | | A General class for Interactions between input and output as an energy. \ 19 | Could be anything even non-linear interactions. 20 | 21 | | E.g. one could use a torch-module all that is needed that things \ 22 | are differentiable for both input and output (energy is something \ 23 | like a negative loss). 24 | | Note that input and output have to be transformed for some distributions \ 25 | from the Exponential Family \ 26 | (https://en.wikipedia.org/wiki/Exponential_family#Interpretation). 27 | | For all methods, this transformation is assumed to be done previously. 28 | ''' 29 | def __init__(self): 30 | super(Interaction, self).__init__() 31 | 32 | def negenergy(self, 33 | input: torch.Tensor, 34 | output: torch.Tensor, 35 | factor: Union[torch.Tensor, float] = 1., 36 | **kwargs 37 | ) -> torch.Tensor: 38 | ''' 39 | Defines the (negative) interaction energy, e.g. something like 40 | input @ W @ output for a fully connected interaction, 41 | with @ being a matrix multiplication. 42 | 43 | :param input: Input values () 44 | :param output: Output values 45 | :param factor: general factor on top (can also be a tensor of batch_dim) 46 | :param **kwargs: additional optional parameters 47 | ''' 48 | return 0. 49 | 50 | def gradInput(self, 51 | output: torch.Tensor, 52 | input: Union[torch.Tensor, None] = None, 53 | factor: Union[torch.Tensor, float] = 1., 54 | **kwargs 55 | ) -> torch.Tensor: 56 | ''' 57 | The negenergy gradient wrt the input. Can be calculated analytically or by 58 | using autograd. For non-linear interactions, the input is needed. 59 | 60 | :param output: the output value to interact with 61 | :param input: an input value for which the gradient is calculated 62 | :param factor: general factor on top (can also be a tensor of batch_dim) 63 | :param **kwargs: additional optional parameters 64 | :return: gradient of negenergy wrt. input transformed variable 65 | ''' 66 | return NotImplementedError 67 | 68 | def gradOutput(self, 69 | input: torch.Tensor, 70 | factor: Union[torch.Tensor, float] = 1., 71 | **kwargs 72 | ) -> torch.Tensor: 73 | ''' 74 | The negenergy gradient wrt the output. Can be calculated analytically or 75 | by using autograd. 76 | 77 | :param input: the input values 78 | :param factor: general factor on top (can also be a tensor of batch_dim) 79 | :param **kwargs: additional optional parameters 80 | :return: gradient of negenergy wrt. output transformed variable 81 | ''' 82 | return NotImplementedError 83 | 84 | def backward(self, 85 | input: torch.Tensor, 86 | output: torch.Tensor, 87 | factor: Union[torch.Tensor, float] = 1., 88 | **kwargs 89 | ) -> torch.Tensor: 90 | ''' 91 | The energy gradient wrt the internal parameters. It is added to the 92 | parameters in their .grad part. 93 | Only computes the gradient and doesn't return anything. 
94 | 95 | :param input: input values 96 | :param output: output values 97 | :param factor: general factor on top (can also be a tensor of batch_dim) 98 | :param **kwargs: additional optional parameters 99 | :return: None 100 | ''' 101 | return NotImplementedError 102 | 103 | 104 | class InteractionLinear(Interaction): 105 | ''' 106 | A simple linear Interaction. 107 | ''' 108 | 109 | def __init__(self, 110 | inputShape: Union[Tuple[int], int] = 1, 111 | outputShape: Union[Tuple[int], int] = 1, 112 | weight: Union[torch.Tensor, None] = None, 113 | dev_factor: float = 1., 114 | batch_norm: bool = False): 115 | ''' 116 | Init of a InteractionLinear between two layers of inputShape and outputShape. 117 | 118 | :param inputShape: Size or Shape (Tuple of Ints) of input UnitLayer 119 | :param outputShape: Size or Shape of output UnitLayer 120 | :param weight: externally define weight matrix 121 | :param dev_factor: Deviation factor for Xavier initialization 122 | :param batch_norm: If batch_norm should be computed for negenergy 123 | ''' 124 | 125 | super().__init__() 126 | self.batch_norm = batch_norm 127 | if weight is not None: 128 | self.weight = nn.Parameter(weight) 129 | else: 130 | if isinstance(inputShape, int): 131 | self.inputShape = (inputShape,) 132 | self.outputShape = (outputShape,) 133 | self.inputSize = inputShape 134 | self.outputSize = outputShape 135 | else: 136 | self.inputShape = inputShape 137 | self.outputShape = outputShape 138 | self.inputSize = reduce(mul, inputShape, 1) 139 | self.outputSize = reduce(mul, outputShape, 1) 140 | weight = dev_factor * 6. / (self.inputSize + self.outputSize) * torch.randn([self.inputSize, self.outputSize]) 141 | self.weight = nn.Parameter(weight) 142 | 143 | 144 | def negenergy(self, 145 | input: torch.Tensor, 146 | output: torch.Tensor, 147 | factor: Union[torch.Tensor, float] = 1., 148 | **kwargs 149 | ) -> torch.Tensor: 150 | negenergy = factor * (output.reshape(-1, self.outputSize) * (input.reshape(-1, self.inputSize) @ self.weight)).sum(1) 151 | if self.batch_norm: 152 | negenergy = negenergy.sum()/ output.shape[0] 153 | return negenergy 154 | 155 | def gradInput(self, 156 | output: torch.Tensor, 157 | input: Union[torch.Tensor, None] = None, 158 | factor: Union[torch.Tensor, float] = 1., 159 | **kwargs 160 | ) -> torch.Tensor: 161 | x = factor*output.reshape(-1, self.outputSize) @ self.weight.t() 162 | return x.reshape(*tuple([-1] + list(self.inputShape))) 163 | 164 | def gradOutput(self, 165 | input: torch.Tensor, 166 | factor: Union[torch.Tensor, float] = 1., 167 | **kwargs 168 | ) -> torch.Tensor: 169 | x = factor * input.reshape(-1, self.inputSize) @ self.weight 170 | return x.reshape(*tuple([-1] + list(self.outputShape))) 171 | 172 | def backward(self, 173 | input: torch.Tensor, 174 | output: torch.Tensor, 175 | factor: Union[torch.Tensor, float] = 1., 176 | **kwargs 177 | ) -> torch.Tensor: 178 | gw = (factor*input.reshape(-1, self.inputSize)).t() @ output.reshape(-1, self.outputSize) / input.shape[0] 179 | #print("GW:", gw) 180 | if self.weight.grad is None: 181 | self.weight.grad = gw 182 | else: 183 | self.weight.grad += gw 184 | 185 | def zero_grad(self) -> None: 186 | self.weight.grad = None 187 | 188 | def plot_weight(self, mesh_size=None): 189 | if mesh_size == None: 190 | mesh_size = int(np.sqrt(self.inputSize)) 191 | for i in range(10): 192 | plt.subplot(2, 5, i + 1) 193 | plt.tight_layout() 194 | plt.imshow(self.weight[:,i].reshape([mesh_size, mesh_size]).detach().cpu(), cmap='gray', interpolation='none') 195 | 
plt.xticks([]) 196 | plt.yticks([]) 197 | plt.show() 198 | 199 | class InteractionModule(Interaction): 200 | ''' 201 | A class taking a torch module as interaction between two layers. Note that only 202 | for linear interactions (Linear + Conv) layers this makes sense. For nonlinear 203 | models only gradients are used, leading to potentially wrong results. 204 | ''' 205 | 206 | def __init__(self, module: torch.nn.Module, inputShape=None): 207 | ''' 208 | :param module: The torch Module to be used. Usually only linear ones make sense 209 | :param inputShape: The shape of the input UnitLayer, needed for gradInput 210 | ''' 211 | super().__init__() 212 | self.module = module 213 | if inputShape is not None: 214 | self.lastInputShape = inputShape 215 | else: 216 | self.lastInputShape = torch.Size([1,1]) 217 | 218 | def enableModuleGrad(self, enable: bool=True) -> None: 219 | ''' 220 | Enables/Disables the internal gradient calculation inside the module. 221 | :param enable: [bool] If internal module gradients are enabled. 222 | ''' 223 | if enable: 224 | for p in self.module.parameters(): 225 | p.requires_grad = True 226 | else: 227 | for p in self.module.parameters(): 228 | p.requires_grad = False 229 | 230 | def negenergy(self, 231 | input: torch.Tensor, 232 | output: torch.Tensor, 233 | factor: Union[torch.Tensor, float] = 1., 234 | **kwargs 235 | ) -> torch.Tensor: 236 | self.lastInputShape = input.shape 237 | negenergy = (factor * output * self.module.forward(input)).sum(list(range(1, len(output.shape)))) 238 | return negenergy.detach() 239 | 240 | def gradInput(self, 241 | output: torch.Tensor, 242 | input: Union[torch.Tensor, None] = None, 243 | factor: Union[torch.Tensor, float] = 1., 244 | **kwargs 245 | ) -> torch.Tensor: 246 | ''' 247 | In the case of variable layer shapes the input is needed. 248 | ''' 249 | if input is None: 250 | inp = torch.ones([output.shape[0]] + list(self.lastInputShape[1:]), device=output.device, requires_grad=True) 251 | else: 252 | inp = input 253 | output = torch.tensor(output.data, requires_grad=False, device=output.device) 254 | if isinstance(factor, torch.Tensor): 255 | factor = factor.detach() 256 | # input.requires_grad = True 257 | with torch.enable_grad(): 258 | self.enableModuleGrad(False) 259 | outprime = self.module.forward(inp) 260 | #self.enableModuleGrad(False) 261 | outprime.backward(output * factor) 262 | del output 263 | return inp.grad.detach() 264 | 265 | def gradOutput(self, 266 | input: torch.Tensor, 267 | factor: Union[torch.Tensor, float] = 1., 268 | **kwargs 269 | ) -> torch.Tensor: 270 | # print(self.module.forward(input).shape) 271 | self.lastInputShape = input.shape 272 | self.enableModuleGrad(False) 273 | #print("MGO:", self.module.forward(input).shape, input.shape) 274 | return factor*self.module.forward(input).detach() 275 | 276 | def backward(self, 277 | input: torch.Tensor, 278 | output: torch.Tensor, 279 | factor: Union[torch.Tensor, float] = 1., 280 | **kwargs 281 | ) -> torch.Tensor: 282 | with torch.enable_grad(): 283 | self.enableModuleGrad(True) 284 | outprime = self.module.forward(input) 285 | outprime.backward(output * factor) 286 | 287 | class InteractionPoolMapIn1D(InteractionModule): 288 | ''' 289 | A class for a Mapping of a tensor to a tensor of different shape 290 | [ ... , N ] -> [ ... 
, N / m , m + 1] 291 | where only the first m elements are filled in the last dim 292 | ''' 293 | def __init__(self, poolsize: int): 294 | ''' 295 | :param poolsize: Pooling size in 1D 296 | ''' 297 | module = torch.nn.Sequential( 298 | Expansion1D(poolsize), 299 | Projection((0, ), (poolsize, ), (poolsize,), (0,), (poolsize, ), (poolsize+1,)) 300 | ) 301 | super().__init__(module) 302 | 303 | def backward(self, 304 | input: torch.Tensor, 305 | output: torch.Tensor, 306 | factor: Union[torch.Tensor, float] = 1., 307 | **kwargs 308 | ) -> torch.Tensor: 309 | ''' 310 | No internal parameters here. 311 | ''' 312 | pass 313 | 314 | class InteractionPoolMapOut1D(InteractionModule): 315 | ''' 316 | A class for a Mapping between tensors 317 | [ ..., N, m+1] -> [ ..., N] 318 | where only the last 1 element is taken 319 | ''' 320 | def __init__(self, poolsize: int): 321 | ''' 322 | :param poolsize: Pooling size in 1D 323 | ''' 324 | module = torch.nn.Sequential( 325 | Projection((poolsize,), (poolsize+1,), (poolsize+1,), (0,), (1,), (1,)), 326 | Truncation1D(poolsize), 327 | InvertGrayscale() 328 | ) 329 | 330 | super().__init__(module) 331 | 332 | def backward(self, 333 | input: torch.Tensor, 334 | output: torch.Tensor, 335 | factor: Union[torch.Tensor, float] = 1., 336 | **kwargs 337 | ) -> torch.Tensor: 338 | ''' 339 | No internal parameters here. 340 | ''' 341 | pass 342 | 343 | 344 | class InteractionReversed(Interaction): 345 | ''' 346 | A class for reverting an interaction (exchanging input and output). 347 | ''' 348 | def __init__(self, interaction): 349 | ''' 350 | :param interaction: The Interaction to be reversed. 351 | ''' 352 | super().__init__() 353 | self.interaction = interaction 354 | 355 | def gradInput(self, 356 | output: torch.Tensor, 357 | input: Union[torch.Tensor, None] = None, 358 | factor: Union[torch.Tensor, float] = 1., 359 | **kwargs 360 | ) -> torch.Tensor: 361 | return self.interaction.gradOutput(output, factor=factor) 362 | 363 | def gradOutput(self, 364 | input: torch.Tensor, 365 | factor: Union[torch.Tensor, float] = 1., 366 | **kwargs 367 | ) -> torch.Tensor: 368 | return self.interaction.gradInput(input, factor=factor) 369 | 370 | def backward(self, 371 | input: torch.Tensor, 372 | output: torch.Tensor, 373 | factor: Union[torch.Tensor, float] = 1., 374 | **kwargs 375 | ) -> torch.Tensor: 376 | return self.interaction.backward(output, input, factor=factor) 377 | 378 | def negenergy(self, 379 | input: torch.Tensor, 380 | output: torch.Tensor, 381 | factor: Union[torch.Tensor, float] = 1., 382 | **kwargs 383 | ) -> torch.Tensor: 384 | return self.interaction.negenergy(output, input, factor, **kwargs) 385 | 386 | 387 | class InteractionPoolMapIn2D(InteractionModule): 388 | ''' 389 | Interaction merging pooling input to a common tensor of 390 | pooling input and output. 391 | ''' 392 | def __init__(self, poolsize1: int, poolsize2: int): 393 | module = torch.nn.Sequential( 394 | Expansion2D(poolsize1, poolsize2), 395 | Truncation1D(poolsize2), 396 | Projection((0,), (poolsize1*poolsize2,), (poolsize1*poolsize2,), 397 | (0,), (poolsize1*poolsize2,), (poolsize1*poolsize2+1,)), 398 | ) 399 | super().__init__(module) 400 | 401 | def backward(self, 402 | input: torch.Tensor, 403 | output: torch.Tensor, 404 | factor: Union[torch.Tensor, float] = 1., 405 | **kwargs 406 | ) -> torch.Tensor: 407 | ''' 408 | No internal parameters here. 
409 | ''' 410 | pass 411 | 412 | 413 | class InteractionPoolMapOut2D(InteractionModule): 414 | ''' 415 | Interaction mapping a ProbMaxPool layer back to an image. 416 | ''' 417 | def __init__(self, poolsize1: int, poolsize2: int): 418 | p1p2 = poolsize1 * poolsize2 419 | module = torch.nn.Sequential( 420 | Projection((p1p2,), (p1p2+1,), (p1p2+1,), (0,), (1,), (1,)), 421 | Truncation1D(1), 422 | InvertGrayscale() 423 | ) 424 | super().__init__(module) 425 | 426 | def backward(self, 427 | input: torch.Tensor, 428 | output: torch.Tensor, 429 | factor: Union[torch.Tensor, float] = 1., 430 | **kwargs 431 | ) -> torch.Tensor: 432 | ''' 433 | No internal parameters here. 434 | ''' 435 | pass 436 | 437 | class InteractionSequential(Interaction): 438 | ''' 439 | Combines Interactions sequentially with no random UnitLayers 440 | (i.e. DiracDeltaLayers) in between. 441 | ''' 442 | def __init__(self, *interactions): 443 | ''' 444 | :param interactions: List of Interactions to be concatenated. 445 | ''' 446 | super().__init__() 447 | self.interactions = ListModule(*interactions) 448 | 449 | def negenergy(self, 450 | input: torch.Tensor, 451 | output: torch.Tensor, 452 | factor: Union[torch.Tensor, float] = 1., 453 | **kwargs 454 | ) -> torch.Tensor: 455 | inputs = [input] 456 | outputs = [output] 457 | for inter in self.interactions: 458 | inputs.append(inter.gradOutput(inputs[-1], factor=factor, **kwargs)) 459 | for inter in reversed(self.interactions): 460 | outputs = [inter.gradInput(outputs[0])] + outputs 461 | # forward backward style algorithm 462 | negen = 0. 463 | for n, inter in enumerate(self.interactions): 464 | negen = negen + inter.negenergy(inputs[n], outputs[n+1], factor=factor, **kwargs) 465 | return negen 466 | 467 | def gradInput(self, 468 | output: torch.Tensor, 469 | input: Union[torch.Tensor, None] = None, 470 | factor: Union[torch.Tensor, float] = 1., 471 | **kwargs 472 | ) -> torch.Tensor: 473 | if input is not None: 474 | inputs = [input] 475 | for inter in self.interactions: 476 | inputs.append(inter.gradOutput(inputs[-1], factor=1., **kwargs)) 477 | else: 478 | inputs = [None]*len(self.interactions) 479 | for n, inter in enumerate(reversed(self.interactions)): 480 | inp = inputs[len(self.interactions)-n-1] 481 | if inp is not None: 482 | inp2 = torch.tensor(inp.data, requires_grad=True, device=inp.device) 483 | else: 484 | inp2 = None 485 | output = inter.gradInput(output, input=inp2, factor=1., **kwargs) 486 | if inp2 is not None: 487 | del inp2 488 | return factor*output 489 | 490 | def gradOutput(self, 491 | input: torch.Tensor, 492 | factor: Union[torch.Tensor, float] = 1., 493 | **kwargs 494 | ) -> torch.Tensor: 495 | for inter in self.interactions: 496 | input = inter.gradOutput(input, factor=1., **kwargs) 497 | return factor*input 498 | 499 | def backward(self, 500 | input: torch.Tensor, 501 | output: torch.Tensor, 502 | factor: Union[torch.Tensor, float] = 1., 503 | **kwargs 504 | ) -> torch.Tensor: 505 | inputs = [input] 506 | outputs = [output] 507 | for inter in self.interactions: 508 | inputs.append(inter.gradOutput(inputs[-1], factor = 1., **kwargs)) 509 | for n, inter in enumerate(reversed(self.interactions)): 510 | device = inputs[len(self.interactions)-n-1].device 511 | inp = torch.tensor(inputs[len(self.interactions)-n-1].data, requires_grad=True, device=device) 512 | outputs = [inter.gradInput(outputs[0], input=inp, factor=1.)] + outputs 513 | del inp 514 | # forward backward style algorithm 515 | for n, inter in enumerate(self.interactions): 516 |
inter.backward(inputs[n], outputs[n+1], factor=factor) 517 | -------------------------------------------------------------------------------- /pytorch_probgraph/rbm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from itertools import chain 4 | import matplotlib.pyplot as plt 5 | import torch.nn as nn 6 | from .unitlayer import UnitLayer 7 | from .interaction import Interaction 8 | from typing import Iterable, Union, Optional, List 9 | from tqdm import tqdm 10 | 11 | class RestrictedBoltzmannMachine(nn.Module): 12 | ''' 13 | A two layer undirected hierarchical probabilistic model 14 | 15 | Sources: 16 | 17 | [1] R. Salakhutdinov "LEARNING DEEP GENERATIVE MODELS", 18 | https://tspace.library.utoronto.ca/bitstream/1807/19226/3/Salakhutdinov_Ruslan_R_200910_PhD_thesis.pdf 19 | 20 | [2] G. Hinton "A Practical Guide to Training Restricted Boltzmann Machines", 2010 21 | https://www.cs.toronto.edu/~hinton/absps/guideTR.pdf 22 | ''' 23 | 24 | def __init__(self, 25 | visible: UnitLayer, 26 | hidden: UnitLayer, 27 | interaction: Interaction): 28 | ''' 29 | :param visible: [UnitLayer], visible layer 30 | :param hidden: [UnitLayer], hidden layer 31 | :param interaction: [Interaction], the interaction unit 32 | ''' 33 | 34 | super(RestrictedBoltzmannMachine, self).__init__() 35 | self.add_module('visible', visible) 36 | self.add_module('hidden', hidden) 37 | self.add_module('interaction', interaction) 38 | 39 | def get_visible_layer(self): 40 | ''' 41 | :return: visible layer 42 | ''' 43 | return self.visible 44 | 45 | def get_hidden_layer(self): 46 | ''' 47 | :return: hidden layer 48 | ''' 49 | return self.hidden 50 | 51 | def get_interaction(self): 52 | ''' 53 | :return: interaction 54 | ''' 55 | return self.interaction 56 | 57 | def logprob_joint(self, 58 | visible: torch.Tensor, 59 | hidden: torch.Tensor 60 | ) -> torch.Tensor: 61 | ''' 62 | Computes the unnormalized log probability of the RBM state 63 | 64 | :param visible: visible unit activations 65 | :param hidden: hidden unit activations 66 | :return: unnormalized log probability 67 | ''' 68 | 69 | # transform activations according to the layer (can be the identity for a linear layer) 70 | intvisible = self.visible.transform(visible) 71 | inthidden = self.hidden.transform(hidden) 72 | 73 | # get the joint energy from the interaction term 74 | joint_energy = self.interaction.negenergy(intvisible, inthidden) 75 | # get the joint energy from the layer term (unnormalized log probability contribution of the layers) 76 | lpv = self.visible.logprob_joint(visible) 77 | lph = self.hidden.logprob_joint(hidden) 78 | return joint_energy + lpv + lph 79 | 80 | def free_energy(self, 81 | visible: torch.Tensor 82 | ) -> torch.Tensor: 83 | ''' 84 | The free energy is computed as: 85 | $$ F(v) = - \ln \sum_{h} e^{ - E(v, h) } $$ 86 | (see Hinton: A Practical Guide to Training Restricted Boltzmann Machines) 87 | It can be used to track the learning process and overfitting. 88 | Note that here an averaged free energy is used. 89 | 90 | :param visible: visible units activation 91 | :return: free energy (negative unnormalized log probability of the visible state) 92 | ''' 93 | free_en = -self.visible.logprob_joint(visible) 94 | hidden_int = self.interaction.gradOutput(self.visible.transform(visible)) 95 | free_en += self.hidden.free_energy(hidden_int) 96 | return free_en 97 | 98 | def mean_visible(self, 99 | hidden: torch.Tensor, 100 | interaction_factor: Union[torch.Tensor, float]=1.
101 | ) -> torch.Tensor: 102 | ''' 103 | Get the mean of the visible conditional distribution 104 | If the factor is set to 0, the RBM follows a factorized distribution. 105 | 106 | :param hidden: hidden activation 107 | :param interaction_factor: a factor to weight the interaction energy. 108 | :return: mean of the visible activation 109 | ''' 110 | hidden_int = self.hidden.transform(hidden) 111 | visible_int = interaction_factor*self.interaction.gradInput(hidden_int) 112 | return self.visible.mean_cond(visible_int) 113 | 114 | def mean_hidden(self, 115 | visible: torch.Tensor, 116 | interaction_factor: Union[torch.Tensor, float]=1. 117 | ) -> torch.Tensor: 118 | ''' 119 | Get the mean of the hidden conditional distribution 120 | If the factor is set to 0, the RBM follows a factorizing distribution. 121 | 122 | :param visible: visible activation 123 | :param interaction_factor: a factor to weight the interaction energy. 124 | :return: mean of hidden activation 125 | ''' 126 | visible_int = self.visible.transform(visible) 127 | hidden_int = interaction_factor*self.interaction.gradOutput(visible_int) 128 | return self.hidden.mean_cond(hidden_int) 129 | 130 | def sample_visible(self, 131 | hidden: Union[torch.Tensor, None]=None, 132 | N: int=1, 133 | interaction_factor: Union[torch.Tensor, float]=1. 134 | ) -> torch.Tensor: 135 | ''' 136 | Sample from the visible conditional distribution. 137 | If the factor is set to 0, the RBM follows a factorizing distribution. 138 | 139 | :param hidden: hidden activation 140 | :param N: batch size 141 | :param interaction_factor: a factor to weight the interaction energy. 142 | :return: visible activation according to the probability distribution 143 | ''' 144 | 145 | # if hidden is none, the visible state is randomly sampled 146 | if hidden is None: 147 | return self.visible.sample_cond(N=N) 148 | else: 149 | hidden_int = self.hidden.transform(hidden) 150 | interaction_hidden = interaction_factor*self.interaction.gradInput(hidden_int) 151 | return self.visible.sample_cond(interaction_hidden) 152 | 153 | def sample_hidden(self, 154 | visible: Union[torch.Tensor, None]=None, 155 | N: int=1, 156 | interaction_factor: Union[torch.Tensor, float]=1. 157 | ) -> torch.Tensor: 158 | ''' 159 | Sample from the hidden conditional distribution 160 | If the factor is set to 0, the RBM follows a factorizing distribution. 161 | 162 | :param visible: visible activation 163 | :param N: batch size 164 | :param interaction_factor: a factor to weight the interaction energy. 
165 | :return: hidden activation according to the probability distribution 166 | ''' 167 | if visible is None: 168 | return self.hidden.sample_cond(N=N) 169 | else: 170 | visible_int = self.visible.transform(visible) 171 | interaction_visible = interaction_factor*self.interaction.gradOutput(visible_int) 172 | return self.hidden.sample_cond(interaction_visible) 173 | 174 | def reconstruct(self, 175 | N: int=1, 176 | visible_input: Union[torch.Tensor, None]=None, 177 | gibbs_steps: int=1, 178 | visible_interaction_factor: Union[torch.Tensor, float]=1., 179 | hidden_interaction_factor: Union[torch.Tensor, float]=1., 180 | mean: bool=False, 181 | all_states: bool=False 182 | ) -> torch.Tensor: 183 | ''' 184 | Take N visible samples of an RBM, either preconditioned or 185 | by Gibbs sampling with random start 186 | 187 | :param N: batch size 188 | :param visible_input: visible activation 189 | :param gibbs_steps: number of Gibbs steps 190 | :param visible_interaction_factor: a factor to weight the visible layer energy. 191 | :param hidden_interaction_factor: a factor to weight the hidden layer energy. 192 | :param mean: if true, returns the mean activation of the visible units, \ 193 | else returns a sample 194 | :param all_states: if true, returns a sample for both the hidden and the \ 195 | visible layer, else returns a sample of the visible layer 196 | :return: sample according to the RBM probability distribution 197 | ''' 198 | if visible_input is not None: 199 | hidden_mean = self.mean_hidden(visible_input, interaction_factor=hidden_interaction_factor) 200 | hidden_sample = self.sample_hidden(visible_input, interaction_factor=hidden_interaction_factor) 201 | else: 202 | hidden_sample = self.sample_hidden(N=N) 203 | for _ in range(gibbs_steps): 204 | visible_mean = self.mean_visible(hidden_sample, interaction_factor=visible_interaction_factor) 205 | visible_sample = self.sample_visible(hidden_sample, interaction_factor=visible_interaction_factor) 206 | hidden_mean = self.mean_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 207 | hidden_sample = self.sample_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 208 | if all_states: 209 | if mean: 210 | return visible_mean, hidden_mean 211 | else: 212 | return visible_sample, hidden_sample 213 | else: 214 | if mean: 215 | return visible_mean 216 | else: 217 | return visible_sample 218 | 219 | def train(self, 220 | data: Iterable[torch.Tensor], 221 | epochs: int, 222 | optimizer: torch.optim.Optimizer, 223 | scheduler = None, 224 | visible_interaction_factor: Union[torch.Tensor, float]=1., 225 | hidden_interaction_factor: Union[torch.Tensor, float]=1., 226 | device: Union[torch.device, None]=None): 227 | ''' 228 | Training method for the RBM 229 | 230 | :param data: training data 231 | :param epochs: number of epochs 232 | :param optimizer: torch optimizer 233 | :param scheduler: torch scheduler 234 | :param visible_interaction_factor: factor for the visible activation units 235 | :param hidden_interaction_factor: factor for the hidden activation units 236 | :return: None 237 | ''' 238 | for epoch in range(epochs): 239 | for j, bdat in enumerate(data): 240 | self.zero_grad() 241 | self.step(bdat.to(device), 242 | visible_interaction_factor=visible_interaction_factor, 243 | hidden_interaction_factor=hidden_interaction_factor) 244 | optimizer.step() 245 | if scheduler is not None: 246 | scheduler.step() 247 | if isinstance(data, tqdm): 248 | data = tqdm(data) 249 | data.set_description('Epoch {}'.format(epoch)) 
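    # Usage sketch (a minimal illustration; `vis_layer`, `hid_layer`, `inter` and
    # `train_loader` are placeholder names for UnitLayer / Interaction instances and
    # a data iterable built elsewhere, e.g. as in examples/Model_RBM_CD.py):
    #
    #   rbm = RestrictedBoltzmannMachineCD(vis_layer, hid_layer, inter, ksteps=1)
    #   opt = torch.optim.SGD(rbm.parameters(), lr=1e-2)
    #   rbm.train(train_loader, epochs=10, optimizer=opt)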
250 | 251 | def plot_reconstruction(self, 252 | X_rec : torch.Tensor, 253 | X_train : torch.Tensor): 254 | ''' 255 | Plots the reconstruction 256 | 257 | :param X_rec: reconstructed image 258 | :param X_train: original image 259 | :return: None 260 | ''' 261 | 262 | mesh_size = int(np.sqrt(X_rec.size()[1])) 263 | plt.subplot(2, 1, 1) 264 | plt.imshow(X_rec.detach().cpu().numpy().reshape([mesh_size, mesh_size]), cmap='gray', interpolation='none') 265 | plt.subplot(2, 1, 2) 266 | plt.imshow(X_train[0].detach().cpu().numpy().reshape([mesh_size, mesh_size]), cmap='gray', interpolation='none') 267 | plt.xticks([]) 268 | plt.yticks([]) 269 | plt.show() 270 | 271 | def plot_reconstruction_list(self, 272 | pictures: List[torch.Tensor]): 273 | ''' 274 | Plots a list of reconstructed images 275 | 276 | :param pictures: list of images 277 | :return: None 278 | ''' 279 | 280 | mesh_size = int(np.sqrt(pictures[0].size()[1])) 281 | fig, axs = plt.subplots(4, int(len(pictures)/4) + 1) 282 | for i in range(len(pictures)): 283 | j = i % 4 284 | k = int(i/4) 285 | axs[j, k].imshow(pictures[i].detach().cpu().numpy().reshape([mesh_size, mesh_size]), cmap='gray', interpolation='none') 286 | plt.show() 287 | 288 | class RestrictedBoltzmannMachineCD_Smooth(RestrictedBoltzmannMachine): 289 | ''' 290 | This class handles the training of the RBM with CD. Also the last sampled 291 | hidden states for the negative and positive phases are mean activations. 292 | ''' 293 | def __init__(self, 294 | visible: UnitLayer, 295 | hidden: UnitLayer, 296 | interaction: Interaction, 297 | ksteps: int=1): 298 | ''' 299 | :param visible: visible UnitLayer 300 | :param hidden: hidden UnitLayer 301 | :param interaction: Interaction 302 | :param ksteps: number of Gibbs steps 303 | ''' 304 | super().__init__(visible, hidden, interaction) 305 | self.ksteps = ksteps 306 | 307 | def step(self, 308 | data : torch.Tensor, 309 | visible_interaction_factor: Union[torch.Tensor, float]=1., 310 | hidden_interaction_factor: Union[torch.Tensor, float]=1. 
311 | ) -> None: 312 | ''' 313 | One update step for a batch 314 | 315 | :param data: batch data 316 | :param visible_interaction_factor: factor for the visible activation units 317 | :param hidden_interaction_factor: factor for the hidden activation units 318 | :return: None 319 | ''' 320 | visible_sample = data.clone() 321 | hidden_mean = self.mean_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 322 | hidden_sample = self.sample_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 323 | hidden_mean_positive = hidden_mean.clone() 324 | # Do gibbs sampling for ksteps 325 | for _ in range(self.ksteps): 326 | visible_mean = self.mean_visible(hidden_sample, interaction_factor=visible_interaction_factor) 327 | visible_sample = self.sample_visible(hidden_sample, interaction_factor=visible_interaction_factor) 328 | hidden_mean = self.mean_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 329 | hidden_sample = self.sample_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 330 | 331 | # pos phase, biases first, then interaction weights 332 | # we are using negative gradients here because of gradient descent 333 | # replaced by ascent 334 | self.visible.backward(data, factor=-1) 335 | # the mean activation of the hidden units is used 336 | self.hidden.backward(hidden_mean_positive, factor=-1) 337 | self.interaction.backward(self.visible.transform(data), 338 | self.hidden.transform(hidden_mean_positive), 339 | factor=-1) 340 | 341 | # neg phase, biases first, then interaction weights 342 | self.visible.backward(visible_mean, factor=1) 343 | self.hidden.backward(hidden_mean, factor=1) 344 | self.interaction.backward(self.visible.transform(visible_mean), 345 | self.hidden.transform(hidden_mean), 346 | factor=1) 347 | 348 | class RestrictedBoltzmannMachineCD(RestrictedBoltzmannMachine): 349 | ''' 350 | This class handles the training of the RBM with CD (contrastive divergence), 351 | using sampled hidden states for both the positive and the negative phase. 352 | ''' 353 | def __init__(self, 354 | visible: UnitLayer, 355 | hidden: UnitLayer, 356 | interaction: Interaction, 357 | ksteps: int=1): 358 | ''' 359 | :param visible: visible layer 360 | :param hidden: hidden layer 361 | :param interaction: interaction 362 | :param ksteps: number of Gibbs steps 363 | ''' 364 | super().__init__(visible, hidden, interaction) 365 | self.ksteps = ksteps 366 | 367 | def step(self, 368 | data: torch.Tensor, 369 | visible_interaction_factor: Union[torch.Tensor, float]=1., 370 | hidden_interaction_factor: Union[torch.Tensor, float]=1.
371 | ) -> None: 372 | ''' 373 | One update step for a batch 374 | 375 | :param data: batch data 376 | :param visible_interaction_factor: factor for the visible activation units 377 | :param hidden_interaction_factor: factor for the hidden activation units 378 | :return: None 379 | ''' 380 | visible_sample = data 381 | hidden_mean = self.mean_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 382 | hidden_sample = self.sample_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 383 | hidden_sample_positive = hidden_sample 384 | 385 | # Do gibbs sampling for ksteps 386 | for _ in range(self.ksteps): 387 | visible_mean = self.mean_visible(hidden_sample, interaction_factor=visible_interaction_factor) 388 | visible_sample = self.sample_visible(hidden_sample, interaction_factor=visible_interaction_factor) 389 | 390 | hidden_mean = self.mean_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 391 | hidden_sample = self.sample_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 392 | 393 | # pos phase, biases first, then interaction weights 394 | # we are using negative gradients here because of gradient descent 395 | # replaced by ascent 396 | trans_dat = self.visible.transform(data) 397 | trans_dat.requires_grad = True 398 | trans_hidden = self.hidden.transform(hidden_sample_positive) 399 | trans_hidden.requires_grad = True 400 | 401 | self.visible.backward(data, factor=-1) 402 | self.hidden.backward(hidden_sample_positive, factor=-1) 403 | self.interaction.backward(trans_dat, 404 | trans_hidden, 405 | factor=-1) 406 | 407 | # neg phase, biases first, then interaction weights 408 | self.visible.backward(visible_sample, factor=1) 409 | self.hidden.backward(hidden_sample, factor=1) 410 | self.interaction.backward(self.visible.transform(visible_sample), 411 | self.hidden.transform(hidden_sample), 412 | factor=1) 413 | 414 | class RestrictedBoltzmannMachinePCD(RestrictedBoltzmannMachine): 415 | ''' 416 | This class handles the training of the RBM with PCD (persistent contrastive divergence). 417 | ''' 418 | def __init__(self, 419 | visible: UnitLayer, 420 | hidden: UnitLayer, 421 | interaction: Interaction, 422 | fantasy_particles: int=1): 423 | ''' 424 | :param visible: visible layer 425 | :param hidden: hidden layer 426 | :param interaction: interaction 427 | :param fantasy_particles: number of persistent fantasy particles 428 | ''' 429 | super().__init__(visible, hidden, interaction) 430 | self.register_parameter("visible_fantasy", torch.nn.Parameter(self.sample_visible(N=fantasy_particles), requires_grad = False)) 431 | 432 | def step(self, 433 | data: torch.Tensor, 434 | visible_interaction_factor: Union[torch.Tensor, float]=1., 435 | hidden_interaction_factor: Union[torch.Tensor, float]=1.
436 | ) -> None: 437 | ''' 438 | One update step for a batch 439 | 440 | :param data: batch data 441 | :param visible_interaction_factor: factor for the visible activation units 442 | :param hidden_interaction_factor: factor for the hidden activation units 443 | :return: 444 | ''' 445 | with torch.no_grad(): 446 | # Sample hidden from visible data 447 | visible_sample = data 448 | hidden_mean = self.mean_hidden(visible_sample, interaction_factor=hidden_interaction_factor) 449 | # hidden_sample = self.rbm.sample_hidden(hidden_mean) 450 | hidden_mean_positive = hidden_mean 451 | # Sample hidden from visible "fantasy particles" 452 | visible_sample_neg = self.visible_fantasy 453 | hidden_mean_neg = self.mean_hidden(self.visible_fantasy, interaction_factor=hidden_interaction_factor) 454 | hidden_sample_neg = self.sample_hidden(self.visible_fantasy, interaction_factor=hidden_interaction_factor) 455 | visible_mean_neg = self.mean_visible(hidden_sample_neg, interaction_factor=visible_interaction_factor) 456 | self.visible_fantasy = torch.nn.Parameter(self.sample_visible(hidden_sample_neg, interaction_factor=visible_interaction_factor), requires_grad = False) 457 | #self.visible_fantasy = self.visible_fantasy.detach() # don't propagate autograd through everything 458 | 459 | # pos phase, biases first, then interaction weights 460 | # we are using negative gradients here because of gradient descent 461 | # replaced by ascent 462 | self.visible.backward(data, factor=-1) 463 | self.hidden.backward(hidden_mean_positive, factor=-1) 464 | self.interaction.backward(self.visible.transform(data), 465 | self.hidden.transform(hidden_mean_positive), 466 | factor=-1) 467 | 468 | # negative phase using fantasy particles 469 | self.visible.backward(visible_sample_neg, factor=1) 470 | self.hidden.backward(hidden_mean_neg, factor=1) 471 | self.interaction.backward(self.visible.transform(visible_sample_neg), 472 | self.hidden.transform(hidden_mean_neg), 473 | factor=1) 474 | -------------------------------------------------------------------------------- /pytorch_probgraph/dbm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from typing import List, Union, Optional, Iterable 4 | from tqdm import tqdm 5 | from .rbm import RestrictedBoltzmannMachine 6 | 7 | class DeepBoltzmannMachine(nn.Module): 8 | ''' 9 | | A deep undirected hierarchical probabilistic model 10 | | From: 11 | 12 | [1] R. Salakhutdinov "LEARNING DEEP GENERATIVE MODELS", 13 | https://tspace.library.utoronto.ca/bitstream/1807/19226/3/Salakhutdinov_Ruslan_R_200910_PhD_thesis.pdf 14 | ''' 15 | 16 | def __init__(self, 17 | rbms: List[RestrictedBoltzmannMachine], 18 | optimizer: torch.optim.Optimizer, 19 | scheduler): 20 | ''' 21 | Constructs the DBM 22 | 23 | :param rbms: list of the RBMs, which will be stacked to form the DBM. 
24 | :param optimizer: torch optimizer 25 | :param scheduler: torch scheduler 26 | ''' 27 | super().__init__() 28 | 29 | # build a ModuleList containing the RBMs 30 | self.rbms = nn.ModuleList() 31 | for rbm in rbms: 32 | self.rbms.append(rbm) 33 | 34 | # build a list containing all layers 35 | self.layers = [r.get_visible_layer() for r in self.rbms] + [self.rbms[-1].get_hidden_layer()] 36 | 37 | # build a list containing all interactions 38 | self.interactions = [r.get_interaction() for r in self.rbms] 39 | 40 | self.optimizer = optimizer 41 | self.scheduler = scheduler 42 | self.nlayers = len(self.layers) 43 | 44 | def zero_grad(self) -> None: 45 | for r in self.rbms: 46 | r.zero_grad() 47 | 48 | def backward(self, 49 | layer_data: List[torch.Tensor], 50 | factor: Union[torch.Tensor, float]=1. 51 | ) -> None: 52 | ''' 53 | Given all the layers' data calculate the gradient wrt latent variables 54 | 55 | :param layer_data: list of unit activations of the layers in one batch 56 | ''' 57 | 58 | for i in range(self.nlayers - 1): 59 | self.layers[i].backward(layer_data[i], factor=factor) 60 | self.interactions[i].backward(layer_data[i], layer_data[i + 1], factor=factor) 61 | self.layers[-1].backward(layer_data[-1], factor=factor) 62 | 63 | def conditional_sample_general_meanfield(self, 64 | layer_data: Optional[ 65 | List[ 66 | Optional[ 67 | torch.Tensor] 68 | ] 69 | ] = None, 70 | N: int=1, 71 | invtemperature: Union[torch.Tensor, float]=1., 72 | iterations: int=10, 73 | sample_state: List[bool]=None, 74 | mean: bool=False 75 | ) -> List[torch.Tensor]: 76 | ''' 77 | Produces a general sample using an ELBO with a mean field approx. 78 | If the factor is set to 0, the DBM follows a factorizing distribution. 79 | 80 | :param layer_data: list of unit activations of the layers in one batch 81 | :param N: batch size 82 | :param invtemperature: a factor to weight the interaction factor. 83 | :param iterations: number of gibbs steps following mean field approximation 84 | :param sample_state: list of bools to indicate, if the layer should be sampled or kept during gibbs steps 85 | :param mean: if false, the returned state is sampled according to the activation probability. \ 86 | If true, the returned state is the activation probability. 
87 | ''' 88 | 89 | # if a state is given in layer data, the batch size can be implied 90 | nbatch = N 91 | if layer_data is not None: 92 | for s in layer_data: 93 | if s is not None: 94 | nbatch = s.shape[0] 95 | 96 | 97 | # build the starting state 98 | # for every layer, for which no starting state is defined, sample a new state 99 | state = [None] * self.nlayers 100 | if not layer_data == None: 101 | for i, layer in enumerate(layer_data): 102 | if layer is not None: 103 | state[i] = layer.clone() 104 | else: 105 | state[i] = self.layers[i].sample_cond(N=nbatch) 106 | else: 107 | for i in range(self.nlayers): 108 | state[i] = self.layers[i].sample_cond(N=nbatch) 109 | 110 | 111 | # if no explicit starting sample states are given, sample all states, which had no staring state 112 | if sample_state == None: 113 | sample_state = [True] * self.nlayers 114 | 115 | # perform the iterative mean field approximation 116 | for i in range(iterations): 117 | # update odd layers 118 | for j in range(1, self.nlayers, 2): 119 | if sample_state[j]: 120 | intterm = self.interactions[j - 1].gradOutput(self.layers[j - 1].transform(state[j - 1])) 121 | intterm *= invtemperature 122 | if j + 1 < self.nlayers: 123 | intterm += invtemperature * self.interactions[j].gradInput(self.layers[j+1].transform(state[j + 1])) 124 | if mean: 125 | state[j] = self.layers[j].mean_cond(interaction=intterm) 126 | else: 127 | state[j] = self.layers[j].sample_cond(interaction=intterm) 128 | 129 | # update even layers 130 | for j in range(0, self.nlayers, 2): 131 | if sample_state[j]: 132 | intterm = 0. 133 | if j > 0: 134 | intterm += invtemperature * self.interactions[j - 1].gradOutput(self.layers[j-1].transform(state[j - 1])) 135 | if j + 1 < self.nlayers: 136 | intterm += invtemperature * self.interactions[j].gradInput(self.layers[j+1].transform(state[j + 1])) 137 | if mean: 138 | state[j] = self.layers[j].mean_cond(interaction=intterm) 139 | else: 140 | state[j] = self.layers[j].sample_cond(interaction=intterm) 141 | return state 142 | 143 | def joint_sample(self, 144 | N: int=1, 145 | iterations: int=10, 146 | invtemperature: Union[torch.Tensor, float]=1., 147 | mean: bool=False 148 | ) -> List[torch.Tensor]: 149 | ''' 150 | Gets a joint sample of all layers with batch size N after some Gibbs iterations. 151 | If the factor is set to 0, the DBM follows a factorizing distribution. 152 | 153 | :param N: [int], batch size 154 | :param iterations: [int], number of gibbs steps 155 | :param invtemperature: [float], a factor to weight the interaction factor. 156 | :param mean: [bool], if false, the returned state is sampled according \ 157 | to the activation probability. If true, the returned state is the \ 158 | activation probability. This can be interpreted as taking the mean \ 159 | over infinitely many samples. 160 | ''' 161 | return self.conditional_sample_general_meanfield(N=N, 162 | iterations=iterations, 163 | invtemperature=invtemperature, 164 | mean=mean) 165 | 166 | def conditional_sample(self, 167 | data: torch.Tensor, 168 | iterations: int=10, 169 | invtemperature: Union[torch.Tensor, float]=1. 170 | ) -> List[torch.Tensor]: 171 | ''' 172 | Gets a sample of all layers conditioned on given visible data. 173 | If the invtemperature is set to 0, the DBM follows a factorizing distribution. 
174 | 175 | :param data: [torch.tensor(batch_size, image size)], the images in one batch 176 | :param iterations: [int], number of meanfield approximation steps 177 | :param invtemperature: [float], a factor to weight the interaction factor. 178 | :return: [List[torch.Tensor]] conditional sample of all layers 179 | ''' 180 | 181 | # initialize a state of the DBM, where the visible units are according to the data 182 | # and the hidden units are not initialized 183 | layer_data = [data] + len(self.rbms) * [None] 184 | 185 | # allow the DBM to sample all states except the visible state 186 | sample_state = [True] * self.nlayers 187 | sample_state[0] = False 188 | 189 | # pass the prepared data to the helper function 190 | return self.conditional_sample_general_meanfield(layer_data=layer_data, 191 | sample_state=sample_state, 192 | iterations=iterations, 193 | invtemperature=invtemperature) 194 | 195 | def ais(self, 196 | steps: int, 197 | M: int, 198 | data: Optional[torch.Tensor]=None, 199 | log_Z: Optional[torch.Tensor]=None 200 | ) -> torch.Tensor: 201 | ''' 202 | Following [1] chapter 4.2.1f. 203 | AIS has two modes of operation. It can either compute the partitioning sum \ 204 | of the DBM, if no data is provided, or calculate the probability of a \ 205 | visible state, if both data and the partitioning sum \ 206 | are provided. 207 | 208 | :param steps: number of intermediate probability distributions 209 | :param M: number of samples to take the mean 210 | :param data: the visible state of the DBM 211 | :param log_Z: the log partitioning sum of the DBM 212 | :return: log partitioning sum / log-likelihood of data 213 | ''' 214 | 215 | with torch.no_grad(): 216 | 217 | # step_size is the change per iteration of the invtemperature 218 | step_size = 1. / steps 219 | # beta_k is the current invtemperature 220 | beta_k = 0 221 | 222 | # if data is provided, don't sample the visible state 223 | sample_state = [True] * self.nlayers 224 | if data is not None: 225 | sample_state[0] = False 226 | 227 | # if no data is provided, initialize with a joint sample of all states, else keep the visible state. 228 | if data is None: 229 | state = self.joint_sample(invtemperature=0, iterations=1, N = M) 230 | else: 231 | state = self.conditional_sample(data, invtemperature=0, iterations=1) 232 | 233 | log_p_k = None 234 | # iteratively increase the invtemperature to move from the trivial distribution to the actual one.
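            # Sketch of the estimator implemented below (following [1] ch. 4.2.1):
            # each step accumulates log p*_{k+1}(x_k) - log p*_k(x_k), where
            # log p*_k(x) = beta_k * log p*_B(x) + (1 - beta_k) * log p*_A(x)
            # interpolates between the factorizing base distribution A (beta = 0)
            # and the DBM distribution B (beta = 1); averaging w = exp(sum of these
            # differences) over the M chains gives log Z_B ~= log(mean(w)) + log Z_A.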
235 | for step in range(steps - 1): 236 | 237 | # calculate the unnormalized probability according to the last and the current invtemperature 238 | log_1 = self.log_free_energy_joint(state, 1) 239 | log_2 = self.log_free_energy_joint(state, 0) 240 | log_p_last = beta_k * log_1 + (1 - beta_k) * log_2 241 | 242 | beta_k += step_size 243 | 244 | 245 | log_1 = self.log_free_energy_joint(state, 1) 246 | log_2 = self.log_free_energy_joint(state, 0) 247 | 248 | log_p_curr = beta_k * log_1 + (1 - beta_k) * log_2 249 | 250 | log_p = log_p_curr - log_p_last 251 | 252 | # sample according to the new invtemperature 253 | state = self.conditional_sample_general_meanfield(state, invtemperature=beta_k, iterations=1, 254 | sample_state=sample_state) 255 | 256 | # add the differences of p_k together 257 | if log_p_k is None: 258 | log_p_k = log_p 259 | else: 260 | log_p_k += log_p 261 | 262 | # for numerical stability, subtract the mean before taking the exponent 263 | normalize = log_p_k.mean() 264 | exponent = log_p_k - normalize 265 | w_ais = torch.exp(exponent) 266 | 267 | # ignore samples with underflow/overflow 268 | w_ais[torch.isinf(w_ais) == 1] = w_ais[torch.isinf(w_ais) != 1].mean() 269 | 270 | # if the partitioning sum has to be computed, the mean over the batch can be taken 271 | if data is None: 272 | log_r_ais = torch.log(w_ais.mean()) + normalize 273 | else: 274 | log_r_ais = torch.log(w_ais) + normalize 275 | 276 | # for the partitioning sum: 277 | if data is None: 278 | num_states = 0 279 | 280 | # calculate the trivial partitioning sum: 281 | for i in range(0, len(self.layers)): 282 | num_states += torch.numel(self.layers[i].bias) 283 | log_Z_A = num_states * torch.log(torch.tensor(2).float()) 284 | 285 | # add the accumulated fractions 286 | log_Z = log_r_ais + log_Z_A 287 | return log_Z 288 | 289 | # for the log probability 290 | else: 291 | num_states = 0 292 | 293 | # calculate the trivial partitioning sum except for the lowest layer. 294 | # This is like summing out the not visible units for the trivial distribution. 295 | for i in range(1, len(self.layers)): 296 | num_states += torch.numel(self.layers[i].bias) 297 | log_Z_A = num_states * torch.log(torch.tensor(2).float()) 298 | 299 | # add the accumulated fraction to the trivial distribution and subtract the partitioning sum 300 | log_p = log_r_ais + log_Z_A - log_Z 301 | log_p[torch.isinf(log_p) == 1] = log_p[torch.isinf(log_p) != 1].mean() 302 | return log_p 303 | 304 | def generate(self, 305 | N: int=32, 306 | gibbs_steps: int=1, 307 | mean: bool=True 308 | ) -> torch.Tensor: 309 | ''' 310 | Compatibility function for the generalized evaluataion method. Returns the visible layer of a joint sample. 311 | 312 | :param N: number of samples to generate 313 | :param gibbs_steps: number of gibbs steps 314 | :param mean: if false, the returned state is sampled according to the \ 315 | activation probability. If true, the returned state is the \ 316 | activation probability. 317 | :return: generated samples 318 | ''' 319 | return self.joint_sample(N=N, mean=mean, iterations=gibbs_steps)[0].reshape([-1,28,28]) 320 | 321 | def loglikelihood(self, 322 | data: torch.Tensor, 323 | log_Z: Optional[torch.Tensor]=None 324 | ) -> torch.Tensor: 325 | ''' 326 | Compute the log-likelihood of a batch of data. In case log_Z 327 | (estimator of log of partition sum) is known, do not recompute it. 
328 | 329 | :param data: data batch 330 | :param log_Z: known log partition sum estimate 331 | ''' 332 | if log_Z is None: 333 | log_Z = self.ais(100,100) 334 | log_p = self.ais(100, -1, data = data, log_Z = log_Z) 335 | return log_p 336 | 337 | def log_free_energy_joint(self, 338 | state: List[torch.Tensor], 339 | invtemperature: Union[torch.Tensor, float]=1. 340 | ) -> torch.Tensor: 341 | ''' 342 | Return an unnormalized logarithmic probability of a hidden+visible 343 | variable state. 344 | If the invtemperature is set to 0, the DBM follows a factorizing distribution. 345 | 346 | :param state: The model state to calculate the free energy for 347 | :param invtemperature: a factor to weight the interaction factor. 348 | ''' 349 | 350 | # the log energy of the DBM can be computed from the energies of the layers and interactions 351 | # according to [1] chapter 5.3. 352 | log_en = 0. 353 | for i in range(self.nlayers): 354 | log_en += self.layers[i].logprob_joint(state[i]) 355 | for i in range(self.nlayers - 1): 356 | int_state_in = self.layers[i].transform(state[i]) 357 | int_state_out = self.layers[i+1].transform(state[i+1]) 358 | log_en += invtemperature * self.interactions[i].negenergy(int_state_in, int_state_out) 359 | return log_en 360 | 361 | def greedy_pretraining(self, 362 | data: Iterable[torch.Tensor], 363 | epochs: int 364 | ) -> None: 365 | ''' 366 | Implementation of the greedy pretraining algorithm from [1] chapter 5, algorithm 6. 367 | Note that we do not use pretraining for our evaluation script, as it has a tendency to overfit, 368 | so the quality of the results depends heavily on the initialisation. 369 | Also note that this implementation can only be used if every second layer has the same dimension. 370 | 371 | :param data: tqdm object, training data 372 | :param epochs: [int] number of training epochs per RBM 373 | ''' 374 | for i in range(len(self.rbms)): 375 | self.greedy_pretraining_layer(i, data, epochs) 376 | # the output of the lower RBM becomes the input to the higher 377 | newdat = [] 378 | for bdat in data: 379 | newdat.append(self.sample_rbm_layer_hidden(i, bdat)) 380 | data = newdat 381 | 382 | def greedy_pretraining_layer(self, 383 | rbm_num: int, 384 | data: Iterable[torch.Tensor], 385 | epochs: int 386 | ) -> None: 387 | ''' 388 | A helper function to train the individual RBMs. 389 | 390 | :param rbm_num: [int] the index of the RBM in the DBM. 391 | :param data: training data, an iterable of batches 392 | :param epochs: [int] number of training epochs for the RBM 393 | :return: None 394 | ''' 395 | 396 | # according to the position of the RBM, the missing layer below or above must be compensated. 397 | if rbm_num == 0: 398 | visible_interaction_factor = 1. 399 | hidden_interaction_factor = 2. 400 | elif rbm_num == len(self.rbms) - 1: 401 | visible_interaction_factor = 2. 402 | hidden_interaction_factor = 1. 403 | else: 404 | visible_interaction_factor = 1. 405 | hidden_interaction_factor = 1. 406 | 407 | # weights are mirrored for initialisation.
408 | if rbm_num > 0: 409 | self.rbms[rbm_num].get_interaction().weight = self.rbms[rbm_num].get_interaction().weight.T.clone() 410 | self.rbms[rbm_num].train(data, epochs, self.optimizer, 411 | self.scheduler, visible_interaction_factor, 412 | hidden_interaction_factor) 413 | 414 | def sample_rbm_layer_hidden(self, 415 | rbm_num: int, 416 | visible: Optional[torch.Tensor]=None, 417 | N: int=1 418 | ) -> torch.Tensor: 419 | ''' 420 | A helper function for the greedy pretraining. 421 | Sample the hidden state of the RBM at position i. If no visible state is given, a random visible state is used. 422 | :param rbm_num: [int], position of the RBM 423 | :param visible: [torch.tensor(batch_size, visible size)], the visible state of the RBM 424 | :param N: [int], if no visible state is given, N must be set to the batch size 425 | :return: [torch.tensor(batch size, hidden size)], sampled hidden state 426 | ''' 427 | 428 | if visible is not None: 429 | return self.rbms[rbm_num].sample_hidden(visible=visible) 430 | else: 431 | return self.rbms[rbm_num].sample_hidden(N=N) 432 | 433 | class DeepBoltzmannMachineLS(DeepBoltzmannMachine): 434 | ''' 435 | A class handling the training strategy for the DBM 436 | ''' 437 | 438 | def __init__(self, 439 | rbms: List[RestrictedBoltzmannMachine], 440 | optimizer: torch.optim.Optimizer, 441 | scheduler, 442 | ksteps: int=1, 443 | learning: str = 'CD', 444 | nFantasy: int=100): 445 | ''' 446 | Builds the DBM and initializes the learning strategy 447 | :param rbms: a list of RBMs 448 | :param optimizer: torch optimizer 449 | :param scheduler: torch scheduler 450 | :param ksteps: number of gibbs steps 451 | :param learning: the learning strategy 452 | :param nFantasy: the number of fantasy particles 453 | ''' 454 | super().__init__(rbms, optimizer, scheduler) 455 | self.fantasy_state = None 456 | self.ksteps = ksteps 457 | self.ls = learning 458 | self.CD = learning 459 | self.nFantasy = nFantasy 460 | 461 | def train_model(self, 462 | data: Iterable[torch.Tensor], 463 | epochs: int=5, 464 | device: Optional[torch.device]=None 465 | ) -> None: 466 | ''' 467 | Implements the training module to train the DBM 468 | :param data: [tqdm object] training data 469 | :param epochs: number of epochs to train 470 | :return: None 471 | ''' 472 | 473 | for epoch in range(epochs): 474 | for i, bdat in enumerate(data): 475 | if i % 10000 == 0 and isinstance(data, tqdm): 476 | self.train_batch(bdat.to(device), iterations=self.ksteps, verbose=True) 477 | else: 478 | self.train_batch(bdat.to(device), iterations=self.ksteps) 479 | if isinstance(data, tqdm): 480 | data = tqdm(data) 481 | #if isinstance(data, tqdm): 482 | # data = tqdm(data) 483 | 484 | def train_batch(self, 485 | data: torch.Tensor, 486 | iterations: int=10, 487 | verbose: bool=False): 488 | ''' 489 | A helper function that takes single batches in trains the model 490 | :param data: training data batch 491 | :param iterations: [int], number of gibbs steps 492 | :param verbose: [bool], if true, training progress is printed. 
493 | :return: None 494 | ''' 495 | 496 | # conditional sample is for the positive phase 497 | conditional_sample = self.conditional_sample(data, iterations) 498 | with torch.no_grad(): 499 | if self.CD == 'CD': 500 | # for CD, a starting state according to the training data is chosen 501 | starting_state = [None] * self.nlayers 502 | starting_state[0] = data 503 | negative_sample = self.conditional_sample_general_meanfield(layer_data= starting_state, iterations=iterations, mean=True) 504 | elif self.CD == 'PCD': 505 | # for PCD, a starting state according to the last negative example is chosen 506 | sample_state = [True] * self.nlayers 507 | if self.fantasy_state is None: 508 | self.fantasy_state = self.conditional_sample_general_meanfield(iterations=iterations, 509 | sample_state = sample_state, 510 | mean=True) 511 | else: 512 | self.fantasy_state = self.conditional_sample_general_meanfield(layer_data=self.fantasy_state, 513 | iterations=iterations, 514 | mean=True, 515 | sample_state = sample_state) 516 | negative_sample = self.fantasy_state 517 | 518 | self.zero_grad() 519 | # positive step 520 | self.backward(conditional_sample, factor=-1) 521 | # negative step 522 | self.backward(negative_sample, factor=1) 523 | # update weights 524 | self.optimizer.step() 525 | self.scheduler.step() 526 | 527 | # if verbose, print the training progress 528 | if verbose: 529 | #print('learning rate:') 530 | for param_group in self.optimizer.param_groups: 531 | print(param_group['lr']) 532 | conditions = (len(self.rbms) + 1) * [None] 533 | random_state = [] 534 | #for st in conditional_sample: 535 | #random_state.append(torch.randn_like(st)) 536 | #unconditional_sample = self.conditional_sample_general_meanfield(conditions, 537 | # iterations=iterations, 538 | # mean=True) 539 | #self.rbms[0].plot_reconstruction(negative_sample[0][0].reshape([1, -1]), 540 | # unconditional_sample[0][0].reshape([1, -1])) 541 | log_z = self.ais(100,100) 542 | print('log_z') 543 | print(log_z) 544 | log_p = self.ais(100, -1, data = data, log_Z = log_z) 545 | print('log_p: ') 546 | print(log_p.mean()) 547 | --------------------------------------------------------------------------------
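A rough end-to-end sketch of how the pieces above fit together. This is a hedged illustration, not code from the repository: `layer0`, `layer1`, `layer2` stand for UnitLayer instances and `int01`, `int12` for Interaction instances configured elsewhere (see the examples/Model_DBM_*.py files for the actual configurations), and `train_loader` is any iterable of data batches.

    import torch
    from pytorch_probgraph import RestrictedBoltzmannMachine, DeepBoltzmannMachineLS

    # stack two RBMs that share the middle layer into a DBM
    rbms = [RestrictedBoltzmannMachine(layer0, layer1, int01),
            RestrictedBoltzmannMachine(layer1, layer2, int12)]
    # deduplicate the shared middle-layer parameters before building the optimizer
    params = list(dict.fromkeys(p for r in rbms for p in r.parameters()))
    opt = torch.optim.SGD(params, lr=1e-3)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1000, gamma=0.95)
    dbm = DeepBoltzmannMachineLS(rbms, opt, sched, ksteps=1, learning='PCD', nFantasy=100)
    dbm.train_model(train_loader, epochs=5)
    samples = dbm.generate(N=16)  # visible-layer samples after Gibbs sampling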