├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── apps └── nn_fit.py ├── build.sh ├── docs ├── Makefile ├── requirements.txt └── source │ ├── cheatsheet.txt │ ├── conf.py │ ├── diagram.rst │ ├── ens.rst │ ├── examples.rst │ ├── func.rst │ ├── ind.rst │ ├── index.rst │ ├── intro.rst │ ├── math.rst │ ├── mcmc.rst │ ├── nns.rst │ ├── references.bib │ ├── refs.rst │ ├── rvar.rst │ ├── solvers.rst │ ├── tests.rst_ │ ├── utils.rst │ └── vi.rst ├── examples ├── ex_fit.py ├── ex_fit_2d.py ├── ex_loss.py ├── ex_lreg_mcmc.py └── ex_ufit.py ├── quinn ├── __init__.py ├── ens │ ├── __init__.py │ └── learner.py ├── func │ ├── __init__.py │ └── funcs.py ├── mcmc │ ├── __init__.py │ ├── admcmc.py │ ├── hmc.py │ ├── mala.py │ └── mcmc.py ├── nns │ ├── __init__.py │ ├── losses.py │ ├── mlp.py │ ├── nnbase.py │ ├── nnfit.py │ ├── nns.py │ ├── nnwrap.py │ ├── rnet.py │ └── tchutils.py ├── rvar │ ├── __init__.py │ └── rvs.py ├── solvers │ ├── __init__.py │ ├── nn_ens.py │ ├── nn_laplace.py │ ├── nn_mcmc.py │ ├── nn_rms.py │ ├── nn_swag.py │ ├── nn_vi.py │ └── quinn.py ├── utils │ ├── __init__.py │ ├── maps.py │ ├── plotting.py │ ├── stats.py │ └── xutils.py └── vi │ ├── __init__.py │ └── bnet.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # User-added 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | # formats: 24 | # - pdf 25 | # - epub 26 | 27 | # Optional but recommended, declare the Python requirements required 28 | # to build your documentation 29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 30 | python: 31 | install: 32 | - requirements: docs/requirements.txt 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Sandia National Laboratories 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Quantification of Uncertainties in Neural Networks (QUiNN) is a Python library of probabilistic wrappers around PyTorch modules that provides uncertainty estimation in Neural Network (NN) predictions. 2 | 3 | # Build the library 4 | ./build.sh 5 | or 6 | python setup.py build; python setup.py install 7 | 8 | # Requirements 9 | numpy, scipy, matplotlib, torch (PyTorch) 10 | 11 | # Examples 12 | examples/ex_fit.py 13 | examples/ex_fit_2d.py 14 | examples/ex_ufit.py <method> # where method is one of: amcmc, hmc, vi, ens, rms, laplace, swag. 15 | 16 | # Authors 17 | Khachik Sargsyan 18 | Javier Murgoitio-Esandi 19 | Oscar Diaz-Ibarra 20 | 21 | # Acknowledgements 22 | This work is supported by 23 | - U.S. Department of Energy, Office of Fusion Energy Sciences (OFES) under Field Work Proposal Number 20-023149. 24 | - Laboratory Directed Research and Development (LDRD) program of Sandia National Laboratories. 25 | 26 | -------------------------------------------------------------------------------- /apps/nn_fit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Command line utility for NN fit.""" 3 | import sys 4 | import torch 5 | import argparse 6 | import numpy as np 7 | 8 | 9 | from quinn.solvers.nn_vi import NN_VI 10 | from quinn.solvers.nn_ens import NN_Ens 11 | from quinn.solvers.nn_rms import NN_RMS 12 | from quinn.solvers.nn_mcmc import NN_MCMC 13 | from quinn.solvers.nn_swag import NN_SWAG 14 | from quinn.solvers.nn_laplace import NN_Laplace 15 | 16 | from quinn.utils.plotting import myrc 17 | from quinn.utils.stats import get_domain 18 | from quinn.utils.maps import scale01ToDom 19 | from quinn.utils.xutils import read_textlist 20 | 21 | from quinn.nns.rnet import RNet, Poly 22 | 23 | torch.set_default_dtype(torch.double) 24 | 25 | myrc() 26 | 27 | 28 | usage_str = 'A command line app to build NN surrogate given x-y training dataset.'
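# Illustrative invocation (a sketch only; the flags and file names are simply the argparse defaults defined below, and 'ens' is one of the supported -m options):
#     python apps/nn_fit.py -x ptrain.txt -y ytrain.txt -m ens -t 0.8 -v 0.1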
29 | parser = argparse.ArgumentParser(description=usage_str) 30 | #parser.add_argument('ind_show', type=int, nargs='*', 31 | # help="indices of requested parameters (count from 0)") 32 | parser.add_argument("-x", "--xdata", dest="xdata", type=str, default='ptrain.txt', 33 | help="Xdata file") 34 | parser.add_argument("-y", "--ydata", dest="ydata", type=str, default='ytrain.txt', 35 | help="Ydata file") 36 | # parser.add_argument("-r", "--xrange", dest="xrange", type=str, default=None, 37 | # help="Xrange file") 38 | parser.add_argument("-q", "--outnames_file", dest="outnames_file", type=str, default='outnames.txt', 39 | help="Output names file") 40 | parser.add_argument("-p", "--pnames_file", dest="pnames_file", type=str, default='pnames.txt', help="Param names file") 41 | parser.add_argument("-m", "--method", dest="method", type=str, default='ens', help="Method") 42 | parser.add_argument("-t", "--trnfactor", dest="trnfactor", type=float, default=0.8, 43 | help="Factor of data used for training") 44 | parser.add_argument("-v", "--valfactor", dest="valfactor", type=float, default=0.1, 45 | help="Factor of data used for validation") 46 | 47 | args = parser.parse_args() 48 | 49 | method = args.method 50 | all_uq_options = ['amcmc', 'hmc', 'vi', 'ens', 'rms', 'laplace', 'swag'] 51 | assert method in all_uq_options, f'Pick among {all_uq_options}' 52 | 53 | trnfactor = args.trnfactor 54 | valfactor = args.valfactor 55 | assert(trnfactor+valfactor<=1.0) 56 | 57 | x = np.loadtxt(args.xdata) 58 | y = np.loadtxt(args.ydata) 59 | 60 | 61 | 62 | if len(x.shape)==1: 63 | x = x[:, np.newaxis] 64 | if len(y.shape)==1: 65 | y = y[:, np.newaxis] 66 | 67 | nsam, ndim = x.shape 68 | nsam_, nout = y.shape 69 | 70 | assert(nsam == nsam_) 71 | 72 | outnames = read_textlist(args.outnames_file, nout, names_prefix='out') 73 | pnames = read_textlist(args.pnames_file, ndim, names_prefix='par') 74 | 75 | ntrn = int(trnfactor * nsam) 76 | nval = int(valfactor * nsam) 77 | ntst = nsam - ntrn - nval 78 | assert(ntst>=0) 79 | print(f"Number of training points : {ntrn}") 80 | print(f"Number of validation points : {nval}") 81 | print(f"Number of testing points : {ntst}") 82 | 83 | rperm = np.random.permutation(nsam) 84 | indtrn = rperm[:ntrn] 85 | indval = rperm[ntrn:ntrn+nval] 86 | indtst = rperm[ntrn+nval:] 87 | 88 | ################################################################################ 89 | ################################################################################ 90 | 91 | 92 | # Plot quantiles or st.dev. 
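# (Presumably plot_qt=True selects quantile bands and plot_qt=False selects mean +/- st.dev., per the comment above; note that the predict_plot call further down passes plot_qt=False directly, so this flag is unused as written.)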
93 | plot_qt = False 94 | 95 | 96 | # Function domain 97 | domain = get_domain(x) 98 | 99 | # Get x data 100 | xsc = scale01ToDom(x, domain) 101 | 102 | 103 | xtrn, ytrn = xsc[indtrn], y[indtrn] 104 | xval, yval = xsc[indval], y[indval] 105 | xtst, ytst = xsc[indtst], y[indtst] 106 | 107 | 108 | # Model to fit 109 | #nnet = TwoLayerNet(1, 4, 1) #Constant() #MLP_simple((ndim, 5, 5, 5, nout)) #Polynomial(4) #Polynomial3() #TwoLayerNet(1, 4, 1) #torch.nn.Linear(1,1, bias=False) 110 | nnet = RNet(3, 3, wp_function=Poly(0), 111 | indim=ndim, outdim=nout, 112 | layer_pre=True, layer_post=True, 113 | biasorno=True, nonlin=True, 114 | mlp=False, final_layer=None) 115 | 116 | # nnet = MLP(ndim, nout, (11,11,11), biasorno=True, 117 | # activ='relu', bnorm=False, bnlearn=True, dropout=0.0) 118 | 119 | 120 | if method == 'amcmc': 121 | nmc = 100 122 | uqnet = NN_MCMC(nnet, verbose=True) 123 | sampler_params = {'gamma': 0.01} 124 | uqnet.fit(xtrn, ytrn, zflag=False, datanoise=0.01, nmcmc=10000, sampler='amcmc', sampler_params=sampler_params) 125 | elif method == 'hmc': 126 | nmc = 100 127 | uqnet = NN_MCMC(nnet, verbose=True) 128 | sampler_params = {'L': 3, 'epsilon': 0.0001} 129 | uqnet.fit(xtrn, ytrn, zflag=False, datanoise=0.01, nmcmc=10000, sampler='hmc', sampler_params=sampler_params) 130 | elif method == 'vi': 131 | nmc = 111 132 | uqnet = NN_VI(nnet, verbose=True) 133 | uqnet.fit(xtrn, ytrn, val=[xval, yval], datanoise=0.01, lrate=0.01, batch_size=None, nsam=1, nepochs=5000) 134 | elif method == 'ens': 135 | nmc = 13 136 | uqnet = NN_Ens(nnet, nens=nmc, dfrac=0.8, verbose=True) 137 | uqnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=2, nepochs=1000) 138 | elif method == 'rms': 139 | nmc = 7 140 | uqnet = NN_RMS(nnet, nens=nmc, dfrac=0.8, verbose=True, datanoise=0.1, priorsigma=0.1) 141 | uqnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=2, nepochs=1000) 142 | elif method == 'laplace': 143 | nmc = 3 144 | uqnet = NN_Laplace(nnet, nens=nmc, dfrac=1.0, verbose=True, la_type='full') 145 | uqnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=2, nepochs=1000) 146 | elif method == 'swag': 147 | nmc = 3 148 | uqnet = NN_SWAG(nnet, nens=nmc, dfrac=1.0, verbose=True, k=10, n_steps=12, c=1, cov_type="lowrank", lr_swag=0.01) 149 | uqnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=2, nepochs=1000) 150 | else: 151 | print(f"UQ Method {method} is unknown. 
Exiting.") 152 | sys.exit() 153 | 154 | 155 | # Prepare lists of inputs and outputs for plotting 156 | xx_list = [xtrn, xval] 157 | yy_list = [ytrn, yval] 158 | ll_list = ['Training', 'Validation'] 159 | if ntst > 0: 160 | xx_list.append(xtst) 161 | yy_list.append(ytst) 162 | ll_list.append('Testing') 163 | 164 | if ndim==1: 165 | uqnet.plot_1d_fits(xx_list, yy_list, nmc=100, labels=ll_list, name_postfix=str(method)) 166 | uqnet.predict_plot(xx_list, yy_list, nmc=100, plot_qt=False, labels=ll_list) 167 | 168 | 169 | 170 | ## This can go inside uqnet and be improved 171 | # for isam in range(nsam): 172 | # f = plt.figure(figsize=(12,4)) 173 | # plt.plot(range(nout), y[isam,:], 'b-', label='Data') 174 | # plt.plot(range(nout), ypred[isam,:], 'g-', label='NN apprx.') 175 | # plt.title(f'Sample #{isam+1}') 176 | # plt.xlabel('x') 177 | # plt.ylabel('y') 178 | # plt.legend() 179 | # plt.tight_layout() 180 | # plt.savefig(f'fit_s{str(isam).zfill(3)}.png') 181 | # plt.close() 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | python setup.py build 4 | python setup.py install develop --user -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | torch 4 | matplotlib 5 | sphinx_rtd_theme 6 | sphinxcontrib-bibtex 7 | -------------------------------------------------------------------------------- /docs/source/cheatsheet.txt: -------------------------------------------------------------------------------- 1 | .. _cheat-sheet: 2 | 3 | ****************** 4 | Sphinx cheat sheet 5 | ****************** 6 | 7 | Here is a quick and dirty cheat sheet for some common stuff you want 8 | to do in sphinx and ReST. You can see the literal source for this 9 | file at :ref:`cheatsheet-literal`. 10 | 11 | 12 | .. _formatting-text: 13 | 14 | Formatting text 15 | =============== 16 | 17 | You use inline markup to make text *italics*, **bold**, or ``monotype``. 18 | 19 | You can represent code blocks fairly easily:: 20 | 21 | import numpy as np 22 | x = np.random.rand(12) 23 | 24 | Or literally include code: 25 | 26 | .. literalinclude:: ../../functions/func.py 27 | 28 | .. 
_making-a-list: 29 | 30 | Making a list 31 | ============= 32 | 33 | It is easy to make lists in rest 34 | 35 | Bullet points 36 | ------------- 37 | 38 | This is a subsection making bullet points 39 | 40 | * point A 41 | 42 | * point B 43 | 44 | * point C 45 | 46 | 47 | Enumerated points 48 | ------------------ 49 | 50 | This is a subsection making numbered points 51 | 52 | #. point A 53 | 54 | #. point B 55 | 56 | #. point C 57 | 58 | 59 | .. _making-a-table: 60 | 61 | Making a table 62 | ============== 63 | 64 | This shows you how to make a table -- if you only want to make a list see :ref:`making-a-list`. 65 | 66 | ================== ============ 67 | Name Age 68 | ================== ============ 69 | John D Hunter 40 70 | Cast of Thousands 41 71 | And Still More 42 72 | ================== ============ 73 | 74 | .. _making-links: 75 | 76 | Making links 77 | ============ 78 | 79 | It is easy to make a link to `yahoo `_ or to some 80 | section inside this document (see :ref:`making-a-table`) or another 81 | document. 82 | 83 | You can also reference classes, modules, functions, etc that are 84 | documented using the sphinx `autodoc 85 | `_ facilites. For example, 86 | see the module :mod:`matplotlib.backend_bases` documentation, or the 87 | class :class:`~matplotlib.backend_bases.LocationEvent`, or the method 88 | :meth:`~matplotlib.backend_bases.FigureCanvasBase.mpl_connect`. 89 | 90 | 91 | 92 | .. _cheatsheet-literal: 93 | 94 | This file 95 | ========= 96 | 97 | .. literalinclude:: cheatsheet.rst 98 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('../..')) 18 | #sys.path.insert(0, os.path.abspath('../../examples')) 19 | # sys.path.insert(0, os.path.abspath('../../tests')) 20 | # sys.path.insert(0, os.path.abspath('../../errors')) 21 | 22 | print(sys.path) 23 | # -- Project information ----------------------------------------------------- 24 | 25 | project = 'QUiNN' 26 | copyright = 'TBD' 27 | author = '2022, Khachik Sargsyan and co' 28 | 29 | # The short X.Y version 30 | version = '' 31 | # The full version, including alpha/beta/rc tags 32 | release = '' 33 | 34 | 35 | # -- General configuration --------------------------------------------------- 36 | 37 | # If your documentation needs a minimal Sphinx version, state it here. 38 | # 39 | # needs_sphinx = '1.0' 40 | 41 | 42 | # Add any Sphinx extension module names here, as strings. They can be 43 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 44 | # ones. 
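# (Note: the third-party pieces used below, sphinx_rtd_theme and sphinxcontrib-bibtex, are listed in docs/requirements.txt, which Read the Docs installs per .readthedocs.yaml.)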
45 | extensions = [ 46 | 'sphinx.ext.autodoc', 47 | 'sphinx.ext.doctest', 48 | 'sphinx.ext.todo', 49 | 'sphinx.ext.coverage', 50 | 'sphinx.ext.mathjax', 51 | 'sphinx.ext.ifconfig', 52 | 'sphinx.ext.viewcode', 53 | 'sphinx.ext.graphviz', 54 | 'sphinx.ext.inheritance_diagram', 55 | 'sphinx.ext.napoleon', 56 | 'sphinxcontrib.bibtex' 57 | # 'sphinx.ext.pngmath' 58 | ] 59 | # Add any paths that contain templates here, relative to this directory. 60 | templates_path = ['_templates'] 61 | 62 | # The suffix(es) of source filenames. 63 | # You can specify multiple suffix as a list of string: 64 | # 65 | # source_suffix = ['.rst', '.md'] 66 | source_suffix = '.rst' 67 | 68 | # The master toctree document. 69 | master_doc = 'index' 70 | 71 | # The language for content autogenerated by Sphinx. Refer to documentation 72 | # for a list of supported languages. 73 | # 74 | # This is also used if you do content translation via gettext catalogs. 75 | # Usually you set "language" from the command line for these cases. 76 | language = None 77 | 78 | # List of patterns, relative to source directory, that match files and 79 | # directories to ignore when looking for source files. 80 | # This pattern also affects html_static_path and html_extra_path. 81 | exclude_patterns = [] 82 | 83 | # The name of the Pygments (syntax highlighting) style to use. 84 | pygments_style = None 85 | 86 | 87 | # -- Options for HTML output ------------------------------------------------- 88 | 89 | # The theme to use for HTML and HTML Help pages. See the documentation for 90 | # a list of builtin themes. 91 | # 92 | # html_theme = 'classic' 93 | html_theme = 'sphinx_rtd_theme' 94 | 95 | # Theme options are theme-specific and customize the look and feel of a theme 96 | # further. For a list of options available for each theme, see the 97 | # documentation. 98 | # 99 | # html_theme_options = {} 100 | 101 | # Add any paths that contain custom static files (such as style sheets) here, 102 | # relative to this directory. They are copied after the builtin static files, 103 | # so a file named "default.css" will overwrite the builtin "default.css". 104 | html_static_path = ['_static'] 105 | 106 | # Custom sidebar templates, must be a dictionary that maps document names 107 | # to template names. 108 | # 109 | # The default sidebars (for documents that don't match any pattern) are 110 | # defined by theme itself. Builtin themes are using these templates by 111 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 112 | # 'searchbox.html']``. 113 | # 114 | # html_sidebars = {} 115 | 116 | 117 | # -- Options for HTMLHelp output --------------------------------------------- 118 | 119 | # Output file base name for HTML help builder. 120 | htmlhelp_basename = 'quinndoc' 121 | 122 | 123 | # -- Options for LaTeX output ------------------------------------------------ 124 | 125 | latex_elements = { 126 | # The paper size ('letterpaper' or 'a4paper'). 127 | # 128 | # 'papersize': 'letterpaper', 129 | 130 | # The font size ('10pt', '11pt' or '12pt'). 131 | # 132 | # 'pointsize': '10pt', 133 | 134 | # Additional stuff for the LaTeX preamble. 135 | # 136 | # 'preamble': '', 137 | 138 | # Latex figure (float) alignment 139 | # 140 | # 'figure_align': 'htbp', 141 | } 142 | 143 | # Grouping the document tree into LaTeX files. List of tuples 144 | # (source start file, target name, title, 145 | # author, documentclass [howto, manual, or own class]). 
146 | latex_documents = [ 147 | (master_doc, 'QUiNN.tex', 'QUiNN Documentation', 148 | 'Khachik Sargsyan', 'manual'), 149 | ] 150 | 151 | 152 | # -- Options for manual page output ------------------------------------------ 153 | 154 | # One entry per manual page. List of tuples 155 | # (source start file, name, description, authors, manual section). 156 | man_pages = [ 157 | (master_doc, 'QUiNN', 'QUiNN Documentation', 158 | [author], 1) 159 | ] 160 | 161 | 162 | # -- Options for Texinfo output ---------------------------------------------- 163 | 164 | # Grouping the document tree into Texinfo files. List of tuples 165 | # (source start file, target name, title, author, 166 | # dir menu entry, description, category) 167 | texinfo_documents = [ 168 | (master_doc, 'QUiNN', 'QUiNN Documentation', 169 | author, 'QUiNN', 'One line description of project.', 170 | 'Miscellaneous'), 171 | ] 172 | 173 | 174 | # -- Options for Epub output ------------------------------------------------- 175 | 176 | # Bibliographic Dublin Core info. 177 | epub_title = project 178 | 179 | # The unique identifier of the text. This can be a ISBN number 180 | # or the project homepage. 181 | # 182 | # epub_identifier = '' 183 | 184 | # A unique identification for the text. 185 | # 186 | # epub_uid = '' 187 | 188 | # A list of files that should not be packed into the epub file. 189 | epub_exclude_files = ['search.html'] 190 | 191 | 192 | # -- Extension configuration ------------------------------------------------- 193 | 194 | # -- Options for todo extension ---------------------------------------------- 195 | 196 | # If true, `todo` and `todoList` produce output, else they produce nothing. 197 | todo_include_todos = True 198 | 199 | # -- Options for Napoleon extension ------------------------------------------- 200 | 201 | napoleon_google_docstring = True 202 | napoleon_numpy_docstring = True 203 | napoleon_include_init_with_doc = True 204 | napoleon_include_private_with_doc = True 205 | napoleon_include_special_with_doc = True 206 | napoleon_use_admonition_for_examples = False 207 | napoleon_use_admonition_for_notes = False 208 | napoleon_use_admonition_for_references = False 209 | napoleon_use_ivar = False 210 | napoleon_use_param = True 211 | napoleon_use_rtype = True 212 | napoleon_use_keyword = True 213 | napoleon_custom_sections = None 214 | 215 | # see https://sphinx-rtd-theme.readthedocs.io/en/latest/configuring.html 216 | html_theme_options = { 217 | 'canonical_url': '', 218 | 'analytics_id': 'UA-XXXXXXX-1', 219 | 'logo_only': False, 220 | 'display_version': True, 221 | 'prev_next_buttons_location': 'bottom', 222 | 'style_external_links': False, 223 | # Toc options 224 | 'collapse_navigation': False, 225 | 'sticky_navigation': True, 226 | 'navigation_depth': -1, 227 | 'includehidden': True, 228 | 'titles_only': False 229 | } 230 | 231 | inheritance_node_attrs = dict(shape='ellipse', fontsize=14, height=0.75, 232 | color='dodgerblue1', style='filled') 233 | 234 | #inheritance_graph_attrs = dict(rankdir="TB", size='""') 235 | inheritance_graph_attrs = dict(rankdir="LR", size='"6.0, 8.0"', 236 | fontsize=14, color='dodgerblue1', ratio='compress', style='filled') 237 | 238 | autodoc_member_order = 'bysource' 239 | 240 | bibtex_bibfiles = ['references.bib'] 241 | bibtex_default_style = 'unsrt' 242 | bibtex_encoding = 'latin' 243 | 244 | graphviz_output_format = 'svg' 245 | 246 | -------------------------------------------------------------------------------- /docs/source/diagram.rst: 
-------------------------------------------------------------------------------- 1 | Class inheritance diagram 2 | ========================= 3 | 4 | .. digraph:: foo 5 | 6 | "bar" -> "baz" -> "quux"; 7 | 8 | .. inheritance-diagram:: quinn.solvers.quinn quinn.solvers.nn_ens quinn.solvers.nn_mcmc quinn.solvers.nn_vi quinn.solvers.nn_laplace quinn.solvers.nn_swag quinn.solvers.nn_rms 9 | 10 | .. inheritance-diagram:: quinn.nns.nns quinn.nns.nnbase quinn.nns.mlp quinn.nns.nnfit quinn.nns.tchutils quinn.nns.rnet 11 | 12 | .. inheritance-diagram:: quinn.nns.losses 13 | 14 | .. inheritance-diagram:: quinn.mcmc.mcmc quinn.mcmc.admcmc quinn.mcmc.hmc quinn.mcmc.mala 15 | 16 | .. inheritance-diagram:: quinn.rvar.rvs 17 | 18 | .. inheritance-diagram:: quinn.utils.maps quinn.utils.stats quinn.utils.plotting quinn.utils.xutils 19 | 20 | .. inheritance-diagram:: quinn.nns.nnwrap 21 | 22 | .. inheritance-diagram:: quinn.ens.learner 23 | 24 | .. inheritance-diagram:: quinn.vi.bnet 25 | 26 | .. inheritance-diagram:: quinn.func.funcs -------------------------------------------------------------------------------- /docs/source/ens.rst: -------------------------------------------------------------------------------- 1 | ens 2 | ------------------ 3 | 4 | learner 5 | ======= 6 | .. automodule:: quinn.ens.learner 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | ex_fit 2 | ========= 3 | .. automodule:: examples.ex_fit 4 | :members: 5 | :undoc-members: 6 | :no-private-members: 7 | :show-inheritance: 8 | 9 | ex_fit2d 10 | ========= 11 | .. automodule:: examples.ex_fit_2d 12 | :members: 13 | :undoc-members: 14 | :no-private-members: 15 | :show-inheritance: 16 | 17 | ex_ufit 18 | ======== 19 | .. automodule:: examples.ex_ufit 20 | :members: 21 | :undoc-members: 22 | :no-private-members: 23 | :show-inheritance: 24 | -------------------------------------------------------------------------------- /docs/source/func.rst: -------------------------------------------------------------------------------- 1 | func 2 | ------------------ 3 | 4 | func 5 | ==== 6 | .. automodule:: quinn.func.funcs 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | 12 | -------------------------------------------------------------------------------- /docs/source/ind.rst: -------------------------------------------------------------------------------- 1 | Index 2 | ================= 3 | * :ref:`genindex` 4 | * :ref:`modindex` 5 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | QUiNN's documentation 2 | ================================= 3 | 4 | .. include:: intro.rst 5 | 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | :caption: List of modules: 10 | 11 | solvers.rst 12 | nns.rst 13 | mcmc.rst 14 | utils.rst 15 | func.rst 16 | rvar.rst 17 | vi.rst 18 | ens.rst 19 | .. tests.rst 20 | 21 | .. toctree:: 22 | :maxdepth: 4 23 | :caption: Examples: 24 | 25 | .. examples.rst 26 | 27 | 28 | .. toctree:: 29 | :maxdepth: 4 30 | :caption: Theory: 31 | 32 | .. math.rst 33 | 34 | .. toctree:: 35 | :maxdepth: 4 36 | :caption: Misc: 37 | 38 | ind.rst 39 | refs.rst 40 | .. 
diagram.rst 41 | 42 | 43 | -------------------------------------------------------------------------------- /docs/source/intro.rst: -------------------------------------------------------------------------------- 1 | Hello, QUiNN is a library for quantifying uncertainties in NN predictions. 2 | -------------------------------------------------------------------------------- /docs/source/math.rst: -------------------------------------------------------------------------------- 1 | MCMC 2 | ====== 3 | 4 | .. math:: 5 | W^3 \sum 6 | 7 | VI 8 | ====== 9 | 10 | .. math:: 11 | W^3 \sum 12 | 13 | ENS 14 | ====== 15 | 16 | .. math:: 17 | W^3 \sum -------------------------------------------------------------------------------- /docs/source/mcmc.rst: -------------------------------------------------------------------------------- 1 | mcmc 2 | ------------------ 3 | 4 | mcmc 5 | ==== 6 | .. automodule:: quinn.mcmc.mcmc 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | 12 | admcmc 13 | ====== 14 | .. automodule:: quinn.mcmc.admcmc 15 | :members: 16 | :undoc-members: 17 | :no-private-members: 18 | :show-inheritance: 19 | 20 | hmc 21 | === 22 | .. automodule:: quinn.mcmc.hmc 23 | :members: 24 | :undoc-members: 25 | :no-private-members: 26 | :show-inheritance: 27 | 28 | mala 29 | ==== 30 | .. automodule:: quinn.mcmc.mala 31 | :members: 32 | :undoc-members: 33 | :no-private-members: 34 | :show-inheritance: 35 | -------------------------------------------------------------------------------- /docs/source/nns.rst: -------------------------------------------------------------------------------- 1 | nns 2 | ------------------ 3 | 4 | nns 5 | ========== 6 | .. automodule:: quinn.nns.nns 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | 12 | nnbase 13 | ========== 14 | .. automodule:: quinn.nns.nnbase 15 | :members: 16 | :undoc-members: 17 | :no-private-members: 18 | :show-inheritance: 19 | 20 | nnwrap 21 | ========== 22 | .. automodule:: quinn.nns.nnwrap 23 | :members: 24 | :undoc-members: 25 | :no-private-members: 26 | :show-inheritance: 27 | 28 | mlp 29 | ========== 30 | .. automodule:: quinn.nns.mlp 31 | :members: 32 | :undoc-members: 33 | :no-private-members: 34 | :show-inheritance: 35 | 36 | rnet 37 | ========== 38 | .. automodule:: quinn.nns.rnet 39 | :members: 40 | :undoc-members: 41 | :no-private-members: 42 | :show-inheritance: 43 | 44 | losses 45 | ========== 46 | .. automodule:: quinn.nns.losses 47 | :members: 48 | :undoc-members: 49 | :no-private-members: 50 | :show-inheritance: 51 | 52 | tchutils 53 | ========== 54 | .. 
automodule:: quinn.nns.tchutils 55 | :members: 56 | :undoc-members: 57 | :no-private-members: 58 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/references.bib: -------------------------------------------------------------------------------- 1 | @article{blundell:2015, 2 | author = {Blundell, Charles and Cornebise, Julien and Kavukcuoglu, Koray and Wierstra, Daan}, 3 | title = {Weight Uncertainty in Neural Networks}, 4 | keywords = {Machine Learning (stat.ML), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, 5 | doi = {10.48550/ARXIV.1505.05424}, 6 | url = {https://arxiv.org/abs/1505.05424}, 7 | publisher = {arXiv}, 8 | year = {2015}, 9 | journal = {}, 10 | note = {http://proceedings.mlr.press/v37/blundell15.pdf}, 11 | copyright = {arXiv.org perpetual, non-exclusive license} 12 | } 13 | 14 | 15 | @inproceedings{pearce:2018, 16 | title={Uncertainty in Neural Networks: Approximately Bayesian Ensembling}, 17 | author={Tim Pearce and Felix Leibfried and Alexandra Brintrup and Mohamed H. Zaki and A. D. Neely}, 18 | booktitle={International Conference on Artificial Intelligence and Statistics}, 19 | year={2018}, 20 | url={https://api.semanticscholar.org/CorpusID:209984372} 21 | } 22 | 23 | @ARTICLE{haario:2001, 24 | author = {H. Haario and E. Saksman and J. Tamminen}, 25 | title = {An adaptive {M}etropolis algorithm}, 26 | journal = {Bernoulli}, 27 | year = {2001}, 28 | volume = {7}, 29 | pages = {223-242}, 30 | doi = {10.2307/3318737} 31 | } 32 | 33 | @book{brooks:2011, 34 | title={Handbook of Markov Chain Monte Carlo}, 35 | ISBN={9780429138508}, 36 | url={http://dx.doi.org/10.1201/b10905}, 37 | DOI={10.1201/b10905}, 38 | publisher={Chapman and Hall/CRC}, 39 | author={Brooks, Steve and Gelman, Andrew and Jones, Galin and Meng, Xiao-Li}, 40 | year={2011}, 41 | month={May} 42 | } 43 | 44 | @article{girolami:2011, 45 | author = {Mark Girolami and Ben Calderhead}, 46 | title = {Riemann manifold Langevin and Hamiltonian Monte Carlo methods}, 47 | journal = {Journal of the Royal Statistical Society: Series B (Statistical Methodology)}, 48 | volume = {73}, 49 | number = {2}, 50 | pages = {123-214}, 51 | year={2011}, 52 | keywords = {Bayesian inference, Geometry in statistics, Hamiltonian Monte Carlo methods, Langevin diffusion, Markov chain Monte Carlo methods, Riemann manifolds}, 53 | doi = {10.1111/j.1467-9868.2010.00765.x}, 54 | url = {https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.1467-9868.2010.00765.x} 55 | } -------------------------------------------------------------------------------- /docs/source/refs.rst: -------------------------------------------------------------------------------- 1 | References 2 | ========== 3 | .. bibliography:: references.bib 4 | :cited: -------------------------------------------------------------------------------- /docs/source/rvar.rst: -------------------------------------------------------------------------------- 1 | rvar 2 | ------------------ 3 | 4 | rvs 5 | === 6 | .. automodule:: quinn.rvar.rvs 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/solvers.rst: -------------------------------------------------------------------------------- 1 | solvers 2 | ------------------ 3 | 4 | quinn 5 | ===== 6 | .. 
automodule:: quinn.solvers.quinn 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | 12 | nn_ens 13 | ====== 14 | .. automodule:: quinn.solvers.nn_ens 15 | :members: 16 | :undoc-members: 17 | :no-private-members: 18 | :show-inheritance: 19 | 20 | nn_mcmc 21 | ======= 22 | .. automodule:: quinn.solvers.nn_mcmc 23 | :members: 24 | :undoc-members: 25 | :no-private-members: 26 | :show-inheritance: 27 | 28 | nn_vi 29 | ===== 30 | .. automodule:: quinn.solvers.nn_vi 31 | :members: 32 | :undoc-members: 33 | :no-private-members: 34 | :show-inheritance: 35 | 36 | nn_laplace 37 | ========== 38 | .. automodule:: quinn.solvers.nn_laplace 39 | :members: 40 | :undoc-members: 41 | :no-private-members: 42 | :show-inheritance: 43 | 44 | nn_swag 45 | ======= 46 | .. automodule:: quinn.solvers.nn_swag 47 | :members: 48 | :undoc-members: 49 | :no-private-members: 50 | :show-inheritance: 51 | 52 | nn_rms 53 | ====== 54 | .. automodule:: quinn.solvers.nn_rms 55 | :members: 56 | :undoc-members: 57 | :no-private-members: 58 | :show-inheritance: 59 | -------------------------------------------------------------------------------- /docs/source/tests.rst_: -------------------------------------------------------------------------------- 1 | tests 2 | ------------------ 3 | 4 | test_func 5 | ========= 6 | .. automodule:: tests.test_mi 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | 12 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | utils 2 | ------------------ 3 | 4 | maps 5 | ======= 6 | .. automodule:: quinn.utils.maps 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | 12 | 13 | plotting 14 | ======== 15 | .. automodule:: quinn.utils.plotting 16 | :members: 17 | :undoc-members: 18 | :no-private-members: 19 | :show-inheritance: 20 | 21 | stats 22 | ======= 23 | .. automodule:: quinn.utils.stats 24 | :members: 25 | :undoc-members: 26 | :no-private-members: 27 | :show-inheritance: 28 | 29 | xutils 30 | ======= 31 | .. automodule:: quinn.utils.xutils 32 | :members: 33 | :undoc-members: 34 | :no-private-members: 35 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/vi.rst: -------------------------------------------------------------------------------- 1 | vi 2 | ------------------ 3 | 4 | bnet 5 | ======= 6 | .. 
automodule:: quinn.vi.bnet 7 | :members: 8 | :undoc-members: 9 | :no-private-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /examples/ex_fit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """An example of a 1d function approximation.""" 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from quinn.nns.mlp import MLP 8 | from quinn.utils.plotting import myrc 9 | from quinn.utils.maps import scale01ToDom 10 | from quinn.func.funcs import Sine, Sine10, blundell 11 | 12 | 13 | 14 | def main(): 15 | 16 | torch.set_default_dtype(torch.double) 17 | myrc() 18 | # defaults to cuda:0 if available 19 | device_id='cuda:0' 20 | device = torch.device(device_id if torch.cuda.is_available() else 'cpu') 21 | print("Using device",device) 22 | 23 | 24 | ######################################################################################## 25 | ################################################################################# 26 | 27 | 28 | nall = 22 # total number of points 29 | trn_factor = 0.8 # which fraction of nall goes to training 30 | ntst = 13 # separate test set 31 | ndim = 1 # input dimensionality 32 | datanoise = 0.00 # Noise in the generated data 33 | 34 | # Plot 1d fits or not 35 | plot_1d = (ndim==1) 36 | 37 | 38 | true_model, nout = blundell, ndim 39 | #true_model, nout = Sine10, 10 40 | 41 | # Function domain 42 | domain = np.tile(np.array([-0.5, 0.5]), (ndim, 1)) 43 | 44 | 45 | # Get x data 46 | xall = scale01ToDom(np.random.rand(nall, ndim), domain) 47 | if true_model is not None: 48 | yall = true_model(xall, datanoise=datanoise) 49 | 50 | # Sample test 51 | if ntst > 0: 52 | xtst = scale01ToDom(np.random.rand(ntst, ndim), domain) 53 | if true_model is not None: 54 | ytst = true_model(xtst, datanoise=datanoise) 55 | 56 | # Model to fit 57 | nnet = MLP(ndim, nout, (11,11,11), biasorno=True, 58 | activ='tanh', 59 | bnorm=False, bnlearn=True, 60 | dropout=0.0, 61 | device=device) 62 | 63 | # Data split to training and validation 64 | ntrn = int(trn_factor * nall) 65 | indperm = np.random.permutation(range(nall)) 66 | indtrn = indperm[:ntrn] 67 | indval = indperm[ntrn:] 68 | xtrn, xval = xall[indtrn, :], xall[indval, :] 69 | ytrn, yval = yall[indtrn, :], yall[indval, :] 70 | 71 | nnet.fit(xtrn, ytrn, val=[xval, yval], 72 | lrate=0.01, 73 | batch_size=None, nepochs=2000) 74 | 75 | print("=======================================") 76 | # print("Best Parameters : ") 77 | # uqnet.print_params() 78 | 79 | # Prepare lists of inputs and outputs for plotting 80 | xx_list = [xtrn, xval] 81 | yy_list = [ytrn, yval] 82 | ll_list = ['Training', 'Validation'] 83 | if ntst > 0: 84 | xx_list.append(xtst) 85 | yy_list.append(ytst) 86 | ll_list.append('Testing') 87 | 88 | if plot_1d: 89 | nnet.plot_1d_fits(xx_list, yy_list, labels=ll_list, true_model=true_model) 90 | nnet.predict_plot(xx_list, yy_list, labels=ll_list) 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /examples/ex_fit_2d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """An example of a 2d function approximation, and an example of a periodic boundary regularization.""" 3 | 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | from quinn.func.funcs import Ackley 10 | from quinn.nns.mlp import MLP 11 | from quinn.nns.tchutils 
import tch 12 | from quinn.nns.nnwrap import NNWrap 13 | from quinn.nns.losses import PeriodicLoss 14 | 15 | from quinn.utils.plotting import myrc, plot_fcn_2d_slice 16 | from quinn.utils.maps import scale01ToDom 17 | 18 | def main(): 19 | torch.set_default_dtype(torch.double) 20 | myrc() 21 | 22 | # defaults to cuda:0 if available 23 | device_id='cuda:0' 24 | device = torch.device(device_id if torch.cuda.is_available() else 'cpu') 25 | print("Using device",device) 26 | 27 | ########################################################################### 28 | ########################################################################### 29 | 30 | ## Set up 31 | ndim = 2 # input dimensionality 32 | nall = 25 # total number of points 33 | trn_factor = 0.8 # which fraction of nall goes to training 34 | ntst = 13 # separate test set 35 | datanoise = 0.02 # Noise in the generated data 36 | 37 | # True model is Ackley function with one output 38 | true_model, nout = Ackley, 1 39 | 40 | # Function domain 41 | domain = np.tile(np.array([-1.5, 1.5]), (ndim, 1)) 42 | 43 | # Get x data 44 | xall = scale01ToDom(np.random.rand(nall, ndim), domain) 45 | if true_model is not None: 46 | yall = true_model(xall, datanoise=datanoise) 47 | 48 | # Sample test 49 | if ntst > 0: 50 | xtst = scale01ToDom(np.random.rand(ntst, ndim), domain) 51 | if true_model is not None: 52 | ytst = true_model(xtst, datanoise=datanoise) 53 | 54 | # Model to fit 55 | nnet = MLP(ndim, nout, (11,11,11), biasorno=True, 56 | activ='tanh', 57 | bnorm=False, bnlearn=True, dropout=0.0, 58 | device=device) 59 | 60 | # Data split to training and validation 61 | ntrn = int(trn_factor * nall) 62 | indperm = np.random.permutation(range(nall)) 63 | indtrn = indperm[:ntrn] 64 | indval = indperm[ntrn:] 65 | xtrn, xval = xall[indtrn, :], xall[indval, :] 66 | ytrn, yval = yall[indtrn, :], yall[indval, :] 67 | 68 | 69 | # Set up a periodic boundary 70 | ngr = 11 71 | bdry1 = -1.5*np.ones((ngr, ndim)) 72 | bdry1[:,1] = np.linspace(-1.5, 1.5, ngr) 73 | bdry2 = 1.5*np.ones((ngr, ndim)) 74 | bdry2[:,1] = np.linspace(-1.5, 1.5, ngr) 75 | # pass input tensors with device 76 | loss = PeriodicLoss(nnet.nnmodel, 10.1, [tch(bdry1,device=device), tch(bdry2,device=device)]) #None 77 | nnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=10, nepochs=1000, loss_xy=loss) 78 | print("=======================================") 79 | 80 | 81 | # Plot the true model and the NN approximation 82 | figs, axarr = plt.subplots(1, 2, figsize=(18, 7)) 83 | plot_fcn_2d_slice(true_model, domain, idim=0, jdim=1, nom=None, ngr=33, ax=axarr[0]) 84 | plot_fcn_2d_slice(NNWrap(nnet), domain, idim=0, jdim=1, nom=None, ngr=33, ax=axarr[1]) 85 | axarr[0].plot(xtrn[:,0], xtrn[:,1], 'bo', ms=11, markeredgecolor='w') 86 | axarr[0].plot(xval[:,0], xval[:,1], 'go', ms=11, markeredgecolor='w') 87 | for ax in axarr: 88 | ax.set_xlabel(r'$x_1$') 89 | ax.set_ylabel(r'$x_2$') 90 | axarr[0].set_title('True Model') 91 | axarr[1].set_title('NN Apprx.') 92 | plt.savefig('fcn2d.png') 93 | 94 | # Prepare lists of inputs and outputs for plotting 95 | xx_list = [xtrn, xval] 96 | yy_list = [ytrn, yval] 97 | ll_list = ['Training', 'Validation'] 98 | if ntst > 0: 99 | xx_list.append(xtst) 100 | yy_list.append(ytst) 101 | ll_list.append('Testing') 102 | # A diagonal plot to check the approximation 103 | nnet.predict_plot(xx_list, yy_list, labels=ll_list) 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- 
/examples/ex_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import argparse 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | from quinn.nns.nnwrap import NNWrap, nn_p, SNet 9 | from quinn.nns.tchutils import npy, tch 10 | from quinn.nns.rnet import RNet 11 | 12 | from quinn.utils.stats import get_domain 13 | from quinn.utils.plotting import myrc, plot_1d_anchored, plot_2d_anchored, plot_2d_anchored_single, plot_1d_anchored_single 14 | from quinn.utils.maps import standardize, scaleDomTo01 15 | 16 | 17 | torch.set_default_dtype(torch.double) 18 | 19 | ##################################################################### 20 | ##################################################################### 21 | ##################################################################### 22 | 23 | 24 | def nnloss(p, modelpars): 25 | nnmodel, loss_fn, xtrn, ytrn = modelpars 26 | 27 | npt = p.shape[0] 28 | fval = np.empty((npt,)) 29 | for ip, pp in enumerate(p): 30 | fval[ip] = loss_fn(tch(nn_p(pp, xtrn, nnmodel).reshape(ytrn.shape)), tch(ytrn)).item() 31 | 32 | # print(fval) 33 | 34 | return fval 35 | 36 | ##################################################################### 37 | ##################################################################### 38 | ##################################################################### 39 | 40 | 41 | myrc() 42 | 43 | ##################################################################### 44 | ##################################################################### 45 | ##################################################################### 46 | 47 | usage_str = 'Script to plot 2d slices of loss function.' 48 | parser = argparse.ArgumentParser(description=usage_str) 49 | # parser.add_argument('ind_show', type=int, nargs='*', 50 | # help="indices of requested parameters (count from 0)") 51 | parser.add_argument("-x", "--xdata", dest="xdata", type=str, default='ptrain.txt', 52 | help="Xdata file") 53 | parser.add_argument("-y", "--ydata", dest="ydata", type=str, default='ytrain.txt', 54 | help="Ydata file") 55 | parser.add_argument("-v", "--valfactor", dest="valfactor", type=float, default=0.1, 56 | help="Factor of data used for validation") 57 | 58 | args = parser.parse_args() 59 | 60 | ##################################################################### 61 | ##################################################################### 62 | ##################################################################### 63 | 64 | scale = 10.0 65 | 66 | 67 | ##################################################################### 68 | ##################################################################### 69 | ##################################################################### 70 | 71 | x = np.loadtxt(args.xdata) 72 | domain = get_domain(x) 73 | x = scaleDomTo01(x, domain) 74 | 75 | y = np.loadtxt(args.ydata) 76 | y = standardize(y) 77 | 78 | valfactor = args.valfactor 79 | 80 | if len(x.shape) == 1: 81 | x = x[:, np.newaxis] 82 | if len(y.shape) == 1: 83 | y = y[:, np.newaxis] 84 | 85 | nsam, ndim = x.shape 86 | nsam_, nout = y.shape 87 | 88 | assert(nsam == nsam_) 89 | 90 | 91 | nval = int(valfactor * nsam) 92 | ntrn = nsam - nval 93 | print(f"Number of training points : {ntrn}") 94 | print(f"Number of validation points : {nval}") 95 | 96 | rperm = np.random.permutation(nsam) 97 | indtrn = rperm[:ntrn] 98 | indval = rperm[ntrn:] 99 | 100 | xtrn, ytrn = x[indtrn], y[indtrn] 101 | xval, yval = 
x[indval], y[indval] 102 | 103 | ##################################################################### 104 | ##################################################################### 105 | ##################################################################### 106 | 107 | 108 | ##################################################################### 109 | ##################################################################### 110 | ##################################################################### 111 | 112 | # hdl = (11,11,11) 113 | # nnet = MLP(ndim, nout, hdl) 114 | 115 | # nnet = torch.nn.Linear(ndim, nout) 116 | 117 | nnet_orig = RNet(13, 3, wp_function=None, 118 | indim=ndim, outdim=nout, 119 | layer_pre=True, layer_post=True, 120 | biasorno=True, nonlin=True, 121 | mlp=True, final_layer=None) 122 | 123 | # nnet = Polynomial(4) 124 | 125 | 126 | loss_fcn = torch.nn.MSELoss(reduction='mean') 127 | 128 | models = [nnloss, nnloss] 129 | modelpars = [[nnet_orig, loss_fcn, xtrn, ytrn], [nnet_orig, loss_fcn, xval, yval]] 130 | 131 | pdim = sum(p.numel() for p in nnet_orig.parameters()) 132 | 133 | ntry = 3 134 | centers = np.empty((ntry, pdim)) 135 | for itry in range(ntry): 136 | nnet = RNet(13, 3, wp_function=None, 137 | indim=ndim, outdim=nout, 138 | layer_pre=True, layer_post=True, 139 | biasorno=True, nonlin=True, 140 | mlp=True, final_layer=None) 141 | 142 | snet = SNet(nnet, ndim, nout) 143 | def loss_xy(x, y): return loss_fcn(snet(x), y) 144 | 145 | snet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, 146 | batch_size=None, nepochs=2000, loss_xy=loss_xy) 147 | 148 | nn_best = snet.best_model 149 | centers[itry] = npy(NNWrap(nn_best).p_flatten()).reshape(-1,) 150 | 151 | labels = [f'Training (N={ntrn})', f'Validation (N={nval})'] 152 | print(centers.shape) 153 | print(centers) 154 | 155 | # Prepare lists of inputs and outputs for plotting 156 | xx_list = [xtrn, xval] 157 | yy_list = [ytrn, yval] 158 | ll_list = ['Training', 'Validation'] 159 | 160 | if ndim == 1: 161 | snet.plot_1d_fits(xx_list, yy_list, labels=ll_list) 162 | snet.predict_plot(xx_list, yy_list, labels=ll_list) 163 | 164 | plot_1d_anchored(models, modelpars, centers[0], scale=scale, 165 | ngr=111, modellabels=labels, ncolrow=(8, 5)) 166 | plt.clf() 167 | 168 | plt.figure(figsize=(8, 8)) 169 | plot_1d_anchored_single(models, modelpars, centers[0], 170 | anchor2=centers[1], 171 | scale=scale, ngr=111, modellabels=labels, 172 | figname='fcn_1d_allcenters_12.png') 173 | plt.clf() 174 | 175 | plt.figure(figsize=(8, 8)) 176 | plot_1d_anchored_single(models, modelpars, centers[0], 177 | anchor2=centers[2], 178 | scale=scale, ngr=111, modellabels=labels, 179 | figname='fcn_1d_allcenters_13.png') 180 | plt.clf() 181 | 182 | plt.figure(figsize=(8, 8)) 183 | plot_1d_anchored_single(models, modelpars, centers[1], 184 | anchor2=centers[2], 185 | scale=scale, ngr=111, modellabels=labels, 186 | figname='fcn_1d_allcenters_23.png') 187 | plt.clf() 188 | 189 | plt.figure(figsize=(10, 10)) 190 | plot_2d_anchored_single(models, modelpars, centers[0], 191 | anchor2=centers[1], anchor3=centers[2], 192 | scale=scale, ngr=55, squished=False, modellabels=labels, 193 | figname='fcn_2d_allcenters.png') 194 | plt.clf() 195 | 196 | plot_2d_anchored(models, modelpars, centers[0], anchor2=centers[1], 197 | scale=scale, ngr=55, squished=False, 198 | modellabels=labels, ncolrow=(4, 3)) 199 | 200 | -------------------------------------------------------------------------------- /examples/ex_lreg_mcmc.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """An example of linear regression via MCMC.""" 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from quinn.func.funcs import Sine 8 | from quinn.utils.maps import scale01ToDom 9 | from quinn.solvers.nn_mcmc import NN_MCMC 10 | from quinn.utils.plotting import myrc 11 | from quinn.utils.plotting import plot_xrv, plot_yx 12 | from quinn.utils.plotting import plot_tri, plot_pdfs 13 | 14 | 15 | def main(): 16 | """Main function.""" 17 | torch.set_default_dtype(torch.double) 18 | myrc() 19 | 20 | ################################################################################# 21 | ################################################################################# 22 | 23 | # defaults to cuda:0 24 | device='cpu' 25 | print("Using device",device) 26 | 27 | 28 | nall = 12 # total number of points 29 | trn_factor = 0.9 # which fraction of nall goes to training 30 | ntst = 13 # separate test set 31 | ndim = 1 # input dimensionality 32 | datanoise = 0.1 # Noise in the generated data 33 | 34 | # One output example 35 | true_model, nout = Sine, 1 36 | 37 | # Function domain 38 | domain = np.tile(np.array([-2., 2.]), (ndim, 1)) 39 | 40 | # Get x data 41 | xall = scale01ToDom(np.random.rand(nall, ndim), domain) 42 | if true_model is not None: 43 | yall = true_model(xall, datanoise=datanoise) 44 | 45 | # Sample test 46 | if ntst > 0: 47 | np.random.seed(100) 48 | xtst = scale01ToDom(np.random.rand(ntst, ndim), domain) 49 | if true_model is not None: 50 | ytst = true_model(xtst, datanoise=datanoise) 51 | 52 | # Neural net is a linear function 53 | nnet = torch.nn.Linear(1,1, bias=True) 54 | ## Polynomial example 55 | #nnet = Polynomial(4) 56 | 57 | 58 | 59 | # Data split to training and validation 60 | ntrn = int(trn_factor * nall) 61 | indperm = range(nall)# np.random.permutation(range(nall)) 62 | indtrn = indperm[:ntrn] 63 | indval = indperm[ntrn:] 64 | xtrn, xval = xall[indtrn, :], xall[indval, :] 65 | ytrn, yval = yall[indtrn, :], yall[indval, :] 66 | 67 | 68 | nmcmc = 10000 69 | uqnet = NN_MCMC(nnet, verbose=True) 70 | #sampler, sampler_params = 'hmc', {'L': 3, 'epsilon': 0.0025} 71 | sampler, sampler_params = 'amcmc', {'gamma': 0.1} 72 | uqnet.fit(xtrn, ytrn, zflag=False, datanoise=datanoise, nmcmc=nmcmc, sampler=sampler, sampler_params=sampler_params) 73 | 74 | 75 | # Prepare lists of inputs and outputs for plotting 76 | xx_list = [xtrn, xval] 77 | yy_list = [ytrn, yval] 78 | ll_list = ['Training', 'Validation'] 79 | if ntst > 0: 80 | xx_list.append(xtst) 81 | yy_list.append(ytst) 82 | ll_list.append('Testing') 83 | 84 | uqnet.plot_1d_fits(xx_list, yy_list, nmc=100, labels=ll_list, true_model=true_model, name_postfix='mcmc') 85 | uqnet.predict_plot(xx_list, yy_list, nmc=100, plot_qt=False, labels=ll_list) 86 | np.savetxt('chain.txt', uqnet.samples) 87 | plot_xrv(uqnet.samples, prefix='chain') 88 | plot_yx(np.arange(uqnet.samples.shape[0])[:,np.newaxis], 89 | uqnet.samples, 90 | rowcols=(1,1), ylabel='', xlabels='Chain Id', 91 | log=False, filename='chain.png', 92 | xpad=0.3, ypad=0.3, gridshow=False, ms=4, labelsize=18) 93 | plot_tri(uqnet.samples, names=None, msize=3, figname='chain_tri.png') 94 | plot_pdfs(plot_type='tri', pdf_type='kde', 95 | samples_=uqnet.samples, burnin=nmcmc//10, every=10, 96 | names_=None, nominal_=None, prange_=None, 97 | show_2dsamples=True, 98 | lsize=13, zsize=13, xpad=0.3, ypad=0.3) 99 | 100 | if __name__ == '__main__': 101 | main() 102 | 
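# (For reference: besides the fit plots, this example writes the MCMC chain to chain.txt and generates chain.png and chain_tri.png, along with KDE/triangle posterior plots, via the quinn.utils.plotting helpers called above.)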
-------------------------------------------------------------------------------- /examples/ex_ufit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """An example of running all the available UQ-for-NN methods.""" 3 | 4 | import sys 5 | import torch 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from quinn.solvers.nn_vi import NN_VI 10 | from quinn.solvers.nn_ens import NN_Ens 11 | from quinn.solvers.nn_rms import NN_RMS 12 | from quinn.solvers.nn_mcmc import NN_MCMC 13 | from quinn.solvers.nn_swag import NN_SWAG 14 | from quinn.solvers.nn_laplace import NN_Laplace 15 | 16 | 17 | from quinn.nns.rnet import RNet, Poly 18 | from quinn.utils.plotting import myrc, plot_vars 19 | from quinn.utils.maps import scale01ToDom 20 | from quinn.func.funcs import Sine, Sine10, blundell 21 | 22 | 23 | def main(): 24 | """Main function.""" 25 | torch.set_default_dtype(torch.double) 26 | myrc() 27 | 28 | ################################################################################# 29 | ################################################################################# 30 | meth = sys.argv[1] 31 | all_uq_options = ['amcmc', 'hmc', 'vi', 'ens', 'rms', 'laplace', 'swag'] 32 | assert meth in all_uq_options, f'Pick among {all_uq_options}' 33 | 34 | # defaults to cuda:0 if available 35 | device_id='cuda:0' 36 | device = torch.device(device_id if torch.cuda.is_available() else 'cpu') 37 | print("Using device",device) 38 | 39 | 40 | nall = 15 # total number of points 41 | trn_factor = 0.9 # which fraction of nall goes to training 42 | ntst = 13 # separate test set 43 | ndim = 1 # input dimensionality 44 | datanoise = 0.02 # Noise in the generated data 45 | 46 | # Plot 1d fits or not 47 | plot_1d = (ndim==1) 48 | # Plot quantiles or st.dev. 
49 | plot_qt = False 50 | 51 | # One output example 52 | true_model, nout = Sine, 1 53 | # 10 output example 54 | # true_model, nout = Sine10, 10 55 | 56 | # Function domain 57 | domain = np.tile(np.array([-np.pi, np.pi]), (ndim, 1)) 58 | #np.random.seed(111) 59 | 60 | # Get x data 61 | xall = scale01ToDom(np.random.rand(nall, ndim), domain) 62 | if true_model is not None: 63 | yall = true_model(xall, datanoise=datanoise) 64 | 65 | # Sample test 66 | if ntst > 0: 67 | np.random.seed(100) 68 | xtst = scale01ToDom(np.random.rand(ntst, ndim), domain) 69 | if true_model is not None: 70 | ytst = true_model(xtst, datanoise=datanoise) 71 | 72 | # Model to fit 73 | #nnet = TwoLayerNet(1, 4, 1) #Constant() #MLP_simple((ndim, 5, 5, 5, nout)) #Polynomial(4) #Polynomial3() #TwoLayerNet(1, 4, 1) #torch.nn.Linear(1,1, bias=False) 74 | nnet = RNet(3, 3, wp_function=Poly(0), 75 | indim=ndim, outdim=nout, 76 | layer_pre=True, layer_post=True, 77 | biasorno=True, nonlin=True, 78 | mlp=False, final_layer=None, 79 | device=device) 80 | # nnet = Polynomial(4) 81 | 82 | # nnet = MLP(ndim, nout, (11,11,11), biasorno=True, 83 | # activ='relu', bnorm=False, bnlearn=True, dropout=0.0) 84 | 85 | 86 | 87 | # Data split to training and validation 88 | ntrn = int(trn_factor * nall) 89 | indperm = range(nall)# np.random.permutation(range(nall)) 90 | indtrn = indperm[:ntrn] 91 | indval = indperm[ntrn:] 92 | xtrn, xval = xall[indtrn, :], xall[indval, :] 93 | ytrn, yval = yall[indtrn, :], yall[indval, :] 94 | 95 | 96 | 97 | 98 | if meth == 'amcmc': 99 | nmc = 100 100 | uqnet = NN_MCMC(nnet, verbose=True) 101 | sampler_params = {'gamma': 0.01} 102 | uqnet.fit(xtrn, ytrn, zflag=False, datanoise=datanoise, nmcmc=10000, sampler='amcmc', sampler_params=sampler_params) 103 | elif meth == 'hmc': 104 | nmc = 100 105 | uqnet = NN_MCMC(nnet, verbose=True) 106 | sampler_params = {'L': 3, 'epsilon': 0.0025} 107 | uqnet.fit(xtrn, ytrn, zflag=False, datanoise=datanoise, nmcmc=10000, sampler='hmc', sampler_params=sampler_params) 108 | elif meth == 'vi': 109 | nmc = 111 110 | uqnet = NN_VI(nnet, verbose=True) 111 | uqnet.fit(xtrn, ytrn, val=[xval, yval], datanoise=datanoise, lrate=0.01, batch_size=None, nsam=1, nepochs=5000) 112 | elif meth == 'ens': 113 | nmc = 3 114 | uqnet = NN_Ens(nnet, nens=nmc, dfrac=0.8, verbose=True) 115 | uqnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=2, nepochs=1000) 116 | elif meth == 'rms': 117 | nmc = 7 118 | uqnet = NN_RMS(nnet, nens=nmc, dfrac=1.0, verbose=True, datanoise=datanoise, priorsigma=0.1) 119 | uqnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=2, nepochs=1000) 120 | elif meth == 'laplace': 121 | nmc = 3 122 | uqnet = NN_Laplace(nnet, nens=nmc, dfrac=1.0, verbose=True, la_type='full') 123 | uqnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=2, nepochs=1000) 124 | elif meth == 'swag': 125 | nmc = 3 126 | uqnet = NN_SWAG(nnet, nens=nmc, dfrac=1.0, verbose=True, k=10, 127 | n_steps=12, c=1, cov_type="lowrank", lr_swag=0.01) 128 | uqnet.fit(xtrn, ytrn, val=[xval, yval], lrate=0.01, batch_size=2, nepochs=1000) 129 | else: 130 | print(f"UQ Method {meth} is unknown. 
Exiting.") 131 | sys.exit() 132 | 133 | 134 | # Prepare lists of inputs and outputs for plotting 135 | xx_list = [xtrn, xval] 136 | yy_list = [ytrn, yval] 137 | ll_list = ['Training', 'Validation'] 138 | if ntst > 0: 139 | xx_list.append(xtst) 140 | yy_list.append(ytst) 141 | ll_list.append('Testing') 142 | 143 | if plot_1d: 144 | uqnet.plot_1d_fits(xx_list, yy_list, nmc=100, labels=ll_list, true_model=true_model, name_postfix=str(meth)) 145 | uqnet.predict_plot(xx_list, yy_list, nmc=100, plot_qt=False, labels=ll_list) 146 | 147 | # Another way of plotting, by explicitly using predict function 148 | if plot_1d: 149 | assert(ndim==1) 150 | xgrid = scale01ToDom(np.linspace(0.0, 1.0, 111), domain) 151 | 152 | ygrid_mean, ygrid_var, ygrid_cov = uqnet.predict_mom_sample(xgrid[:, np.newaxis], msc=2, nsam=1000) 153 | nout = ygrid_mean.shape[-1] 154 | for iout in range(nout): 155 | plt.figure(figsize=(12, 9)) 156 | plot_vars(xgrid, ygrid_mean[:, iout], 157 | variances=ygrid_var[:, iout][:, np.newaxis], 158 | varcolors=['gray'], varlabels=['Std. deviation'], 159 | grid_show=True, connected=True, 160 | interp=None, ax=plt.gca()) 161 | plt.plot(xtrn, ytrn, 'bo', zorder=10000, label='Training points') 162 | plt.plot(xval, yval, 'go', zorder=10000, label='Validation points') 163 | if ntst>0: 164 | plt.plot(xtst, ytst, 'ro', zorder=10000, label='Testing points') 165 | plt.xlabel('Input') 166 | plt.ylabel(f'Output #{iout+1}') 167 | plt.legend() 168 | plt.savefig(f'fit_var_o{iout}_{meth}.png') 169 | plt.clf() 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /quinn/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 4 | -------------------------------------------------------------------------------- /quinn/ens/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from . import learner 4 | -------------------------------------------------------------------------------- /quinn/ens/learner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for a Learner class that is a wrapper with basic training/prediction functionality.""" 3 | 4 | import math 5 | import copy 6 | 7 | from ..nns.tchutils import npy, tch, print_nnparams 8 | from ..nns.nnfit import nnfit 9 | 10 | class Learner(): 11 | """A learner class that holds PyTorch NN module and helps train it. 12 | 13 | Attributes: 14 | nnmodel (torch.nn.Module): Main PyTorch NN module. 15 | best_model (torch.nn.Module): The best trained PyTorch NN module. 16 | trained (bool): Whether the module is trained or not. 17 | verbose (bool): Whether to be verbose or not. 18 | """ 19 | 20 | def __init__(self, nnmodel, verbose=False): 21 | """Initialization. 22 | 23 | Args: 24 | nnmodel (torch.nn.Module): Main PyTorch NN module. 25 | verbose (bool): Whether to be verbose or not. 26 | """ 27 | super().__init__() 28 | self.nnmodel = copy.deepcopy(nnmodel) 29 | self.trained = False 30 | self.verbose = verbose 31 | self.best_model = None 32 | 33 | if self.verbose: 34 | self.print_params(names_only=True) 35 | 36 | def print_params(self, names_only=False): 37 | """Print parameters of the learner's model. 38 | 39 | Args: 40 | names_only (bool, optional): Whether to print the parameter names only or not. 
41 | """ 42 | if self.trained: 43 | print_nnparams(self.best_model, names_only=names_only) 44 | else: 45 | print_nnparams(self.nnmodel, names_only=names_only) 46 | 47 | def init_params(self): 48 | """An example of random initialization of parameters. 49 | 50 | .. todo:: we can and should enrich this. 51 | """ 52 | for p in self.nnmodel.parameters(): 53 | try: 54 | stdv = 1. / math.sqrt(p.size(1)) 55 | except IndexError: 56 | stdv = 1. 57 | p.data.uniform_(-stdv, stdv) 58 | 59 | def fit(self, xtrn, ytrn, **kwargs): 60 | """Fitting function for this learner. 61 | 62 | Args: 63 | xtrn (np.ndarray): Input array of size `(N,d)`. 64 | ytrn (np.ndarray): Output array of size `(N,o)`. 65 | **kwargs (dict): Keyword arguments. 66 | """ 67 | if hasattr(self.nnmodel, 'fit') and callable(getattr(self.nnmodel, 'fit')): 68 | self.best_model = self.nnmodel.fit(xtrn, ytrn, **kwargs) 69 | else: 70 | fit_info = nnfit(self.nnmodel, xtrn, ytrn, **kwargs) 71 | self.best_model = fit_info['best_nnmodel'] 72 | 73 | self.trained = True 74 | 75 | def predict(self, x): 76 | """Prediction of the learner. 77 | 78 | Args: 79 | x (np.ndarray): Input array of size `(N,d)`. 80 | 81 | Returns: 82 | np.ndarray: Output array of size `(N,o)`. 83 | """ 84 | assert(self.trained) 85 | 86 | try: 87 | device = self.best_model.device 88 | except AttributeError: 89 | device = 'cpu' 90 | 91 | y = self.best_model(tch(x, rgrad=False, device=device)) 92 | 93 | return npy(y) 94 | -------------------------------------------------------------------------------- /quinn/func/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from . import funcs 4 | -------------------------------------------------------------------------------- /quinn/func/funcs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Various analytical functions for testing the methods.""" 3 | 4 | import numpy as np 5 | 6 | 7 | 8 | def blundell(xx, datanoise=0.0): 9 | r"""Classical example from :cite:t:`blundell:2015`. 10 | 11 | .. math:: 12 | f(x)=x+0.3 \sin(2\pi(x+\sigma\:{\cal N}(0,1)))+0.3 \sin(4\pi(x+\sigma\:{\cal N}(0,1)))+\sigma\:{\cal N}(0,1) 13 | 14 | Args: 15 | xx (np.ndarray): Input array :math:`x` of size `(N,d)`. 16 | datanoise (float, optional): Standard deviation :math:`\sigma` of i.i.d. gaussian noise, both on the input and output. 17 | 18 | Returns: 19 | np.ndarray: Output array of size `(N,d)`. 20 | Note: 21 | This function is typically used in `d=1` setting. 22 | """ 23 | noise = datanoise * np.random.randn(xx.shape[0], xx.shape[1]) 24 | yy = xx + 0.3 * np.sin(2. * np.pi * (xx + noise)) + 0.3 * \ 25 | np.sin(4. * np.pi * (xx + noise)) + noise 26 | return yy 27 | 28 | 29 | def Sine(xx, datanoise=0.0): 30 | r"""Simple sum of sines function 31 | 32 | .. math:: 33 | f(x)=\sin(x_1)+...+\sin(x_d) + \sigma \: {\cal N} (0,1) 34 | 35 | Args: 36 | xx (np.ndarray): Input array :math:`x` of size `(N,d)`. 37 | datanoise (float, optional): Standard deviation :math:`\sigma` of i.i.d. gaussian noise. 38 | 39 | Returns: 40 | np.ndarray: Output array of size `(N,1)`. 41 | """ 42 | yy = datanoise * np.random.randn(xx.shape[0], 1) 43 | yy += np.sum(np.sin(xx), axis=1).reshape(-1, 1) 44 | 45 | return yy 46 | 47 | 48 | def Summation(xx, datanoise=0.0): 49 | r"""Summation function. 50 | 51 | .. math:: 52 | f(x)=x_1 + x_2 + \dots + x_d + \sigma \: {\cal N} (0,1) 53 | 54 | Args: 55 | xx (np.ndarray): Input array :math:`x` of size `(N,d)`. 
56 | datanoise (float, optional): Standard deviation :math:`\sigma` of i.i.d. gaussian noise, both on the input and output. 57 | 58 | Returns: 59 | np.ndarray: Output array of size `(N,d)`. 60 | """ 61 | 62 | yy = datanoise * np.random.randn(xx.shape[0], 1) 63 | yy += np.sum(xx, axis=1).reshape(-1, 1) 64 | 65 | return yy 66 | 67 | 68 | def Sine10(xx, datanoise=0.02): 69 | r"""Sum of sines function with 10 outputs 70 | 71 | .. math:: 72 | f_1(x)=\sin(x_1)+...+\sin(x_d) + \sigma \: {\cal N} (0,1)\\ 73 | \dots \qquad\qquad\qquad\qquad\\ 74 | f_{10}(x)=\sin(x_1)+...+\sin(x_d) + \sigma \: {\cal N} (0,1) 75 | 76 | 77 | Args: 78 | xx (np.ndarray): Input array :math:`x` of size `(N,d)`. 79 | datanoise (float, optional): Standard deviation :math:`\sigma` of i.i.d. gaussian noise. 80 | 81 | Returns: 82 | np.ndarray: Output array of size `(N,10)`. 83 | """ 84 | yy = datanoise * np.random.randn(xx.shape[0], 10) 85 | yy += np.sum(np.sin(xx), axis=1).reshape(-1, 1) 86 | 87 | return yy 88 | 89 | 90 | def Ackley(x, datanoise=0.02): 91 | r"""Ackley4 or Modified Ackley function from https://arxiv.org/pdf/1308.4008v1.pdf. 92 | 93 | .. math:: 94 | f(x)=\sum_{i=1}^{d-1} \left(\exp(-0.2)\sqrt{x_i^2+x_{i+1}^2} + 3 (\cos{2x_i}+\sin{2x_{i+1}})\right) + \sigma \: {\cal N} (0,1) 95 | 96 | Args: 97 | xx (np.ndarray): Input array :math:`x` of size `(N,d)`. 98 | datanoise (float, optional): Standard deviation :math:`\sigma` of i.i.d. gaussian noise. 99 | 100 | Returns: 101 | np.ndarray: Output array of size `(N,1)`. 102 | """ 103 | yy = datanoise * np.random.randn(x.shape[0],) 104 | ndim = x.shape[1] 105 | 106 | for i in range(ndim - 1): 107 | yy += np.exp(-0.2) * np.sqrt(x[:, i]**2 + x[:, i + 1]**2) + \ 108 | 3 * (np.cos(2 * x[:, i]) + np.sin(2 * x[:, i + 1])) 109 | return yy.reshape(-1, 1) 110 | 111 | 112 | def x5(xx, datanoise=0.0): 113 | r"""Fifth power function. 114 | 115 | Args: 116 | xx (np.ndarray): Input array :math:`x` of size `(N,d)`. 117 | datanoise (float, optional): Standard deviation :math:`\sigma` of i.i.d. gaussian noise. 118 | """ 119 | yy = datanoise * np.random.randn(xx.shape[0], 1) 120 | yy += xx[:, 0].reshape(-1, 1)**5 121 | 122 | return yy 123 | -------------------------------------------------------------------------------- /quinn/mcmc/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from . import mcmc 4 | from . import admcmc 5 | from . import hmc 6 | from . import mala 7 | -------------------------------------------------------------------------------- /quinn/mcmc/admcmc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | r"""Module for Adaptive MCMC (AMCMC) sampling. For details of the method, see :cite:t:`haario:2001`""" 3 | 4 | import numpy as np 5 | from .mcmc import MCMCBase 6 | 7 | class AMCMC(MCMCBase): 8 | """Adaptive MCMC class. 9 | 10 | Attributes: 11 | cov_ini (np.ndarray): Initial covariance array of size `(p,p)`. 12 | gamma (float): Proposal jump size factor :math:`\gamma`. 13 | _propcov (np.ndarray): A 2d array of size `(p,p)` for proposal covariance. 14 | t0 (int): Step where adaptivity begins. 15 | tadapt (int): Frequency for adapting/updating the covariance. 16 | """ 17 | def __init__(self, cov_ini=None, gamma=0.1, t0=100, tadapt=1000): 18 | """Initialization. 19 | 20 | Args: 21 | cov_ini (np.ndarray, optional): Initial covariance array of size `(p,p)`. 
Defaults to None which sets the initial covariance as some fraction of the chain state. 22 | gamma (float, optional): Proposal jump size factor :math:`\gamma`. Defaults to None. 23 | t0 (int, optional): Step where adaptivity begins. Defaults to 100. 24 | tadapt (int): Frequency for adapting/updating the covariance. Defaults to 1000. 25 | """ 26 | super().__init__() 27 | 28 | self.cov_ini = cov_ini 29 | self.t0 = t0 30 | self.tadapt = tadapt 31 | self.gamma = gamma 32 | 33 | # Working attributes 34 | self._Xm = None 35 | self._cov = None 36 | self._propcov = None 37 | 38 | def sampler(self, current, imcmc): 39 | """Sampler method. 40 | 41 | Args: 42 | current (np.ndarray): Current chain state. 43 | imcmc (int): Current chain step number. 44 | 45 | Returns: 46 | tuple(current_proposal, current_K, proposed_K): A tuple containing the current proposal sample, and two zeros irrelevant for AMCMC. 47 | """ 48 | current_proposal = current.copy() 49 | cdim = len(current) 50 | 51 | # Compute covariance matrix 52 | if imcmc == 0: 53 | self._Xm = current.copy() 54 | self._cov = np.zeros((cdim, cdim)) 55 | else: 56 | self._Xm = (imcmc * self._Xm + current) / (imcmc + 1.0) 57 | rt = (imcmc - 1.0) / imcmc 58 | st = (imcmc + 1.0) / imcmc**2 59 | self._cov = rt * self._cov + st * np.dot(np.reshape(current - self._Xm, (cdim, 1)), np.reshape(current - self._Xm, (1, cdim))) 60 | 61 | if imcmc == 0: 62 | if self.cov_ini is not None: 63 | self._propcov = self.cov_ini 64 | else: 65 | self._propcov = 0.01 + np.diag(0.09*np.abs(current)) 66 | elif (imcmc > self.t0) and (imcmc % self.tadapt == 0): 67 | self._propcov = (self.gamma * 2.4**2 / cdim) * (self._cov + 10**(-8) * np.eye(cdim)) 68 | 69 | # Generate proposal candidate 70 | current_proposal += np.random.multivariate_normal(np.zeros(cdim,), self._propcov) 71 | proposed_K = 0.0 72 | current_K = 0.0 73 | 74 | return current_proposal, current_K, proposed_K 75 | 76 | 77 | -------------------------------------------------------------------------------- /quinn/mcmc/hmc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | r"""Module for Hamiltonian MCMC (HMC) sampling. For details of the method, see Chapter 5 of :cite:t:`brooks:2011` or https://arxiv.org/pdf/1206.1901.pdf""" 3 | 4 | import numpy as np 5 | from .mcmc import MCMCBase 6 | 7 | 8 | class HMC(MCMCBase): 9 | """Hamiltonian MCMC class. 10 | 11 | Attributes: 12 | epsilon (float): Step size of the method. 13 | L (int): Number of steps in the Hamiltonian integrator. 14 | """ 15 | 16 | def __init__(self, epsilon=0.05, L=3): 17 | """Initialization. 18 | 19 | Args: 20 | epsilon (float, optional): Step size of the method. Defaults to 0.05. 21 | L (int, optional): Number of steps in the Hamiltonian integrator. Defaults to 3. 22 | """ 23 | super().__init__() 24 | self.epsilon = epsilon 25 | self.L = L 26 | 27 | def sampler(self, current, imcmc): 28 | """Sampler method. 29 | 30 | Args: 31 | current (np.ndarray): Current chain state. 32 | imcmc (int): Current chain step number. 33 | 34 | Returns: 35 | tuple(current_proposal, current_K, proposed_K): A tuple containing the current proposal sample, current and proposed kinetic energies. 
36 | """ 37 | assert(self.logPostGrad is not None) 38 | 39 | current_proposal = current.copy() 40 | cdim = len(current) 41 | 42 | 43 | p = np.random.randn(cdim) 44 | current_K = np.sum(np.square(p)) / 2 45 | 46 | # Make a half step for momentum at the beginning (Leapfrog Method step starts here) 47 | 48 | p += self.epsilon * self.logPostGrad(current_proposal, **self.postInfo) / 2 49 | 50 | for jj in range(self.L): 51 | # Make a full step for the position 52 | current_proposal += self.epsilon * p 53 | 54 | # Make a full step for the momentum, expecpt at the end of the trajectory 55 | 56 | if jj != self.L - 1: 57 | p += self.epsilon * self.logPostGrad(current_proposal, **self.postInfo) 58 | 59 | # Make a half step for momentum at the end (Leapfrog Method step ends here) 60 | p += self.epsilon* self.logPostGrad(current_proposal, **self.postInfo) / 2 61 | 62 | 63 | # Negate momentum to make proposal symmetric 64 | # This is really not necessary, but we kept it per original paper 65 | p = -p 66 | 67 | # Evaluate kinetic and potential energies 68 | proposed_K = np.sum(np.square(p)) / 2 69 | 70 | return current_proposal, current_K, proposed_K 71 | 72 | -------------------------------------------------------------------------------- /quinn/mcmc/mala.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | r"""Module for Metropolis Adjusted Langevin (MALA) MCMC sampling. For details of the method, see :cite:t:`girolami:2011` or https://arxiv.org/pdf/1206.1901.pdf""" 3 | 4 | import numpy as np 5 | from .mcmc import MCMCBase 6 | 7 | 8 | class MALA(MCMCBase): 9 | """MALA MCMC class. 10 | 11 | Attributes: 12 | epsilon (float): Step size of the method. 13 | """ 14 | 15 | def __init__(self, epsilon=0.05): 16 | """Initialization. 17 | 18 | Args: 19 | epsilon (float, optional): Step size of the method. Defaults to 0.05. 20 | """ 21 | super().__init__() 22 | self.epsilon = epsilon 23 | 24 | def sampler(self, current, imcmc): 25 | """Sampler method. 26 | 27 | Args: 28 | current (np.ndarray): Current chain state. 29 | imcmc (int): Current chain step number. 30 | 31 | Returns: 32 | tuple(current_proposal, current_K, proposed_K): A tuple containing the current proposal sample, current and proposed kinetic energies. 33 | 34 | Note: When the dust settles, MALA is actually exactly HMC with L=1. 35 | """ 36 | assert(self.logPostGrad is not None) 37 | 38 | current_proposal = current.copy() 39 | cdim = len(current) 40 | 41 | 42 | p = np.random.randn(cdim) 43 | 44 | grad_current = self.logPostGrad(current, **self.postInfo) 45 | current_proposal += 0.5*self.epsilon**2 * grad_current + self.epsilon * p 46 | 47 | grad_prop = self.logPostGrad(current_proposal, **self.postInfo) 48 | current_K = np.sum(np.square(p)) / 2 49 | 50 | p += self.epsilon * (grad_current+grad_prop)/ 2 51 | proposed_K = np.sum(np.square(p)) / 2 52 | 53 | return current_proposal, current_K, proposed_K 54 | -------------------------------------------------------------------------------- /quinn/mcmc/mcmc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Base module for MCMC methods.""" 3 | 4 | import sys 5 | import copy 6 | import numpy as np 7 | 8 | 9 | class MCMCBase(object): 10 | """Base class for MCMC. 11 | 12 | Attributes: 13 | logPost (callable): Log-posterior evaluator function. It has a singature of logPost(model_parameters, \**postInfo) returning a float. 
14 | logPostGrad (callable): Log-posterior gradient evaluator function. It has a singature of logPostGrad(model_parameters, \**postInfo) returning an np.ndarray of size model_parameters. Defaults to None, for non-gradient based methods. 15 | postInfo (dict): Dictionary that holds auxiliary parameters for posterior evaluations. 16 | """ 17 | 18 | def __init__(self): 19 | """Dummy instantiation.""" 20 | self.logPost = None 21 | self.logPostGrad = None 22 | self.postInfo = {} 23 | 24 | 25 | def setLogPost(self, logPost, logPostGrad, **postInfo): 26 | """Setting logposterior and its gradient functions. 27 | 28 | Args: 29 | logPost (callable): Log-posterior evaluator function. It has a singature of logPost(model_parameters, \**postInfo) returning a float. 30 | logPostGrad (callable): Log-posterior gradient evaluator function. It has a singature of logPostGrad(model_parameters, \**postInfo) returning an np.ndarray of size model_parameters. Can be None, for non-gradient based methods. 31 | postInfo (dict): Dictionary that holds auxiliary parameters for posterior evaluations. 32 | """ 33 | self.logPost = logPost 34 | self.logPostGrad = logPostGrad 35 | self.postInfo = postInfo 36 | 37 | 38 | 39 | def run(self, nmcmc, param_ini): 40 | """Markov chain Monte Carlo running function. 41 | 42 | Args: 43 | nmcmc (int): Number of steps. 44 | param_ini (np.ndarray): Initial state of the chain. 45 | 46 | Returns: 47 | dict: Dictionary of results. Keys are 'chain' (chain samples array), 'mapparams' (MAP parameters array), 'maxpost' (maximal log-post value), 'accrate' (acceptance rate), 'logpost' (log-posterior array), 'alphas' (array of Metropolis-Hastings probability ratios). 48 | """ 49 | assert(self.logPost is not None) 50 | samples = [] # MCMC samples 51 | alphas = [] # Store alphas (posterior ratios) 52 | logposts = [] # Log-posterior values] 53 | na = 0 # counter for accepted steps 54 | 55 | current = param_ini.copy() # first step 56 | current_U = -self.logPost(current, **self.postInfo) # Negative logposterior 57 | cmode = current # MAP value (maximum a posteriori) 58 | pmode = -current_U # record MCMC mode, which is where the current MAP value is achieved 59 | 60 | samples.append(current) 61 | logposts.append(-current_U) 62 | alphas.append(0.0) 63 | 64 | # Loop over MCMC steps 65 | for imcmc in range(nmcmc): 66 | current_proposal, current_K, proposed_K = self.sampler(current, imcmc) 67 | 68 | proposed_U = -self.logPost(current_proposal, **self.postInfo) 69 | proposed_H = proposed_U + proposed_K 70 | current_H = current_U + current_K 71 | 72 | mh_prob = np.exp(current_H - proposed_H) 73 | 74 | # Accept block 75 | if np.random.random_sample() < mh_prob: 76 | na += 1 # Acceptance counter 77 | current = current_proposal + 0.0 78 | current_U = proposed_U + 0.0 79 | if -current_U >= pmode: 80 | pmode = -current_U 81 | cmode = current + 0.0 82 | 83 | samples.append(current) 84 | alphas.append(mh_prob) 85 | logposts.append(-current_U) 86 | 87 | acc_rate = float(na) / (imcmc+1) 88 | 89 | if((imcmc + 2) % (nmcmc / 10) == 0) or imcmc == nmcmc - 2: 90 | print('%d / %d completed, acceptance rate %lg' % (imcmc + 2, nmcmc, acc_rate)) 91 | 92 | results = { 93 | 'chain' : np.array(samples), 94 | 'mapparams' : cmode, 95 | 'maxpost' : pmode, 96 | 'accrate' : acc_rate, 97 | 'logpost' : np.array(logposts), 98 | 'alphas' : np.array(alphas) 99 | } 100 | 101 | return results 102 | 103 | 104 | def sampler(self, current, imcmc): 105 | """Sampler method. 106 | 107 | Args: 108 | current (np.ndarray): Current chain state. 
109 | imcmc (int): Current chain step number. 110 | 111 | Raises: 112 | NotImplementedError: Not implemented in the base class. 113 | """ 114 | raise NotImplementedError("sampler method not implemented in the base class and should be implemented in children.") 115 | -------------------------------------------------------------------------------- /quinn/nns/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 4 | from . import nnbase 5 | from . import nns 6 | from . import nnwrap 7 | from . import nnfit 8 | 9 | from . import mlp 10 | from . import rnet 11 | 12 | from . import losses 13 | 14 | from . import tchutils 15 | -------------------------------------------------------------------------------- /quinn/nns/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import numpy as np 5 | import torch.autograd.functional as F 6 | 7 | from .tchutils import tch 8 | 9 | class PeriodicLoss(torch.nn.Module): 10 | r"""Example of a periodic loss regularization. 11 | 12 | Attributes: 13 | model (callable): NN model evaluator. 14 | lam (float): Penalty strength. 15 | bdry1 (torch.Tensor): First boundary. 16 | bdry2 (torch.Tensor): Second boundary. 17 | 18 | The loss function has a form 19 | 20 | .. math:: 21 | \frac{1}{N}||y_{\text{pred}}-y_{\text{target}}||^2 + \frac{\lambda}{N}||M\text{(boundary1)}-M\text{(boundary2)}||^2. 22 | """ 23 | 24 | def __init__(self, nnmodel, lam, boundary): 25 | """Initialization. 26 | 27 | Args: 28 | nnmodel (torch.nn.Module): NN model. 29 | lam (float, optional): Penalty strength. Defaults to 0. 30 | boundary (tuple): A tuple of form (boundary1, boundary2). 31 | """ 32 | super().__init__() 33 | self.nnmodel = nnmodel 34 | self.lam = lam 35 | self.bdry1, self.bdry2 = boundary 36 | 37 | def forward(self, inputs, targets): 38 | """Forward function. 39 | 40 | Args: 41 | inputs (torch.Tensor): Input tensor. 42 | targets (torch.Tensor): Targets tensor. 43 | 44 | Returns: 45 | float: Loss value. 46 | """ 47 | 48 | predictions = self.nnmodel(inputs) 49 | 50 | fit = torch.mean((predictions-targets)**2) 51 | 52 | penalty = self.lam * torch.mean((self.nnmodel(self.bdry1)-self.nnmodel(self.bdry2))**2) 53 | loss = fit + penalty 54 | 55 | return loss 56 | 57 | 58 | ######################################################## 59 | ######################################################## 60 | ######################################################## 61 | 62 | class GradLoss(torch.nn.Module): 63 | r"""Example of grad loss function, including derivative contraints. 64 | 65 | Attributes: 66 | lam (float): Penalty strength. 67 | nnmodel (callable): NN model evaluator. 68 | 69 | The loss function has a form 70 | 71 | .. math:: 72 | \frac{1}{N}||M(x_{\text{train}})-y_{\text{train}}||^2 + \frac{\lambda}{Nd}||\nabla M(x_{\text{train}})-G_{\text{train}}||_F^2. 73 | """ 74 | 75 | def __init__(self, nnmodel, lam=0.0, xtrn=None, gtrn=None): 76 | """Initialization. 77 | 78 | Args: 79 | nnmodel (torch.nn.Module): NN model. 80 | lam (float, optional): Penalty strength. Defaults to 0. 81 | xtrn (np.ndarray, optional): Input array of size `(N,d)`. Needs to be user-provided: default produces assertion error. 82 | gtrn (np.ndarray, optional): Gradient array of size `(N,d)`. Needs to be user-provided: default produces assertion error. 
83 | """ 84 | super().__init__() 85 | self.nnmodel = nnmodel 86 | self.lam = lam 87 | assert(xtrn is not None) 88 | assert(gtrn is not None) 89 | 90 | self._xtrn = tch(xtrn, rgrad=True) 91 | self._gtrn = tch(gtrn, rgrad=True) 92 | 93 | def forward(self, inputs, targets): 94 | """Forward function. 95 | 96 | Args: 97 | inputs (torch.Tensor): Input tensor. 98 | targets (torch.Tensor): Target tensor. 99 | 100 | Returns: 101 | float: Loss value. 102 | """ 103 | 104 | predictions = self.nnmodel(inputs) 105 | 106 | loss = torch.mean((predictions-targets)**2) 107 | 108 | 109 | # outputs = self.nnmodel(self.xtrn) 110 | # outputs.requires_grad_() 111 | # der = torch.autograd.grad(outputs=outputs, inputs=self.xtrn, 112 | # grad_outputs=torch.ones_like(outputs), 113 | # create_graph=True, retain_graph=True, allow_unused=True)[0] 114 | # if der is not None: 115 | # #print(der1.shape, self.gtrn.shape, der.shape) 116 | # loss += self.lam*torch.mean((der-self.gtrn)**2) 117 | 118 | der = torch.vstack( [ F.jacobian(self.nnmodel, state, create_graph=True, strict=True).squeeze() for state in self.xtrn ] ) 119 | 120 | 121 | loss += self.lam*torch.mean((der-self.gtrn)**2) 122 | 123 | return loss 124 | 125 | 126 | ######################################################## 127 | ######################################################## 128 | ######################################################## 129 | 130 | class NegLogPost(torch.nn.Module): 131 | r"""Negative log-posterior loss function. 132 | 133 | Attributes: 134 | nnmodel (callable): Model evaluator. 135 | priorparams (float): Dictionary of parameters of prior. 136 | sigma (float): Likelihood data noise standard deviation. 137 | fulldatasize (int): Full datasize. Important for weighting in case likelihood is computed on a batch. 138 | pi (float): 3.1415... 139 | 140 | The negative log-posterior has the form: 141 | 142 | .. math:: 143 | \frac{N}{2}\log{(2\pi\sigma^2)} + \frac{1}{2\sigma^2}||M(x_{\text{train}})-y_{\text{train}}||^2 + 144 | .. math:: 145 | +\frac{N}{N_{\text{full}}} \left(\frac{1}{2\sigma_{\text{prior}}^2}||w-w_{\text{anchor}}||^2 + \frac{K}{2} \log{(2\pi\sigma_{\text{prior}}^2)}\right). 146 | """ 147 | 148 | def __init__(self, nnmodel, fulldatasize, sigma, priorparams): 149 | """Initialization. 150 | 151 | Args: 152 | nnmodel (callable): Model evaluator. 153 | fulldatasize (int): Full datasize. Important for weighting in case likelihood is computed on a batch. 154 | sigma (float): Likelihood data noise standard deviation. 155 | priorparams (float): Dictionary of parameters of prior. If None, there will be no prior. 156 | """ 157 | super().__init__() 158 | self.nnmodel = nnmodel 159 | self.sigma = tch(float(sigma), rgrad=False) 160 | self.priorparams = priorparams 161 | self.pi = tch(np.pi, rgrad=False) 162 | self.fulldatasize = fulldatasize 163 | 164 | def forward(self, inputs, targets): 165 | """Forward function. 166 | 167 | Args: 168 | inputs (torch.Tensor): Input tensor. 169 | targets (torch.Tensor): Target tensor. 170 | 171 | Returns: 172 | float: Loss value. 
173 | """ 174 | 175 | predictions = self.nnmodel(inputs) 176 | neglogpost = 0.5 * torch.sum(torch.pow(targets - predictions, 2)) / self.sigma**2 177 | neglogpost += (len(predictions) / 2) * torch.log(2 * self.pi) 178 | neglogpost += len(predictions) * torch.log(self.sigma) 179 | 180 | if self.priorparams is not None: 181 | neglogprior_fcn = NegLogPrior(self.priorparams['sigma'], self.priorparams['anchor']) 182 | neglogpost += len(predictions)*neglogprior_fcn(self.nnmodel)/self.fulldatasize 183 | 184 | return neglogpost 185 | 186 | ######################################################## 187 | ######################################################## 188 | ######################################################## 189 | 190 | class NegLogPrior(torch.nn.Module): 191 | r"""Calculates a Gaussian negative log-prior. 192 | 193 | Attributes: 194 | anchor (torch.Tensor): Anchor, i.e. center vector of the gaussian prior. 195 | sigma (float): The standard deviation of the gaussian prior (same for all parameters). 196 | pi (float): 3.1415.. 197 | 198 | The negative log-prior has the form: 199 | 200 | .. math:: 201 | \frac{1}{2\sigma_{\text{prior}}^2}||w-w_{\text{anchor}}||^2 + \frac{K}{2} \log{(2\pi\sigma_{\text{prior}}^2)}. 202 | """ 203 | 204 | def __init__(self, sigma, anchor): 205 | """ 206 | Args: 207 | sigma (float): The standard deviation of the gaussian prior (same for all parameters). 208 | anchor (torch.Tensor): Anchor, i.e. center vector of the gaussian prior. 209 | """ 210 | super().__init__() 211 | self.sigma = tch(float(sigma), rgrad=False) 212 | self.pi = tch(np.pi, rgrad=False) 213 | self.anchor = anchor 214 | 215 | 216 | def forward(self, model): 217 | """Forward evaluator 218 | 219 | Args: 220 | model (torch.nn.Module): The corresponding NN module. 221 | 222 | Returns: 223 | float: Negative log-prior value. 224 | """ 225 | neglogprior = 0 226 | i = 0 227 | for p in model.parameters(): 228 | cur_len = p.flatten().size()[0] 229 | neglogprior += torch.sum( 230 | torch.pow(p.flatten() - self.anchor[i : i + cur_len], 2) ) / 2 / self.sigma**2 231 | 232 | i += cur_len 233 | neglogprior += (i / 2) * torch.log(2 * self.pi * self.sigma**2) 234 | return neglogprior 235 | 236 | ######################################################## 237 | ######################################################## 238 | ######################################################## 239 | 240 | class CustomLoss(torch.nn.Module): 241 | r"""Example of custom one-dimensional loss function, including derivative and periodicity contraints. Quite experimental, but a base for developing problem-specific loss functions. 242 | 243 | Attributes: 244 | model (callable): Model evaluator. 245 | lam1 (float): Penalty strength for the periodicity constraint. 246 | lam2 (float): Penalty strength for the derivative constraint. 247 | 248 | The loss function has a form: 249 | 250 | .. math:: 251 | \frac{1}{N}||y_{\text{pred}}-y_{\text{target}}||^2 + \lambda_1 (M(0.5)-M(-0.5))^2 + \lambda_2 (M'(0.5)-M'(-0.5))^2 252 | """ 253 | 254 | def __init__(self, loss_params): 255 | """Initialization. 256 | 257 | Args: 258 | loss_params (tuple): (model, penalty1, penalty2) pair. 259 | """ 260 | super().__init__() 261 | self.model, self.lam1, self.lam2 = loss_params 262 | 263 | def forward(self, predictions, targets): 264 | """Forward function. 265 | 266 | Args: 267 | predictions (torch.Tensor): Input tensor. 268 | targets (torch.Tensor): Target tensor. 269 | 270 | Returns: 271 | float: Loss value. 
272 | """ 273 | loss = torch.mean((predictions-targets)**2) 274 | 275 | loss += self.lam1 * (self.model(torch.Tensor([0.5]))-self.model(torch.Tensor([-0.5])))**2 276 | 277 | x = torch.Tensor([-0.5, 0.5]).view(-1,1) 278 | x.requires_grad_() 279 | 280 | outputs = self.model(x) 281 | outputs.requires_grad_() 282 | der = torch.autograd.grad(outputs=outputs, inputs=x, 283 | grad_outputs=torch.ones_like(outputs), 284 | create_graph=True, retain_graph=True, allow_unused=True)[0] 285 | 286 | if der is not None: # in testing regimes, der is None 287 | reg = (der[0]-der[1])**2 288 | else: 289 | reg = 0.0 290 | 291 | loss += self.lam2*reg 292 | 293 | return loss 294 | 295 | -------------------------------------------------------------------------------- /quinn/nns/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | 5 | from .nns import Expon, Sine 6 | from .nnbase import MLPBase 7 | 8 | class MLP(MLPBase): 9 | """Multilayer perceptron class. 10 | 11 | Attributes: 12 | hls (tuple): Tuple of hidden layer widths. 13 | biasorno (bool): Whether biases are included or not. 14 | bnorm (bool): Whether batch normalization is implemented or not. 15 | bnlearn (bool): Whether batch normalization is learnable or not. 16 | dropout (float): Dropout fraction. 17 | final_transform (str): Final transformation. Currently only 'exp' is implemented for Exponential. 18 | nlayers (int): Number of layers. 19 | nnmodel (torch.nn.Module): Underlying model evaluator. 20 | """ 21 | 22 | def __init__(self, indim, outdim, hls, biasorno=True, 23 | activ='relu', bnorm=False, bnlearn=True, dropout=0.0, 24 | final_transform=None, device='cpu'): 25 | """Initialization. 26 | 27 | Args: 28 | indim (int): Input dimensionality. 29 | outdim (int): Output dimensionality. 30 | hls (tuple): Tuple of hidden layer widths. 31 | biasorno (bool): Whether biases are included or not. 32 | activ (str, optional): Activation function. Options are 'tanh', 'relu', 'sin' or else identity is used. 33 | bnorm (bool): Whether batch normalization is implemented or not. 34 | bnlearn (bool): Whether batch normalization is learnable or not. 35 | dropout (float, optional): Dropout fraction. Default is 0.0. 36 | final_transform (str, optional): Final transform, if any (onle 'exp' is implemented). Default is None. 37 | device (str): It represents where computations are performed and tensors are allocated. Default is cpu. 
38 | """ 39 | super(MLP, self).__init__(indim, outdim, device=device) 40 | 41 | self.nlayers = len(hls) 42 | assert(self.nlayers > 0) 43 | self.hls = hls 44 | self.biasorno = biasorno 45 | self.dropout = dropout 46 | self.bnorm = bnorm 47 | self.bnlearn = bnlearn 48 | self.final_transform = final_transform 49 | 50 | if activ == 'tanh': 51 | activ_fcn = torch.nn.Tanh() 52 | elif activ == 'relu': 53 | activ_fcn = torch.nn.ReLU() 54 | elif activ == 'sin': 55 | activ_fcn = Sine() 56 | else: 57 | activ_fcn = torch.nn.Identity() 58 | 59 | modules = [] 60 | modules.append(torch.nn.Linear(self.indim, self.hls[0], self.biasorno)) 61 | if self.dropout > 0.0: 62 | modules.append(torch.nn.Dropout(p=self.dropout)) 63 | 64 | if self.bnorm: 65 | modules.append(torch.nn.BatchNorm1d(self.hls[0], affine=self.bnlearn)) 66 | for i in range(1, self.nlayers): 67 | modules.append(activ_fcn) 68 | modules.append(torch.nn.Linear(self.hls[i - 1], self.hls[i], self.biasorno)) 69 | if self.dropout > 0.0: 70 | modules.append(torch.nn.Dropout(p=self.dropout)) 71 | if self.bnorm: 72 | modules.append(torch.nn.BatchNorm1d(self.hls[i], affine=self.bnlearn)) 73 | 74 | 75 | modules.append(activ_fcn) 76 | modules.append(torch.nn.Linear(self.hls[-1], self.outdim, bias=self.biasorno)) 77 | if self.dropout > 0.0: 78 | modules.append(torch.nn.Dropout(p=self.dropout)) 79 | if self.bnorm: 80 | modules.append(torch.nn.BatchNorm1d(self.outdim, affine=self.bnlearn)) 81 | 82 | if self.final_transform=='exp': 83 | modules.append(Expon()) 84 | 85 | 86 | self.nnmodel = torch.nn.Sequential(*modules) 87 | # sync model to device 88 | self.to(device) 89 | 90 | 91 | 92 | def forward(self, x): 93 | """Forward function. 94 | 95 | Args: 96 | x (torch.Tensor): Input tensor. 97 | 98 | Returns: 99 | torch.Tensor: Output tensor. 100 | """ 101 | return self.nnmodel(x) 102 | 103 | -------------------------------------------------------------------------------- /quinn/nns/nnbase.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for the MLP NN base class.""" 3 | 4 | import torch 5 | import functools 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from .tchutils import npy, tch 10 | 11 | from ..nns.nnfit import nnfit 12 | from ..utils.stats import get_domain 13 | from ..utils.plotting import plot_dm 14 | from ..utils.maps import scale01ToDom 15 | 16 | torch.set_default_dtype(torch.double) 17 | 18 | 19 | class MLPBase(torch.nn.Module): 20 | """Base class for an MLP architecture. 21 | 22 | Attributes: 23 | best_model (torch.nn.Module): Best trained instance, if any. 24 | device (str): Device this object's model will live in. 25 | history (list[np.ndarray]): List containing training history, namely, [fepoch, loss_trn, loss_trn_full, loss_val] 26 | indim (int): Input dimensionality. 27 | outdim (int): Output dimensionality. 28 | trained (bool): Whether the NN is already trained. 29 | """ 30 | 31 | def __init__(self, indim, outdim, device='cpu'): 32 | """Initialization. 33 | 34 | Args: 35 | indim (int): Input dimensionality, `d`. 36 | outdim (int): Output dimensionality, `o`. 37 | device (str): Indicates where computations are performed and tensors are allocated. Default to 'cpu'. 38 | """ 39 | super().__init__() 40 | self.indim = indim 41 | self.outdim = outdim 42 | self.best_model = None 43 | self.trained = False 44 | self.history = None 45 | self.device = device 46 | 47 | 48 | def forward(self, x): 49 | """Forward function is not implemented in base class. 
50 | 51 | Args: 52 | x (torch.Tensor): Input of the function. 53 | 54 | Raises: 55 | NotImplementedError: Needs to be implemented in children. 56 | """ 57 | raise NotImplementedError 58 | 59 | def predict(self, x): 60 | """Prediction of the NN. 61 | 62 | Args: 63 | x (np.ndarray): Input array of size `(N,d)`. 64 | 65 | Returns: 66 | np.ndarray: Output array of size `(N,o)`. 67 | 68 | Note: 69 | Both input and outputs are numpy arrays. 70 | 71 | Note: 72 | If trained, it uses the best trained model, otherwise it will use the current weights. 73 | """ 74 | try: 75 | device = self.best_model.device 76 | except AttributeError: 77 | device = 'cpu' 78 | 79 | if self.trained: 80 | y = npy(self.best_model(tch(x, device=device))) 81 | else: 82 | y = npy(self.forward(tch(x, device=device))) 83 | 84 | return y 85 | 86 | def numpar(self): 87 | """Get the number of parameters of NN. 88 | 89 | Returns: 90 | int: Number of parameters, trainable or not. 91 | """ 92 | pdim = sum(p.numel() for p in self.parameters()) 93 | return pdim 94 | 95 | def fit(self, xtrn, ytrn, **kwargs): 96 | """Fit function. 97 | 98 | Args: 99 | xtrn (np.ndarray): Input array of size `(N,d)`. 100 | ytrn (np.ndarray): Output array of size `(N,o)`. 101 | **kwargs (dict): Keyword arguments. 102 | 103 | Returns: 104 | torch.nn.Module: Best trained instance. 105 | 106 | """ 107 | #self.fitdict = locals() 108 | fit_info = nnfit(self, xtrn, ytrn, **kwargs) 109 | self.best_model = fit_info['best_nnmodel'] 110 | self.history = fit_info['history'] 111 | self.trained = True 112 | 113 | return self.best_model 114 | 115 | 116 | def printParams(self): 117 | """Print parameter names and values.""" 118 | for name, param in self.named_parameters(): 119 | if param.requires_grad: 120 | print(name, param.data) 121 | 122 | 123 | def printParamNames(self): 124 | """Print parameter names and shapes.""" 125 | for name, param in self.named_parameters(): 126 | if param.requires_grad: 127 | print(name, param.data.shape) 128 | 129 | 130 | def predict_plot(self, xx_list, yy_list, labels=None, colors=None, iouts=None): 131 | """Plots the diagonal comparison figures. 132 | 133 | Args: 134 | xx_list (list[np.ndarray]): List of `(N,d)` inputs (e.g., training, validation, testing). 135 | yy_list (list[np.ndarray]): List of `(N,o)` outputs. 136 | labels (list[str], optional): List of labels. If None, set label internally. 137 | colors (list[str], optional): List of colors. If None, sets colors internally. 138 | iouts (list[int], optional): List of outputs to plot. If None, plot all. 139 | 140 | Note: 141 | There is a similar function for probabilistic NN in :class:`..solvers.quinn.QUiNNBase`. 
142 | """ 143 | nlist = len(xx_list) 144 | assert(nlist==len(yy_list)) 145 | 146 | 147 | yy_pred_list = [] 148 | for xx in xx_list: 149 | yy_pred = self.predict(xx) 150 | yy_pred_list.append(yy_pred) 151 | 152 | nout = yy_pred.shape[1] 153 | if iouts is None: 154 | iouts = range(nout) 155 | 156 | if labels is None: 157 | labels = [f'Set {i+1}' for i in range(nlist)] 158 | assert(len(labels)==nlist) 159 | 160 | if colors is None: 161 | colors = ['b', 'g', 'r', 'c', 'm', 'y']*nlist 162 | colors = colors[:nlist] 163 | assert(len(colors)==nlist) 164 | 165 | for iout in iouts: 166 | x1 = [yy[:, iout] for yy in yy_list] 167 | x2 = [yy[:, iout] for yy in yy_pred_list] 168 | 169 | plot_dm(x1, x2, labels=labels, colors=colors, 170 | axes_labels=[f'Model output # {iout+1}', f'Fit output # {iout+1}'], 171 | figname='fitdiag_o'+str(iout)+'.png', 172 | legendpos='in', msize=13) 173 | 174 | def plot_1d_fits(self, xx_list, yy_list, domain=None, ngr=111, true_model=None, labels=None, colors=None): 175 | """Plotting one-dimensional slices, with the other dimensions at the nominal, of the fit. 176 | 177 | Args: 178 | xx_list (list[np.ndarray]): List of `(N,d)` inputs (e.g., training, validation, testing). 179 | yy_list (list[np.ndarray]): List of `(N,o)` outputs. 180 | domain (np.ndarray, optional): Domain of the function, `(d,2)` array. If None, sets it automatically based on data. 181 | ngr (int, optional): Number of grid points in the 1d plot. 182 | true_model (callable, optional): Optionally, plot the true function. 183 | labels (list[str], optional): List of labels. If None, set label internally. 184 | colors (list[str], optional): List of colors. If None, sets colors internally. 185 | 186 | Note: 187 | There is a similar function for probabilistic NN in :class:`..solvers.quinn.QUiNNBase`. 188 | """ 189 | 190 | nlist = len(xx_list) 191 | assert(nlist==len(yy_list)) 192 | 193 | if labels is None: 194 | labels = [f'Set {i+1}' for i in range(nlist)] 195 | assert(len(labels)==nlist) 196 | 197 | if colors is None: 198 | colors = ['b', 'g', 'r', 'c', 'm', 'y']*nlist 199 | colors = colors[:nlist] 200 | assert(len(colors)==nlist) 201 | 202 | if domain is None: 203 | xall = functools.reduce(lambda x,y: np.vstack((x,y)), xx_list) 204 | domain = get_domain(xall) 205 | 206 | mlabel = 'Mean Pred.' 
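        # Note: each 1-d slice below is built in the unit hypercube, with the swept
        # input on a uniform ngr-point grid and all remaining inputs held at 0.5;
        # scale01ToDom then maps the grid to the physical domain before the model
        # prediction is overlaid on the provided data sets.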
207 | 208 | ndim = xx_list[0].shape[1] 209 | nout = yy_list[0].shape[1] 210 | for idim in range(ndim): 211 | xgrid_ = 0.5 * np.ones((ngr, ndim)) 212 | xgrid_[:, idim] = np.linspace(0., 1., ngr) 213 | 214 | xgrid = scale01ToDom(xgrid_, domain) 215 | ygrid_pred = self.predict(xgrid) 216 | 217 | for iout in range(nout): 218 | 219 | for j in range(nlist): 220 | xx = xx_list[j] 221 | yy = yy_list[j] 222 | 223 | plt.plot(xx[:, idim], yy[:, iout], colors[j]+'o', markersize=13, markeredgecolor='w', label=labels[j]) 224 | 225 | if true_model is not None: 226 | true = true_model(xgrid, 0.0) 227 | plt.plot(xgrid[:, idim], true[:, iout], 'k-', label='Truth', alpha=0.5) 228 | 229 | 230 | p, = plt.plot(xgrid[:, idim], ygrid_pred[:, iout], 'm-', linewidth=5, label=mlabel) 231 | 232 | 233 | plt.legend() 234 | plt.xlabel(f'Input # {idim+1}') 235 | plt.ylabel(f'Output # {iout+1}') 236 | plt.savefig('fit_d' + str(idim) + '_o' + str(iout) + '.png') 237 | plt.clf() 238 | 239 | 240 | -------------------------------------------------------------------------------- /quinn/nns/nnfit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for the mother-of-all fit function.""" 3 | 4 | import sys 5 | import copy 6 | import torch 7 | import matplotlib.pyplot as plt 8 | 9 | from .tchutils import tch 10 | from .losses import NegLogPost, GradLoss 11 | 12 | 13 | def nnfit(nnmodel, xtrn, ytrn, val=None, 14 | loss_fn='mse', loss_xy=None, 15 | datanoise=None, wd=0.0, 16 | priorparams=None, 17 | optimizer='adam', 18 | lrate=0.1, lmbd=None, scheduler_lr=None, 19 | nepochs=5000, batch_size=None, 20 | gradcheck=False, 21 | cooldown=100, factor=0.95, 22 | freq_out=100, freq_plot=1000, lhist_suffix='' 23 | ): 24 | """Generic PyTorch NN fit function that is utilized in appropriate NN classes. 25 | 26 | Args: 27 | nnmodel (torch.nn.Module): The PyTorch NN module of interest. 28 | xtrn (np.ndarray): Training input array `x` of size `(N,d)`. 29 | ytrn (np.ndarray): Training output array `y` of size `(N,o)`. 30 | val (tuple, optional): `x,y` tuple of validation points. Default uses the training set for validation. 31 | loss_fn (str, optional): Loss function string identifier. Currently only 'mse' is implemented and it is the default. Only used if the next argument, loss_xy=None. 32 | loss_xy (None, optional): Optionally, a more flexible loss function (e.g. used in variational inference) with signature :math:`\textrm{loss}(x_{pred}, y_{target})`. The default is None, which triggers the use of previous argument, loss_fn. 33 | datanoise (None, optional): Datanoise for certain loss types. 34 | wd (float, optional): Optional weight decay (L2 regularization) parameter. 35 | priorparams (dict, optional): Dictionary of prior parameters. 36 | optimizer (str, optional): Optimizer string. Currently implemented 'adam' (default) and 'sgd'. 37 | lrate (float, optional): Learning rate or learning rate schedule factor. Default is 0.1. 38 | lmbd (callable, optional): Optional learning rate schedule. The actual learning rate is `lrate * lmbd(epoch)`. 39 | scheduler_lr (str, optional): Learning rate is adjusted during training according to a generic pytorch scheduler. Currently only ReduceLROnPlateau is implemented. Conflicts with user-defined lmbd scheduler: need to pick one. Defaults to None, which uses no-schedule lrate. 40 | nepochs (int, optional): Number of epochs. 41 | batch_size (int, optional): Batch size. Default is None, i.e. single batch. 
42 | gradcheck (bool, optional): For code verification, whether we want to check the auto-computed gradients against numerically computed ones. Makes the code slow. Experimental - this is not tested enough. 43 | cooldown (int, optional): cooldown in ReduceLROnPlateau 44 | factor (float, optional): factor in ReduceLROnPlateau 45 | freq_out (int, optional): Frequency, in epochs, of screen output. Defaults to 100. 46 | freq_plot (int, optional): Frequency, in epochs, of plotting loss convergence graph. Defaults to 1000. 47 | lhist_suffix (str, optional): Optional suffix of loss history figure filename. 48 | 49 | Returns: 50 | dict: Dictionary of the results. Keys 'best_fepoch', 'best_epoch', 'best_loss', 'best_nnmodel', 'history'. 51 | """ 52 | 53 | ntrn = xtrn.shape[0] 54 | 55 | # Loss function 56 | if loss_xy is None: 57 | if loss_fn == 'mse': 58 | loss = torch.nn.MSELoss(reduction='mean') 59 | def loss_xy(x, y): 60 | return loss(nnmodel(x), y) 61 | elif loss_fn == 'logpost': 62 | loss_xy = NegLogPost(nnmodel, ntrn, datanoise, priorparams) 63 | else: 64 | print(f"Loss function {loss_fn} is unknown. Exiting.") 65 | sys.exit() 66 | 67 | 68 | # Optimizer selection 69 | if optimizer == 'adam': 70 | opt = torch.optim.Adam(nnmodel.parameters(), lr=lrate, weight_decay=wd) 71 | elif optimizer == 'sgd': 72 | opt = torch.optim.SGD(nnmodel.parameters(), lr=lrate, weight_decay=wd) 73 | else: 74 | print(f"Optimizer {optimizer} is unknown. Exiting.") 75 | sys.exit() 76 | 77 | # Learning rate schedule 78 | if scheduler_lr == "ReduceLROnPlateau" and not lmbd is None: 79 | print(f"Trying to use two schedulers. Exiting.") 80 | sys.exit() 81 | 82 | if lmbd is None: 83 | def lmbd(epoch): return 1.0 84 | scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lmbd) 85 | 86 | if scheduler_lr == "ReduceLROnPlateau": 87 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', cooldown=cooldown, factor=factor, verbose=False) 88 | 89 | 90 | 91 | if batch_size is None or batch_size > ntrn: 92 | batch_size = ntrn 93 | 94 | try: 95 | device = nnmodel.device 96 | except AttributeError: 97 | device = 'cpu' 98 | 99 | xtrn_ = tch(xtrn , device=device) 100 | ytrn_ = tch(ytrn , device=device) 101 | 102 | # Validation data 103 | if val is None: 104 | xval, yval = xtrn.copy(), ytrn.copy() 105 | else: 106 | xval, yval = val 107 | 108 | xval_ = tch(xval , device=device) 109 | yval_ = tch(yval , device=device) 110 | 111 | # print("device: ", device) 112 | # print("xval_ is device: ", xval_.get_device(), xval_.device) 113 | # Training process 114 | fit_info = {'best_fepoch': 0, 'best_epoch': 0, 115 | 'best_loss': 1.e+100, 'best_nnmodel': nnmodel, 116 | 'history': []} 117 | 118 | 119 | fepoch = 0 120 | for t in range(nepochs): 121 | permutation = torch.randperm(ntrn) 122 | # for parameter in model.parameters(): 123 | # print(parameter) 124 | nsubepochs = len(range(0, ntrn, batch_size)) 125 | for i in range(0, ntrn, batch_size): 126 | indices = permutation[i:i + batch_size] 127 | 128 | loss_trn = loss_xy(xtrn_[indices, :], ytrn_[indices, :]) 129 | #loss_val = loss_trn 130 | with torch.no_grad(): 131 | loss_val = loss_xy(xval_, yval_) 132 | #loss_trn_full = loss_trn 133 | if i == 0: # otherwise too expensive 134 | with torch.no_grad(): 135 | loss_trn_full = loss_xy(xtrn_, ytrn_) 136 | 137 | fepoch += 1. 
/ nsubepochs 138 | 139 | curr_state = [fepoch + 0.0, loss_trn.item(), loss_trn_full.item(), loss_val.item()] 140 | crit = loss_val.item() 141 | 142 | fit_info['history'].append(curr_state) 143 | 144 | if crit < fit_info['best_loss']: 145 | fit_info['best_loss'] = crit 146 | fit_info['best_nnmodel'] = copy.copy(nnmodel) 147 | 148 | fit_info['best_fepoch'] = fepoch 149 | fit_info['best_epoch'] = t 150 | 151 | 152 | if gradcheck: 153 | gc = torch.autograd.gradcheck(nnmodel, (xtrn_,), 154 | eps=1e-2, atol=1e-2) 155 | 156 | opt.zero_grad() 157 | loss_trn.backward() 158 | 159 | opt.step() 160 | if scheduler_lr == "ReduceLROnPlateau": 161 | ## using ValLoss as metric 162 | scheduler.step(curr_state[3]) 163 | else: 164 | scheduler.step() 165 | 166 | ## Printout to screen 167 | if t == 0: 168 | print('{:>10} {:>10} {:>12} {:>12} {:>12} {:>18} {:>10}'.\ 169 | format("NEpochs", "NUpdates", 170 | "BatchLoss", "TrnLoss", "ValLoss", 171 | "BestLoss (Epoch)", "LrnRate"), flush=True) 172 | 173 | if (t + 1) % freq_out == 0 or t == 0 or t == nepochs - 1: 174 | tlr = opt.param_groups[0]['lr'] 175 | printout = f"{t+1:>10}" \ 176 | f"{len(fit_info['history']):>10}" \ 177 | f"{fit_info['history'][-1][1]:>14.6f}" \ 178 | f"{fit_info['history'][-1][2]:>13.6f}" \ 179 | f"{fit_info['history'][-1][3]:>13.6f}" \ 180 | f"{fit_info['best_loss']:>14.6f} ({fit_info['best_epoch']})" \ 181 | f"{tlr:>10}" 182 | print(printout, flush=True) 183 | 184 | ## Plotting 185 | if t % freq_plot == 0 or t == nepochs - 1: 186 | fepochs = [state[0] for state in fit_info['history']] 187 | losses_trn = [state[1] for state in fit_info['history']] 188 | losses_trn_full = [state[2] for state in fit_info['history']] 189 | losses_val = [state[3] for state in fit_info['history']] 190 | 191 | _ = plt.figure(figsize=(12, 8)) 192 | 193 | plt.plot(fepochs, losses_trn, label='Batch loss') 194 | plt.plot(fepochs, losses_trn_full, label='Training loss') 195 | plt.plot(fit_info['best_fepoch'], fit_info['best_loss'], 196 | 'ro', markersize=11) 197 | plt.vlines(fit_info['best_fepoch'], 0.0, 2.0, 198 | colors=None, linestyles='--') 199 | plt.plot(fepochs, losses_val, label='Validation loss') 200 | 201 | plt.legend() 202 | 203 | plt.savefig(f'loss_history{lhist_suffix}.png') 204 | plt.yscale('log') 205 | plt.savefig(f'loss_history{lhist_suffix}_log.png') 206 | plt.clf() 207 | 208 | return fit_info 209 | -------------------------------------------------------------------------------- /quinn/nns/nns.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module containing various simple PyTorch NN modules.""" 3 | 4 | 5 | import math 6 | import torch 7 | 8 | class Gaussian(torch.nn.Module): 9 | r"""Gaussian function. :math:`\textrm{Gaussian}(x) = e^{-x^2}` 10 | """ 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__() 14 | 15 | def forward(self, x): 16 | """Forward function. 17 | 18 | Args: 19 | x (torch.Tensor): Input tensor `x`. 20 | 21 | Returns: 22 | torch.Tensor: Output tensor of same size as input `x`. 23 | """ 24 | return torch.exp(-x**2) 25 | 26 | 27 | class Sine(torch.nn.Module): 28 | r"""Sine function. :math:`\textrm{Sin}(x) = A\sin\left(2\pi x/T\right)`""" 29 | def __init__(self, A=1.0, T=1.0): 30 | """Initialization. 31 | 32 | Args: 33 | A (float, optional): Amplitude `A`. Defaults to 1. 34 | T (float, optional): Period `T`. Defaults to 1. 35 | """ 36 | super().__init__() 37 | self.A = A 38 | self.T = T 39 | 40 | def forward(self, x): 41 | """Forward function. 
42 | 43 | Args: 44 | x (torch.Tensor): Input tensor `x`. 45 | 46 | Returns: 47 | torch.Tensor: Output tensor of same size as input `x`. 48 | """ 49 | 50 | return torch.sin(self.A*torch.Tensor(math.pi)*x/self.T) 51 | 52 | 53 | class Polynomial(torch.nn.Module): 54 | r"""Polynomial function :math:`\textrm{Polynomial}(x)=\sum_{i=0}^p c_i x^i`. 55 | 56 | Attributes: 57 | order (int): Order of the polynomial. 58 | coefs (torch.nn.Parameter): Coefficient array of size `p+1`. 59 | """ 60 | 61 | def __init__(self, order): 62 | """Initialization. 63 | 64 | Args: 65 | order (int): Order of the polynomial. 66 | """ 67 | super().__init__() 68 | self.order = order 69 | 70 | self.coefs= torch.nn.Parameter(torch.randn((self.order+1,))) 71 | 72 | # Parameter List does not work with quinn.vi.vi 73 | # self.coefs= torch.nn.ParameterList([torch.nn.Parameter(torch.randn(())) for i in range(self.order+1)]) 74 | 75 | def forward(self, x): 76 | """Forward function. 77 | 78 | Args: 79 | x (torch.Tensor): Input tensor `x`. 80 | 81 | Returns: 82 | torch.Tensor: Output tensor of same size as input `x`. 83 | """ 84 | 85 | 86 | val = torch.zeros_like(x) 87 | for i, cf in enumerate(self.coefs): 88 | val += cf*x**i 89 | 90 | return val 91 | 92 | 93 | class Polynomial3(torch.nn.Module): 94 | r"""Example 3-rd order polynomial function :math:`\textrm{Polynomial3}(x)=a+bx+cx^2+dx^3`. 95 | 96 | Attributes: 97 | a (torch.nn.Parameter): Constant coefficient. 98 | b (torch.nn.Parameter): First-order coefficient. 99 | c (torch.nn.Parameter): Second-order coefficient. 100 | d (torch.nn.Parameter): Third-order coefficient. 101 | """ 102 | 103 | def __init__(self): 104 | """Instantiate four parameters. 105 | """ 106 | super().__init__() 107 | self.a = torch.nn.Parameter(torch.randn(())) 108 | self.b = torch.nn.Parameter(torch.randn(())) 109 | self.c = torch.nn.Parameter(torch.randn(())) 110 | self.d = torch.nn.Parameter(torch.randn(())) 111 | 112 | def forward(self, x): 113 | """Forward function. 114 | 115 | Args: 116 | x (torch.Tensor): Input tensor `x`. 117 | 118 | Returns: 119 | torch.Tensor: Output tensor of same size as input `x`. 120 | """ 121 | return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 122 | 123 | class Constant(torch.nn.Module): 124 | r"""Constant function :math:`\textrm{Constant}(x)=C`. 125 | 126 | Attributes: 127 | constant (torch.nn.Parameter): Constant `C`. 128 | """ 129 | 130 | def __init__(self): 131 | """Instantiate the constant.""" 132 | super().__init__() 133 | self.constant = torch.nn.Parameter(torch.randn(())) 134 | 135 | def forward(self, x): 136 | """Forward function. 137 | 138 | Args: 139 | x (torch.Tensor): Input tensor `x`. 140 | 141 | Returns: 142 | torch.Tensor: Output tensor of same size as input `x`. 143 | """ 144 | return self.constant * torch.ones_like(x) 145 | 146 | 147 | class SiLU(torch.nn.Module): 148 | r"""Sigmoid Linear Unit (SiLU) function :math:`\textrm{SiLU}(x) = x \sigma(x) = \frac{x}{1+e^{-x}}` 149 | """ 150 | def __init__(self): 151 | """Initialization. """ 152 | super().__init__() 153 | 154 | def forward(self, x): 155 | """Forward function. 156 | 157 | Args: 158 | x (torch.Tensor): Input tensor `x`. 159 | 160 | Returns: 161 | torch.Tensor: Output tensor of same size as input `x`. 162 | """ 163 | return x * torch.sigmoid(x) 164 | 165 | 166 | class Expon(torch.nn.Module): 167 | r"""Exponential function :math:`\textrm{Expon}(x) = e^{x}` 168 | """ 169 | def __init__(self): 170 | """Initialization. 
""" 171 | super().__init__() 172 | 173 | def forward(self, x): 174 | """Forward function. 175 | 176 | Args: 177 | x (torch.Tensor): Input tensor `x`. 178 | 179 | Returns: 180 | torch.Tensor: Output tensor of same size as input `x`. 181 | """ 182 | return torch.exp(x) 183 | 184 | class TwoLayerNet(torch.nn.Module): 185 | """Example two-layer function, with a cubic polynomical between layers. 186 | 187 | Attributes: 188 | linear1 (torch.nn.Linear): First linear layer. 189 | linear2 (torch.nn.Linear): Second linear layer. 190 | cubic (torch.nn.Module): Cubic layer in-between the linear ones. 191 | """ 192 | 193 | def __init__(self, D_in, H, D_out): 194 | r"""Initializes give the input, output dimensions and the hidden width. 195 | 196 | Args: 197 | D_in (int): Input dimension :math:`d_{in}`. 198 | H (int): Hidden layer width. 199 | D_out (int): Output dimension :math:`d_{out}`. 200 | """ 201 | super(TwoLayerNet, self).__init__() 202 | self.linear1 = torch.nn.Linear(D_in, H) 203 | self.linear2 = torch.nn.Linear(H, D_out) 204 | self.cubic = Polynomial3() 205 | 206 | def forward(self, x): 207 | r"""Forward function. 208 | 209 | Args: 210 | x (torch.Tensor): Input tensor `x` of size :math:`(N,d_{in})`. 211 | 212 | Returns: 213 | torch.Tensor: Output tensor of size :math:`(N,d_{out})`. 214 | """ 215 | h_relu = self.linear1(x).clamp(min=0) 216 | y_pred = self.cubic(h_relu) 217 | y_pred = self.linear2(y_pred) 218 | 219 | return y_pred 220 | 221 | 222 | class MLP_simple(torch.nn.Module): 223 | r"""Simple MLP example. 224 | 225 | Attributes: 226 | biasorno (bool): Whether to use bias or not. 227 | hls (tuple[int]): List of layer widths. 228 | indim (int): Input dimensionality :math:`d_{in}`. 229 | outdim (int): Output dimensionality :math:`d_{out}`. 230 | model (torch.nn.Sequential): The PyTorch Sequential model behind the forward function. 231 | 232 | Note: 233 | Uses :math:`\tanh(x)` as activation function between layers. 234 | """ 235 | 236 | def __init__(self, hls, biasorno=True): 237 | """Initialization. 238 | 239 | Args: 240 | hls (tuple[int]): Tuple of number of units per layer, length of list if number of layers 241 | biasorno (bool, optional): Whether to use bias or not. Defaults to True. 242 | """ 243 | super().__init__() 244 | assert(len(hls)>1) 245 | self.hls = hls[1:-1] 246 | self.indim = hls[0] 247 | self.outdim = hls[-1] 248 | self.biasorno = biasorno 249 | 250 | modules = [] 251 | for j in range(len(hls)-2): 252 | modules.append(torch.nn.Linear(hls[j], hls[j+1], self.biasorno)) 253 | modules.append(torch.nn.Tanh()) 254 | modules.append(torch.nn.Linear(hls[-2], hls[-1], bias=self.biasorno)) 255 | 256 | self.model = torch.nn.Sequential(*modules) 257 | 258 | def forward(self, x): 259 | r"""Forward function. 260 | 261 | Args: 262 | x (torch.Tensor): Input tensor `x` of size :math:`(N,d_{in})`. 263 | 264 | Returns: 265 | torch.Tensor: Output tensor of size :math:`(N,d_{out})`. 266 | """ 267 | 268 | return self.model(x) 269 | 270 | 271 | -------------------------------------------------------------------------------- /quinn/nns/nnwrap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for various useful wrappers to NN functions.""" 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from .nnbase import MLPBase 8 | from .tchutils import tch, npy 9 | 10 | class NNWrap(): 11 | """Wrapper class to any PyTorch NN module to make it work as a numpy function. 
Basic usage is therefore :math:`f=NNWrap(); y=f(x)` 12 | 13 | Attributes: 14 | indices (list): List containing [start index, end index) for each model parameter. Useful for flattening/unflattening of parameter arrays. 15 | nnmodel (torch.nn.Module): The original PyTorch NN module. 16 | """ 17 | 18 | def __init__(self, nnmodel): 19 | """Instantiate a NN Wrapper object. 20 | 21 | Args: 22 | nnmodel (torch.nn.Module): The original PyTorch NN module. 23 | """ 24 | self.nnmodel = nnmodel 25 | self.indices = None 26 | _ = self.p_flatten() 27 | 28 | def reinitialize_instance(self): 29 | """Reinitialize the underlying NN module.""" 30 | self.nnmodel.reinitialize_instance() 31 | 32 | 33 | def __call__(self, x): 34 | """Calling the wrapper function. 35 | 36 | Args: 37 | x (np.ndarray): A numpy input array of size `(N,d)`. 38 | 39 | Returns: 40 | np.ndarray: A numpy output array of size `(N,o)`. 41 | """ 42 | try: 43 | device = self.nnmodel.device 44 | except AttributeError: 45 | device = 'cpu' 46 | 47 | return npy(self.nnmodel.forward(tch(x, device=device))) 48 | 49 | def predict(self, x_in, weights): 50 | """Model prediction given new weights. 51 | 52 | Args: 53 | x_in (np.ndarray): A numpy input array of size `(N,d)`. 54 | weights (np.ndarray): flattened parameter vector. 55 | 56 | Returns: 57 | np.ndarray: A numpy output array of size `(N,o)`. 58 | """ 59 | x_in = tch(x_in) 60 | self.p_unflatten(weights) 61 | y_out = self.nnmodel(x_in).detach().numpy() 62 | return y_out 63 | 64 | def p_flatten(self): 65 | """Flattens all parameters of the underlying NN module into an array. 66 | 67 | Returns: 68 | torch.Tensor: A flattened (1d) torch tensor. 69 | """ 70 | l = [torch.flatten(p) for p in self.nnmodel.parameters()] 71 | self.indices = [] 72 | s = 0 73 | for p in l: 74 | size = p.shape[0] 75 | self.indices.append((s, s+size)) 76 | s += size 77 | flat_parameter = torch.cat(l).view(-1, 1) 78 | 79 | return flat_parameter 80 | 81 | def p_unflatten(self, flat_parameter): 82 | """Fills the values of corresponding parameters given the flattened numpy form. 83 | 84 | Args: 85 | flat_parameter (np.ndarray): A flattened form of parameters. 86 | 87 | Returns: 88 | list[torch.Tensor]: List of recovered parameters, reshaped and ordered to match the model. 89 | 90 | Note: 91 | Returning the list is secondary. The most important result is that this function internally fills the values of corresponding parameters. 92 | """ 93 | # FIXME: we should only allocate tensors in initialization. 94 | try: 95 | device = self.nnmodel.device 96 | except AttributeError: 97 | device = 'cpu' 98 | 99 | ll = [tch(flat_parameter[s:e],device=device) for (s, e) in self.indices] 100 | for i, p in enumerate(self.nnmodel.parameters()): 101 | if len(p.shape)>0: 102 | ll[i] = ll[i].view(*p.shape) 103 | 104 | p.data = ll[i] 105 | 106 | return ll 107 | 108 | 109 | def calc_loss(self, weights, loss_fn, inputs, targets): 110 | """Calculates a given loss function with respect to model parameters. 111 | 112 | Args: 113 | weights (np.ndarray): weights of the model. 114 | loss_fn (torch.nn.Module): pytorch loss module of signature loss(inputs, targets) 115 | inputs (np.ndarray): inputs to the model. 116 | targets (np.ndarray): target outputs that get compared to model outputs. 117 | 118 | Returns: 119 | loss (float): loss of the model given the data. 
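        Example:
            A minimal sketch of evaluating the loss at the current flattened
            weights. Here `net`, `xtrn` and `ytrn` are hypothetical placeholders
            (any torch.nn.Module and matching numpy arrays), not defined in this
            module::

                from quinn.nns.nnwrap import NNWrap
                from quinn.nns.losses import NegLogPost

                wrapper = NNWrap(net)
                # flattened copy of the current parameters, as a 1d numpy array
                weights = wrapper.p_flatten().detach().squeeze().numpy()
                # negative log-posterior loss built around the same module
                loss_fn = NegLogPost(net, xtrn.shape[0], 0.1, None)
                value = wrapper.calc_loss(weights, loss_fn, xtrn, ytrn)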
120 | """ 121 | inputs = tch(inputs, rgrad=False) 122 | targets = tch(targets, rgrad=False) 123 | self.p_unflatten(weights) # TODO: this is not always necessary if loss_fn already incorporates the weights? 124 | 125 | loss = loss_fn(inputs, targets) 126 | return loss.item() 127 | 128 | def calc_lossgrad(self, weights, loss_fn, inputs, targets): 129 | """Calculates the gradients of a given loss function with respect to model parameters. 130 | 131 | Args: 132 | weights (np.ndarray): weights of the model. 133 | loss_fn (torch.nn.Module): pytorch loss module of signature loss(inputs, targets) 134 | inputs (np.ndarray): inputs to the model. 135 | targets (np.ndarray): target outputs that get compared to model outputs. 136 | 137 | Returns: 138 | np.ndarray: A numpy array of the loss gradient w.r.t. to the model parameters at inputs. 139 | """ 140 | inputs = tch(inputs, rgrad=False) 141 | targets = tch(targets, rgrad=False) 142 | self.p_unflatten(weights) # TODO: this is not always necessary if loss_fn already incorporates the weights? 143 | 144 | loss = loss_fn(inputs, targets) 145 | loss.backward() 146 | gradients = [] 147 | for p in self.nnmodel.parameters(): 148 | gradients.append(npy(p.grad).flatten()) 149 | p.grad = None 150 | return np.concatenate(gradients, axis=0) 151 | 152 | 153 | def calc_hess_full(self, weigths, loss_fn, inputs, targets): 154 | """Calculates the hessian of a given loss function with respect to model parameters. 155 | 156 | Args: 157 | weights (np.ndarray): weights of the model. 158 | loss_fn (torch.nn.Module): pytorch loss module of signature loss(inputs, targets) 159 | inputs (np.ndarray): inputs to the model. 160 | targets (np.ndarray): target outputs that get compared to model outputs. 161 | 162 | Returns: 163 | np.ndarray: Hessian matrix of the loss with respect to the model parameters at inputs. 164 | """ 165 | inputs = tch(inputs, rgrad=False) 166 | targets = tch(targets, rgrad=False) 167 | self.p_unflatten(weigths) # TODO: this is not always necessary if loss_fn already incorporates the weights? 168 | 169 | # Calculate the gradient 170 | loss = loss_fn(inputs, targets) 171 | 172 | ## One method... 173 | # loss.backward() 174 | # gradients = [] 175 | # for p in self.nnmodel.parameters(): 176 | # gradients.append(npy(p.grad).flatten()) 177 | # p.grad = None 178 | # gradients = np.concatenate(gradients, axis=0) 179 | 180 | ## ... or its alternative 181 | gradients = torch.autograd.grad( 182 | loss, self.nnmodel.parameters(), create_graph=True, retain_graph=True 183 | ) 184 | gradients = [gradient.flatten() for gradient in gradients] 185 | 186 | hessian_rows = [] 187 | # Calculate the gradient of the elements of the gradient 188 | for gradient in gradients: 189 | for j in range(gradient.size(0)): 190 | hessian_rows.append( 191 | torch.autograd.grad(gradient[j], self.nnmodel.parameters(), retain_graph=True) 192 | ) 193 | hessian_mat = [] 194 | # Shape the Hessian to a 2D tensor 195 | for i in range(len(hessian_rows)): 196 | row_hessian = [] 197 | for gradient in hessian_rows[i]: 198 | row_hessian.append(gradient.flatten().unsqueeze(0)) 199 | hessian_mat.append(torch.cat(row_hessian, dim=1)) 200 | hessian_mat = torch.cat(hessian_mat, dim=0) 201 | return hessian_mat.detach().numpy() 202 | 203 | 204 | def calc_hess_diag(self, weigths, loss_fn, inputs, targets): 205 | """Calculates the diagonal hessian approximation of a given loss function with respect to model parameters. 206 | 207 | Args: 208 | weights (np.ndarray): weights of the model. 
209 | loss_fn (torch.nn.Module): pytorch loss module of signature loss(inputs, targets) 210 | inputs (np.ndarray): inputs to the model. 211 | targets (np.ndarray): target outputs that get compared to model outputs. 212 | 213 | Returns: 214 | np.ndarray: A diagonal Hessian matrix of the loss with respect to the model parameters at inputs. 215 | """ 216 | inputs = tch(inputs, rgrad=False) 217 | targets = tch(targets, rgrad=False) 218 | self.p_unflatten(weigths) # TODO: this is not always necessary if loss_fn already incorporates the weights? 219 | 220 | # Calculate the gradient 221 | gradient_list = [] 222 | for input_, target_ in zip(inputs, targets): 223 | loss = loss_fn(input_, target_) 224 | 225 | gradients = torch.autograd.grad(loss, self.nnmodel.parameters(), create_graph=True, retain_graph=True) 226 | gradient_list.append(torch.cat([gradient.flatten() for gradient in gradients]).unsqueeze(0)) 227 | diag_fim = torch.cat(gradient_list, dim=0).pow(2).mean(0) 228 | 229 | return torch.diag(diag_fim).detach().numpy() 230 | 231 | ############################################################ 232 | ############################################################ 233 | ############################################################ 234 | 235 | class SNet(MLPBase): 236 | """A single NN wrapper of a given torch NN module. This is useful as it will inherit all the methods of MLPBase. Written in the spirit of UQ wrapper/solvers. 237 | 238 | Attributes: 239 | nnmodel (torch.nn.Module): The underlying torch NN module. 240 | """ 241 | 242 | def __init__(self, nnmodel, indim, outdim, device='cpu'): 243 | """Initialization. 244 | 245 | Args: 246 | nnmodel (torch.nn.Module): The underlying torch NN module. 247 | indim (int): Input dimensionality. 248 | outdim (int): Output dimensionality. 249 | device (str, optional): Device where the computations will be done. Defaults to 'cpu'. 250 | """ 251 | super().__init__(indim, outdim, device=device) 252 | self.nnmodel = nnmodel 253 | 254 | def forward(self, x): 255 | """Forward function. 256 | 257 | Args: 258 | x (torch.Tensor): Input tensor. 259 | 260 | Returns: 261 | torch.Tensor: Output tensor. 262 | """ 263 | return self.nnmodel(x) 264 | 265 | ############################################################### 266 | ############################################################### 267 | ############################################################### 268 | 269 | def nnwrapper(x, nnmodel): 270 | """A simple numpy-ifying wrapper function to any PyTorch NN module :math:`f(x)=\textrm{NN}(x)`. 271 | 272 | Args: 273 | x (np.ndarray): An input numpy array `x` of size `(N,d)`. 274 | nnmodel (torch.nn.Module): The underlying PyTorch NN module. 275 | 276 | Returns: 277 | np.ndarray: An output numpy array of size `(N,o)`. 278 | """ 279 | try: 280 | device = nnmodel.device 281 | except AttributeError: 282 | device = 'cpu' 283 | return npy(nnmodel.forward(tch(x,device=device, rgrad=False))) 284 | 285 | 286 | def nn_surrogate(x, *otherpars): 287 | r"""A simple wrapper function as a surrogate to a PyTorch NN module :math:`f(x)=\textrm{NN}(x)`. 288 | 289 | Args: 290 | x (np.ndarray): An input numpy array `x` of size `(N,d)`. 291 | otherpars (list): List containing one element, the PyTorch NN module of interest. 292 | 293 | Returns: 294 | np.ndarray: An output numpy array of size `(N,o)`. 295 | 296 | Note: 297 | This is effectively the same as nnwrapper. It is kept for backward compatibility. 
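        Example:
            Illustrative call only; `x` is a hypothetical `(N,d)` numpy array and
            `net` a hypothetical torch.nn.Module::

                y = nn_surrogate(x, net)   # same result as nnwrapper(x, net)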
298 | """ 299 | nnmodule = otherpars[0] 300 | 301 | return nnwrapper(x, nnmodule) 302 | 303 | ############################################################### 304 | ############################################################### 305 | ############################################################### 306 | 307 | def nn_surrogate_multi(par, *otherpars): 308 | r"""A simple wrapper function as a surrogate to a PyTorch NN module :math:`f_i(x)=\textrm{NN}_i(x)` for `i=1,...,o`. 309 | 310 | Args: 311 | x (np.ndarray): An input numpy array `x` of size `(N,d)`. 312 | otherpars (list[list]): List containing one element, a list of PyTorch NN modules of interest (a total of `o` modules). 313 | 314 | Returns: 315 | np.ndarray: An output numpy array of size `(N,o)`. 316 | """ 317 | nnmodules = otherpars[0] 318 | 319 | nout = len(nnmodules) 320 | yy = np.empty((par.shape[0], nout)) 321 | for iout in range(nout): 322 | yy[:, iout] = nnwrapper(par, nnmodules[iout]).reshape(-1,) 323 | 324 | return yy 325 | 326 | ############################################################### 327 | ############################################################### 328 | ############################################################### 329 | 330 | def nn_p(p, x, *otherpars): 331 | r"""A NN wrapper that evaluates a given PyTorch NN module given input `x` and flattened parameter vector `p`. In other words, :math:`f(p,x)=\textrm{NN}_p(x).` 332 | 333 | Args: 334 | p (np.ndarray): Flattened parameter (weights) vector. 335 | x (np.ndarray): An input numpy array `x` of size `(N,d)`. 336 | otherpars (list): List containing one element, the PyTorch NN module of interest. 337 | 338 | Returns: 339 | np.ndarray: A numpy output array of size `(N,o)`. 340 | 341 | Note: 342 | The size checks on `p` are missing: wherever this is used in QUiNN, the size checks are implied and correct. Use with care outside QUiNN. 343 | """ 344 | nnmodule = otherpars[0] 345 | nnw = NNWrap(nnmodule) 346 | nnw.p_unflatten(p) 347 | return nnw(x) 348 | -------------------------------------------------------------------------------- /quinn/nns/tchutils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Various useful PyTorch related utilities.""" 3 | 4 | import copy 5 | import torch 6 | 7 | 8 | torch.set_default_dtype(torch.double) 9 | 10 | def tch(arr, device='cpu', rgrad=True): 11 | """Convert a numpy array to torch Tensor. 12 | 13 | Args: 14 | arr (np.ndarray): A numpy array of any size. 15 | device (str, optional): It represents where tensors are allocated. Default to cpu. 16 | rgrad (bool, optional): Whether to require gradient tracking or not. 17 | 18 | Returns: 19 | torch.Tensor: Torch tensor of the same size as the input numpy array. 20 | """ 21 | 22 | # return torch.from_numpy(arr.astype(np.double)).to(device) 23 | # return torch.from_numpy(arr).double() 24 | return torch.tensor(arr, requires_grad=rgrad, device=device) 25 | 26 | 27 | def npy(arr): 28 | """Convert a torch tensor to numpy array. 29 | 30 | Args: 31 | arr (torch.Tensor): Torch tensor of any size. 32 | 33 | Returns: 34 | np.ndarray: Numpy array of the same size as the input torch tensor. 35 | """ 36 | # return data.detach().numpy() 37 | return arr.cpu().data.numpy() 38 | 39 | def print_nnparams(nnmodel, names_only=False): 40 | """Print parameter names of a PyTorch NN module and optionally, values. 41 | 42 | Args: 43 | nnmodel (torch.nn.Module): The torch NN module. 44 | names_only (bool, optional): Print names only. Default is False. 
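        Example:
            For instance, with `net` any torch.nn.Module (hypothetical here)::

                print_nnparams(net, names_only=True)   # one line per parameter: name and shape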
45 | """ 46 | assert(isinstance(nnmodel, torch.nn.Module)) 47 | 48 | for name, param in nnmodel.named_parameters(): 49 | if names_only: 50 | print(f"{name}, shape {npy(param.data).shape}") 51 | else: 52 | print(name, param.data) 53 | 54 | def flatten_params(parameters): 55 | """Flattens all parameters into an array. 56 | 57 | Args: 58 | parameters (torch.nn.Parameters): Description 59 | 60 | Returns: 61 | (torch.Tensor, list[tuple]): A tuple of the flattened (1d) torch tensor and a list of pairs that correspond to start/end indices of the flattened parameters. 62 | """ 63 | l = [torch.flatten(p) for p in parameters] 64 | 65 | indices = [] 66 | s = 0 67 | for p in l: 68 | size = p.shape[0] 69 | indices.append((s, s+size)) 70 | s += size 71 | flat = torch.cat(l).view(-1, 1) 72 | return flat, indices 73 | 74 | 75 | def recover_flattened(flat_params, indices, model): 76 | """Fills the values of corresponding parameters given the flattened form. 77 | 78 | Args: 79 | flat_params (np.ndarray): A flattened form of parameters. 80 | indices (list[tuple]): A list of pairs that correspond to start/end indices of the flattened parameters. 81 | model (torch.nn.Module): The underlying PyTorch NN module. 82 | 83 | Returns: 84 | list[torch.Tensor]: List of recovered parameters, reshaped and ordered to match the model. 85 | """ 86 | l = [flat_params[s:e] for (s, e) in indices] 87 | for i, p in enumerate(model.parameters()): 88 | l[i] = l[i].view(*p.shape) 89 | return l 90 | 91 | 92 | -------------------------------------------------------------------------------- /quinn/rvar/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from . import rvs 4 | -------------------------------------------------------------------------------- /quinn/rvar/rvs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for random variable classes.""" 3 | 4 | import math 5 | import torch 6 | 7 | 8 | class RV(torch.nn.Module): 9 | """Parent class for random variables.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__() 14 | 15 | 16 | def sample(self, num_samples=1): 17 | """Sampling function. 18 | 19 | Raises: 20 | NotImplementedError: Expected to be implemented in children classes 21 | """ 22 | raise NotImplementedError 23 | 24 | def log_prob(self, x): 25 | """Evaluate log-probability. 26 | 27 | Raises: 28 | NotImplementedError: Expected to be implemented in children classes 29 | """ 30 | raise NotImplementedError 31 | 32 | ######################################## 33 | ######################################## 34 | ######################################## 35 | 36 | 37 | class MVN(RV): 38 | def __init__(self, mean, cov): 39 | super().__init__() 40 | self.mean = mean 41 | self.cov = cov 42 | self.distribution = torch.distributions.MultivariateNormal(self.mean, self.cov) 43 | 44 | def sample(self, num_samples): 45 | return self.distribution.sample((num_samples,)) 46 | 47 | def log_prob(self, x): 48 | return self.distribution.log_prob(x) 49 | 50 | 51 | ######################################## 52 | ######################################## 53 | ######################################## 54 | 55 | class Gaussian_1d(RV): 56 | r"""One dimensional gaussian random variable. 57 | 58 | Attributes: 59 | mu (torch.Tensor): Mean tensor. 60 | rho (torch.Tensor): :math:`\rho` tensor, where :math:`\rho=\log{(e^\sigma-1)}` or, equivalently, :math:`\sigma=\log{(1+e^\rho)}`. 
This is the parameterization used in :cite:t:`blundell:2015`. 61 | logsigma (torch.Tensor): A more typical parameterization of the gaussian standard deviation :math:`\sigma` via its natural logarithm :math:`\log{\sigma}`. 62 | normal (torch.distributions.Normal): The underlying torch-based normal random variable. 63 | """ 64 | 65 | def __init__(self, mu, rho=None, logsigma=None): 66 | r"""Instantiate the random variable. 67 | 68 | Args: 69 | mu (torch.Tensor): Mean tensor. 70 | rho (torch.Tensor, optional): Parameterization that relates to standard deviation as :math:`\sigma=\log{(1+e^\rho)}`. 71 | logsigma (torch.Tensor, optional): Parameterization that relates to standard deviation as :math:`\log{\sigma}`. 72 | 73 | Note: 74 | Exactly one of rho or logsigma should be not None. 75 | Note: 76 | rho and logsigma, if not None, should have same shape as mu. 77 | """ 78 | super().__init__() 79 | 80 | self.mu = mu 81 | self.rho = None 82 | self.logsigma = None 83 | 84 | if rho is not None: 85 | assert(logsigma is None) 86 | assert(rho.shape==self.mu.shape) 87 | self.rho = rho 88 | else: 89 | assert(logsigma is not None) 90 | assert(logsigma.shape==self.mu.shape) 91 | self.logsigma = logsigma 92 | 93 | self.normal = torch.distributions.Normal(0,1) 94 | 95 | 96 | def sample(self): 97 | r"""Sampling function. 98 | 99 | Returns: 100 | torch.Tensor: A torch tensor of the same shape as :math:`\mu` and :math:`\rho` (or `\log{\sigma}`). 101 | """ 102 | if self.rho is not None: 103 | sigma = torch.log1p(torch.exp(self.rho)) 104 | else: 105 | sigma = torch.exp(self.logsigma) 106 | # FIXME: compute epsilon with pyTorch to avoid transfer data from host to device 107 | epsilon = self.normal.sample(sigma.size()).to(self.mu.device) 108 | return self.mu + sigma * epsilon 109 | 110 | def log_prob(self, x): 111 | """Evaluate the natural logarithm of the probability density function. 112 | 113 | Args: 114 | x (torch.Tensor): An input tensor of same shape (or broadcastable to) as mu and rho (logsigma). 115 | 116 | Returns: 117 | float: scalar torch.Tensor. 118 | """ 119 | if self.rho is not None: 120 | sigma = torch.log1p(torch.exp(self.rho)) 121 | else: 122 | sigma = torch.exp(self.logsigma) 123 | 124 | logprob = (-math.log(math.sqrt(2 * math.pi)) 125 | - torch.log(sigma) 126 | - ((x - self.mu) ** 2) / (2 * sigma ** 2)).sum() 127 | return logprob 128 | 129 | ######################################## 130 | ######################################## 131 | ######################################## 132 | 133 | class GMM2_1d(RV): 134 | """One dimensional gaussian mixture random variable with two gaussians that have zero mean and user-defined standard deviations. 135 | 136 | Attributes: 137 | pi (float): Weight of the first gaussian. The second weight is 1-pi. 138 | sigma1 (float): Standard deviation of the first gaussian. Can also be a scalar torch.Tensor. 139 | sigma2 (float): Standard deviation of the second gaussian. Can also be a scalar torch.Tensor. 140 | normal1 (torch.distributions.Normal): The underlying torch-based normal random variable for the first gaussian. 141 | normal2 (torch.distributions.Normal): The underlying torch-based normal random variable for the second gaussian. 142 | """ 143 | 144 | def __init__(self, pi, sigma1, sigma2): 145 | """Instantiation of the GMM2 object. 146 | 147 | Args: 148 | pi (float): Weight of the first gaussian. The second weight is 1-pi. 149 | sigma1 (float): Standard deviation of the first gaussian. Can also be a scalar torch.Tensor. 
150 | sigma2 (float): Standard deviation of the second gaussian. Can also be a scalar torch.Tensor. 151 | """ 152 | super().__init__() 153 | self.pi = pi 154 | self.sigma1 = sigma1 155 | self.sigma2 = sigma2 156 | self.normal1 = torch.distributions.Normal(0,sigma1) 157 | self.normal2 = torch.distributions.Normal(0,sigma2) 158 | 159 | def log_prob(self, x): 160 | """Evaluate the natural logarithm of the probability density function. 161 | 162 | Args: 163 | x (torch.Tensor): An input tensor. 164 | 165 | Returns: 166 | float: scalar torch.Tensor. 167 | """ 168 | 169 | prob1 = torch.exp(self.normal1.log_prob(x)) 170 | prob2 = torch.exp(self.normal2.log_prob(x)) 171 | logprob = (torch.log(self.pi * prob1 + (1-self.pi) * prob2)).sum() 172 | 173 | return logprob 174 | -------------------------------------------------------------------------------- /quinn/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from . import quinn 4 | 5 | from . import nn_mcmc 6 | from . import nn_vi 7 | 8 | from . import nn_ens 9 | from . import nn_laplace 10 | 11 | from . import nn_swag 12 | from . import nn_rms 13 | -------------------------------------------------------------------------------- /quinn/solvers/nn_ens.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for Ensemble NN wrapper.""" 3 | 4 | import numpy as np 5 | 6 | from .quinn import QUiNNBase 7 | from ..ens.learner import Learner 8 | 9 | class NN_Ens(QUiNNBase): 10 | """Deep Ensemble NN Wrapper. 11 | 12 | Attributes: 13 | dfrac (float): Fraction of data each learner sees. 14 | learners (list[Learner]): List of learners. 15 | nens (int): Number of ensemble members. 16 | verbose (bool): Verbose or not. 17 | """ 18 | 19 | def __init__(self, nnmodel, nens=1, dfrac=1.0, verbose=False): 20 | """Initialization. 21 | 22 | Args: 23 | nnmodel (torch.nn.Module): PyTorch NN model. 24 | nens (int, optional): Number of ensemble members. Defaults to 1. 25 | dfrac (float, optional): Fraction of data for each learner. Defaults to 1.0. 26 | verbose (bool, optional): Verbose or not. 27 | """ 28 | super().__init__(nnmodel) 29 | self.verbose = verbose 30 | self.nens = nens 31 | self.dfrac = dfrac 32 | self.learners = [] 33 | for i in range(nens): 34 | self.learners.append(Learner(nnmodel)) 35 | 36 | if self.verbose: 37 | self.print_params(names_only=True) 38 | 39 | 40 | def print_params(self, names_only=False): 41 | """Print model parameter names and optionally, values. 42 | 43 | Args: 44 | names_only (bool, optional): Print names only. Default is False. 45 | """ 46 | for i, learner in enumerate(self.learners): 47 | print(f"========== Learner {i+1}/{self.nens} ============") 48 | learner.print_params(names_only=names_only) 49 | 50 | 51 | def fit(self, xtrn, ytrn, **kwargs): 52 | """Fitting function for each ensemble member. 53 | 54 | Args: 55 | xtrn (np.ndarray): Input array of size `(N,d)`. 56 | ytrn (np.ndarray): Output array of size `(N,o)`. 57 | **kwargs (dict): Any keyword argument that :meth:`..nns.nnfit.nnfit` takes. 
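        Example:
            A minimal end-to-end sketch; the arrays `xtrn`, `ytrn`, `xval` are
            hypothetical placeholders, and `MLP_simple` is used purely for
            illustration::

                from quinn.nns.nns import MLP_simple
                from quinn.solvers.nn_ens import NN_Ens

                net = MLP_simple((1, 20, 20, 1))
                ens = NN_Ens(net, nens=5, dfrac=0.8)
                ens.fit(xtrn, ytrn, nepochs=500, lrate=0.01)
                ypred = ens.predict_ens(xval)   # array of shape (5, N, 1)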
58 | """ 59 | for jens in range(self.nens): 60 | print(f"======== Fitting Learner {jens+1}/{self.nens} =======") 61 | 62 | ntrn = ytrn.shape[0] 63 | permutation = np.random.permutation(ntrn) 64 | ind_this = permutation[:int(ntrn*self.dfrac)] 65 | 66 | this_learner = self.learners[jens] 67 | 68 | kwargs['lhist_suffix'] = f'_e{jens}' 69 | this_learner.fit(xtrn[ind_this], ytrn[ind_this], **kwargs) 70 | 71 | 72 | def predict_sample(self, x): 73 | """Predict a single, randomly selected sample. 74 | 75 | Args: 76 | x (np.ndarray): Input array of size `(N,d)`. 77 | 78 | Returns: 79 | np.ndarray: Output array of size `(N,o)`. 80 | """ 81 | jens = np.random.randint(0, self.nens) 82 | return self.learners[jens].predict(x) 83 | 84 | 85 | def predict_ens(self, x, nens=None): 86 | """Predict from all ensemble members. 87 | 88 | Args: 89 | x (np.ndarray): `(N,d)` input array. 90 | 91 | Returns: 92 | list[np.ndarray]: List of `M` arrays of size `(N, o)`, i.e. `M` random samples of `(N,o)` outputs. 93 | 94 | Note: 95 | This overloads QUiNN's base predict_ens function. 96 | """ 97 | if nens is None: 98 | nens = self.nens 99 | if nens>self.nens: 100 | print(f"Warning: Requested {nens} but only {self.nens} ensemble members available.") 101 | nens = self.nens 102 | 103 | permuted_inds=np.random.permutation(nens) 104 | 105 | y_all = [] 106 | for jens in range(nens): 107 | y = self.learners[permuted_inds[jens]].predict(x) 108 | y_all.append(y) 109 | 110 | return np.array(y_all) 111 | 112 | def predict_ens_fromsamples(self, x, nens=1): 113 | """Predict ensemble in a loop using individual predict_sample() calls. 114 | 115 | Args: 116 | x (np.ndarray): `(N,d)` input array. 117 | nens (int, optional): Number of samples requested. 118 | 119 | Returns: 120 | list[np.ndarray]: List of `M` arrays of size `(N, o)`, i.e. `M` random samples of `(N,o)` outputs. 121 | """ 122 | y_all = [] 123 | for _ in range(nens): 124 | y = self.predict_sample(x) 125 | y_all.append(y) 126 | 127 | return np.array(y_all) 128 | -------------------------------------------------------------------------------- /quinn/solvers/nn_laplace.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for Laplace NN wrapper.""" 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from .nn_ens import NN_Ens 8 | from ..nns.nnwrap import NNWrap 9 | from ..nns.losses import NegLogPost 10 | 11 | 12 | class NN_Laplace(NN_Ens): 13 | """Wrapper class for the Laplace method. 14 | 15 | Attributes: 16 | cov_mats (list): List of covariance matrices. 17 | cov_scale (TYPE): Covariance scaling factor for prediction. 18 | datanoise (float): Data noise standard deviation. 19 | la_type (str): Laplace approximation type ('full' or 'diag'). 20 | means (list): List of MAP centers. 21 | nparams (int): Number of parameters in the model. 22 | priorsigma (float): Gaussian prior standard deviation. 23 | """ 24 | 25 | def __init__(self, nnmodel, la_type='full', cov_scale=1.0, datanoise=0.1, priorsigma=1.0, **kwargs): 26 | """Initialization. 27 | 28 | Args: 29 | nnmodel (torch.nn.Module): NNWrapper class. 30 | la_type (str, optional): Laplace approximation type ('full' or 'diag'). Dedaults to 'full'. 31 | cov_scale (float, optional): Covariance scaling factor for prediction. Defaults to 1.0. 32 | datanoise (float, optional): Data noise standard deviation. Defaults to 0.1. 33 | priorsigma (float, optional): Gaussian prior standard deviation. Defaults to 1.0. 
34 | **kwargs: Any keyword argument that :meth:`..nns.nnfit.nnfit` takes. 35 | """ 36 | super().__init__(nnmodel, **kwargs) 37 | self.la_type = la_type 38 | self.cov_scale = cov_scale 39 | print( 40 | "NOTE: the hessian has not been averaged,", 41 | " i.e., it has not been divided by the number of training data points.", 42 | "Hence, the hyperparameter hessian scale can be tuned to calibrate uncertainty.", 43 | ) 44 | self.datanoise = datanoise 45 | self.priorsigma = priorsigma 46 | self.nparams = sum(p.numel() for p in self.nnmodel.parameters()) 47 | 48 | self.means = [] 49 | self.cov_mats = [] 50 | 51 | def fit(self, xtrn, ytrn, **kwargs): 52 | """Fitting function for each ensemble member. 53 | Args: 54 | xtrn (np.ndarray): Input array of size `(N,d)`. 55 | ytrn (np.ndarray): Output array of size `(N,o)`. 56 | **kwargs (dict): Any keyword argument that :meth:`..nns.nnfit.nnfit` takes. 57 | """ 58 | for jens in range(self.nens): 59 | print(f"======== Fitting Learner {jens+1}/{self.nens} =======") 60 | 61 | ntrn = ytrn.shape[0] 62 | permutation = np.random.permutation(ntrn) 63 | ind_this = permutation[: int(ntrn * self.dfrac)] 64 | 65 | this_learner = self.learners[jens] 66 | 67 | kwargs["lhist_suffix"] = f"_e{jens}" 68 | kwargs["loss_fn"] = "logpost" 69 | kwargs["datanoise"] = self.datanoise 70 | kwargs["priorparams"] = {'sigma': self.priorsigma, 'anchor': torch.randn(size=(self.nparams,)) * self.priorsigma} 71 | 72 | this_learner.fit(xtrn[ind_this], ytrn[ind_this], **kwargs) 73 | self.la_calc(this_learner, xtrn[ind_this], ytrn[ind_this]) 74 | 75 | 76 | def la_calc(self, learner, xtrn, ytrn, batch_size=None): 77 | """Given alearner, this method stores in the corresponding lists 78 | the vectors and matrices defining the posterior according to the 79 | laplace approximation. 80 | Args: 81 | learner (Learner): Instance of the Learner class including the model 82 | torch.nn.Module being used. 83 | xtrn (np.ndarray): input part of the training data. 84 | ytrn (np.ndarray): target part of the training data. 85 | batch_size (int): batch size used in the hessian estimation. 86 | Defaults to None, i.e. single batch. 87 | """ 88 | model = NNWrap(learner.nnmodel) 89 | 90 | weights_map = model.p_flatten().detach().squeeze().numpy() 91 | 92 | if self.la_type == "full": 93 | hessian_func = model.calc_hess_full 94 | elif self.la_type == "diag": 95 | hessian_func = model.calc_hess_diag 96 | else: 97 | assert ( 98 | NotImplementedError 99 | ), "Wrong approximation type given. Only full and diag are accepted." 100 | 101 | ntrn = len(xtrn) 102 | loss = NegLogPost(learner.nnmodel, ntrn, 0.1, None) # TODO: hardwired datanoise 103 | if not batch_size: 104 | hessian_mat = hessian_func(weights_map, loss, xtrn, ytrn) 105 | if batch_size: 106 | hessian_mat = None 107 | for i in range(ntrn // batch_size + 1): 108 | j = min(batch_size, ntrn - i * batch_size) 109 | if j > 0: 110 | x_batch, y_batch = ( 111 | xtrn[i * batch_size : i * batch_size + j], 112 | ytrn[i * batch_size : i * batch_size + j], 113 | ) 114 | hessian_cur = hessian_func(weights_map, loss, x_batch, y_batch) 115 | if i == 0: 116 | hessian_mat = hessian_cur 117 | else: 118 | hessian_mat = hessian_mat + hessian_cur 119 | 120 | cov_mat = np.linalg.inv(hessian_mat * self.cov_scale) 121 | self.means.append(weights_map) 122 | self.cov_mats.append(cov_mat) 123 | 124 | 125 | def predict_sample(self, x): 126 | """Predict a single sample. 127 | 128 | Args: 129 | x (np.ndarray): Input array `x` of size `(N,d)`. 
130 | 131 | Returns: 132 | np.ndarray: Output array `x` of size `(N,o)`. 133 | """ 134 | jens = np.random.randint(0, self.nens) 135 | theta = np.random.multivariate_normal(self.means[jens], cov=self.cov_mats[jens]) 136 | 137 | model = NNWrap(self.learners[jens].nnmodel) 138 | 139 | return model.predict(x, theta) 140 | 141 | def predict_ens(self, x, nens=1): 142 | """Predict an ensemble of results. 143 | 144 | Args: 145 | x (np.ndarray): `(N,d)` input array. 146 | 147 | Returns: 148 | list[np.ndarray]: List of `M` arrays of size `(N, o)`, i.e. `M` random samples of `(N,o)` outputs. 149 | 150 | Note: 151 | This overloads NN_Ens's and QUiNN's base predict_ens function. 152 | """ 153 | 154 | return self.predict_ens_fromsamples(x, nens=nens) 155 | 156 | 157 | 158 | # def kfac_hessian(model, weigths_map, loss_func, x_train, y_train): 159 | # """ 160 | # Calculates the Kronecker-factor Approximate Curvature hessian of the loss. 161 | # To be implemented taking the following papers into account: 162 | # Aleksandaer Botev, Hippolyt Ritter, David Barber. Practrical Gauss-Newton 163 | # Optimisation for Deep Learining. Proceedings of the 34th International 164 | # Conference on Machine Learning. 165 | # James Martens and Roger Grosse. Optimizing Neural Networks with Kronecker- 166 | # factored Approximate Curvature. 167 | # -------- 168 | # Inputs: 169 | # - model (NNWrapp_Torch class instance): model over whose parameters we are 170 | # calculating the posterior. 171 | # - weights_map (torch.Tensor): weights of the MAP of the log_posterior given 172 | # by the image. 173 | # - loss_func: function that calculates the negative of the log posterior 174 | # given the model, the training data (x and y) and requires_grad to indicate that 175 | # gradients will be calculated. 176 | # - x_train (numpy.ndarray or torch.Tensor): input part of the training data. 177 | # - y_train (numpy.ndarray or torch.Tensor): target part of the training data. 178 | # -------- 179 | # Outputs: 180 | # - (torch.Tensor) KFAC Hessian of the loss with respect to the model parameters. 181 | 182 | # """ 183 | # return NotImplementedError 184 | -------------------------------------------------------------------------------- /quinn/solvers/nn_mcmc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for MCMC NN wrapper.""" 3 | 4 | import sys 5 | import copy 6 | import numpy as np 7 | from scipy.optimize import minimize 8 | 9 | from ..mcmc.admcmc import AMCMC 10 | from ..mcmc.hmc import HMC 11 | from .quinn import QUiNNBase 12 | from ..nns.nnwrap import nn_p, NNWrap 13 | from ..nns.losses import NegLogPost 14 | 15 | class NN_MCMC(QUiNNBase): 16 | """MCMC NN wrapper class. 17 | 18 | Attributes: 19 | cmode (np.ndarray): MAP values of all parameters, size `M`. 20 | lpinfo (dict): Dictionary that holds likelihood computation necessary information. 21 | pdim (int): Dimensonality `d` of chain. 22 | samples (np.ndarray): MCMC samples of all parameters, size `(M,d)`. 23 | verbose (bool): Whether to be verbose or not. 24 | """ 25 | 26 | def __init__(self, nnmodel, verbose=True): 27 | """Initialization. 28 | 29 | Args: 30 | nnmodel (torch.nn.Module): PyTorch NN model. 31 | verbose (bool, optional): Verbose or not. 
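        Example:
            For instance, with `net` any torch.nn.Module (hypothetical here)::

                nn_mcmc = NN_MCMC(net, verbose=False)   # reports the number of parameters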
32 | """ 33 | super().__init__(nnmodel) 34 | self.verbose = verbose 35 | self.pdim = sum(p.numel() for p in self.nnmodel.parameters()) 36 | print("Number of parameters:", self.pdim) 37 | 38 | if self.verbose: 39 | self.print_params(names_only=True) 40 | 41 | self.samples = None 42 | self.cmode = None 43 | self.lpinfo = {} 44 | 45 | def logpost(self, modelpars, lpinfo): 46 | """Function that computes log-posterior given model parameters. 47 | 48 | Args: 49 | modelpars (np.ndarray): Log-posterior input parameters. 50 | lpinfo (dict): Dictionary of arguments needed for likelihood computation. 51 | 52 | Returns: 53 | float: log-posterior value. 54 | """ 55 | model = NNWrap(self.nnmodel) 56 | model.p_unflatten(modelpars) 57 | 58 | 59 | # Data 60 | ydata = lpinfo['yd'] 61 | nd = len(ydata) 62 | 63 | if lpinfo['ltype'] == 'classical': 64 | loss = NegLogPost(self.nnmodel, nd, lpinfo['lparams']['sigma'], None) 65 | 66 | lpostm = - model.calc_loss(modelpars, loss, lpinfo['xd'], ydata) 67 | else: 68 | print('Likelihood type is not recognized. Exiting.') 69 | sys.exit() 70 | 71 | return lpostm 72 | 73 | def logpostgrad(self, modelpars, lpinfo): 74 | """Function that computes log-posterior given model parameters. 75 | 76 | Args: 77 | modelpars (np.ndarray): Log-posterior input parameters. 78 | lpinfo (dict): Dictionary of arguments needed for likelihood computation. 79 | 80 | Returns: 81 | np.ndarray: log-posterior gradient array. 82 | """ 83 | model = NNWrap(self.nnmodel) 84 | model.p_unflatten(modelpars) 85 | 86 | 87 | # Data 88 | ydata = lpinfo['yd'] 89 | nd = len(ydata) 90 | if lpinfo['ltype'] == 'classical': 91 | loss = NegLogPost(self.nnmodel, nd, lpinfo['lparams']['sigma'], None) 92 | #lpostm = - npy(loss(tch(lpinfo['xd'], rgrad=False), tch(ydata, rgrad=False), requires_grad=False)) 93 | lpostm = - model.calc_lossgrad(modelpars, loss, lpinfo['xd'], ydata) 94 | else: 95 | print('Likelihood type is not recognized. Exiting') 96 | sys.exit() 97 | 98 | return lpostm 99 | 100 | def fit(self, xtrn, ytrn, zflag=True, datanoise=0.05, nmcmc=6000, param_ini=None, sampler='amcmc', sampler_params=None): 101 | """Fit function that perfoms MCMC on NN parameters. 102 | 103 | Args: 104 | xtrn (np.ndarray): Input data array `x` of size `(N,d)`. 105 | ytrn (np.ndarray): Output data array `y` of size `(N,o)`. 106 | zflag (bool, optional): Whether to precede MCMC with a LBFGS optimization. Default is True. 107 | datanoise (float, optional): Datanoise size. Defaults to 0.05. 108 | nmcmc (int, optional): Number of MCMC steps, `M`. 109 | param_ini (None, optional): Initial parameter array of size `p`. Default samples randomly. 110 | sampler (str, optional): Sampler method ('amcmc', 'hmc', 'mala'). Defaults to 'amcmc'. 111 | sampler_params (dict, optional): Sampler parameter dictionary. 
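        Example:
            A minimal sketch, assuming `nn_mcmc` is an NN_MCMC instance and the
            data arrays are hypothetical; passing an empty `sampler_params` dict
            assumes the AMCMC defaults are adequate::

                nn_mcmc.fit(xtrn, ytrn, datanoise=0.05, nmcmc=10000,
                            sampler='amcmc', sampler_params={})
                ypred = nn_mcmc.predict_ens(xval, nens=20, nburn=2000)   # (20, N, o) array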
112 | """ 113 | shape_xtrn = xtrn.shape 114 | ntrn = shape_xtrn[0] 115 | ntrn_, outdim = ytrn.shape 116 | 117 | # Set dictionary info for posterior computation 118 | self.lpinfo = {'model': nn_p, 119 | 'xd': xtrn, 'yd': [y for y in ytrn], 120 | 'ltype': 'classical', 121 | 'lparams': {'sigma': datanoise}} 122 | 123 | if param_ini is None: 124 | param_ini = np.random.rand(self.pdim) # initial parameter values 125 | if zflag: 126 | res = minimize((lambda x, fcn, lpinfo: -fcn(x, lpinfo)), param_ini, args=(self.logpost,self.lpinfo), method='BFGS',options={'gtol': 1e-13}) 127 | param_ini = res.x 128 | 129 | 130 | if sampler == 'amcmc': 131 | mymcmc = AMCMC(**sampler_params) 132 | mymcmc.setLogPost(self.logpost, None, lpinfo=self.lpinfo) 133 | elif sampler == 'hmc': 134 | mymcmc = HMC(**sampler_params) 135 | mymcmc.setLogPost(self.logpost, self.logpostgrad, lpinfo=self.lpinfo) 136 | 137 | 138 | mcmc_results = mymcmc.run(param_ini=param_ini, nmcmc=nmcmc) 139 | self.samples, self.cmode, pmode, acc_rate = mcmc_results['chain'], mcmc_results['mapparams'], mcmc_results['maxpost'], mcmc_results['accrate'] 140 | 141 | 142 | def get_best_model(self, param): 143 | """Creates a PyTorch NN module with parameters set with a given flattened parameter array. 144 | 145 | Args: 146 | param (np.ndarray): A flattened weight parameter vector. 147 | 148 | Returns: 149 | torch.nn.Module: PyTorch NN module with the given parameters. 150 | """ 151 | nnw = NNWrap(self.nnmodel) 152 | nnw.p_unflatten(param) 153 | 154 | return copy.deepcopy(nnw.nnmodel) 155 | 156 | 157 | def predict_MAP(self, x): 158 | """Predict with the max a posteriori (MAP) parameter setting. 159 | 160 | Args: 161 | x (np.ndarray): Input array of size `(N,d)`. 162 | 163 | Returns: 164 | np.ndarray: Outpur array of size `(N,o)`. 165 | """ 166 | return nn_p(self.cmode, x, self.nnmodel) 167 | 168 | def predict_sample(self, x, param): 169 | """Predict with a given parameter array. 170 | 171 | Args: 172 | x (np.ndarray): Input array of size `(N,d)`. 173 | param (np.ndarray): Flattened weight parameter array. 174 | 175 | Returns: 176 | np.ndarray: Outpur array of size `(N,o)`. 177 | """ 178 | return nn_p(param, x, self.nnmodel) 179 | 180 | def predict_ens(self, x, nens=10, nburn=1000): 181 | """Predict an ensemble of results. 182 | 183 | Args: 184 | x (np.ndarray): `(N,d)` input array. 185 | nens (int, optional): Number of ensemble members requested, `M`. Defaults to 10. 186 | nburn (int, optional): Burn-in for the MCMC chain. Defaults to 1000. 187 | 188 | Returns: 189 | np.ndarray: Array of size `(M, N, o)`, i.e. `M` random samples of `(N,o)` outputs. 190 | 191 | Note: 192 | This overloads QUiNN's base predict_ens functions 193 | """ 194 | nevery = int((self.samples.shape[0]-nburn)/nens) 195 | for j in range(nens): 196 | yy = self.predict_sample(x, self.samples[nburn+j*nevery,:]) 197 | if j == 0: 198 | y = np.empty((nens, yy.shape[0], yy.shape[1])) 199 | y[j, :, :] = yy 200 | return y 201 | -------------------------------------------------------------------------------- /quinn/solvers/nn_rms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for RMS NN wrapper.""" 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from .nn_ens import NN_Ens 8 | 9 | 10 | class NN_RMS(NN_Ens): 11 | """RMS Ensemble NN Wrapper. For details of the method, see :cite:t:`pearce:2018`. 12 | 13 | Attributes: 14 | datanoise (float): Data noise standard deviation. 15 | nparams (int): Number of model parameters. 
16 | priorsigma (float): Prior standard deviation. 17 | """ 18 | 19 | def __init__(self, nnmodel, datanoise=0.1, priorsigma=1.0, **kwargs): 20 | """Initialization. 21 | 22 | Args: 23 | nnmodel (torch.nn.Module): NNWrapper class. 24 | datanoise (float, optional): Data noise standard deviation. Defaults to 0.1. 25 | priorsigma (float, optional): Gaussian prior standard deviation. Defaults to 1.0. 26 | **kwargs: Any keyword argument that :meth:`..nns.nnfit.nnfit` takes. 27 | """ 28 | super().__init__(nnmodel, **kwargs) 29 | self.datanoise = datanoise 30 | self.priorsigma = priorsigma 31 | self.nparams = sum(p.numel() for p in self.nnmodel.parameters()) 32 | 33 | def fit(self, xtrn, ytrn, **kwargs): 34 | """Fitting function for each ensemble member. 35 | 36 | Args: 37 | xtrn (np.ndarray): Input array of size `(N,d)`. 38 | ytrn (np.ndarray): Output array of size `(N,o)`. 39 | **kwargs (dict): Any keyword argument that :meth:`..nns.nnfit.nnfit` takes. 40 | """ 41 | for jens in range(self.nens): 42 | print(f"======== Fitting Learner {jens+1}/{self.nens} =======") 43 | 44 | ntrn = ytrn.shape[0] 45 | permutation = np.random.permutation(ntrn) 46 | ind_this = permutation[: int(ntrn * self.dfrac)] 47 | 48 | this_learner = self.learners[jens] 49 | 50 | kwargs["lhist_suffix"] = f"_e{jens}" 51 | #kwargs["loss"] = torch.nn.MSELoss(reduction='mean') #Loss_Gaussian(self.nnmodel, 1.1) 52 | kwargs["loss_fn"] = "logpost" 53 | kwargs["datanoise"] = self.datanoise 54 | kwargs["priorparams"] = {'sigma': self.priorsigma, 'anchor': torch.randn(size=(self.nparams,)) * self.priorsigma} 55 | 56 | this_learner.fit(xtrn[ind_this], ytrn[ind_this], **kwargs) 57 | -------------------------------------------------------------------------------- /quinn/solvers/nn_swag.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for SWAG NN wrapper.""" 3 | 4 | import numpy as np 5 | 6 | from .nn_ens import NN_Ens 7 | from ..nns.tchutils import npy 8 | from ..nns.nnwrap import NNWrap 9 | 10 | 11 | class NN_SWAG(NN_Ens): 12 | """SWAG NN Wrapper class. 13 | 14 | Attributes: 15 | c (int): Frequency of the moment update. 16 | cov_diags (list): List of diagonal covariances. 17 | cov_type (str): Covariance type. 18 | d_mats (list): List of D-matrices. 19 | datanoise (float): Data noise standard deviation. 20 | k (int): k-parameter of the method 21 | lr_swag (float): Learning rate. 22 | means (list): List of mean values of the history. 23 | n_steps (int): Number of steps in SWAG algorithm. 24 | nparams (int): Number of underlying NN module parameters. 25 | priorsigma (float): Standard deviation of the prior. 26 | """ 27 | 28 | def __init__(self, nnmodel, k=10, 29 | n_steps=12, c=1, cov_type="lowrank", lr_swag=0.1, 30 | datanoise=0.1, priorsigma=1.0, **kwargs): 31 | """Initialization. 32 | 33 | Args: 34 | nnmodel (torch.nn.Module): NNWrapper class. 35 | k (int, optional): k-parameter of the method. Defaults to 10. 36 | n_steps (int, optional): Number of steps. Defaults to 12. 37 | c (int, optional): Frequency of moment update. Defaults to 1. 38 | cov_type (str, optional): Covariance type. Defaults to 'lowrank', anything else ignores low-rank approximation. 39 | lr_swag (float, optional): Learning rate. Defaults to 0.1. 40 | datanoise (float, optional): Data noise standard deviation. Defaults to 0.1. 41 | priorsigma (float, optional): Standard deviation of the prior. Defaults to 1.0. 42 | **kwargs: Any other keyword argument that :meth:`..nns.nnfit.nnfit` takes. 
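        Example:
            A minimal construction sketch; `net` is a hypothetical torch.nn.Module,
            and extra keyword arguments such as `nens` are passed on to NN_Ens::

                swag = NN_SWAG(net, k=10, n_steps=20, cov_type="lowrank",
                               lr_swag=0.05, nens=3)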
43 | """ 44 | super().__init__(nnmodel, **kwargs) 45 | self.k = k 46 | assert(self.k>1) 47 | self.c = c 48 | self.n_steps = n_steps 49 | self.cov_type = cov_type 50 | if self.cov_type == "lowrank": 51 | assert(self.n_steps >= self.k) 52 | self.lr_swag = lr_swag 53 | self.datanoise = datanoise 54 | self.priorsigma = priorsigma 55 | self.nparams = sum(p.numel() for p in self.nnmodel.parameters()) 56 | 57 | self.means = [] 58 | self.cov_diags = [] 59 | self.d_mats = [] 60 | 61 | def fit(self, xtrn, ytrn, **kwargs): 62 | """Fitting function for each ensemble member. 63 | 64 | Args: 65 | xtrn (np.ndarray): Input array of size `(N,d)`. 66 | ytrn (np.ndarray): Output array of size `(N,o)`. 67 | **kwargs (dict): Any keyword argument that :meth:`..nns.nnfit.nnfit` takes. 68 | """ 69 | for jens in range(self.nens): 70 | print(f"======== Fitting Learner {jens+1}/{self.nens} =======") 71 | 72 | ntrn = ytrn.shape[0] 73 | permutation = np.random.permutation(ntrn) 74 | ind_this = permutation[: int(ntrn * self.dfrac)] 75 | 76 | this_learner = self.learners[jens] 77 | 78 | kwargs["lhist_suffix"] = f"_e{jens}" 79 | kwargs["loss_fn"] = "logpost" 80 | kwargs["datanoise"] = self.datanoise 81 | #kwargs["priorparams"] = {'sigma': self.priorsigma} 82 | 83 | this_learner.fit(xtrn[ind_this], ytrn[ind_this], **kwargs) 84 | self.swag_calc(this_learner, xtrn[ind_this], ytrn[ind_this]) 85 | 86 | 87 | def swag_calc(self, learner, xtrn, ytrn): 88 | """Given a learner, this method stores in the corresponding lists 89 | the vectors and matrices defining the posterior according to the 90 | laplace approximation. 91 | 92 | Args: 93 | learner (Learner): Instance of the Learner class including the model 94 | torch.nn.Module being used. 95 | xtrn (np.ndarray): input part of the training data. 96 | ytrn (np.ndarray): target part of the training data. 97 | """ 98 | model = NNWrap(learner.nnmodel) 99 | 100 | moment1 = npy(model.p_flatten()) 101 | moment2 = np.power(npy(model.p_flatten()), 2) 102 | 103 | 104 | d_mat = [] 105 | for i in range(1, self.n_steps + 1): 106 | learner.fit(xtrn, ytrn, nepochs=1, optimizer='sgd', lrate=self.lr_swag) # TODO: does this need the main loss function, or the default is ok? 107 | 108 | if i % self.c == 0: 109 | n = i // self.c 110 | model = NNWrap(learner.nnmodel) 111 | moment1 = (n * moment1 + npy(model.p_flatten())) / (n + 1) 112 | moment2 = (n * moment2 + np.power(npy(model.p_flatten()), 2)) / (n + 1) 113 | if self.cov_type == "lowrank": 114 | d_mat.append(npy(model.p_flatten()) - moment1) 115 | if len(d_mat)>=self.k: 116 | d_mat = d_mat[-self.k :] 117 | 118 | self.means.append(np.squeeze(moment1)) 119 | self.cov_diags.append(np.squeeze(moment2 - np.power(moment1, 2))) 120 | if self.cov_type == "lowrank": 121 | self.d_mats.append(np.squeeze(np.array(d_mat).T)) 122 | 123 | def predict_sample(self, x): 124 | """Predict a single sample. 125 | 126 | Args: 127 | x (np.ndarray): Input array `x` of size `(N,d)`. 128 | 129 | Returns: 130 | np.ndarray: Output array `x` of size `(N,o)`. 
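        Example:
            Illustrative sketch; `swag` is a fitted NN_SWAG instance and `x` a
            hypothetical `(N,d)` numpy array::

                y = swag.predict_sample(x)           # one posterior draw, shape (N,o)
                ys = swag.predict_ens(x, nens=50)    # 50 draws, shape (50, N, o)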
131 | """ 132 | 133 | jens = np.random.randint(0, self.nens) 134 | 135 | z_1 = np.random.randn(self.nparams) 136 | z_2 = np.random.randn(self.k) 137 | theta = self.means[jens] 138 | theta_corr = np.multiply(np.sqrt(self.cov_diags[jens]), z_1) 139 | if self.cov_type == "lowrank": 140 | theta_corr = np.sqrt(0.5)*theta_corr + np.sqrt(0.5)*np.dot(self.d_mats[jens], z_2)/np.sqrt(self.k-1) 141 | 142 | theta += theta_corr 143 | model = NNWrap(self.learners[jens].nnmodel) 144 | 145 | return model.predict(x, theta) 146 | 147 | def predict_ens(self, x, nens=1): 148 | """Predict an ensemble of results. 149 | 150 | Args: 151 | x (np.ndarray): `(N,d)` input array. 152 | 153 | Returns: 154 | list[np.ndarray]: List of `M` arrays of size `(N, o)`, i.e. `M` random samples of `(N,o)` outputs. 155 | 156 | Note: 157 | This overloads NN_Ens's and QUiNN's base predict_ens function. 158 | """ 159 | 160 | return self.predict_ens_fromsamples(x, nens=nens) 161 | -------------------------------------------------------------------------------- /quinn/solvers/nn_vi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for the Variational Inference (VI) NN wrapper.""" 3 | 4 | import copy 5 | import math 6 | import torch 7 | 8 | from ..vi.bnet import BNet 9 | 10 | from ..solvers.quinn import QUiNNBase 11 | from ..nns.tchutils import npy, tch, print_nnparams 12 | from ..nns.nnfit import nnfit 13 | 14 | class NN_VI(QUiNNBase): 15 | """VI wrapper class. This implements the Bayes-by-backprop method. For details of the method, see :cite:t:`blundell:2015`. 16 | 17 | Attributes: 18 | best_model (torch.nn.Module): The best PyTorch NN model found during training. 19 | bmodel (BNet): The underlying Bayesian model. 20 | device (str): Device on which the computations are done. 21 | trained (bool): Whether the model is trained or not. 22 | verbose (bool): Whether to be verbose or not. 23 | """ 24 | 25 | def __init__(self, nnmodel, verbose=False, pi=0.5, sigma1=1.0, sigma2=1.0, 26 | mu_init_lower=-0.2, mu_init_upper=0.2, 27 | rho_init_lower=-5.0, rho_init_upper=-4.0 ): 28 | """Instantiate a VI wrapper object. 29 | 30 | Args: 31 | nnmodel (torch.nn.Module): The underlying PyTorch NN model. 32 | verbose (bool, optional): Whether to print out model details or not. 33 | pi (float): Weight of the first gaussian. The second weight is 1-pi. 34 | sigma1 (float): Standard deviation of the first gaussian. Can also be a scalar torch.Tensor. 35 | sigma2 (float): Standard deviation of the second gaussian. Can also be a scalar torch.Tensor. 
36 | mu_init_lower (float): Initialization of mu lower value 37 | mu_init_upper (float): Initialization of mu upper value 38 | rho_init_lower (float): Initialization of rho lower value 39 | rho_init_upper (float): Initialization of rho upper value 40 | 41 | """ 42 | super().__init__(nnmodel) 43 | 44 | self.bmodel = BNet(nnmodel,pi=pi,sigma1=sigma1,sigma2=sigma2, 45 | mu_init_lower=mu_init_lower, mu_init_upper=mu_init_upper, 46 | rho_init_lower=rho_init_lower, rho_init_upper=rho_init_upper ) 47 | try: 48 | self.device = nnmodel.device 49 | except AttributeError: 50 | self.device = 'cpu' 51 | 52 | self.bmodel.to(self.device) 53 | self.verbose = verbose 54 | self.trained = False 55 | self.best_model = None 56 | 57 | if self.verbose: 58 | print("=========== Deterministic model parameters ================") 59 | self.print_params(names_only=True) 60 | print("=========== Variational model parameters ==================") 61 | print_nnparams(self.bmodel, names_only=True) 62 | print("===========================================================") 63 | 64 | def fit(self, xtrn, ytrn, val=None, 65 | nepochs=600, lrate=0.01, batch_size=None, freq_out=100, 66 | freq_plot=1000, wd=0, 67 | cooldown=100, 68 | factor=0.95, 69 | nsam=1,scheduler_lr=None, datanoise=0.05): 70 | """Fit function to train the network. 71 | 72 | Args: 73 | xtrn (np.ndarray): Training input array of size `(N,d)`. 74 | ytrn (np.ndarray): Training output array of size `(N,o)`. 75 | val (tuple, optional): `x,y` tuple of validation points. Default uses the training set for validation. 76 | nepochs (int, optional): Number of epochs. 77 | lrate (float, optional): Learning rate or learning rate schedule factor. Default is 0.01. 78 | batch_size (int, optional): Batch size. Default is None, i.e. single batch. 79 | freq_out (int, optional): Frequency, in epochs, of screen output. Defaults to 100. 80 | freq_plot (int, optional): Frequency, in epoch, of plotting the loss. 81 | wd (float, optional): Optional weight decay (L2 regularization) parameter. 82 | cooldown (int, optional): cooldown in ReduceLROnPlateau 83 | factor (float, optional): factor in ReduceLROnPlateau 84 | nsam (int, optional): Number of samples for ELBO computation. Defaults to 1. 85 | scheduler_lr (None, optional): Scheduler of learning rate. See the corresponding argument in :func:`..nns.nnfit.nnfit()`. 86 | datanoise (float, optional): Datanoise for ELBO computation. Defaults to 0.05. 87 | """ 88 | 89 | shape_xtrn = xtrn.shape 90 | ntrn = shape_xtrn[0] 91 | ntrn_, outdim = ytrn.shape 92 | assert(ntrn==ntrn_) 93 | 94 | if batch_size is None or batch_size > ntrn: 95 | batch_size = ntrn 96 | 97 | if batch_size == 1: 98 | num_batches = ntrn 99 | else: 100 | num_batches = (ntrn + 1) // batch_size 101 | 102 | self.bmodel.loss_params = [datanoise, nsam, num_batches] 103 | 104 | fit_info = nnfit(self.bmodel, xtrn, ytrn, val=val, 105 | loss_xy=self.bmodel.viloss, 106 | lrate=lrate, batch_size=batch_size, 107 | nepochs=nepochs, 108 | wd=wd, 109 | cooldown=cooldown, 110 | factor=factor, 111 | freq_plot=freq_plot, 112 | scheduler_lr=scheduler_lr, freq_out=freq_out) 113 | self.best_model = fit_info['best_nnmodel'] 114 | self.trained = True 115 | 116 | def predict_sample(self, x): 117 | """Predict a single sample. 118 | 119 | Args: 120 | x (np.ndarray): Input array `x` of size `(N,d)`. 121 | 122 | Returns: 123 | np.ndarray: Output array `x` of size `(N,o)`. 124 | 125 | Note: 126 | predict_ens() from the parent class will use this to sample an ensemble. 
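Example: A minimal usage sketch (names are illustrative; assumes ``net`` is any ``torch.nn.Module`` mapping ``(N,d)`` inputs to ``(N,o)`` outputs, and ``xtrn``, ``ytrn``, ``xtst`` are numpy arrays): ``vi = NN_VI(net)``; ``vi.fit(xtrn, ytrn, nepochs=500, datanoise=0.05, nsam=1)``; ``ysam = vi.predict_ens(xtst, nens=50)`` then holds 50 sampled predictions of shape ``(N,o)`` each.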
127 | """ 128 | assert(self.trained) 129 | device = self.best_model.device 130 | y = npy(self.best_model(tch(x, rgrad=False,device=device), sample=True)) 131 | 132 | return y 133 | 134 | ###################################################################### 135 | ###################################################################### 136 | ###################################################################### 137 | 138 | -------------------------------------------------------------------------------- /quinn/solvers/quinn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for base QUiNN class.""" 3 | 4 | import copy 5 | import functools 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from ..utils.plotting import plot_dm, lighten_color 10 | from ..utils.maps import scale01ToDom 11 | from ..utils.stats import get_stats, get_domain 12 | from ..nns.tchutils import print_nnparams 13 | 14 | 15 | class QUiNNBase(): 16 | """Base QUiNN class. 17 | 18 | Attributes: 19 | nens (int): Number of samples requested, `M`. 20 | nnmodel (torch.nn.Module): Underlying PyTorch NN model. 21 | """ 22 | 23 | def __init__(self, nnmodel): 24 | """Initialization. 25 | 26 | Args: 27 | nnmodel (torch.nn.Module): Underlying PyTorch NN model. 28 | """ 29 | self.nnmodel = copy.deepcopy(nnmodel) 30 | self.nens = None 31 | 32 | def print_params(self, names_only=False): 33 | """Print model parameter names and optionally, values. 34 | 35 | Args: 36 | names_only (bool, optional): Print names only. Default is False. 37 | """ 38 | print_nnparams(self.nnmodel, names_only=names_only) 39 | 40 | def predict_sample(self, x): 41 | """Produce a single sample prediction. 42 | 43 | Args: 44 | x (np.ndarray): `(N,d)` input array. 45 | 46 | Raises: 47 | NotImplementedError: Not implemented in the base class. 48 | """ 49 | raise NotImplementedError 50 | 51 | def predict_ens(self, x, nens=None): 52 | """Produce an ensemble of predictions. 53 | 54 | Args: 55 | x (np.ndarray): `(N,d)` input array. 56 | nens (int, optional): Number of samples requested, `M`. 57 | 58 | Returns: 59 | np.ndarray: Array of size `(M, N, o)`, i.e. `M` random samples of `(N,o)` outputs 60 | """ 61 | if nens is None: 62 | nens = self.nens 63 | y_list = [] 64 | for _ in range(nens): 65 | yy = self.predict_sample(x) 66 | y_list.append(yy) 67 | 68 | y = np.array(y_list) # y.shape is nens, nsam(x.shape[0]), nout 69 | 70 | return y 71 | 72 | def predict_mom_sample(self, x, msc=0, nsam=1000): 73 | r"""Predict function, given input :math:`x`. 74 | 75 | Args: 76 | x (np.ndarray): A 2d array of inputs of size :math:`(N,d)` at which bases are evaluated. 77 | msc (int, optional): Prediction mode: 0 (mean-only), 1 (mean and variance), or 2 (mean, variance and covariance). Defaults to 0. 78 | 79 | Returns: 80 | tuple(np.ndarray, np.ndarray, np.ndarray): triple of Mean (array of size `(N, o)`), Variance (array of size `(N, o)` or None), Covariance (array of size `(N, N, o)` or None). 
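Example: For instance, ``ymean, yvar, _ = solver.predict_mom_sample(xtst, msc=1, nsam=500)`` (``solver`` being any trained QUiNN wrapper and ``xtst`` an input array; names illustrative) estimates the predictive mean and variance from 500 ensemble samples, while the covariance slot is returned as None in this mode.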
81 | """ 82 | y = self.predict_ens(x, nens=nsam) 83 | nsam_, nx, nout = y.shape 84 | ymean = np.mean(y, axis=0) 85 | ycov = np.empty((nx, nx, nout)) 86 | yvar = np.empty((nx, nout)) 87 | if msc==2: 88 | for iout in range(nout): 89 | ycov[:,:,iout] = np.cov(y[:,:,iout], rowvar=False, ddof=1) 90 | yvar[:, iout] = np.diag(ycov[:,:,iout]) 91 | elif msc==1: 92 | ycov = None 93 | # no full covariance requested in this mode: 94 | # the variance is computed directly from the ensemble of samples 95 | yvar = np.var(y, axis=0, ddof=1) 96 | elif msc==0: 97 | ycov = None 98 | yvar = None 99 | else: 100 | # raise instead of printing and calling sys.exit(), since sys is not imported in this module 101 | raise ValueError(f"msc={msc}, but needs to be 0, 1, or 2.") 102 | 103 | return ymean, yvar, ycov 104 | 105 | def predict_plot(self, xx_list, yy_list, nmc=100, 106 | plot_qt=False, labels=None, 107 | colors=None, iouts=None, msize=14, 108 | figname=None): 109 | """Plots the diagonal comparison figures. 110 | 111 | Args: 112 | xx_list (list[np.ndarray]): List of `(N,d)` inputs (e.g., training, validation, testing). 113 | yy_list (list[np.ndarray]): List of `(N,o)` outputs. 114 | nmc (int, optional): Requested number of samples for computing statistics, `M`. 115 | plot_qt (bool, optional): Whether to plot quantiles or mean/st.dev. 116 | labels (list[str], optional): List of labels. If None, set label internally. 117 | colors (list[str], optional): List of colors. If None, sets colors internally. 118 | iouts (list[int], optional): List of outputs to plot. If None, plot all. 119 | msize (int, optional): Markersize. Defaults to 14. 120 | figname (str, optional): Name of the figure to be saved. 121 | 122 | Note: 123 | There is a similar function for deterministic NN in :class:``..nns.nnbase``. 124 | """ 125 | nlist = len(xx_list) 126 | assert(nlist==len(yy_list)) 127 | yy_pred_mb_list = [] 128 | yy_pred_lb_list = [] 129 | yy_pred_ub_list = [] 130 | 131 | for xx in xx_list: 132 | yy_pred = self.predict_ens(xx, nens=nmc) 133 | yy_pred_mb, yy_pred_lb, yy_pred_ub = get_stats(yy_pred, plot_qt) 134 | #print(yy_pred.shape) 135 | yy_pred_mb_list.append(yy_pred_mb) 136 | yy_pred_lb_list.append(yy_pred_lb) 137 | yy_pred_ub_list.append(yy_pred_ub) 138 | 139 | nout = yy_pred_mb.shape[1] 140 | if iouts is None: 141 | iouts = range(nout) 142 | 143 | if labels is None: 144 | labels = [f'Set {i+1}' for i in range(nlist)] 145 | assert(len(labels)==nlist) 146 | 147 | if colors is None: 148 | colors = ['b', 'g', 'r', 'c', 'm', 'y']*nlist 149 | colors = colors[:nlist] 150 | assert(len(colors)==nlist) 151 | 152 | for iout in iouts: 153 | x1 = [yy[:, iout] for yy in yy_list] 154 | x2 = [yy[:, iout] for yy in yy_pred_mb_list] 155 | eel = [yy[:, iout] for yy in yy_pred_lb_list] 156 | eeu = [yy[:, iout] for yy in yy_pred_ub_list] 157 | ee = list(zip(eel, eeu)) 158 | 159 | if figname is None: 160 | figname_ = 'fitdiag_o'+str(iout)+'.png' 161 | else: 162 | figname_ = figname 163 | 164 | plot_dm(x1, x2, errorbars=ee, labels=labels, colors=colors, 165 | axes_labels=[f'Model output # {iout+1}', f'Fit output # {iout+1}'], 166 | figname=figname_, 167 | legendpos='in', msize=msize) 168 | 169 | 170 | def plot_1d_fits(self, xx_list, yy_list, domain=None, ngr=111, plot_qt=False, nmc=100, true_model=None, labels=None, colors=None, name_postfix=''): 171 | """Plotting one-dimensional slices, with the other dimensions at the nominal, of the fit. 172 | 173 | Args: 174 | xx_list (list[np.ndarray]): List of `(N,d)` inputs (e.g., training, validation, testing). 175 | yy_list (list[np.ndarray]): List of `(N,o)` outputs. 
176 | domain (np.ndarray, optional): Domain of the function, `(d,2)` array. If None, sets it automatically based on data. 177 | ngr (int, optional): Number of grid points in the 1d plot. 178 | plot_qt (bool, optional): Whether to plot quantiles or mean/st.dev. 179 | nmc (int, optional): Requested number of samples for computing statistics, `M`. 180 | true_model (callable, optional): Optionally, plot a function 181 | labels (list[str], optional): List of labels. If None, set label internally. 182 | colors (list[str], optional): List of colors. If None, sets colors internally. 183 | name_postfix (str, optional): Postfix of the filename of the saved fig. 184 | 185 | Note: 186 | There is a similar function for deterministic NN in :class:``..nns.nnbase``. 187 | """ 188 | nlist = len(xx_list) 189 | assert(nlist==len(yy_list)) 190 | 191 | if labels is None: 192 | labels = [f'Set {i+1}' for i in range(nlist)] 193 | assert(len(labels)==nlist) 194 | 195 | if colors is None: 196 | colors = ['b', 'g', 'r', 'c', 'm', 'y']*nlist 197 | colors = colors[:nlist] 198 | assert(len(colors)==nlist) 199 | 200 | if domain is None: 201 | xall = functools.reduce(lambda x,y: np.vstack((x,y)), xx_list) 202 | domain = get_domain(xall) 203 | 204 | _ = plt.figure(figsize=(12, 8)) 205 | 206 | if plot_qt: 207 | mlabel = 'Median Pred.' 208 | slabel = 'Qtile' 209 | else: 210 | mlabel = 'Mean Pred.' 211 | slabel = 'St.Dev.' 212 | 213 | ndim = xx_list[0].shape[1] 214 | nout = yy_list[0].shape[1] 215 | for idim in range(ndim): 216 | xgrid_ = 0.5 * np.ones((ngr, ndim)) 217 | xgrid_[:, idim] = np.linspace(0., 1., ngr) 218 | 219 | xgrid = scale01ToDom(xgrid_, domain) 220 | ygrid_pred = self.predict_ens(xgrid, nens=nmc) 221 | ygrid_pred_mb, ygrid_pred_lb, ygrid_pred_ub = get_stats(ygrid_pred, plot_qt) 222 | 223 | for iout in range(nout): 224 | 225 | for j in range(nlist): 226 | xx = xx_list[j] 227 | yy = yy_list[j] 228 | 229 | plt.plot(xx[:, idim], yy[:, iout], colors[j]+'o', markersize=13, markeredgecolor='w', label=labels[j], zorder=1000) 230 | 231 | if true_model is not None: 232 | true = true_model(xgrid, 0.0) 233 | plt.plot(xgrid[:, idim], true[:, iout], 'k-', label='Truth', alpha=0.5) 234 | 235 | 236 | p, = plt.plot(xgrid[:, idim], ygrid_pred_mb[:, iout], 'm-', linewidth=5, label=mlabel) 237 | for ygrid_pred_sample in ygrid_pred: 238 | p, = plt.plot(xgrid[:, idim], ygrid_pred_sample[:, iout], 'm--', linewidth=1, zorder=-10000) 239 | lc = lighten_color(p.get_color(), 0.5) 240 | plt.fill_between(xgrid[:, idim], 241 | ygrid_pred_mb[:, iout] - ygrid_pred_lb[:, iout], 242 | ygrid_pred_mb[:, iout] + ygrid_pred_ub[:, iout], 243 | color=lc, zorder=-1000, alpha=0.9, 244 | label=slabel) 245 | 246 | plt.legend() 247 | plt.xlabel(f'Input # {idim+1}') 248 | plt.ylabel(f'Output # {iout+1}') 249 | plt.savefig('fit_d' + str(idim) + '_o' + str(iout) + '_' + name_postfix+'.png') 250 | plt.clf() 251 | -------------------------------------------------------------------------------- /quinn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from . import maps 4 | from . import plotting 5 | from . import stats 6 | from . 
import xutils 7 | -------------------------------------------------------------------------------- /quinn/utils/maps.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for various mapping functions.""" 3 | 4 | import numpy as np 5 | 6 | 7 | def scale01ToDom(xx, dom): 8 | """Scaling an array to a given domain, assuming \ 9 | the inputs are in [0,1]^d. 10 | 11 | Args: 12 | xx (np.ndarray): Nxd input array. 13 | dom (np.ndarray): dx2 domain. 14 | Returns: 15 | np.ndarray: Nxd scaled array. 16 | Note: 17 | If input is outside [0,1]^d, a warning is given, but the scaling will happen nevertheless. 18 | """ 19 | if np.any(xx<0.0) or np.any(xx>1.0): 20 | print("Warning: some elements are outside the [0,1] range.") 21 | 22 | return xx*np.abs(dom[:,1]-dom[:,0])+np.min(dom, axis=1) 23 | 24 | def scaleDomTo01(xx, dom): 25 | """Scaling an array from a given domain to [0,1]^d. 26 | 27 | Args: 28 | xx (np.ndarray): Nxd input array. 29 | dom (np.ndarray): dx2 domain. 30 | Returns: 31 | np.ndarray: Nxd scaled array. 32 | Note: 33 | If input is outside domain, a warning is given, but the scaling will happen nevertheless. 34 | """ 35 | xxsc = (xx-np.min(dom, axis=1)) / np.abs(dom[:,1]-dom[:,0]) 36 | if np.any(xxsc<0.0) or np.any(xxsc>1.0): 37 | print("Warning: some elements are outside the [0,1] range.") 38 | 39 | return xxsc 40 | 41 | def scaleTo01(xx): 42 | """Scale an array to [0,1], using dimension-wise min and max. 43 | 44 | Args: 45 | xx (np.ndarray): Initial 2d array 46 | 47 | Returns: 48 | np.ndarray: Scaled array. 49 | """ 50 | return (xx - np.min(xx, axis=0)) / (np.max(xx, axis=0) - np.min(xx, axis=0)) 51 | 52 | def standardize(xx): 53 | """Normalize an array, i.e. map it to zero mean and unit variance. 54 | 55 | Args: 56 | xx (np.ndarray): Initial 2d array 57 | 58 | Returns: 59 | np.ndarray: Normalized array. 60 | """ 61 | return (xx - np.mean(xx)) / np.std(xx) 62 | 63 | class XMap(): 64 | """Base class for a map.""" 65 | 66 | def __init__(self): 67 | """Initialization.""" 68 | ... 69 | 70 | def __call__(self, x): 71 | raise NotImplementedError("Base XMap call is not implemented") 72 | 73 | def forw(self, x): 74 | """Forward map. 75 | 76 | Args: 77 | x (np.ndarray): 2d numpy input array. 78 | 79 | Returns: 80 | np.ndarray: 2d numpy output array. 81 | """ 82 | return self.__call__(x) 83 | 84 | def inv(self, xs): 85 | """Inverse of the map. 86 | 87 | Args: 88 | xs (np.ndarray): 2d numpy array. 89 | Returns: 90 | np.ndarray: if implemented, 2d numpy array. 91 | """ 92 | raise NotImplementedError("Base XMap inverse is not implemented") 93 | 94 | class Expon(XMap): 95 | """Exponential map.""" 96 | 97 | def __init__(self): 98 | super().__init__() 99 | 100 | def __call__(self, x): 101 | return np.exp(x) 102 | 103 | def inv(self, xs): 104 | return np.log(xs) 105 | 106 | class Logar(XMap): 107 | """Logarithmic map.""" 108 | 109 | def __init__(self): 110 | super().__init__() 111 | 112 | def __call__(self, x): 113 | return np.log(x) 114 | 115 | def inv(self, xs): 116 | return np.exp(xs) 117 | 118 | class ComposeMap(XMap): 119 | """Composition of two maps.""" 120 | 121 | def __init__(self, map1, map2): 122 | """Initialize with the two maps to be composed. 
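The composition applies ``map1`` first and then ``map2``; for example, ``ComposeMap(Logar(), Expon())`` maps a positive ``x`` to ``np.exp(np.log(x))``, i.e. back to itself, and its ``inv`` undoes the two maps in reverse order.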
123 | 124 | Args: 125 | map1 (XMap): Inner map 126 | map2 (XMap): Outer map 127 | """ 128 | super().__init__() 129 | self.map1 = map1 130 | self.map2 = map2 131 | 132 | def __repr__(self): 133 | return f"ComposeMap({self.map1=}, {self.map2=}" 134 | 135 | def __call__(self, x): 136 | return self.map2(self.map1(x)) 137 | 138 | def inv(self, xs): 139 | return self.map1.inv(self.map2.inv(xs)) 140 | 141 | 142 | class LinearScaler(XMap): 143 | """Linear scaler map.""" 144 | 145 | def __init__(self, shift=None, scale=None): 146 | """Initialize with shift and scale. 147 | 148 | Args: 149 | shift (np.ndarray, optional): Shift array, broadcast-friendly 150 | scale (np.ndarray, optional): Scale array, broadcast-friendly 151 | """ 152 | super().__init__() 153 | self.shift = shift 154 | self.scale = scale 155 | return 156 | 157 | def __repr__(self): 158 | return f"Scaler({self.shift=}, {self.scale=}" 159 | 160 | def __call__(self, x): 161 | if self.shift is None: 162 | xs = x - 0.0 163 | else: 164 | xs = x - self.shift 165 | 166 | if self.scale is None: 167 | xs /= 1.0 168 | else: 169 | xs /= self.scale 170 | 171 | return xs 172 | 173 | def inv(self, xs): 174 | if self.scale is None: 175 | x = xs * 1.0 176 | else: 177 | x = xs * self.scale#.reshape(1,-1) 178 | 179 | if self.shift is None: 180 | x += 0.0 181 | else: 182 | x += self.shift 183 | 184 | return x 185 | 186 | class Standardizer(LinearScaler): 187 | """Standardizer map, linearly scaling data to zero mean and unit variance.""" 188 | 189 | def __init__(self, x): 190 | """Initialize with a given 2d array. 191 | 192 | Args: 193 | x (np.ndarray): Data according to which the standardization happens. 194 | Note: 195 | This also can be accomplished by function `normalize` 196 | """ 197 | super().__init__(shift=np.mean(x, axis=0), scale=np.std(x, axis=0)) 198 | return 199 | 200 | class Normalizer(LinearScaler): 201 | """Normalizer map, linearly scaling data to [0,1].""" 202 | 203 | def __init__(self, x, nugget=0.0): 204 | """Initialize with a given 2d array and a nugget to keep slightly above zero. 205 | 206 | Args: 207 | x (np.ndarray): Data according to which the normalization happens. 208 | nugget (float, optional): Small value to keep data above zero if needed. 209 | Note: 210 | When nugget is 0, this also can be accomplished by function `scaleTo01` 211 | """ 212 | super().__init__(shift=np.min(x, axis=0)-nugget, 213 | scale=np.max(x, axis=0)-np.min(x, axis=0)) 214 | return 215 | 216 | class Domainizer(LinearScaler): 217 | """Domainizer map, linearly scaling data (assumed to be in [0,1]) to a given domain. 218 | 219 | Note: 220 | This also can be accomplished by functions `scaleDomTo01` and its inverse `scale01ToDom`. 221 | """ 222 | 223 | def __init__(self, dom): 224 | """Initialize with a given domain. 225 | 226 | Args: 227 | dom (np.ndarray): Domain of size `(d,2)` according to which the normalization happens. 228 | """ 229 | super().__init__(shift=dom[:,0], scale=dom[:,1]-dom[:,0]) 230 | return 231 | 232 | class Affine(XMap): 233 | """Affine map.""" 234 | 235 | def __init__(self, weight=None, bias=None): 236 | """Initializes with weight and bias arrays. 
237 | 238 | Args: 239 | weight (np.ndarray, optional): 2d array 240 | bias (np.ndarray, optional): 1d array 241 | """ 242 | super().__init__() 243 | self.weight = weight 244 | self.bias = bias 245 | return 246 | 247 | def __repr__(self): 248 | return f"Affine({self.weight=}, {self.bias=})" 249 | 250 | 251 | def __call__(self, x): 252 | if self.weight is None: 253 | xs = x * 1.0 254 | else: 255 | xs = x @ self.weight.T 256 | 257 | if self.bias is None: 258 | xs += 0.0 259 | else: 260 | xs += self.bias 261 | 262 | return xs 263 | 264 | def inv(self, xs): 265 | if self.bias is None: 266 | x = xs - 0.0 267 | else: 268 | x = xs - self.bias 269 | 270 | if self.weight is None: 271 | x *= 1.0 272 | else: 273 | x = x @ np.linalg.inv(self.weight.T) 274 | 275 | return x 276 | 277 | 278 | -------------------------------------------------------------------------------- /quinn/utils/stats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module with statistics utilities for plotting and domain handling. 3 | """ 4 | 5 | import numpy as np 6 | 7 | 8 | def get_stats(yy, qt): 9 | """Gets stats of a given dataset to help with plotting. 10 | 11 | Args: 12 | yy (np.ndarray): array of predicted values 13 | qt (bool): whether to compute quantiles or not 14 | 15 | Returns: 16 | tuple: tuple of np.ndarray, (mean, std, std) or 17 | (median, q50-q25, q75-q50) 18 | """ 19 | yy_mean = np.mean(yy, axis=0) 20 | yy_std = np.std(yy, axis=0) 21 | yy_qt = np.quantile(yy, [0.25, 0.5, 0.75], axis=0) 22 | 23 | if qt: 24 | yy_lb = yy_qt[1] - yy_qt[0] 25 | yy_ub = yy_qt[2] - yy_qt[1] 26 | yy_mb = yy_qt[1] 27 | else: 28 | yy_lb = yy_std 29 | yy_ub = yy_std 30 | yy_mb = yy_mean 31 | 32 | return yy_mb, yy_lb, yy_ub 33 | 34 | 35 | def get_domain(xx): 36 | """Get the domain of a given data array. 37 | 38 | Args: 39 | xx (np.ndarray): A data array of size `(N,d)`. 40 | 41 | Returns: 42 | np.ndarray: `(d,2)` domain array. 43 | """ 44 | _, ndim = xx.shape 45 | domain = np.empty((ndim, 2)) 46 | domain[:, 0] = np.min(xx, axis=0) 47 | domain[:, 1] = np.max(xx, axis=0) 48 | 49 | return domain 50 | 51 | def intersect_domain(dom1, dom2): 52 | """Create an intersection domain/hypercube. 53 | 54 | Args: 55 | dom1 (np.ndarray): `(d,2)` first domain array. 56 | dom2 (np.ndarray): `(d,2)` second domain array. 57 | 58 | Returns: 59 | np.ndarray: `(d,2)` intersection domain or None if no intersection. 60 | """ 61 | assert(dom1.shape[0]==dom2.shape[0]) 62 | domain = np.empty_like(dom1) 63 | domain[:, 0]=np.max((dom1[:,0], dom2[:,0]), axis=0) 64 | domain[:, 1]=np.min((dom1[:,1], dom2[:,1]), axis=0) 65 | 66 | if (domain[:,1]-domain[:,0]<0).any(): 67 | return None 68 | 69 | return domain 70 | 71 | 72 | 73 | def diam(xx): 74 | """Get the diameter of a given data array. 75 | 76 | Args: 77 | xx (np.ndarray): A data array of size `(N,d)`. 78 | 79 | Returns: 80 | float: diameter, i.e. max pairwise distance. 
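Example: ``diam(np.array([[0., 0.], [3., 4.]]))`` returns ``5.0``, the distance between the two points.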
81 | """ 82 | pdist = np.linalg.norm(xx[:, None, :] - xx[None, :, :], axis=-1) 83 | diameter = np.max(pdist) 84 | 85 | return diameter 86 | -------------------------------------------------------------------------------- /quinn/utils/xutils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Collection of various useful utilities.""" 3 | 4 | import sys 5 | import os 6 | import itertools 7 | import numpy as np 8 | try: 9 | import dill as pk 10 | except ModuleNotFoundError: 11 | import pickle as pk 12 | import matplotlib as mpl 13 | 14 | from scipy import stats 15 | from scipy.stats.mstats import mquantiles 16 | from scipy.interpolate import interp1d 17 | 18 | def idt(x): 19 | """Identity function. 20 | 21 | Args: 22 | x (any type): input 23 | 24 | Returns: 25 | any type: output 26 | """ 27 | return x 28 | 29 | #################################################################### 30 | #################################################################### 31 | 32 | def savepk(sobj, nameprefix='savestate'): 33 | """Pickle a python object. 34 | 35 | Args: 36 | sobj (any type): Object to be pickled. 37 | nameprefix (str, optional): Name prefix. 38 | """ 39 | pk.dump(sobj, open(nameprefix + '.pk', 'wb'), -1) 40 | 41 | 42 | def loadpk(nameprefix='savestate'): 43 | """Unpickle a python object from a pickle file. 44 | 45 | Args: 46 | nameprefix (str, optional): Filename prefix 47 | 48 | Returns: 49 | any type: Unpickled object 50 | """ 51 | return pk.load(open(nameprefix + '.pk', 'rb')) 52 | 53 | #################################################################### 54 | #################################################################### 55 | 56 | def cartes_list(somelists): 57 | """Generate a list of all combination of elements in given lists. 58 | 59 | Args: 60 | somelists (list): List of lists 61 | Returns: 62 | list[tuple]: List of all combinations of elements in lists that make up somelists 63 | Example: 64 | >>> cartes_list([['a', 'b'], [3, 4, 2]]) 65 | [('a', 3), ('a', 4), ('a', 2), ('b', 3), ('b', 4), ('b', 2)] 66 | 67 | """ 68 | 69 | # final_list = [] 70 | # for element in itertools.product(*somelists): 71 | # final_list.append(element) 72 | 73 | final_list = list(itertools.product(*somelists)) 74 | 75 | return final_list 76 | 77 | #################################################################### 78 | #################################################################### 79 | 80 | def read_textlist(filename, nsize, names_prefix=''): 81 | """Read a textfile into a list containing the rows. 82 | 83 | Args: 84 | filename (str): File name 85 | nsize (int): Number of rows in the file 86 | names_prefix (str, optional): Prefix of a dummy list entry names if the file is not present. 87 | 88 | Returns: 89 | list[str]: List of elements that are rows of the file 90 | """ 91 | if os.path.exists(filename): 92 | with open(filename) as f: 93 | names = f.read().splitlines() 94 | assert(len(names) == nsize) 95 | else: 96 | names = [names_prefix + '_' + str(i) for i in range(1, nsize + 1)] 97 | 98 | return names 99 | 100 | #################################################################### 101 | #################################################################### 102 | 103 | def sample_sphere(center=None, rad=1.0, nsam=100): 104 | """Sample on a hypersphere of a given radius. 105 | 106 | Args: 107 | center (np.ndarray, optional): Center of the sphere. Defaults to origin. 108 | rad (float, optional): Radius of the sphere. Defaults to 1.0. 
109 | nsam (int, optional): Number of samples requested. Defaults to 100. 110 | 111 | Returns: 112 | np.ndarray: Array of size `(N,d)` 113 | """ 114 | if center is None: 115 | center = np.zeros((3,)) 116 | dim = center.shape[0] 117 | 118 | samples = np.random.randn(nsam, dim) 119 | samples /= np.linalg.norm(samples, axis=1).reshape(-1,1) 120 | samples *= rad 121 | samples += center 122 | 123 | return samples 124 | 125 | #################################################################### 126 | #################################################################### 127 | 128 | #################################################################### 129 | #################################################################### 130 | 131 | def get_opt_bw(xsam, bwf=1.0): 132 | """Get the rule-of-thumb optimal bandwidth for kernel density estimation. 133 | 134 | Args: 135 | xsam (np.ndarray): Data array, `(N,d)` 136 | bwf (float): Factor behind the scaling optimal rule 137 | Returns: 138 | np.ndarray: Array of length `d`, the optimal per-dimension bandwidth 139 | """ 140 | nsam, ndim = xsam.shape 141 | xstd = np.std(xsam, axis=0) 142 | bw=xstd 143 | bw *= np.power(4./(ndim+2),1./(ndim+4.)) 144 | bw *= np.power(nsam,-1./(ndim+4.)) 145 | 146 | bw *= bwf 147 | 148 | #xmin, xmax = np.min(xsam, axis=0), np.max(xsam, axis=0) 149 | 150 | # in case standard deviation is 0 151 | bw[bw<1.e-16] = 0.5 152 | return bw 153 | 154 | #################################################################### 155 | #################################################################### 156 | 157 | def get_pdf(data, target): 158 | """Compute PDF given data at target points. 159 | 160 | Args: 161 | data (np.ndarray): an `(N,d)` array of `N` samples in `d` dimensions 162 | target np.ndarray): an `(M,d)` array of target points 163 | 164 | Returns: 165 | np.ndarray: PDF values at target 166 | """ 167 | assert(np.prod(np.var(data, axis=0))>0.0) 168 | 169 | # Python Scipy built-in method of KDE 170 | kde_py=stats.kde.gaussian_kde(data.T) 171 | dens=kde_py(target.T) 172 | 173 | # Return the target points and the probability density 174 | return dens 175 | 176 | #################################################################### 177 | #################################################################### 178 | 179 | def strarr(array): 180 | """Turn an array into a neatly formatted one for annotating figures. 181 | 182 | Args: 183 | array (np.ndarray): 1d array 184 | 185 | Returns: 186 | list: list of floats with two decimal digits 187 | """ 188 | return [float("{:0.2f}".format(i)) for i in array] 189 | 190 | 191 | #################################################################### 192 | #################################################################### 193 | 194 | def project(a, b): 195 | """Project a vector onto another vector in high-d space. 196 | 197 | Args: 198 | a (np.ndarray): The 1d array to be projected. 199 | b (np.ndarray): The array to project onto. 200 | 201 | Returns: 202 | tuple(np.ndarray, np.ndarray): tuple (projection, residual) where projection+residual=a, and projection is orthogonal to residual, and colinear with b. 
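Example: ``project(np.array([1., 1.]), np.array([1., 0.]))`` returns ``(array([1., 0.]), array([0., 1.]))``, i.e. the component along the first axis and the orthogonal remainder.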
203 | """ 204 | assert(a.shape[0]==b.shape[0]) 205 | proj = (np.dot(a, b)/ np.dot(b, b))*b 206 | resid = a - proj 207 | return proj, resid 208 | 209 | #################################################################### 210 | #################################################################### 211 | 212 | def pick_basis(x1, x2, x3, x0=None, random_direction_in_plane=None): 213 | """Given three points in a high-d space, picks a basis in a plane that goes through these points. 214 | 215 | Args: 216 | x1 (np.ndarray): 1d array, the first point 217 | x2 (np.ndarray): 1d array, the second point 218 | x3 (np.ndarray): 1d array, the third point 219 | x0 (np.ndarray, optional): 1d array, the central point of basis. Defaults to None, in which case the center-of-mass is selected. 220 | random_direction_in_plane (np.ndarray, optional): Direction aligned with the first basis. Has to be in the plane already. Defaults to None, in which case a random direction is selected. 221 | 222 | Returns: 223 | tuple(np.ndarray, np.ndarray, np.ndarray): tuple(origin, e1, e2) of the origin and two basis directions. 224 | """ 225 | assert(x1.shape==x2.shape and x1.shape==x3.shape) 226 | if x0 is None: 227 | x0 = (x1+x2+x3)/3. 228 | 229 | assert(x0.shape==x1.shape) 230 | 231 | 232 | # random direction in that plane 233 | x1230 = np.vstack((x1-x0, x2-x0, x3-x0)) 234 | assert(np.linalg.matrix_rank(x1230)==2) 235 | if random_direction_in_plane is None: 236 | random_direction_in_plane = np.dot(np.random.rand(1, 3), x1230)[0] 237 | random_direction_in_plane /= np.linalg.norm(random_direction_in_plane) 238 | # TODO: this assertion occasionally fails, e.g. when running all examples in bulk 239 | #assert(np.linalg.matrix_rank(np.vstack((x1230, random_direction_in_plane)))==2) 240 | 241 | proj_norms = np.empty(3,) 242 | resid_norms = np.empty(3,) 243 | for i in range(3): 244 | proj, resid = project(x1230[i], random_direction_in_plane) 245 | proj_norms[i] = np.linalg.norm(proj) 246 | resid_norms[i] = np.linalg.norm(resid) 247 | 248 | pm = np.argmax(proj_norms) 249 | rm = np.argmax(resid_norms) 250 | 251 | origin = x0 252 | e1, _ = project(x1230[pm], random_direction_in_plane) 253 | _, e2 = project(x1230[rm], random_direction_in_plane) 254 | 255 | return origin, e1, e2 256 | 257 | #################################################################### 258 | #################################################################### 259 | 260 | def safe_cholesky(cov): 261 | r"""Cholesky decomposition with some error handlers, and using SVD+QR trick in case the covariance is degenerate. 262 | 263 | Args: 264 | cov (np.ndarray): Positive-definite or zero-determinant symmetric matrix `C`. 265 | 266 | Returns: 267 | np.ndarray: Lower-triangular factor `L` such that `C=L L^T`. 268 | """ 269 | 270 | dim, dim_ = cov.shape 271 | assert(dim_==dim) 272 | assert(np.linalg.norm(cov-cov.T)<1.e-14) 273 | 274 | if np.min(np.linalg.eigvals(cov))<0: 275 | print("The matrix is not a covariance matrix (negative eigenvalues). 
Exiting.") 276 | sys.exit() 277 | elif np.min(np.linalg.eigvals(cov))<1e-14: 278 | print("Small/near-zero eigenvalue: replacing Cholesky with SVD+QR.") 279 | u, s, vd = np.linalg.svd(cov, hermitian=True) 280 | lower = np.linalg.qr(np.dot(np.diag(np.sqrt(s)),vd))[1].T 281 | signs = np.sign(np.diag(lower)) 282 | lower = np.dot(lower, np.diag(signs)) 283 | else: 284 | lower = np.linalg.cholesky(cov) 285 | 286 | assert(np.linalg.norm(cov - np.dot(lower, lower.T)) < 1.e-12) 287 | 288 | return lower 289 | -------------------------------------------------------------------------------- /quinn/vi/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from . import bnet 4 | -------------------------------------------------------------------------------- /quinn/vi/bnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module for the Bayesian network.""" 3 | 4 | import copy 5 | import math 6 | import torch 7 | 8 | from ..rvar.rvs import Gaussian_1d, GMM2_1d 9 | 10 | 11 | class BNet(torch.nn.Module): 12 | """Bayesian NN class. 13 | 14 | Attributes: 15 | device (str): Device where the computations are done. 16 | log_prior (float): Value of log-prior. 17 | log_variational_posterior (float): Value of logarithm of variational posterior. 18 | nnmodel (torch.nn.Module): The underlying PyTorch NN module. 19 | nparams (int): Number of deterministic parameters. 20 | param_names (list[str]): List of parameter names. 21 | param_priors (list[quinn.vi.rvs.RV]): List of parameter priors. 22 | params (torch.nn.ParameterList): Variational parameters. 23 | rparams (list[quinn.rvar.rvs.RV]): List of variational PDFs. 24 | """ 25 | 26 | def __init__(self, nnmodel, pi=0.5, sigma1=1.0, sigma2=1.0, 27 | mu_init_lower=-0.2, mu_init_upper=0.2, 28 | rho_init_lower=-5.0, rho_init_upper=-4.0): 29 | """Instantiate a Bayesian NN object given an underlying PyTorch NN module. 30 | 31 | Args: 32 | nnmodel (torch.nn.Module): The original PyTorch NN module. 33 | pi (float): Weight of the first gaussian. The second weight is 1-pi. 34 | sigma1 (float): Standard deviation of the first gaussian. Can also be a scalar torch.Tensor. 35 | sigma2 (float): Standard deviation of the second gaussian. Can also be a scalar torch.Tensor. 
36 | mu_init_lower (float): Initialization of mu lower value 37 | mu_init_upper (float): Initialization of mu upper value 38 | rho_init_lower (float): Initialization of rho lower value 39 | rho_init_upper (float): Initialization of rho upper value 40 | """ 41 | super().__init__() 42 | assert(isinstance(nnmodel, torch.nn.Module)) 43 | 44 | self.nnmodel = copy.deepcopy(nnmodel) 45 | 46 | try: 47 | self.device = nnmodel.device 48 | except AttributeError: 49 | self.device = 'cpu' 50 | # for name, param in self.nnmodel.named_parameters(): 51 | # print(name) 52 | # if param.requires_grad: 53 | # name_ = (name+'').replace('.', '_') 54 | # self.del_attr(self.nnmodel, [name]) 55 | # self.nnmodel.register_parameter(name, torch.nn.Parameter(param.data)) 56 | # print("######") 57 | # for name, param in self.nnmodel.named_parameters(): 58 | # print(name) 59 | # sys.exit() 60 | 61 | self.param_names = [] 62 | self.rparams = [] 63 | self.param_priors = [] 64 | i=0 65 | for name, param in self.nnmodel.named_parameters(): 66 | if param.requires_grad: 67 | 68 | #param.requires_grad = False 69 | mu = torch.nn.Parameter(torch.Tensor(param.shape).uniform_(mu_init_lower, mu_init_upper)) 70 | self.register_parameter(name.replace('.', '_')+"_mu", mu) 71 | 72 | rho = torch.nn.Parameter(torch.Tensor(param.shape).uniform_(rho_init_lower, rho_init_upper)) 73 | self.register_parameter(name.replace('.', '_')+"_rho", rho) 74 | 75 | if i==0: 76 | self.params = torch.nn.ParameterList([mu, rho]) 77 | else: 78 | self.params.append(mu) 79 | self.params.append(rho) 80 | self.rparams.append(Gaussian_1d(mu, logsigma=rho)) 81 | 82 | ## PRIOR 83 | self.param_priors.append(GMM2_1d(pi, sigma1, sigma2)) 84 | self.param_names.append(name) 85 | 86 | # for i, param_name in enumerate(self.param_names): 87 | # #print("AAAA ", i, param_name) 88 | # self.set_attr(self.nnmodel,param_name.split("."), par_samples[i]) 89 | 90 | i+=1 91 | 92 | self.log_prior = 0.0 93 | self.log_variational_posterior = 0.0 94 | self.nparams = len(self.rparams) 95 | 96 | #print("AAAAA ", self.param_names) 97 | for param_name in self.param_names: 98 | self.del_attr(self.nnmodel,param_name.split(".")) 99 | 100 | 101 | # Inspired by this https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076/10 102 | # this could be better https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties 103 | def del_attr(self, obj, names): 104 | """Deletes attributes from a given object. 105 | 106 | Args: 107 | obj (any): The object of interest. 108 | names (list): List that corresponds to the attribute to be deleted. If list is ['A', 'B', 'C'], the attribute A.B.C is deleted recursively. 109 | """ 110 | # print("Del ", names) 111 | if len(names) == 1: 112 | delattr(obj, names[0]) 113 | else: 114 | self.del_attr(getattr(obj, names[0]), names[1:]) 115 | 116 | def set_attr(self, obj, names, val): 117 | """Sets attributes of a given object. 118 | 119 | Args: 120 | obj (any): The object of interest. 121 | names (list): List that corresponds to the attribute of interest. If list is ['A', 'B', 'C'], the attribute A.B.C is filled with value val. 122 | val (torch.Tensor): Value to be set. 123 | """ 124 | # print("Set ", names, val) 125 | if len(names) == 1: 126 | setattr(obj, names[0], val) 127 | else: 128 | self.set_attr(getattr(obj, names[0]), names[1:], val) 129 | 130 | 131 | def forward(self, x, sample=False, par_samples=None): 132 | """Forward function of Bayesian NN object. 
133 | 134 | Args: 135 | x (torch.Tensor): Input array of size `(N,d)`. 136 | sample (bool, optional): Whether this is used in a sampling mode or not. 137 | par_samples (None, optional): Parameter samples. Default is None, in which cases the mean values of variational PDFs are used. 138 | 139 | Returns: 140 | torch.Tensor: Output array of size `(N,o)`. 141 | """ 142 | if self.training or sample: 143 | assert par_samples is None 144 | par_samples = [] 145 | for rpar in self.rparams: 146 | par_samples.append(rpar.sample()) 147 | else: 148 | if par_samples is None: 149 | par_samples = [] 150 | for rpar in self.rparams: 151 | par_samples.append(rpar.mu) 152 | 153 | assert(len(par_samples)==self.nparams) 154 | 155 | 156 | if self.training: 157 | self.log_prior = 0.0 158 | for par_sample, param_prior in zip(par_samples, self.param_priors): 159 | self.log_prior += param_prior.log_prob(par_sample) 160 | 161 | self.log_variational_posterior = 0.0 162 | for par_sample, rpar in zip(par_samples, self.rparams): 163 | self.log_variational_posterior += rpar.log_prob(par_sample) 164 | else: 165 | self.log_prior, self.log_variational_posterior = 0, 0 166 | 167 | 168 | for i, param_name in enumerate(self.param_names): 169 | #print("AAAA ", i, param_name, param_name.split("."), par_samples[i]) 170 | self.set_attr(self.nnmodel,param_name.split("."), par_samples[i]) 171 | #print([i for i in self.nnmodel.coefs]) 172 | #self.nnmodel.register_parameter(param_name.replace(".","_"), torch.nn.Parameter(par_samples[i])) 173 | #print("BBB ", list(self.nnmodel.parameters())) 174 | #print(dir(self)) 175 | #print(par_samples) 176 | 177 | 178 | return self.nnmodel(x) 179 | 180 | 181 | def sample_elbo(self, x, target, nsam, likparams=None): 182 | """Sample from ELBO. 183 | 184 | Args: 185 | x (torch.Tensor): A 2d input tensor. 186 | target (torch.Tensor): A 2d output tensor. 187 | nsam (int): Number of samples 188 | likparams (tuple, optional): Other parameters of the likelihood, e.g. data noise. 189 | 190 | Returns: 191 | tuple: (log_prior, log_variational_posterior, negative_log_likelihood) 192 | """ 193 | shape_x = x.shape 194 | batch_size = shape_x[0] 195 | batch_size_, outdim = target.shape 196 | assert(batch_size == batch_size_) 197 | # FIXME: 198 | device = x.device 199 | outputs = torch.zeros(nsam, batch_size, outdim, device=device) 200 | log_priors = torch.zeros(nsam, device=device) 201 | log_variational_posteriors = torch.zeros(nsam, device=device) 202 | for i in range(nsam): 203 | outputs[i] = self(x, sample=True) 204 | log_priors[i] = self.log_prior 205 | log_variational_posteriors[i] = self.log_variational_posterior 206 | #print("AA ", outputs) 207 | log_prior = log_priors.mean() 208 | log_variational_posterior = log_variational_posteriors.mean() 209 | #print(F.mse_loss(outputs, target, reduction='mean').shape) 210 | # outputs is MxBxd, target is Bxd, below broadcasting works, and we average over MxBxd (usually d=1) 211 | #negative_log_likelihood = F.mse_loss(outputs, target, reduction='none').mean() 212 | #print(outputs.shape, target.shape) 213 | ## FIXME transfer data to device is expensive. 214 | datasigma = torch.Tensor([likparams[0]]).to(device) 215 | negative_log_likelihood = batch_size * torch.log(datasigma) + 0.5*batch_size*torch.log(2.0*torch.tensor(math.pi))+ 0.5 * batch_size * ((outputs - target)**2).mean() / datasigma**2 216 | 217 | return log_prior, log_variational_posterior, negative_log_likelihood 218 | 219 | def viloss(self, data, target): 220 | """Variational loss function `L(x,y)`. 
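Concretely, as assembled below from :meth:`sample_elbo`, this is the minibatch-weighted negative ELBO, :math:`[\log q(w) - \log p(w)]/n_{\mathrm{batches}} + \mathrm{NLL}`, with the terms averaged over ``nsam`` sampled weight realizations.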
221 | 222 | Args: 223 | data (torch.Tensor): A 2d input tensor `x`. 224 | target (torch.Tensor): A 2d output tensor `y`. 225 | 226 | Returns: 227 | torch.Tensor: The value of the loss function. 228 | """ 229 | datanoise, nsam, num_batches = self.loss_params 230 | log_prior, log_variational_posterior, negative_log_likelihood = self.sample_elbo(data, target, nsam, likparams=[datanoise]) 231 | 232 | return (log_variational_posterior - log_prior)/num_batches + negative_log_likelihood 233 | 234 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import io 4 | import os 5 | from setuptools import setup, find_packages 6 | 7 | here = os.path.abspath(os.path.dirname(__file__)) 8 | README = io.open(os.path.join(here, 'README.md'), encoding="latin-1").read() 9 | 10 | setup( 11 | name='QUINN', 12 | version="1.0", 13 | url='https://github.com/XYZ', 14 | description="Library for augmenting neural networks (NNs) with uncertainty quantification (UQ)", 15 | long_description=README, 16 | author="K. Sargsyan and team", 17 | author_email="ksargsy@sandia.gov", 18 | license='BSD 3-Clause', 19 | platforms="any", 20 | packages=find_packages(), 21 | package_dir={'quinn': 'quinn'}, 22 | # package_data={"":["*.pdf"]}, 23 | include_package_data=True, 24 | py_modules=['quinn.__init__'], 25 | test_suite='tests', 26 | install_requires=[ 27 | "numpy", "scipy", "matplotlib", "torch", 28 | ], 29 | # setup_requires=['setuptools'], 30 | classifiers=[ 31 | 'Programming Language :: Python :: 3', 32 | 'Intended Audience :: Science/Research', 33 | 'Topic :: Scientific/Engineering :: Mathematics', 34 | 'Natural Language :: English', 35 | ], 36 | ) 37 | 38 | --------------------------------------------------------------------------------