├── .github └── workflows │ ├── CI.template │ ├── CI.yml │ └── publish.yml.disabled ├── .gitignore ├── .idea ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── vcs.xml └── workspace.xml ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── cloud-setup.sh ├── customize.py ├── probml_utils ├── ._version.py ├── LogisticRegression.py ├── __init__.py ├── active_learn_utils.py ├── ae_mnist_conv.py ├── blackjax_utils.py ├── conv_vae_flax_utils.py ├── d2l.py ├── download_celeba.py ├── dp_mixgauss_truncatated_utils.py ├── dp_mixgauss_utils.py ├── fisher_lda_fit.py ├── gauss_inv_wishart_utils.py ├── gauss_utils.py ├── gibbs_finite_mixgauss_utils.py ├── gmm_lib.py ├── logreg_flax.py ├── logreg_flax_test.py ├── lvm_plots_utils.py ├── mcmc_utils.py ├── mfa_celeba_helpers.py ├── mix_bernoulli_em_mnist.py ├── mix_bernoulli_lib.py ├── mixture_lib.py ├── mlp_flax.py ├── mlp_flax_demo.py ├── mlp_logreg_flax_demo.py ├── mnist_helper_tf.py ├── multivariate_t_utils.py ├── nb_utils.py ├── pgmpy_utils.py ├── plotting.py ├── prefit_voting_classifier.py ├── pyprobml_utils.py ├── rvm_classifier.py ├── rvm_regressor.py ├── svi_gmm_model_tfp.py ├── url_utils.py ├── vae_celeba_lightning.py ├── vae_conv_mnist.py ├── vae_lightning_data.py └── variational_mixture_gaussians.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements-extra.txt ├── requirements.txt ├── scratchpad.ipynb ├── setup.cfg ├── setup.py └── tests ├── LogisticRegression_test.py ├── __init__.py ├── test_extension.py ├── test_import.py ├── test_latexify_status.py ├── test_save.py └── test_url_utils.py /.github/workflows/CI.template: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: [3.6, 3.7, 3.8] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: 
actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | pip install --no-cache-dir -U -r requirements.txt | cat 22 | pip install pytest pytest-cov coveralls 23 | pip install --upgrade numpy 24 | - name: Test 25 | env: 26 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 27 | COVERALLS_FLAG_NAME: ${{ matrix.python-version }} 28 | COVERALLS_PARALLEL: true 29 | run: | 30 | pytest -v --cov= --cov-report term-missing 31 | coveralls --service=github 32 | coveralls: 33 | name: Finish coverage 34 | needs: test 35 | runs-on: ubuntu-latest 36 | container: python:3-slim 37 | steps: 38 | - name: Finished 39 | run: | 40 | pip3 install --upgrade coveralls 41 | coveralls --finish 42 | env: 43 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 44 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | container: texlive/texlive:TL2020-historic 11 | steps: 12 | 13 | - name: Clone the reference repository 14 | uses: actions/checkout@v2 15 | 16 | - name: Setup pip 17 | run: | 18 | apt update 19 | apt install wget python3-distutils -y 20 | wget https://bootstrap.pypa.io/get-pip.py 21 | python3 get-pip.py 22 | 23 | - name: Install dependencies 24 | run: | 25 | pip install --no-cache-dir -U -r requirements-dev.txt | cat 26 | pip install --no-cache-dir -U -r requirements-extra.txt | cat 27 | pip install -e . 
28 | pip install pytest pytest-cov coveralls 29 | 30 | - name: Test 31 | env: 32 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 33 | COVERALLS_FLAG_NAME: ${{ matrix.python-version }} 34 | COVERALLS_PARALLEL: true 35 | run: | 36 | pytest -v --cov=probml_utils --cov-report term-missing 37 | coveralls --service=github 38 | 39 | coveralls: 40 | name: Finish coverage 41 | needs: test 42 | runs-on: ubuntu-latest 43 | container: python:3-slim 44 | steps: 45 | - name: Finished 46 | run: | 47 | pip3 install --upgrade coveralls 48 | coveralls --finish 49 | env: 50 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 51 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml.disabled: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python package using Twine when a release is 2 | # created. For more information see the following link: 3 | # https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 4 | 5 | name: Publish to PyPI 6 | 7 | on: 8 | release: 9 | types: [published] 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | # Make sure tags are fetched so we can get a version. 
19 | - run: | 20 | git fetch --prune --unshallow --tags 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: '3.x' 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -U setuptools 'setuptools_scm[toml]' setuptools_scm_git_archive wheel twine 30 | - name: Build and publish 31 | env: 32 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 33 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 34 | 35 | run: | 36 | python setup.py sdist 37 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.vscode 3 | *.pyc 4 | *.pdf 5 | *.png 6 | *.egg-info/ 7 | probml_utils/_version.py 8 | .coverage 9 | data.py 10 | *.idea 11 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 20 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 9 | 10 | 11 | 16 | 17 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 1654255942601 48 | 52 | 53 | 54 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.1.0 4 | hooks: 5 | - id: black-jupyter 6 | args: ["--verbose", "--line-length=120"] 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Probabilistic machine learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # probml-utils 2 | 3 | [![CI](https://github.com/probml/probml-utils/workflows/CI/badge.svg?branch=main)](https://github.com/probml/probml-utils/actions?query=workflow%3ACI) 4 | [![Coverage Status](https://coveralls.io/repos/github/probml/probml-utils/badge.svg?branch=main)](https://coveralls.io/github/probml/probml-utils?branch=main) 5 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | 7 | This is a pip-installable repo of common utilities for probabilistic machine learning. 8 | These are used by various demos in [pyprobml](https://github.com/probml/pyprobml). 9 | 10 | ## Installation 11 | 12 | ``` 13 | pip install git+https://github.com/probml/probml-utils.git 14 | ``` 15 | or 16 | ``` 17 | git clone https://github.com/probml/probml-utils.git 18 | cd probml-utils 19 | pip install -e . 20 | 21 | python 22 | import probml_utils 23 | ``` 24 | 25 | -------------------------------------------------------------------------------- /cloud-setup.sh: -------------------------------------------------------------------------------- 1 | # Cloud VM setup. 
Modified from 2 | # https://github.com/Paperspace/ml-in-a-box/blob/main/ml_in_a_box.sh 3 | # Also works for lambdalabs and TPU VM 4 | 5 | #!/usr/bin/env bash 6 | 7 | 8 | # Upgrade Pytorch and TF and JAX 9 | 10 | pip3 install torch torchvision torchaudio 11 | 12 | # for pytorch on TPU V4 13 | export TPU_NUM_DEVICES=4 14 | 15 | #pip3 install --upgrade --user tensorflow tensorflow_probability 16 | 17 | pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 18 | 19 | # jax Libraries 20 | pip3 install distrax optax flax chex einops jaxtyping jax-tqdm 21 | pip3 install -Uq tfp-nightly[jax] 22 | # https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX 23 | 24 | # other libraries 25 | pip3 install seaborn tdqm scikit-learn 26 | 27 | # github 28 | git config --global user.email "murphyk@gmail.com" 29 | git config --global user.name "Kevin Murphy" 30 | 31 | 32 | # avoid having to paste PAT more than once per day 33 | # https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage 34 | git config --global credential.helper 'cache --timeout 90000' 35 | 36 | #sudo snap install gh 37 | #gh auth login # paste personal access token 38 | 39 | 40 | # Install mamba 41 | curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-$(uname)-$(uname -m).sh" 42 | bash Mambaforge-$(uname)-$(uname -m).sh 43 | # Restart bash instance 44 | #exec bash 45 | 46 | mamba install -y jupyterlab 47 | mamba install -y nodejs 48 | mamba install -y seaborn 49 | mamba install -y jupyter 50 | pip install --upgrade jupyterlab-git 51 | sudo snap install tree 52 | mamba install dask 53 | #jupyter labextension install base16-mexico-light -------------------------------------------------------------------------------- /customize.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ######## Manually modify following parameters to customize the structure of 
your project 4 | path = os.path.abspath(os.path.dirname(__file__)).split("/") 5 | # print(path) 6 | REPO_HOME_PATH = "/".join(path[:-1]) 7 | REPO_NAME = path[-1] 8 | PACKAGE_NAME = REPO_NAME 9 | AUTHOR = "Kevin P Murphy" 10 | AUTHOR_EMAIL = "murphyk@gmail.com" 11 | description = "Utilities for probabilistic ML" 12 | URL = "https://github.com/probml/" + REPO_NAME 13 | LICENSE = "MIT" 14 | LICENSE_FILE = "LICENSE" 15 | LONG_DESCRIPTION = "file: README.md" 16 | LONG_DESCRIPTION_CONTENT_TYPE = "text/markdown" 17 | 18 | full_path = os.path.join(REPO_HOME_PATH, REPO_NAME) 19 | 20 | ######################################################################################## 21 | 22 | 23 | ############### This part of the code is automatically updating the relevant files. 24 | # Write setup.cfg 25 | 26 | with open(os.path.join(full_path, "setup.cfg"), "w") as f: 27 | f.write("[metadata]\n") 28 | f.write("name = " + PACKAGE_NAME + "\n") 29 | f.write("author = " + AUTHOR + "\n") 30 | f.write("author-email = " + AUTHOR_EMAIL + "\n") 31 | f.write("description = " + description + "\n") 32 | f.write("url = " + URL + "\n") 33 | f.write("license = " + LICENSE + "\n") 34 | f.write("long_description_content_type = " + LONG_DESCRIPTION_CONTENT_TYPE + "\n") 35 | f.write("long_description = " + LONG_DESCRIPTION + "\n") 36 | 37 | # Write CI 38 | 39 | with open(os.path.join(full_path, ".github/workflows/CI.template"), "r") as f: 40 | content = f.read() 41 | 42 | with open(os.path.join(full_path, ".github/workflows/CI.yml"), "w") as f: 43 | content = content.replace("", REPO_NAME) 44 | f.write(content) 45 | 46 | # Write .gitignore 47 | with open(os.path.join(full_path, ".gitignore"), "w") as f: 48 | f.write("__pycache__/\n") 49 | f.write("*.vscode\n") 50 | f.write("*.pyc\n") 51 | f.write("*.egg-info/\n") 52 | f.write(f"{PACKAGE_NAME}/_version.py\n") 53 | 54 | 55 | # Write pyproject.toml 56 | with open(os.path.join(full_path, "pyproject.toml"), "w") as f: 57 | f.write("[build-system]\n") 58 | 
f.write("requires = [\n") 59 | f.write('\t"setuptools>=50.0",\n') 60 | f.write('\t"setuptools_scm[toml]>=6.0",\n') 61 | f.write('\t"setuptools_scm_git_archive",\n') 62 | f.write('\t"wheel>=0.33",\n') 63 | f.write('\t"numpy>=1.16",\n') 64 | f.write('\t"cython>=0.29",\n') 65 | f.write("\t]\n") 66 | f.write("\n") 67 | f.write("[tool.setuptools_scm]\n") 68 | f.write(f'write_to = "{PACKAGE_NAME}/_version.py"') 69 | 70 | # Write requirements.txt 71 | with open(os.path.join(full_path, "requirements.txt"), "r") as f: 72 | data = f.read() 73 | 74 | with open(os.path.join(full_path, "requirements.txt"), "w") as f: 75 | f.write(data) 76 | if "setuptools_scm[toml]" not in data: 77 | f.write("\nsetuptools_scm[toml]\n") 78 | if "setuptools_scm_git_archive" not in data: 79 | f.write("\nsetuptools_scm_git_archive\n") 80 | 81 | # Initialize project folder 82 | os.makedirs(os.path.join(full_path, PACKAGE_NAME)) 83 | 84 | with open(os.path.join(full_path, PACKAGE_NAME, "__init__.py"), "w") as f: 85 | f.write("from ._version import version as __version__ # noqa") 86 | 87 | print("Successful") 88 | -------------------------------------------------------------------------------- /probml_utils/._version.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /probml_utils/LogisticRegression.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jaxopt import LBFGS 4 | 5 | @jax.jit 6 | def binary_loss_function(weights, auxs): 7 | """ 8 | Arguments: 9 | weights : parameters, shape=(no. of features, ) 10 | auxs: list contains X, y, lambd 11 | X : datasets, shape (no. of examples, no. of features) 12 | y : targets, shape (no. of examples, ) 13 | lambd = regularization rate 14 | 15 | return: 16 | loss - binary cross entropy loss 17 | """ 18 | X, y, lambd = auxs 19 | m = X.shape[0] # no. 
of examples 20 | z = jnp.dot(X, weights) 21 | hypothesis_x = jax.nn.sigmoid(z) 22 | cost0 = jnp.dot(y.T, jnp.log(hypothesis_x + 1e-7)) 23 | cost1 = jnp.dot((1 - y).T, jnp.log(1 - hypothesis_x + 1e-7)) 24 | regularization_cost = (lambd * jnp.sum(weights[1:] ** 2)) / (2 * m) 25 | 26 | return -((cost0 + cost1) / m) + regularization_cost, auxs 27 | 28 | @jax.jit 29 | def multi_loss_function(weights, auxs): 30 | """ 31 | Arguments: 32 | weights : parameters, shape=(no. of features, no. of classes) 33 | auxs: list contains X, y, lambd 34 | X : datasets, shape (no. of examples, no. of features) 35 | y = targets, shape (no. of examples, ) 36 | lambd = regularization rate 37 | return: 38 | loss - CrossEntropy loss 39 | """ 40 | X, y, lambd = auxs 41 | m = X.shape[0] # no. of examples 42 | z = jnp.dot(X, weights) 43 | hypothesis_x = jax.nn.softmax(z, axis=1) 44 | regularization_cost = (lambd * jnp.sum(jnp.sum(weights[1:, :] ** 2, axis=0))) / ( 45 | 2 * m 46 | ) 47 | 48 | return (-jnp.sum(y * jnp.log(hypothesis_x + 1e-7)) / m + regularization_cost), auxs 49 | 50 | 51 | def fit(X, y, max_iter=None, lambd=1, random_key=0, tol=1e-8): 52 | """ 53 | Arguments: 54 | X : training dataset, shape = (no. of examples, no. of features) 55 | y : targets, shape = (no. of example, ) 56 | max_iter : maximum no. of iteration algorithms run 57 | learning_rate : stepsize to update the parameters 58 | lambda : regularization rate 59 | random_key : unique key to generate pseudo random numbers 60 | tol : gradient tolerance factor for LBFGS 61 | returns: 62 | weights : parameters, shape = (no. of features, no. of classes) 63 | bias : intercept, shape = (no. of classes, ) 64 | weights : coefficient, shape = (no. of features, no. 
of classes) 65 | 66 | """ 67 | classes = jnp.unique(y) 68 | n_classes = len(classes) 69 | key = jax.random.PRNGKey(random_key) 70 | 71 | # adding one more feature of ones for the bias term 72 | X = jnp.concatenate([jnp.ones([X.shape[0], 1]), X], axis=1) 73 | 74 | n_f = X.shape[1] # no. of features 75 | m = X.shape[0] # no. of examples 76 | 77 | if max_iter is None: 78 | max_iter = n_f * 200 79 | 80 | if n_classes > 2: 81 | weights = jax.random.normal(key=key, shape=[n_f, n_classes]) 82 | y = jax.nn.one_hot(y, n_classes) 83 | opt = LBFGS(multi_loss_function, has_aux=True, maxiter=max_iter, tol=tol) 84 | weights = opt.run(weights, auxs=[X, y, lambd]).params 85 | return weights, weights[0, :], weights[1:, :] 86 | 87 | elif n_classes == 2: 88 | weights = jax.random.normal( 89 | key=key, 90 | shape=[ 91 | n_f, 92 | ], 93 | ) 94 | opt = LBFGS(binary_loss_function, has_aux=True, maxiter=max_iter, tol=tol) 95 | weights = opt.run(weights, auxs=[X, y, lambd]).params 96 | return weights, weights[0], weights[1:] 97 | 98 | 99 | def predict_proba(weights, x): 100 | """ 101 | Arguments: 102 | weights : Trained Parameter, shape = (no. of features, no. of classes) 103 | x : int or array->shape(no. of examples, no. of features) 104 | Return: 105 | probs_y : probability of class 106 | """ 107 | x = jnp.concatenate([jnp.ones([x.shape[0], 1]), x], axis=1) 108 | z = jnp.dot(x, weights) 109 | if len(z.shape) > 1: 110 | probs_y = jax.nn.softmax(z, axis=1) 111 | else: 112 | probs_y = jax.nn.sigmoid(z) 113 | return probs_y 114 | 115 | def predict(weights, x, threshold = 0.5): 116 | """ 117 | Arguments: 118 | weights : Trained Parameter, shape = (no. of features, no. of classes) 119 | x : int or array->shape(no. of examples, no. 
of features) 120 | threshold : default 0.5, threshold value for binary classification 121 | Return: 122 | pred_y : predicted class 123 | """ 124 | probs_y = predict_proba(weights, x) 125 | if len(probs_y.shape) > 1: 126 | pred_y = jnp.argmax(probs_y, axis=1) 127 | else: 128 | pred_y = (probs_y > threshold).astype(int) 129 | return pred_y 130 | 131 | 132 | def score(weights, x, y, threshold = 0.5): 133 | """ 134 | Arguments: 135 | weights : Trained Parameter, shape = (no. of features, no. of classes) 136 | x : int or array->shape(no. of examples, no. of features) 137 | y : int or array->shape(no. of examples,) 138 | threshold : default 0.5, threshold value for binary classification 139 | """ 140 | y_pred = predict(weights, x, threshold) 141 | return jnp.sum(y_pred == y) / len(x) 142 | -------------------------------------------------------------------------------- /probml_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #from ._version import version as __version__ 2 | from .plotting import savefig, latexify, _get_fig_name, is_latexify_enabled 3 | from .pyprobml_utils import ( 4 | hinton_diagram, 5 | plot_ellipse, 6 | convergence_test, 7 | kdeg, 8 | scale_3d, 9 | style3d, 10 | ) 11 | -------------------------------------------------------------------------------- /probml_utils/ae_mnist_conv.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | except ModuleNotFoundError: 4 | os.system("pip install -qq torch") 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | 9 | try: 10 | from torchvision.datasets import MNIST 11 | except: 12 | os.system("pip install -qq torchvision") 13 | from torchvision.datasets import MNIST 14 | import torch.nn.functional as F 15 | import torchvision.transforms as transforms 16 | from torch.utils.data import DataLoader 17 | 18 | try: 19 | from pytorch_lightning import LightningModule, Trainer 20 | except: 21 | 
os.system("pip install -qq pytorch-lightning") 22 | from pytorch_lightning import LightningModule, Trainer 23 | try: 24 | from einops import rearrange 25 | except: 26 | os.system("pip install -qq einops") 27 | from einops import rearrange 28 | from argparse import ArgumentParser 29 | 30 | 31 | class ConvAEModule(nn.Module): 32 | def __init__( 33 | self, 34 | input_shape, 35 | encoder_conv_filters, 36 | decoder_conv_t_filters, 37 | latent_dim, 38 | deterministic=False, 39 | ): 40 | super(ConvAEModule, self).__init__() 41 | self.input_shape = input_shape 42 | 43 | self.latent_dim = latent_dim 44 | self.deterministic = deterministic 45 | 46 | all_channels = [self.input_shape[0]] + encoder_conv_filters 47 | 48 | self.enc_convs = nn.ModuleList([]) 49 | 50 | # encoder_conv_layers 51 | for i in range(len(encoder_conv_filters)): 52 | self.enc_convs.append( 53 | nn.Conv2d( 54 | all_channels[i], 55 | all_channels[i + 1], 56 | kernel_size=3, 57 | stride=2, 58 | padding=1, 59 | ) 60 | ) 61 | if not self.latent_dim == 2: 62 | self.enc_convs.append(nn.BatchNorm2d(all_channels[i + 1])) 63 | self.enc_convs.append(nn.LeakyReLU()) 64 | 65 | self.flatten_out_size = self.flatten_enc_out_shape(input_shape) 66 | 67 | if self.latent_dim == 2: 68 | self.mu_linear = nn.Linear(self.flatten_out_size, self.latent_dim) 69 | else: 70 | self.mu_linear = nn.Sequential( 71 | nn.Linear(self.flatten_out_size, self.latent_dim), 72 | nn.LeakyReLU(), 73 | nn.Dropout(0.2), 74 | ) 75 | 76 | if self.latent_dim == 2: 77 | self.log_var_linear = nn.Linear(self.flatten_out_size, self.latent_dim) 78 | else: 79 | self.log_var_linear = nn.Sequential( 80 | nn.Linear(self.flatten_out_size, self.latent_dim), 81 | nn.LeakyReLU(), 82 | nn.Dropout(0.2), 83 | ) 84 | 85 | if self.latent_dim == 2: 86 | self.decoder_linear = nn.Linear(self.latent_dim, self.flatten_out_size) 87 | else: 88 | self.decoder_linear = nn.Sequential( 89 | nn.Linear(self.latent_dim, self.flatten_out_size), 90 | nn.LeakyReLU(), 91 | 
nn.Dropout(0.2), 92 | ) 93 | 94 | all_t_channels = [encoder_conv_filters[-1]] + decoder_conv_t_filters 95 | 96 | self.dec_t_convs = nn.ModuleList([]) 97 | 98 | num = len(decoder_conv_t_filters) 99 | 100 | # decoder_trans_conv_layers 101 | for i in range(num - 1): 102 | self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2)) 103 | self.dec_t_convs.append( 104 | nn.ConvTranspose2d( 105 | all_t_channels[i], all_t_channels[i + 1], 3, stride=1, padding=1 106 | ) 107 | ) 108 | if not self.latent_dim == 2: 109 | self.dec_t_convs.append(nn.BatchNorm2d(all_t_channels[i + 1])) 110 | self.dec_t_convs.append(nn.LeakyReLU()) 111 | 112 | self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2)) 113 | self.dec_t_convs.append( 114 | nn.ConvTranspose2d( 115 | all_t_channels[num - 1], all_t_channels[num], 3, stride=1, padding=1 116 | ) 117 | ) 118 | self.dec_t_convs.append(nn.Sigmoid()) 119 | 120 | def reparameterize(self, mu, log_var): 121 | std = torch.exp(0.5 * log_var) # standard deviation 122 | eps = torch.randn_like(std) # `randn_like` as we need the same size 123 | sample = mu + (eps * std) # sampling 124 | return sample 125 | 126 | def _run_step(self, x): 127 | mu, log_var = self.encode(x) 128 | std = torch.exp(0.5 * log_var) 129 | p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std)) 130 | q = torch.distributions.Normal(mu, std) 131 | z = self.reparameterize(mu, log_var) 132 | recon = self.decode(z) 133 | return z, recon, p, q 134 | 135 | def flatten_enc_out_shape(self, input_shape): 136 | x = torch.zeros(1, *input_shape) 137 | for l in self.enc_convs: 138 | x = l(x) 139 | self.shape_before_flattening = x.shape 140 | return int(np.prod(self.shape_before_flattening)) 141 | 142 | def encode(self, x): 143 | for l in self.enc_convs: 144 | x = l(x) 145 | x = x.view(x.size()[0], -1) # flatten 146 | mu = self.mu_linear(x) 147 | log_var = self.log_var_linear(x) 148 | return mu, log_var 149 | 150 | def decode(self, z): 151 | z = 
self.decoder_linear(z) 152 | recon = z.view(z.size()[0], *self.shape_before_flattening[1:]) 153 | for l in self.dec_t_convs: 154 | recon = l(recon) 155 | return recon 156 | 157 | def forward(self, x): 158 | mu, log_var = self.encode(x) 159 | if self.deterministic: 160 | return self.decode(mu), mu, None 161 | else: 162 | z = self.reparameterize(mu, log_var) 163 | recon = self.decode(z) 164 | return recon, mu, log_var 165 | 166 | 167 | class ConvAE(LightningModule): 168 | def __init__( 169 | self, 170 | input_shape, 171 | encoder_conv_filters, 172 | decoder_conv_t_filters, 173 | latent_dim, 174 | kl_coeff=0.1, 175 | lr=0.001, 176 | ): 177 | super(ConvAE, self).__init__() 178 | self.kl_coeff = kl_coeff 179 | self.lr = lr 180 | self.vae = ConvAEModule( 181 | input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim 182 | ) 183 | 184 | def step(self, batch, batch_idx): 185 | x, y = batch 186 | z, x_hat, p, q = self.vae._run_step(x) 187 | 188 | loss = F.binary_cross_entropy(x_hat, x, reduction="sum") 189 | 190 | logs = { 191 | "loss": loss, 192 | } 193 | return loss, logs 194 | 195 | def training_step(self, batch, batch_idx): 196 | loss, logs = self.step(batch, batch_idx) 197 | self.log_dict( 198 | {f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False 199 | ) 200 | return loss 201 | 202 | def validation_step(self, batch, batch_idx): 203 | loss, logs = self.step(batch, batch_idx) 204 | self.log_dict({f"val_{k}": v for k, v in logs.items()}) 205 | return loss 206 | 207 | def configure_optimizers(self): 208 | return torch.optim.Adam(self.parameters(), lr=self.lr) 209 | -------------------------------------------------------------------------------- /probml_utils/blackjax_utils.py: -------------------------------------------------------------------------------- 1 | import arviz as az 2 | import jax.numpy as jnp 3 | import jax 4 | 5 | def arviz_trace_from_states(states, info, burn_in=0): 6 | """ 7 | args: 8 | ........... 
9 | states: contains samples returned by blackjax model (i.e HMCState) 10 | info: conatins the meta info returned by blackjax model (i.e HMCinfo) 11 | 12 | returns: 13 | ........... 14 | trace: arviz trace object 15 | """ 16 | if isinstance(states.position, jnp.DeviceArray): #if states.position is array of samples 17 | ndims = jnp.ndim(states.position) 18 | if ndims > 1: 19 | samples = {"samples":jnp.swapaxes(states.position,0,1)} 20 | divergence = jnp.swapaxes(info.is_divergent, 0, 1) 21 | else: 22 | samples = states.position 23 | divergence = info.is_divergent 24 | 25 | else: # if states.position is dict 26 | samples = {} 27 | for param in states.position.keys(): 28 | ndims = len(states.position[param].shape) 29 | if ndims >= 2: 30 | samples[param] = jnp.swapaxes(states.position[param], 0, 1)[:, burn_in:] # swap n_samples and n_chains 31 | elif ndims == 1: 32 | samples[param] = states.position[param] 33 | 34 | divergence = info.is_divergent 35 | ndims_div = len(divergence.shape) 36 | if ndims_div >= 2: 37 | divergence = jnp.swapaxes(divergence, 0, 1)[:, burn_in:] 38 | elif ndims_div == 1: 39 | divergence = info.is_divergent 40 | 41 | trace_posterior = az.convert_to_inference_data(samples) 42 | trace_sample_stats = az.convert_to_inference_data({"diverging": divergence}, group="sample_stats") 43 | trace = az.concat(trace_posterior, trace_sample_stats) 44 | return trace 45 | 46 | def inference_loop_multiple_chains(rng_key, kernel, initial_states, num_samples, num_chains): 47 | ''' 48 | returns (states, info) 49 | Visit this page for more info: https://blackjax-devs.github.io/blackjax/examples/Introduction.html 50 | ''' 51 | @jax.jit 52 | def one_step(states, rng_key): 53 | keys = jax.random.split(rng_key, num_chains) 54 | states, infos = jax.vmap(kernel)(keys, states) 55 | return states, (states, infos) 56 | 57 | keys = jax.random.split(rng_key, num_samples) 58 | _, (states, infos) = jax.lax.scan(one_step, initial_states, keys) 59 | 60 | return (states, infos) 61 | 
def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run a BlackJAX-style kernel for ``num_samples`` steps with ``jax.lax.scan``.

    ``kernel(key, state)`` must return ``(new_state, info)``. One fresh key is
    split off ``rng_key`` per step.

    Returns:
        (states, infos): the per-step states and infos, stacked along axis 0.

    Visit this page for more info:
    https://blackjax-devs.github.io/blackjax/examples/Introduction.html
    """

    @jax.jit
    def _step(carry, step_key):
        new_state, step_info = kernel(step_key, carry)
        # The carry is the new state; the emitted value records (state, info).
        return new_state, (new_state, step_info)

    step_keys = jax.random.split(rng_key, num_samples)
    _, trace = jax.lax.scan(_step, initial_state, step_keys)
    states, infos = trace
    return (states, infos)


# --------------------------------------------------------------------------------
# /probml_utils/conv_vae_flax_utils.py
# --------------------------------------------------------------------------------
from typing import Tuple, Sequence
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

import flax
import flax.linen as nn
from flax.training import train_state

import optax


class Encoder(nn.Module):
    """Convolutional encoder returning ``(mu, logvar)`` of the latent Gaussian.

    Every entry of ``hidden_channels`` adds a 3x3 stride-2 convolution followed
    by BatchNorm and ReLU. The flattened features are then projected by two
    Dense layers of width ``latent_dim``.
    """

    latent_dim: int
    hidden_channels: Sequence[int]

    @nn.compact
    def __call__(self, X, training):
        h = X
        for n_out in self.hidden_channels:
            h = nn.Conv(n_out, (3, 3), strides=2, padding=1)(h)
            h = nn.BatchNorm(use_running_average=not training)(h)
            h = jax.nn.relu(h)

        # Flatten the trailing (H, W, C) feature map of each example.
        flat_size = np.prod(h.shape[-3:])
        h = h.reshape((-1, flat_size))
        mu = nn.Dense(self.latent_dim)(h)
        logvar = nn.Dense(self.latent_dim)(h)
        return mu, logvar


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latents to images in ``[0, 1]``.

    Each transposed-conv stage doubles the spatial size. The output height and
    width must therefore be divisible by ``2 ** len(hidden_channels)``.
    """

    output_dim: Tuple[int, int, int]
    hidden_channels: Sequence[int]

    @nn.compact
    def __call__(self, X, training):
        height, width, n_channels = self.output_dim

        # TODO: relax this restriction
        factor = 2 ** len(self.hidden_channels)
        assert (
            height % factor == width % factor == 0
        ), f"output_dim must be a multiple of {factor}"
        base_h, base_w = height // factor, width // factor
        deepest = self.hidden_channels[-1]

        h = jax.nn.relu(nn.Dense(base_h * base_w * deepest)(X))
        h = h.reshape((-1, base_h, base_w, deepest))

        # Kernel 3, stride 2 and padding (1, 2) exactly double H and W.
        upsample = dict(strides=(2, 2), padding=((1, 2), (1, 2)))
        for n_out in reversed(self.hidden_channels[:-1]):
            h = nn.ConvTranspose(n_out, (3, 3), **upsample)(h)
            h = nn.BatchNorm(use_running_average=not training)(h)
            h = jax.nn.relu(h)

        h = nn.ConvTranspose(n_channels, (3, 3), **upsample)(h)
        return jax.nn.sigmoid(h)


def reparameterize(key, mean, logvar):
    """Sample ``mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
    noise = jax.random.normal(key, logvar.shape)
    return mean + noise * jnp.exp(0.5 * logvar)


class VAE(nn.Module):
    """(Variational) autoencoder built from ``Encoder`` and ``Decoder``.

    With ``variational=False`` the encoder mean is decoded directly, so the
    model is a plain autoencoder.
    """

    variational: bool
    latent_dim: int
    output_dim: Tuple[int, int, int]
    hidden_channels: Sequence[int]

    def setup(self):
        self.encoder = Encoder(self.latent_dim, self.hidden_channels)
        self.decoder = Decoder(self.output_dim, self.hidden_channels)

    def __call__(self, key, X, training):
        mean, logvar = self.encoder(X, training)
        Z = reparameterize(key, mean, logvar) if self.variational else mean
        return self.decoder(Z, training), mean, logvar

    def decode(self, Z, training):
        return self.decoder(Z, training)


class TrainState(train_state.TrainState):
    # Extra fields: BatchNorm running statistics and the KL weight.
    batch_stats: flax.core.FrozenDict[str, jnp.ndarray]
    beta: float


def create_train_state(
    key, variational, beta, latent_dim, hidden_channels, learning_rate, specimen
):
    """Initialise a VAE for inputs shaped like ``specimen`` and wrap it in a TrainState.

    ``key`` seeds the parameters. A fixed key is used for the latent noise drawn
    during the initialisation forward pass.
    """
    model = VAE(variational, latent_dim, specimen.shape, hidden_channels)
    init_noise_key = jax.random.PRNGKey(42)
    (recon, _, _), variables = model.init_with_output(
        key, init_noise_key, specimen, True
    )
    assert (
        recon.shape[-3:] == specimen.shape
    ), f"{recon.shape} = recon.shape != specimen.shape = {specimen.shape}"
    return TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optax.adam(learning_rate),
        batch_stats=variables["batch_stats"],
        beta=beta,
    )


@jax.vmap
def kl_divergence(mean, logvar):
    """KL(N(mean, diag(exp(logvar))) || N(0, I)), vmapped over the leading axis."""
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))


@jax.jit
def train_step(state, key, image):
    """One Adam step on squared reconstruction error + ``beta`` * KL.

    Returns the updated state (with refreshed BatchNorm statistics) and the loss.
    """

    def objective(params):
        variables = {"params": params, "batch_stats": state.batch_stats}
        (recon, mean, logvar), updates = state.apply_fn(
            variables, key, image, True, mutable=["batch_stats"]
        )
        recon_err = jnp.sum((recon - image) ** 2)
        kl_term = state.beta * jnp.sum(kl_divergence(mean, logvar))
        total = recon_err + kl_term
        return total.sum(), updates

    (loss, updates), grads = jax.value_and_grad(objective, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads, batch_stats=updates["batch_stats"])
    return new_state, loss


@jax.jit
def test_step(state, key, image):
    """Evaluation forward pass (BatchNorm uses running averages)."""
    variables = {"params": state.params, "batch_stats": state.batch_stats}
    recon, mean, logvar = state.apply_fn(variables, key, image, False)
    return recon, mean, logvar


@jax.jit
def decode(state, Z):
    """Decode latent codes ``Z`` with the model stored in ``state``."""
    variables = {"params": state.params, "batch_stats": state.batch_stats}
    return state.apply_fn(variables, Z, False, method=VAE.decode)


def mnist_demo():
    """Train the VAE on 32x32-resized MNIST for a couple of epochs."""
    from torch import Generator
    from torch.utils.data import DataLoader
    import torchvision.transforms as T
    from torchvision.datasets import MNIST

    batch_size = 256
    latent_dim = 20
    hidden_channels = (32, 64, 128, 256, 512)
    lr = 1e-3
    specimen = jnp.empty((32, 32, 1))
    variational = True
    beta = 1
    target_epoch = 2

    transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])
    mnist_train = MNIST("/tmp/torchvision", train=True, download=True, transform=transform)
    generator = Generator().manual_seed(42)
    loader = DataLoader(mnist_train, batch_size, shuffle=True, generator=generator)

    key = jax.random.PRNGKey(42)
    state = create_train_state(key, variational, beta, latent_dim, hidden_channels, lr, specimen)

    for epoch in range(target_epoch):
        epoch_loss = 0
        for batch, _ in loader:
            # Single-channel NCHW -> NHWC; reshape is equivalent to a transpose when C == 1.
            images = jnp.array(batch).reshape((-1, *specimen.shape))
            key, noise_key = jax.random.split(key)
            state, batch_loss = train_step(state, noise_key, images)
            epoch_loss += batch_loss

        print(f"Epoch {epoch + 1}: train loss {epoch_loss}")


if __name__ == "__main__":
    mnist_demo()


# --------------------------------------------------------------------------------
# /probml_utils/download_celeba.py
# --------------------------------------------------------------------------------
# The CelebA dataloader is from
# https://github.com/sayantanauddy/vae_lightning/blob/main/data.py
# and extracts the data from kaggle.
# First make sure you have kaggle.json,
# as explained at https://github.com/Kaggle/kaggle-api#api-credentials.

# We create a dataloader with the required image size,
# and thus force the code to first download the data locally.
import glob
import os
from functools import partial

from absl import app
from absl import flags

import pandas as pd
import PIL.Image  # importing the submodule makes ``PIL.Image.open`` available

try:
    import torch
except ModuleNotFoundError:
    os.system("pip install torch")
    import torch

try:
    import torchvision
except ModuleNotFoundError:
    os.system("pip install torchvision")
    import torchvision

# These come after the install fallbacks above; importing them first
# (as before) made the fallbacks unreachable.
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import utils, io
from torchvision.datasets.utils import verify_str_arg

import pytorch_lightning as pl


class CelebADataset(Dataset):
    """CelebA Dataset class reading the Kaggle ``jessicali9530/celeba-dataset`` layout.

    Args:
        root: directory holding the CSV annotation files and the
            ``img_align_celeba/img_align_celeba`` image folder.
        split: ``"train"``, ``"valid"``, ``"test"`` or ``"all"`` (case-insensitive).
        target_type: ``"attr"``, ``"bbox"``, ``"landmarks"`` or a list of these.
        transform: optional callable applied to each PIL image.
        target_transform: optional callable applied to each target.
        download: if True, download from Kaggle first (needs ``~/.kaggle/kaggle.json``).

    Raises:
        RuntimeError: if ``target_transform`` is given while ``target_type`` is empty.
    """

    def __init__(self,
                 root,
                 split="train",
                 target_type="attr",
                 transform=None,
                 target_transform=None,
                 download=False
                 ):
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.target_type = target_type if isinstance(target_type, list) else [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError('target_transform is specified but target_type is empty')

        if download:
            self.download_from_kaggle()

        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]

        fn = partial(os.path.join, self.root)
        # The Kaggle CSVs are comma separated with a header row and the image
        # filename as index (the pandas defaults, hence no delimiter option).
        read_csv = partial(pd.read_csv, header=0, index_col=0)
        splits = read_csv(fn("list_eval_partition.csv"))
        # identity_CelebA is not part of the Kaggle release, so "identity" targets are unsupported.
        bbox = read_csv(fn("list_bbox_celeba.csv"))
        landmarks_align = read_csv(fn("list_landmarks_align_celeba.csv"))
        attr = read_csv(fn("list_attr_celeba.csv"))

        mask = slice(None) if split_ is None else (splits['partition'] == split_)

        self.filename = splits[mask].index.values
        self.bbox = torch.as_tensor(bbox[mask].values)
        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
        self.attr = torch.as_tensor(attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
        self.attr_names = list(attr.columns)

    def download_from_kaggle(self):
        """Download and unzip the dataset plus its annotation CSVs into ``self.root``.

        The download is skipped only when *every* annotation CSV is already present.
        """
        # Annotation files will be downloaded at the end
        label_files = ['list_attr_celeba.csv', 'list_bbox_celeba.csv', 'list_eval_partition.csv',
                       'list_landmarks_align_celeba.csv']

        # The previous loop kept only the result for the last file, so one
        # missing CSV could go unnoticed.
        files_exist = all(os.path.isfile(os.path.join(self.root, f)) for f in label_files)

        if files_exist:
            print("Files exist already")
            return

        print("Downloading dataset. Please wait while the download and extraction processes complete")
        # Download files from Kaggle using its API as per
        # https://stackoverflow.com/questions/55934733/documentation-for-kaggle-api-within-python

        # Kaggle authentication
        # Remember to place the API token from Kaggle in $HOME/.kaggle
        from kaggle.api.kaggle_api_extended import KaggleApi
        api = KaggleApi()
        api.authenticate()

        # Download all files of a dataset
        # Signature: dataset_download_files(dataset, path=None, force=False, quiet=True, unzip=False)
        api.dataset_download_files(dataset='jessicali9530/celeba-dataset',
                                   path=self.root,
                                   unzip=True,
                                   force=False,
                                   quiet=False)

        # Download the label files
        # Signature: dataset_download_file(dataset, file_name, path=None, force=False, quiet=True)
        for label_file in label_files:
            api.dataset_download_file(dataset='jessicali9530/celeba-dataset',
                                      file_name=label_file,
                                      path=self.root,
                                      force=False,
                                      quiet=False)

        # Clear any remaining *.csv.zip files
        for f in glob.glob(os.path.join(self.root, "*.csv.zip")):
            os.remove(f)

        print("Done!")

    def __getitem__(self, index: int):
        """Return ``(image, target)``; ``target`` is None when no target type is requested."""
        X = PIL.Image.open(os.path.join(self.root,
                                        "img_align_celeba",
                                        "img_align_celeba",
                                        self.filename[index]))

        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                raise ValueError(f"Target type {t} is not recognized")

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        return len(self.attr)


class CelebADataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train/valid/test CelebA splits.

    Only the training set is created with ``download``; validation and test
    reuse the files it fetched.
    """

    def __init__(self,
                 data_dir,
                 target_type="attr",
                 train_transform=None,
                 val_transform=None,
                 target_transform=None,
                 download=False,
                 batch_size=32,
                 num_workers=8):
        super().__init__()

        self.data_dir = data_dir
        self.target_type = target_type
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.target_transform = target_transform
        self.download = download

        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        # Training dataset
        self.celebA_trainset = CelebADataset(root=self.data_dir,
                                             split='train',
                                             target_type=self.target_type,
                                             download=self.download,
                                             transform=self.train_transform,
                                             target_transform=self.target_transform)

        # Validation dataset
        self.celebA_valset = CelebADataset(root=self.data_dir,
                                           split='valid',
                                           target_type=self.target_type,
                                           download=False,
                                           transform=self.val_transform,
                                           target_transform=self.target_transform)

        # Test dataset
        self.celebA_testset = CelebADataset(root=self.data_dir,
                                            split='test',
                                            target_type=self.target_type,
                                            download=False,
                                            transform=self.val_transform,
                                            target_transform=self.target_transform)

    def train_dataloader(self):
        return DataLoader(self.celebA_trainset, batch_size=self.batch_size, shuffle=True, drop_last=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.celebA_valset, batch_size=self.batch_size, shuffle=False, drop_last=False,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.celebA_testset, batch_size=self.batch_size, shuffle=False, drop_last=False,
                          num_workers=self.num_workers)


FLAGS = flags.FLAGS

flags.DEFINE_float(
    'crop_size', default=128,
    help=('Center-crop size applied before resizing; values <= 0 disable cropping.')
)

flags.DEFINE_integer(
    'batch_size', default=256,
    help=('Batch size for training.')
)

flags.DEFINE_integer(
    'image_size', default=64,
    help=('Image size for training.')
)

flags.DEFINE_string(
    'data_dir', default="kaggle",
    help=('Data directory for training.')
)


def celeba_dataloader(bs, IMAGE_SIZE, CROP, DATA_PATH):
    """Build a CelebADataModule with flip, optional center crop, resize and ToTensor.

    The same transform is used for training and validation. ``download=True`` is
    passed, so the data is fetched from Kaggle on ``setup()`` if missing.
    """
    trans = [transforms.RandomHorizontalFlip()]
    if CROP > 0:
        trans.append(transforms.CenterCrop(CROP))
    trans.append(transforms.Resize(IMAGE_SIZE))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)

    dm = CelebADataModule(data_dir=DATA_PATH,
                          target_type='attr',
                          train_transform=transform,
                          val_transform=transform,
                          download=True,
                          batch_size=bs)
    return dm


def main(argv):
    del argv

    bs = FLAGS.batch_size
    IMAGE_SIZE = FLAGS.image_size
    CROP = FLAGS.crop_size
    DATA_PATH = FLAGS.data_dir

    dm = celeba_dataloader(bs, IMAGE_SIZE, CROP, DATA_PATH)

    # prepare_data() is not overridden by CelebADataModule; the actual download
    # happens in setup(), since the module was built with download=True.
    dm.prepare_data()
    dm.setup()  # download if needed and build the train/valid/test datasets


if __name__ == '__main__':
    app.run(main)


# --------------------------------------------------------------------------------
# /probml_utils/dp_mixgauss_truncatated_utils.py
# --------------------------------------------------------------------------------
# Transformed from
# https://github.com/ericmjl/dl-workshop/blob/master/src/dl_workshop/gaussian_mixture.py
# change the code to work with the current jax

from functools import partial

import jax.numpy as jnp
from jax import lax, random, vmap
from jax.scipy import stats
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def loglike_one_component(component_weight, component_mu, log_component_scale, datum):
    """Log likelihood of datum under one component of the mixture.

    Defined as the log of the component probability added to the log
    likelihood of observing that datum under the Gaussian belonging to
    that component.

    :param component_weight: Component weight, a scalar value between 0 and 1.
    :param component_mu: A scalar value.
    :param log_component_scale: A scalar value.
        Gets exponentiated before being passed into norm.logpdf.
    :returns: A scalar.
    """
    component_scale = jnp.exp(log_component_scale)
    return jnp.log(component_weight) + stats.norm.logpdf(datum, loc=component_mu, scale=component_scale)


def normalize_weights(weights):
    """Normalize a weights vector to sum to 1."""
    return weights / jnp.sum(weights)


def loglike_across_components(log_component_weights, component_mus,
                              log_component_scales, datum):
    """Log likelihood of datum under all components of the mixture (log-sum-exp over components)."""
    component_weights = normalize_weights(jnp.exp(log_component_weights))
    loglike_components = vmap(partial(loglike_one_component, datum=datum))(
        component_weights, component_mus, log_component_scales)
    return logsumexp(loglike_components)


def mixture_loglike(log_component_weights, component_mus,
                    log_component_scales, data):
    """Log likelihood of data (not datum!) under all components of the mixture."""
    ll_per_data = vmap(partial(loglike_across_components, log_component_weights,
                               component_mus, log_component_scales,))(data)
    return jnp.sum(ll_per_data)


def plot_component_norm_pdfs(log_component_weights, component_mus,
                             log_component_scales, xmin, xmax, ax, title):
    """Plot each weighted component density on ``ax`` over ``[xmin, xmax]``."""
    component_weights = normalize_weights(jnp.exp(log_component_weights))
    component_scales = jnp.exp(log_component_scales)
    # Column vector of grid points so the pdf broadcasts to (1000, n_components).
    x = jnp.linspace(xmin, xmax, 1000).reshape(-1, 1)
    pdfs = component_weights * norm.pdf(x, loc=component_mus, scale=component_scales)
    for component in range(pdfs.shape[1]):
        ax.plot(x, pdfs[:, component])
    ax.set_title(title)


def get_loss(state, get_params_func, loss_func, data):
    """Evaluate ``loss_func(params, data)`` on the params extracted from an optimizer state."""
    params = get_params_func(state)
    return loss_func(params, data)


def animate_training(params_for_plotting, interval, data_mixture):
    """Animation function for mixture likelihood.

    Uses every ``interval``-th entry of the recorded parameter histories.
    """
    # Plotting-only dependencies are imported lazily so that the likelihood
    # utilities above can be used without matplotlib/celluloid installed.
    import matplotlib.pyplot as plt
    from celluloid import Camera

    log_component_weights_history = params_for_plotting['log_component_weight']
    component_mus_history = params_for_plotting['component_mus']
    log_component_scales_history = params_for_plotting['log_component_scale']
    fig, ax = plt.subplots()
    cam = Camera(fig)
    for w, m, s in zip(log_component_weights_history[::interval],
                       component_mus_history[::interval],
                       log_component_scales_history[::interval]):
        ax.hist(data_mixture, bins=40, density=True, color="blue")
        plot_component_norm_pdfs(w, m, s, xmin=-20, xmax=20, ax=ax, title=None)
        cam.snap()
    animation = cam.animate()
    return animation


def stick_breaking_weights(beta_draws):
    """Return weights from a stick breaking process.

    :param beta_draws: i.i.d draws from a Beta distribution.
        This should be a row vector.
    :returns: (total occupied probability, weights normalized to sum to 1).
    """
    def weighting(occupied_probability, beta_i):
        """
        :param occupied_probability: The cumulative occupied probability taken up.
        :param beta_i: Current value of beta to consider.
        """
        weight = (1 - occupied_probability) * beta_i
        return occupied_probability + weight, weight
    occupied_probability, weights = lax.scan(weighting, jnp.array(0.0), beta_draws)
    weights = weights / jnp.sum(weights)
    return occupied_probability, weights


def beta_draw_from_weights(weights):
    """Invert the stick-breaking construction: recover beta draws from weights.

    :returns: (total accounted probability, beta draws).
    """
    def beta_from_w(accounted_probability, weights_i):
        """
        :param accounted_probability: The cumulative probability accounted for.
        :param weights_i: Current value of weights to consider.
        """
        denominator = 1 - accounted_probability
        log_denominator = jnp.log(denominator)
        log_beta_i = jnp.log(weights_i) - log_denominator
        newly_accounted_probability = accounted_probability + weights_i
        return newly_accounted_probability, jnp.exp(log_beta_i)
    final, betas = lax.scan(beta_from_w, jnp.array(0.0), weights)
    return final, betas


def component_probs_loglike(log_component_probs, log_concentration, num_components):
    """Evaluate log likelihood of probability vector under Dirichlet process.

    :param log_component_probs: A vector.
    :param log_concentration: Real-valued scalar.
    :param num_components: Scalar integer; number of leading beta draws evaluated.
    """
    concentration = jnp.exp(log_concentration)
    component_probs = normalize_weights(jnp.exp(log_component_probs))
    _, beta_draws = beta_draw_from_weights(component_probs)
    eval_draws = beta_draws[:num_components]
    return jnp.sum(stats.beta.logpdf(x=eval_draws, a=1, b=concentration))


# --------------------------------------------------------------------------------
# /probml_utils/fisher_lda_fit.py
# --------------------------------------------------------------------------------
"""
This file implements fisher projection of classification data onto K(< no.of.classes) dimensions
to further fit it with LDA.
Referenced from
https://github.com/probml/pmtk3/blob/master/toolbox/SupervisedModels/fisherLda/fisherLdaFit.m
Author: Srikar-Reddy-Jilugu(@always-newbie161)
"""

import numpy as np


def fisher_lda_fit(X, y, kdims):
    """Fisher linear discriminant projection.

    :param X: shape (nsamples, ndims)
    :param y: shape (nsamples,) or (nsamples, 1); integer labels in {1, ..., nclasses}
    :param kdims: int, number of projection directions (multi-class case)
    :return W: Linear Projection Matrix; ndarray of shape (ndims, kdims).
        With two classes only one discriminant direction exists, so W has
        shape (ndims, 1) regardless of ``kdims``.
    """
    X = np.asarray(X)
    labels = np.asarray(y).reshape(-1).astype(int)  # accept column or flat labels
    nclasses = np.max(labels)  # no. of classes

    if nclasses == 2:
        # assuming y is from {1,2}
        X1, X2 = X[labels == 1], X[labels == 2]
        # Per-feature means and feature-by-feature covariances (as in the MATLAB
        # original); previously these were computed over the wrong axes.
        mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)
        S1 = np.atleast_2d(np.cov(X1, rowvar=False))
        S2 = np.atleast_2d(np.cov(X2, rowvar=False))
        Sw = S1 + S2
        W = (np.linalg.pinv(Sw) @ (mu2 - mu1)).reshape(-1, 1)
    else:
        # assuming y is from {1,2,..nclasses}
        muC = np.stack([X[labels == c].mean(axis=0) for c in range(1, nclasses + 1)])

        centered = X - muC[labels - 1]  # subtract each sample's class mean
        Sw = centered.T @ centered
        between = X.mean(axis=0) - muC
        Sb = between.T @ between
        eigvals, eigvecs = np.linalg.eig(np.linalg.pinv(Sw) @ Sb)
        # np.linalg.eig returns eigenpairs in no particular order; the Fisher
        # directions are those with the largest eigenvalues.
        order = np.argsort(-eigvals.real)
        W = eigvecs[:, order[:kdims]]

    return W
# --------------------------------------------------------------------------------
# /probml_utils/gauss_inv_wishart_utils.py
# --------------------------------------------------------------------------------
"""
This implementation of Normal Inverse Wishart distribution is directly copied from the notebooks of Scott Linderman:
'Implementing a Normal Inverse Wishart Distribution in Tensorflow Probability'
https://github.com/lindermanlab/hackathons/blob/master/notebooks/TFP_Normal_Inverse_Wishart.ipynb
and
https://github.com/lindermanlab/hackathons/blob/master/notebooks/TFP_Normal_Inverse_Wishart_(Part_2).ipynb
"""
import jax.numpy as np
import jax.random as jr
from jax import vmap
from jax.tree_util import tree_map
import tensorflow_probability.substrates.jax as tfp
import matplotlib.pyplot as plt
from functools import partial

tfd = tfp.distributions
tfb = tfp.bijectors


class NormalInverseWishart(tfd.JointDistributionNamed):
    def __init__(self, loc, mean_precision, df, scale, **kwargs):
        r"""
        A normal inverse Wishart (NIW) distribution with density
        :math:`\mathrm{N}(\mu | \mu_0, \Sigma / \kappa_0) \, \mathrm{IW}(\Sigma | \nu, \Psi)`.

        (Raw docstring: the previous plain string turned ``\nu`` into a
        newline and ``\mu``/``\Psi`` into invalid escape sequences.)

        Args:
            loc: :math:`\mu_0`
            mean_precision: :math:`\kappa_0`
            df: :math:`\nu`
            scale: :math:`\Psi`
            **kwargs: accepted but currently unused.

        Returns:
            A tfp.JointDistribution object.
        """
        # Store hyperparameters.
        self._loc = loc
        self._mean_precision = mean_precision
        self._df = df
        self._scale = scale

        # Convert the inverse Wishart scale to the scale_tril of a Wishart.
        # Note: this could be done more efficiently.
        self.wishart_scale_tril = np.linalg.cholesky(np.linalg.inv(scale))

        super(NormalInverseWishart, self).__init__(dict(
            Sigma=lambda: tfd.TransformedDistribution(
                tfd.WishartTriL(df, scale_tril=self.wishart_scale_tril),
                tfb.Chain([tfb.CholeskyOuterProduct(),
                           tfb.CholeskyToInvCholesky(),
                           tfb.Invert(tfb.CholeskyOuterProduct())
                           ])),
            mu=lambda Sigma: tfd.MultivariateNormalFullCovariance(
                loc, Sigma / mean_precision)
        ))

        # Replace the default JointDistributionNamed parameters with the NIW ones
        # because the JointDistributionNamed parameters contain lambda functions,
        # which are not jittable.
        self._parameters = dict(
            loc=loc,
            mean_precision=mean_precision,
            df=df,
            scale=scale
        )

    # These functions compute the pseudo-observations implied by the NIW prior
    # and convert sufficient statistics to a NIW posterior.
    @property
    def natural_parameters(self):
        """Compute pseudo-observations from standard NIW parameters."""
        dim = self._loc.shape[-1]
        chi_1 = self._df + dim + 2
        chi_2 = np.einsum('...,...i->...i', self._mean_precision, self._loc)
        chi_3 = self._scale + self._mean_precision * \
            np.einsum("...i,...j->...ij", self._loc, self._loc)
        chi_4 = self._mean_precision
        return chi_1, chi_2, chi_3, chi_4

    @classmethod
    def from_natural_parameters(cls, natural_params):
        """Convert natural parameters into standard parameters and construct (inverse of ``natural_parameters``)."""
        chi_1, chi_2, chi_3, chi_4 = natural_params
        dim = chi_2.shape[-1]
        df = chi_1 - dim - 2
        mean_precision = chi_4
        loc = np.einsum('..., ...i->...i', 1 / mean_precision, chi_2)
        scale = chi_3 - mean_precision * np.einsum('...i,...j->...ij', loc, loc)
        return cls(loc, mean_precision, df, scale)

    def _mode(self):
        r"""Solve for the mode. Recall,
        .. math::
            p(\mu, \Sigma) \propto
                \mathrm{N}(\mu | \mu_0, \Sigma / \kappa_0) \times
                \mathrm{IW}(\Sigma | \nu_0, \Psi_0)
        The optimal mean is :math:`\mu^* = \mu_0`. Substituting this in,
        .. math::
            p(\mu^*, \Sigma) \propto IW(\Sigma | \nu_0 + 1, \Psi_0)
        and the mode of this inverse Wishart distribution is at
        .. math::
            \Sigma^* = \Psi_0 / (\nu_0 + d + 2)
        """
        dim = self._loc.shape[-1]
        covariance = np.einsum("...,...ij->...ij",
                               1 / (self._df + dim + 2), self._scale)
        return self._loc, covariance


class MultivariateNormalFullCovariance(tfd.MultivariateNormalFullCovariance):
    """
    This wrapper adds simple functions to get sufficient statistics and
    construct a MultivariateNormalFullCovariance from parameters drawn
    from the normal inverse Wishart distribution.
    """
    @classmethod
    def from_parameters(cls, params, **kwargs):
        return cls(*params, **kwargs)

    @staticmethod
    def sufficient_statistics(datapoint):
        return (1.0, datapoint, np.outer(datapoint, datapoint), 1.0)


# --------------------------------------------------------------------------------
# /probml_utils/gauss_utils.py
# --------------------------------------------------------------------------------
# utility functions for multivariate Gaussians
# author: murphyk@

import numpy as np


def is_pos_def(x):
    """Return True if every eigenvalue of ``x`` is strictly positive."""
    j = np.linalg.eigvals(x)
    return np.all(j > 0)


def gauss_sample(mu, sigma, n):
    """Draw ``n`` samples from N(mu, sigma); returns an array of shape (n, d).

    ``mu`` may be a flat (d,) vector or a (d, 1) column. A flat mean used to be
    broadcast along the sample axis, giving wrong samples.
    """
    mu_col = np.asarray(mu).reshape(-1, 1)
    chol = np.linalg.cholesky(sigma)
    z = np.random.randn(mu_col.shape[0], n)
    return np.transpose(mu_col + chol @ z)


def gauss_condition(mu, sigma, visible_nodes, visible_values):
    """Condition N(mu, sigma) on ``x[visible_nodes] = visible_values``.

    Returns ``(mean, covariance)`` of the hidden coordinates, ordered by
    increasing index. Both are empty arrays when nothing is hidden. When nothing
    is visible, they are the prior ``(mu, sigma)``.
    """
    mu = np.asarray(mu)
    sigma = np.asarray(sigma)
    d = len(mu)
    v = np.asarray(visible_nodes).reshape(-1)
    h = np.setdiff1d(np.arange(d), v)
    if len(h) == 0:
        return np.array([]), np.array([])
    if len(v) == 0:
        return mu, sigma

    sigma_hh = sigma[np.ix_(h, h)]
    sigma_hv = sigma[np.ix_(h, v)]
    sigma_vv = sigma[np.ix_(v, v)]
    resid = np.asarray(visible_values, dtype=float).reshape(-1) - mu[v].reshape(-1)
    # Keep the shape of mu[h], so flat means give flat results and column means
    # give columns. Previously (h,) + (h, 1) broadcast to an (h, h) matrix.
    mugivh = mu[h] + (sigma_hv @ np.linalg.solve(sigma_vv, resid)).reshape(mu[h].shape)
    sigivh = sigma_hh - sigma_hv @ np.linalg.solve(sigma_vv, sigma_hv.T)
    return mugivh, sigivh


def gauss_impute(mu, sigma, x):
    """Replace NaNs in each row of ``x`` by their conditional mean given the observed entries."""
    x = np.asarray(x)
    n_data, _ = x.shape
    x_imputed = np.copy(x)
    for i in range(n_data):
        missing = np.isnan(x[i, :])
        if not missing.any():
            continue
        visible = np.flatnonzero(~missing)
        hidden = np.flatnonzero(missing)
        mu_hgv, _ = gauss_condition(mu, sigma, visible, x[i, visible])
        x_imputed[i, hidden] = np.ravel(mu_hgv)
    return x_imputed


def gauss_fit_em(X, max_iter=50, eps=1e-04):
    """
    Compute MLE of multivariate Gaussian given missing data using EM.

    Args:
        X: array of shape (nr, nc); missing entries are NaN.
        max_iter: maximum number of EM iterations.
        eps: stop once the mean change (2-norm) and the covariance change
            (spectral norm) between iterations both fall below ``eps``.

    Returns:
        dict with "mu" (shape (nc, 1)), "Sigma" (shape (nc, nc)) holding the
        latest M-step estimates, and "niter" (number of iterations run).
    """
    X = np.asarray(X, dtype=float)
    nr, nc = X.shape
    missing = np.isnan(X)
    ridge = 1e-7  # diagonal jitter on the observed-block covariance before inversion

    # Initial Mu: column means. Initial Sigma: covariance of the fully observed
    # rows, or a diagonal of per-column variances when that is undefined.
    Mu_new = np.nanmean(X, axis=0).reshape(-1, 1)
    complete_rows = ~missing.any(axis=1)
    S_new = np.atleast_2d(np.cov(X[complete_rows].T))
    if np.isnan(S_new).any():
        S_new = np.diag(np.nanvar(X, axis=0))

    Mu, S = Mu_new, S_new
    no_conv = True
    iteration = 0

    while no_conv and iteration < max_iter:
        Mu, S = Mu_new, S_new
        EXsum = np.zeros((nc, 1))
        EXXsum = np.zeros((nc, nc))

        # E-step: expected sufficient statistics E[x] and E[x x^T] per row.
        for i in range(nr):
            x_i = X[i].reshape(-1, 1)
            m = missing[i]
            if not m.any():
                # Fully observed rows contribute their own statistics.
                # Previously they were skipped or reused the last row's values.
                EX = x_i
                EXX = x_i @ x_i.T
            else:
                o = ~m
                S_oo_inv = np.linalg.pinv(S[np.ix_(o, o)] + ridge * np.eye(o.sum()))
                S_mo = S[np.ix_(m, o)]
                EX = x_i.copy()
                EX[m] = Mu[m] + S_mo @ S_oo_inv @ (x_i[o] - Mu[o])
                V = S[np.ix_(m, m)] - S_mo @ S_oo_inv @ S_mo.T
                EXX = EX @ EX.T
                EXX[np.ix_(m, m)] += V
            EXsum = EXsum + EX
            EXXsum = EXXsum + EXX

        # M-step:
        Mu_new = EXsum / nr
        S_new = EXXsum / nr - Mu_new @ Mu_new.T

        # Convergence condition:
        no_conv = np.linalg.norm(Mu - Mu_new) >= eps or np.linalg.norm(S - S_new, ord=2) >= eps
        iteration += 1

    return {"mu": Mu_new, "Sigma": S_new, "niter": iteration}


# --------------------------------------------------------------------------------
# /probml_utils/gibbs_finite_mixgauss_utils.py
# --------------------------------------------------------------------------------
import jax.numpy as jnp
from jax import random
from multivariate_t_utils import log_predic_t


def gibbs_gmm(T, X, alpha, K, hyper_params, key):
    """
    Collapsed Gibbs sampling for cluster analysis with a K component mixture distribution,
    using Gaussian likelihood and normal inverse Wishart (NIW) prior
    --------------------------------------------------------------------
    T: int
        Number of iterations of the Gibbs sampling
    X: array(size_of_data, dimension)
        The array of observations
    alpha: float
        Precision of a symmetric Dirichlet distribution, which is the prior of the mixture weights
    hyper_params: object of NormalInverseWishart
        NIW prior on the component parameters (passed to log_predic_t)
    K: int
        Number of component of the mixture distribution
    key: jax.random.PRNGKey
        Seed of initial random cluster
    ----------------------------------
    * array(T, size_of_data):
        Cluster assignment after each sweep; all points start in cluster 0
    """
    Zs = []
    n, dim = X.shape
    CR = [[] for _ in range(K)]  # CR[k]: indices of the points currently in cluster k
    Z = jnp.full(n, 0)
    CR[0] = list(range(n))
    logits = jnp.ones(K)
    for t in range(T):
        print(t)
        for i in range(n):
            k_i = int(Z[i])
            CR[k_i].remove(i)
            for k in range(K):
                l_k = len(CR[k])
                X_k = jnp.atleast_2d(X[jnp.asarray(CR[k])]) if l_k > 0 else jnp.empty((0, dim))
                # log p(z_i = k | rest) up to a constant: prior count term + posterior predictive.
                logits = logits.at[k].set(jnp.log(l_k + alpha / K) + log_predic_t(X[i,], X_k, hyper_params))
            key, subkey = random.split(key)
            j = int(random.categorical(subkey, logits=logits))
            Z = Z.at[i].set(j)
            CR[j].append(i)
        Zs.append(Z)
    return jnp.array(Zs)


# --------------------------------------------------------------------------------
# /probml_utils/gmm_lib.py
# --------------------------------------------------------------------------------
# Library of Gaussian Mixture Models
# To-do: convert library into class
# Author: Gerardo Durán-Martín

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal


def plot_mixtures(X, mu, pi, Sigma, r, step=0.01, cmap="viridis", ax=None):
    """Scatter 2-D data coloured by ``r`` and draw one contour per mixture component.

    ``pi`` is accepted for API symmetry but is not used.
    """
    ax = ax if ax is not None else plt.subplots()[1]
    # The first two colours match the original palette; the rest let mixtures
    # with more than two components be drawn instead of silently truncated.
    colors = ["tab:red", "tab:blue", "tab:green", "tab:orange", "tab:purple",
              "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan"]
    x0, y0 = X.min(axis=0)
    x1, y1 = X.max(axis=0)
    xx, yy = np.mgrid[x0:x1:step, y0:y1:step]
    zdom = np.c_[xx.ravel(), yy.ravel()]

    Norms = [multivariate_normal(mean=mui, cov=Sigmai) for mui, Sigmai in zip(mu, Sigma)]

    for Norm, color in zip(Norms, colors):
        density = Norm.pdf(zdom).reshape(xx.shape)
        ax.contour(xx, yy, density, levels=1, colors=color, linewidths=5)

    ax.scatter(*X.T, alpha=0.7, c=r, cmap=cmap, s=10)
    ax.set_xlim(x0, x1)
    ax.set_ylim(y0, y1)


def compute_responsibilities(k, pi, mu, sigma):
    """Return a function mapping points to the responsibility of component ``k``."""
    Ns = [multivariate_normal(mean=mu_i, cov=Sigma_i) for mu_i, Sigma_i in zip(mu, sigma)]

    def respons(x):
        elements = [pi_i * Ni.pdf(x) for pi_i, Ni in zip(pi, Ns)]
        return elements[k] / np.sum(elements, axis=0)

    return respons


def e_step(pi, mu, Sigma):
    """Return one responsibility function per component (see ``compute_responsibilities``)."""
    return [compute_responsibilities(i, pi, mu, Sigma) for i in range(len(mu))]
def m_step(X, responsibilities, S=None, eta=None):
    """M-step: re-estimate mixing weights, means and covariances.

    Parameters
    ----------
    X : array (N, M)
        Observations.
    responsibilities : sequence of callables
        One callable per component; ``resp_k(X)`` must return that component's
        responsibility for every row of X (as produced by ``e_step``).
    S, eta : optional
        When ``eta`` is given, each covariance becomes
        ``(S + scatter_k) / (eta + N_k + M + 2)``. This looks like the MAP
        estimate under an inverse-Wishart prior with scale S and eta degrees
        of freedom (NOTE(review): confirm against the caller). ``S`` must then
        be an (M, M) array.

    Returns
    -------
    pi, mu, Sigma : lists
        Per-component weights, means (M,) and covariances (M, M).
    """
    N, M = X.shape
    pi, mu, Sigma = [], [], []
    has_priors = eta is not None
    for resp_fn in responsibilities:
        resp_k = resp_fn(X)
        Nk = resp_k.sum()
        mu_k = (resp_k[:, np.newaxis] * X).sum(axis=0) / Nk
        dk = X - mu_k
        # Responsibility-weighted scatter sum_n r_nk d_n d_n^T, computed without
        # materialising an (N, M, M) array of outer products.
        Sigma_k = np.einsum("n,ni,nj->ij", resp_k, dk, dk)
        if not has_priors:
            Sigma_k = Sigma_k / Nk
        else:
            Sigma_k = (S + Sigma_k) / (eta + Nk + M + 2)

        pi.append(Nk / N)
        mu.append(mu_k)
        Sigma.append(Sigma_k)
    return pi, mu, Sigma


def gmm_log_likelihood(X, pi, mu, Sigma):
    """Return sum_n log sum_k pi_k N(x_n | mu_k, Sigma_k)."""
    likelihood = 0
    for pi_k, mu_k, Sigma_k in zip(pi, mu, Sigma):
        norm_k = multivariate_normal(mean=mu_k, cov=Sigma_k)
        likelihood += pi_k * norm_k.pdf(X)
    return np.log(likelihood).sum()


def apply_em(X, pi, mu, Sigma, threshold=1e-5, S=None, eta=None, max_iter=None):
    """Run EM from the given parameters until the log-likelihood stabilises.

    Iteration stops when the relative change of the log-likelihood,
    ``|ll_t / ll_{t-1} - 1|``, drops below ``threshold``, or after
    ``max_iter`` EM steps when ``max_iter`` is not None.

    Returns
    -------
    dict with keys
        "coeffs": list of (pi, mu, Sigma) per iteration (the initial one first),
        "rvals": responsibilities of component 0 on X per iteration,
        "logl": log-likelihood per iteration.
    """
    r = compute_responsibilities(0, pi, mu, Sigma)(X)
    log_likelihood = gmm_log_likelihood(X, pi, mu, Sigma)
    hist_log_likelihood = [log_likelihood]
    hist_coeffs = [(pi, mu, Sigma)]
    hist_responsibilities = [r]

    n_iter = 0
    while max_iter is None or n_iter < max_iter:
        n_iter += 1
        responsibilities = e_step(pi, mu, Sigma)
        pi, mu, Sigma = m_step(X, responsibilities, S, eta)
        log_likelihood = gmm_log_likelihood(X, pi, mu, Sigma)

        hist_coeffs.append((pi, mu, Sigma))
        hist_responsibilities.append(responsibilities[0](X))
        hist_log_likelihood.append(log_likelihood)

        prev_ll, curr_ll = hist_log_likelihood[-2], hist_log_likelihood[-1]
        # Same criterion as |curr/prev - 1| < threshold, written without the
        # division so a zero previous log-likelihood cannot produce inf/NaN.
        if np.abs(curr_ll - prev_ll) < threshold * np.abs(prev_ll):
            break
    results = {"coeffs": hist_coeffs, "rvals": hist_responsibilities, "logl": hist_log_likelihood}
    return results
# --------------------------------------------------------------------------------
# /probml_utils/logreg_flax.py:
# --------------------------------------------------------------------------------
# Logistic regression using flax, optax and jaxopt
# Since the objective is convex, we should be able to get good results using
# BFGS or first order methods with automatic (Armijo) step size tuning.

from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import einops
import matplotlib
from functools import partial
from collections import namedtuple
import jax
import jax.random as jr
import jax.numpy as jnp
from jax import vmap, grad, jit
#import jax.debug
import itertools
from itertools import repeat
from time import time
import chex
import typing

import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
import flax

import jaxopt
import optax
import distrax
from jaxopt import OptaxSolver
import tensorflow as tf

from sklearn.base import ClassifierMixin

logistic_loss = jax.vmap(jaxopt.loss.multiclass_logistic_loss)


def regularizer(params, l2reg):
    """Return 0.5 * l2reg * ||params||^2 over all leaves of the parameter pytree."""
    sqnorm = jaxopt.tree_util.tree_l2_norm(params, squared=True)
    return 0.5 * l2reg * sqnorm


def loss_from_logits(params, l2reg, logits, labels):
    """Mean multiclass logistic loss of ``logits`` plus the L2 penalty on ``params``."""
    mean_loss = jnp.mean(logistic_loss(labels, logits))
    return mean_loss + regularizer(params, l2reg)


def loglikelihood_fn(params, model, X, y):
    # 1/N sum_n log p(yn | xn, params)
    logits = model.apply(params, X)
    return jnp.mean(distrax.Categorical(logits).log_prob(y))


def logprior_fn(params, sigma):
    # log p(params) under an independent N(0, sigma^2) prior on every parameter
    leaves, _ = jax.tree_util.tree_flatten(params)
    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
    return jnp.sum(distrax.Normal(0, sigma).log_prob(flat_params))


@partial(jax.jit, static_argnames=["network"])
def objective(params, data, network, prior_sigma, ntrain):
    # objective = -1/N [ (sum_n log p(yn|xn, theta)) + log p(theta) ]
    X, y = data["X"], data["y"]
    logjoint = loglikelihood_fn(params, network, X, y) + (1/ntrain)*logprior_fn(params, prior_sigma)
    return -logjoint


class LogRegNetwork(nn.Module):
    """A single dense layer producing ``nclasses`` logits."""
    nclasses: int
    W_init_fn: Any  # kernel initializer, or None to use flax's default
    b_init_fn: Any  # bias initializer; only used when W_init_fn is not None

    @nn.compact
    def __call__(self, x):
        if self.W_init_fn is not None:
            logits = nn.Dense(self.nclasses, kernel_init=self.W_init_fn, bias_init=self.b_init_fn)(x)
        else:
            logits = nn.Dense(self.nclasses)(x)
        return logits


class LogReg(ClassifierMixin):
    """Multinomial logistic regression with a Gaussian prior (L2 penalty).

    ``predict`` returns class probabilities (softmax of the logits), not
    labels; ``predict_proba`` is provided as the scikit-learn-named alias.
    NOTE(review): ClassifierMixin.score calls ``predict`` and expects labels,
    so ``score`` will not work with the current ``predict``.
    """

    def __init__(self, key, nclasses, *, l2reg=1e-5,
                 optimizer='lbfgs', batch_size=0, max_iter=500,
                 W_init=None, b_init=None):
        # optimizer is {'lbfgs', 'polyak', 'armijo'} or an optax object
        self.nclasses = nclasses
        if W_init is not None:  # specify initial parameters by hand
            if b_init is None:
                # Otherwise the bias initializer would hand flax a None parameter.
                b_init = jnp.zeros((nclasses,))
            W_init_fn = lambda key, shape, dtype: W_init  # (D,C)
            b_init_fn = lambda key, shape, dtype: b_init  # (C)
            ninputs = W_init.shape[0]
            self.network = LogRegNetwork(nclasses, W_init_fn, b_init_fn)
            x = jr.normal(key, (ninputs,))  # single random input
            self.params = self.network.init(key, x)
        else:
            self.network = LogRegNetwork(nclasses, None, None)
            self.params = None  # must call fit to infer size of input
        self.optimization_results = None
        self.max_iter = max_iter
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.l2reg = l2reg
        self.key = key

    def predict(self, inputs):
        """Return softmax class probabilities for ``inputs``."""
        return jax.nn.softmax(self.network.apply(self.params, inputs))

    def predict_proba(self, inputs):
        """Alias of ``predict`` under the scikit-learn name."""
        return self.predict(inputs)

    def fit(self, X, y):
        """Initialise parameters from X[0] and optimise them; returns ``self``.

        ``batch_size`` of None, 0 or N selects deterministic full-batch fitting;
        any other value streams shuffled minibatches.
        NOTE(review): this re-initialises ``params``, discarding any values
        supplied through ``W_init``/``b_init``.
        """
        self.params = self.network.init(self.key, X[0])
        N = X.shape[0]
        if self.batch_size in (None, 0, N):
            self.fit_batch(self.key, X, y)
        else:
            self.fit_minibatch(self.key, X, y)
        return self

    def _make_solver(self, loss_fn):
        """Build the jaxopt solver selected by ``self.optimizer``."""
        if not isinstance(self.optimizer, str):
            # Anything that is not a name is treated as an optax optimizer.
            return OptaxSolver(opt=self.optimizer, fun=loss_fn, maxiter=self.max_iter)
        name = self.optimizer.lower()
        if name == "lbfgs":
            return jaxopt.LBFGS(fun=loss_fn, maxiter=self.max_iter)
        if name == "polyak":
            return jaxopt.PolyakSGD(fun=loss_fn, maxiter=self.max_iter)
        if name == "armijo":
            return jaxopt.ArmijoSGD(fun=loss_fn, maxiter=self.max_iter)
        raise ValueError(
            f"Unknown optimizer {self.optimizer!r}; expected 'lbfgs', 'polyak', "
            "'armijo' or an optax optimizer"
        )

    def fit_batch(self, key, X, y):
        """Full-batch optimisation of the MAP objective (fully deterministic)."""
        del key
        sigma = np.sqrt(1/self.l2reg)
        N = X.shape[0]
        data = {"X": X, "y": y}

        def loss_fn(params):
            return objective(params=params, data=data, network=self.network, prior_sigma=sigma, ntrain=N)

        res = self._make_solver(loss_fn).run(self.params)
        self.params = res.params

    def fit_minibatch(self, key, X, y):
        """Stochastic optimisation over an endless stream of shuffled minibatches."""
        del key
        # https://jaxopt.github.io/stable/auto_examples/deep_learning/flax_resnet.html
        # https://github.com/blackjax-devs/blackjax/discussions/360#discussioncomment-3756412
        sigma = np.sqrt(1/self.l2reg)
        N = X.shape[0]

        def loss_fn(params, data):
            return objective(params=params, data=data, network=self.network, prior_sigma=sigma, ntrain=N)

        # Convert dataset into a stream of minibatches (for stochastic optimizers)
        # https://www.tensorflow.org/api_docs/python/tf/data/Dataset?version=nightly#from_tensor_slices
        ds = tf.data.Dataset.from_tensor_slices({"X": X, "y": y})
        # https://jaxopt.github.io/stable/auto_examples/deep_learning/haiku_image_classif.htm
        ds = ds.cache().repeat()
        ds = ds.shuffle(10 * self.batch_size, seed=0)  # how use jax key?
        ds = ds.batch(self.batch_size)
        iterator = ds.as_numpy_iterator()

        # NOTE(review): confirm that jaxopt.LBFGS supports run_iterator when
        # 'lbfgs' is combined with minibatching.
        res = self._make_solver(loss_fn).run_iterator(self.params, iterator=iterator)
        self.params = res.params


# --------------------------------------------------------------------------------
# /probml_utils/logreg_flax_test.py:
# --------------------------------------------------------------------------------
# to show output from the 'tests', run with
# pytest logreg_flax_test.py -rP

from functools import partial
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(precision=3)
import scipy.stats
import einops
import matplotlib
from functools import partial
from collections import namedtuple
import jax
import jax.random as jr
import jax.numpy as jnp
from jax import vmap, grad, jit
#import jax.debug
import itertools
from itertools import repeat
from time import time
import chex
import typing

import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn
import flax

import jaxopt
import optax

import sklearn.datasets
from sklearn.model_selection import train_test_split
import sklearn
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.linear_model import LogisticRegression

from logreg_flax import *
#jax.config.update("jax_enable_x64", True) # jaxopt.lbfgs uses float32


def print_probs(probs):
    """Print ``probs`` as a list of strings with three decimals each."""
    # Named ``formatted`` so the builtin ``str`` is not shadowed.
    formatted = ['{:0.3f}'.format(p) for p in probs]
    print(formatted)


def make_iris_data():
    """Return the iris features with N(0, 2^2) noise added, and the 3-class labels."""
    iris = sklearn.datasets.load_iris()
    X = iris["data"]
    y = iris["target"]
    ndata, ndim = X.shape  # 150, 4
    key = jr.PRNGKey(0)
    noise = jr.normal(key, (ndata, ndim)) * 2.0
    X = X + noise  # add noise to make the classes less separable
    return X, y


def make_data(seed, n_samples, class_sep, n_features):
    """Synthetic 10-class classification data from sklearn's make_classification."""
    X, y = sklearn.datasets.make_classification(n_samples=n_samples, n_features=n_features, n_informative=5,
        n_redundant=5, n_repeated=0, n_classes=10, n_clusters_per_class=1, weights=None, flip_y=0.01,
        class_sep=class_sep, hypercube=True, shift=0.0, scale=1.0, shuffle=True, random_state=seed)
    return X, y


def compute_mle(X, y):
    """Reference fit with sklearn: returns (probs on X, W (nclasses, ndim), b (nclasses,))."""
    # C is set to a large number to (effectively) turn off regularization.
    # The intercept is fitted because LogRegNetwork's Dense layer also has a bias.
    log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=True)
    log_reg.fit(X, y)
    W_mle = log_reg.coef_  # (nclasses, ndim)
    b_mle = log_reg.intercept_  # (nclasses,)
    true_probs = log_reg.predict_proba(X)
    return true_probs, W_mle, b_mle

############

def test_inference():
    # test inference at the MLE params
    X, y = make_iris_data()
    true_probs, W_mle, b_mle = compute_mle(X, y)
    nclasses = W_mle.shape[0]
    key = jr.PRNGKey(0)
    model = LogReg(key, nclasses, W_init=W_mle.T, b_init=b_mle)
    probs = np.array(model.predict(X))
    assert np.allclose(probs, true_probs, atol=1e-2)


def compare_training(X, y, name, tol=0.02):
    """Fit LogReg and require its probabilities to be within ``tol`` of sklearn's."""
    true_probs, W_mle, b_mle = compute_mle(X, y)
    nclasses = W_mle.shape[0]
    key = jr.PRNGKey(0)
    model = LogReg(key, nclasses, max_iter=500, l2reg=1e-5)
    model.fit(X, y)
    probs = np.array(model.predict(X))
    # Absolute value: a signed max ignores entries where LogReg over-predicts.
    delta = np.max(np.abs(true_probs - probs))
    print('dataset ', name)
    print('max difference in predicted probabilities', delta)
    print('truth'); print_probs(true_probs[0])
    print('pred'); print_probs(probs[0])
    assert (delta < tol)


def test_training_iris():
    X, y = make_iris_data()
    compare_training(X, y, 'iris', 0.02)


def test_training_blobs():
    X, y = make_data(0, n_samples=1000, class_sep=1, n_features=10)
    compare_training(X, y, 'blobs', 0.02)



def skip_test_objectives():
    # compare logprior to the l2 regularizer
    X, y = make_iris_data()
    ndata, ndim = X.shape
    nclasses = 3
    key = jr.PRNGKey(0)
    model = LogReg(key, nclasses, max_iter=2, l2reg=1e-5)
    model.fit(X, y)  # we need to fit the model to populate the params field
    params = model.params
    leaves, _ = jax.tree_util.tree_flatten(params)
    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
    nparams = len(flat_params)
    logits = model.network.apply(params, X)
    l2reg = 0.1
    sigma = np.sqrt(1/l2reg)
    l1 = loss_from_logits(params, l2reg, logits, y)
    l2 = loglikelihood_fn(params, model.network, X, y)
    l3 = logprior_fn(params, sigma)
    Z = sigma*np.sqrt(2*np.pi)
    # log p(w) = sum_i log N(wi | 0, sigma) = sum_i [-log Z_i - 0.5*l2reg*w_i^2]
    assert np.allclose(-l1 -nparams*np.log(Z), l2+l3)

##########

def fit_pipeline_sklearn(key, X, Y):
    """Standardise -> degree-2 features -> sklearn logistic regression (key unused)."""
    classifier = Pipeline([
        ('standardscaler', StandardScaler()),
        ('poly', PolynomialFeatures(degree=2)),
        ('logreg', LogisticRegression(random_state=0, max_iter=500, C=1e5))])
    classifier.fit(np.array(X), np.array(Y))
    return classifier


def fit_pipeline_logreg(key, X, Y):
    """Same pipeline as fit_pipeline_sklearn but ending in the flax LogReg."""
    nclasses = len(np.unique(Y))
    classifier = Pipeline([
        ('standardscaler', StandardScaler()),
        ('poly', PolynomialFeatures(degree=2)),
        ('logreg', LogReg(key, nclasses, max_iter=500, l2reg=1e-5))])
    classifier.fit(np.array(X), np.array(Y))
    return classifier


def compare_pipeline(X, y, name, tol=0.02):
    """Compare the two pipelines' probabilities on X (max absolute difference)."""
    key = jr.PRNGKey(0)
    clf = fit_pipeline_sklearn(key, X, y)
    true_probs = clf.predict_proba(X)
    model = fit_pipeline_logreg(key, X, y)
    probs = np.array(model.predict(X))
    delta = np.max(np.abs(true_probs - probs))
    print('data ', name)
    print('max difference in predicted probs {:.3f}'.format(delta))
    print('truth: ', true_probs[0])
    print('pred: ', probs[0])
    assert delta < tol


def test_pipeline_iris():
    X, y = make_iris_data()
    compare_pipeline(X, y, 'iris', 0.02)


def test_pipeline_blobs():
    X, y = make_data(0, n_samples=1000, class_sep=1, n_features=10)
    compare_pipeline(X, y, 'blobs', 0.1)  # much less accurate!
175 | 176 | 177 | 178 | ######### 179 | 180 | def compare_optimizer(optimizer, name=None, batch_size=None, max_iter=5000, tol=0.02): 181 | X, y = make_iris_data() 182 | true_probs, W_mle, b_mle = compute_mle(X, y) 183 | nclasses, ndim = W_mle.shape 184 | key = jr.PRNGKey(0) 185 | l2reg = 1e-5 186 | model = LogReg(key, nclasses, max_iter=max_iter, l2reg=l2reg, optimizer=optimizer, batch_size=batch_size) 187 | model.fit(X, y) 188 | probs = np.array(model.predict(X)) 189 | error = np.max(true_probs - probs) 190 | print('method {:s}, max deviation from true probs {:.3f}'.format(name, error)) 191 | print('truth: ', true_probs[0]) 192 | print('pred: ', probs[0]) 193 | assert (error < tol) 194 | 195 | 196 | def test_bfgs(): 197 | compare_optimizer("lbfgs", name= "lbfgs, bs=N", batch_size=0) 198 | 199 | def test_armijo_full_batch(): 200 | compare_optimizer("armijo", name="armijo, bs=N", batch_size=0) 201 | 202 | 203 | def test_adam_full_batch_lr2(): 204 | compare_optimizer(optax.adam(1e-2), name="adam 1e-2, bs=N", batch_size=0) 205 | 206 | # These tests fail at reasonable tolerance 207 | 208 | def test_armijo_minibatch(): 209 | compare_optimizer("armijo", name="armijo, bs=32", batch_size=32, tol=0.25) 210 | 211 | def test_adam_mini_batch_lr2(): 212 | compare_optimizer(optax.adam(1e-2), name="adam 1e-2, bs=32", batch_size=32, tol=0.1) 213 | -------------------------------------------------------------------------------- /probml_utils/lvm_plots_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import umap 4 | from typing import Callable, Tuple 5 | 6 | import torch 7 | import numpy as np 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | 11 | try: 12 | from einops import rearrange 13 | except ModuleNotFoundError: 14 | os.system("pip install -qq einops") 15 | from einops import rearrange 16 | 17 | try: 18 | from torchvision.utils import make_grid 19 | except ModuleNotFoundError: 20 | os.system("pip 
install -qq torchvision") 21 | from torchvision.utils import make_grid 22 | 23 | from scipy.stats import truncnorm 24 | from scipy.stats import norm 25 | from sklearn.manifold import TSNE 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | 29 | 30 | def get_interpolation(interpolation): 31 | """ 32 | interpolation: can accept either string or function 33 | """ 34 | if interpolation == "spherical": 35 | return slerp 36 | elif interpolation == "linear": 37 | return lerp 38 | elif callable(interpolation): 39 | return interpolation 40 | 41 | 42 | def get_embedder(encoder, X_data, y_data=None, use_embedder="TSNE"): 43 | X_data_2D = encoder(X_data) 44 | if X_data_2D.shape[-1] == 2: 45 | return X_data_2D 46 | if use_embedder == "UMAP": 47 | umap_fn = umap.UMAP() 48 | X_data_2D = umap_fn.fit_transform(X_data_2D.cpu().detach().numpy(), y_data) 49 | elif use_embedder == "TSNE": 50 | tsne = TSNE() 51 | X_data_2D = tsne.fit_transform(X_data_2D.cpu().detach().numpy()) 52 | return X_data_2D 53 | 54 | 55 | def lerp(val, low, high): 56 | """Linear interpolation""" 57 | return low + (high - low) * val 58 | 59 | 60 | def slerp(val, low, high): 61 | """Spherical interpolation. 
val has a range of 0 to 1.""" 62 | if val <= 0: 63 | return low 64 | elif val >= 1: 65 | return high 66 | elif torch.allclose(low, high): 67 | return low 68 | omega = torch.arccos(torch.dot(low / torch.norm(low), high / torch.norm(high))) 69 | so = torch.sin(omega) 70 | return ( 71 | torch.sin((1.0 - val) * omega) / so * low + torch.sin(val * omega) / so * high 72 | ) 73 | 74 | 75 | def make_imrange(arr: list): 76 | interpolation = torch.stack(arr) 77 | imgs = rearrange(make_grid(interpolation, 11), "c h w -> h w c") 78 | imgs = ( 79 | imgs.cpu().detach().numpy() 80 | if torch.cuda.is_available() 81 | else imgs.detach().numpy() 82 | ) 83 | return imgs 84 | 85 | 86 | def get_imrange( 87 | G: Callable[[torch.tensor], torch.tensor], 88 | start: torch.tensor, 89 | end: torch.tensor, 90 | nums: int = 8, 91 | interpolation="spherical", 92 | ) -> torch.tensor: 93 | """ 94 | Decoder must produce a 3d vector to be appened togther to form a new grid 95 | """ 96 | val = 0 97 | arr2 = [] 98 | inter = get_interpolation(interpolation) 99 | for val in torch.linspace(0, 1, nums): 100 | new_z = torch.unsqueeze(inter(val, start, end), 0) 101 | arr2.append(G(new_z)) 102 | return make_imrange(arr2) 103 | 104 | 105 | def get_random_samples( 106 | decoder: Callable[[torch.tensor], torch.tensor], 107 | truncation_threshold=1, 108 | latent_dim=20, 109 | num_images=64, 110 | num_images_per_row=8, 111 | ) -> torch.tensor: 112 | """ 113 | Decoder must produce a 4d vector to be feed into make_grid 114 | """ 115 | values = truncnorm.rvs( 116 | -truncation_threshold, truncation_threshold, size=(num_images, latent_dim) 117 | ) 118 | z = torch.from_numpy(values).float() 119 | z = z.to(device) 120 | imgs = ( 121 | rearrange(make_grid(decoder(z), num_images_per_row), "c h w -> h w c") 122 | .cpu() 123 | .detach() 124 | .numpy() 125 | ) 126 | return imgs 127 | 128 | 129 | def get_grid_samples( 130 | decoder: Callable[[torch.tensor], torch.tensor], 131 | latent_size: int = 2, 132 | size: int = 10, 
133 | max_z: float = 3.1, 134 | ) -> torch.tensor: 135 | """ 136 | Decoder must produce a 3d vector to be appened togther to form a new grid 137 | """ 138 | arr = [] 139 | for i in range(0, size): 140 | z1 = (((i / (size - 1)) * max_z) * 2) - max_z 141 | for j in range(0, size): 142 | z2 = (((j / (size - 1)) * max_z) * 2) - max_z 143 | z_ = torch.tensor([[z1, z2] + (latent_size - 2) * [0]], device=device) 144 | decoded = decoder(z_) 145 | arr.append(decoded) 146 | return torch.stack(arr) 147 | 148 | 149 | def plot_scatter_plot(batch, encoder, use_embedder="TSNE", min_distance=0.03): 150 | """ 151 | Plots scatter plot of embeddings 152 | """ 153 | X_data, y_data = batch 154 | X_data = X_data.to(device) 155 | np.random.seed(42) 156 | X_data_2D = get_embedder(encoder, X_data, y_data, use_embedder) 157 | X_data_2D = (X_data_2D - X_data_2D.min()) / (X_data_2D.max() - X_data_2D.min()) 158 | 159 | # adapted from https://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html 160 | fig = plt.figure(figsize=(10, 8)) 161 | cmap = plt.cm.tab10 162 | plt.scatter(X_data_2D[:, 0], X_data_2D[:, 1], c=y_data, s=10, cmap=cmap) 163 | image_positions = np.array([[1.0, 1.0]]) 164 | for index, position in enumerate(X_data_2D): 165 | dist = np.sum((position - image_positions) ** 2, axis=1) 166 | if np.min(dist) > 0.04: # if far enough from other images 167 | image_positions = np.r_[image_positions, [position]] 168 | if X_data[index].shape[0] == 3: 169 | imagebox = matplotlib.offsetbox.AnnotationBbox( 170 | matplotlib.offsetbox.OffsetImage( 171 | rearrange(X_data[index].cpu(), "c h w -> h w c"), cmap="binary" 172 | ), 173 | position, 174 | bboxprops={"edgecolor": tuple(cmap([y_data[index]])[0]), "lw": 2}, 175 | ) 176 | elif X_data[index].shape[0] == 1: 177 | imagebox = matplotlib.offsetbox.AnnotationBbox( 178 | matplotlib.offsetbox.OffsetImage( 179 | rearrange(X_data[index].cpu(), "c h w -> (c h) w"), 180 | cmap="binary", 181 | ), 182 | position, 183 | bboxprops={"edgecolor": 
tuple(cmap([y_data[index]])[0]), "lw": 2}, 184 | ) 185 | plt.gca().add_artist(imagebox) 186 | plt.axis("off") 187 | return fig 188 | 189 | 190 | def plot_grid_plot( 191 | batch, encoder, use_cdf=False, use_embedder="TSNE", model_name="VAE mnist" 192 | ): 193 | """ 194 | This takes in images in batch, so G should produce a 3D tensor output example 195 | for a model that outputs images with a channel dim along with a batch dim we need 196 | to rearrange the tensor as such to produce the correct shape 197 | def decoder(z): 198 | return rearrange(m.decode(z), "b c h w -> b (c h) w") 199 | """ 200 | figsize = 8 201 | example_images, example_labels = batch 202 | example_images = example_images.to(device=device) 203 | 204 | z_points = get_embedder(encoder, example_images, use_embedder=use_embedder) 205 | p_points = norm.cdf(z_points) 206 | 207 | fig = plt.figure(figsize=(figsize, figsize)) 208 | if use_cdf: 209 | plt.scatter( 210 | p_points[:, 0], 211 | p_points[:, 1], 212 | cmap="rainbow", 213 | c=example_labels, 214 | alpha=0.5, 215 | s=5, 216 | ) 217 | else: 218 | plt.scatter( 219 | z_points[:, 0], 220 | z_points[:, 1], 221 | cmap="rainbow", 222 | c=example_labels, 223 | alpha=0.5, 224 | s=2, 225 | ) 226 | plt.colorbar() 227 | plt.title(f"{model_name} embedding") 228 | return fig 229 | 230 | 231 | def plot_grid_plot_with_sample( 232 | batch, encoder, decoder, use_embedder="TSNE", model_name="VAE mnist" 233 | ): 234 | """ 235 | This takes in images in batch, so G should produce a 3D tensor output example 236 | for a model that outputs images with a channel dim along with a batch dim we need 237 | to rearrange the tensor as such to produce the correct shape 238 | def decoder(z): 239 | return rearrange(m.decode(z), "b c h w -> b (c h) w") 240 | """ 241 | figsize = 8 242 | example_images, example_labels = batch 243 | example_images = example_images.to(device=device) 244 | 245 | z_points = get_embedder(encoder, example_images, use_embedder=use_embedder) 246 | 
plt.figure(figsize=(figsize, figsize)) 247 | # plt.scatter(z_points[:, 0] , z_points[:, 1], c='black', alpha=0.5, s=2) 248 | plt.scatter( 249 | z_points[:, 0], z_points[:, 1], cmap="rainbow", c=example_labels, alpha=0.5, s=2 250 | ) 251 | plt.colorbar() 252 | 253 | grid_size = 15 254 | grid_depth = 2 255 | np.random.seed(42) 256 | x_min = np.min(z_points[:, 0]) 257 | x_max = np.max(z_points[:, 0]) 258 | y_min = np.min(z_points[:, 1]) 259 | y_max = np.max(z_points[:, 1]) 260 | x = np.random.uniform(low=x_min, high=x_max, size=grid_size * grid_depth) 261 | y = np.random.uniform(low=y_min, high=y_max, size=grid_size * grid_depth) 262 | 263 | z_grid = np.array(list(zip(x, y))) 264 | t_z_grid = torch.FloatTensor(z_grid).to(device) 265 | reconst = decoder(t_z_grid) 266 | reconst = reconst.cpu().detach() if torch.cuda.is_available() else reconst.detach() 267 | plt.scatter(z_grid[:, 0], z_grid[:, 1], c="red", alpha=1, s=20) 268 | n = np.shape(z_grid)[0] 269 | for i in range(n): 270 | x = z_grid[i, 0] 271 | y = z_grid[i, 1] 272 | plt.text(x, y, i) 273 | plt.title(f"{model_name} embedding with samples") 274 | 275 | fig = plt.figure(figsize=(figsize, grid_depth)) 276 | fig.subplots_adjust(hspace=0.4, wspace=0.4) 277 | for i in range(grid_size * grid_depth): 278 | ax = fig.add_subplot(grid_depth, grid_size, i + 1) 279 | ax.axis("off") 280 | # ax.text(0.5, -0.35, str(np.round(z_grid[i],1)), fontsize=8, ha='center', transform=ax.transAxes) 281 | ax.text(0.5, -0.35, str(i)) 282 | ax.imshow(reconst[i, :], cmap="Greys") 283 | -------------------------------------------------------------------------------- /probml_utils/mcmc_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author : Ang Ming Liang 3 | ''' 4 | 5 | import numpy as np 6 | #from tqdm.notebook import tqdm 7 | from tqdm import tqdm 8 | 9 | def slice_sample(init, dist, iters, sigma, burnin, step_out=True, rng=None): 10 | """ 11 | based on 
http://homepages.inf.ed.ac.uk/imurray2/teaching/09mlss/ 12 | """ 13 | 14 | # set up empty sample holder 15 | D = len(init) 16 | samples = np.zeros((D, iters)) 17 | sigma = 5*np.ones(init.shape[-1]) 18 | 19 | # initialize 20 | xx = init.copy() 21 | 22 | for i in tqdm(range(iters)): 23 | perm = list(range(D)) 24 | rng.shuffle(perm) 25 | last_llh = dist(xx) 26 | 27 | for d in perm: 28 | llh0 = last_llh + np.log(rng.random()) 29 | rr = rng.random(1) 30 | x_l = xx.copy() 31 | x_l[d] = x_l[d] - rr * sigma[d] 32 | x_r = xx.copy() 33 | x_r[d] = x_r[d] + (1 - rr) * sigma[d] 34 | 35 | if step_out: 36 | llh_l = dist(x_l) 37 | while llh_l > llh0: 38 | x_l[d] = x_l[d] - sigma[d] 39 | llh_l = dist(x_l) 40 | llh_r = dist(x_r) 41 | while llh_r > llh0: 42 | x_r[d] = x_r[d] + sigma[d] 43 | llh_r = dist(x_r) 44 | 45 | x_cur = xx.copy() 46 | while True: 47 | xd = rng.random() * (x_r[d] - x_l[d]) + x_l[d] 48 | x_cur[d] = xd.copy() 49 | last_llh = dist(x_cur) 50 | if last_llh > llh0: 51 | xx[d] = xd.copy() 52 | break 53 | elif xd > xx[d]: 54 | x_r[d] = xd 55 | elif xd < xx[d]: 56 | x_l[d] = xd 57 | else: 58 | raise RuntimeError('Slice sampler shrank too far.') 59 | 60 | samples[:, i] = xx.copy().ravel() 61 | 62 | return samples[:, burnin:] 63 | -------------------------------------------------------------------------------- /probml_utils/mix_bernoulli_em_mnist.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Fits Bernoulli mixture model for mnist digits using em algorithm 3 | Author: Meduri Venkata Shivaditya, Aleyna Kara(@karalleyna) 4 | ''' 5 | import os 6 | try: 7 | import tensorflow as tf 8 | except ModuleNotFoundError: 9 | os.system("pip install tensorflow") 10 | import tensorflow as tf 11 | 12 | from jax.random import PRNGKey, randint 13 | from probml_utils.mix_bernoulli_lib import BMM 14 | 15 | def mnist_data(n_obs, rng_key=None): 16 | ''' 17 | Downloads data from tensorflow datasets 18 | Parameters 19 | ---------- 20 | n_obs : int 
21 | Number of digits randomly chosen from mnist 22 | rng_key : array 23 | Random key of shape (2,) and dtype uint32 24 | Returns 25 | ------- 26 | * array((n_obs, 784)) 27 | Dataset 28 | ''' 29 | rng_key = PRNGKey(0) if rng_key is None else rng_key 30 | 31 | (x_train, y_train), _ = tf.keras.datasets.mnist.load_data() 32 | x = (x_train > 0).astype('int') # Converting to binary 33 | dataset_size = x.shape[0] 34 | 35 | perm = randint(rng_key, minval=0, maxval=dataset_size, shape=((n_obs,))) 36 | x_train = x[perm] 37 | x_train = x_train.reshape((n_obs, 784)) 38 | 39 | return x_train -------------------------------------------------------------------------------- /probml_utils/mix_bernoulli_lib.py: -------------------------------------------------------------------------------- 1 | # Implementation of Bernoulli Mixture Model 2 | # Author : Aleyna Kara(@karalleyna) 3 | 4 | import jax.numpy as jnp 5 | from jax import vmap, jit, value_and_grad 6 | from jax.random import PRNGKey, uniform, split, permutation 7 | from jax.lax import scan 8 | from jax.scipy.special import expit, logit 9 | from jax.nn import softmax 10 | from jax.example_libraries import optimizers 11 | 12 | import distrax 13 | from distrax._src.utils import jittable 14 | 15 | from .mixture_lib import MixtureSameFamily 16 | import probml_utils as pml 17 | 18 | import matplotlib.pyplot as plt 19 | import itertools 20 | 21 | opt_init, opt_update, get_params = optimizers.adam(1e-1) 22 | 23 | 24 | class BMM(jittable.Jittable): 25 | def __init__(self, K, n_vars, rng_key=None): 26 | """ 27 | Initializes Bernoulli Mixture Model 28 | 29 | Parameters 30 | ---------- 31 | K : int 32 | Number of latent variables 33 | 34 | n_vars : int 35 | Dimension of binary random variable 36 | 37 | rng_key : array 38 | Random key of shape (2,) and dtype uint32 39 | """ 40 | rng_key = PRNGKey(0) if rng_key is None else rng_key 41 | 42 | mixing_coeffs = uniform(rng_key, (K,), minval=100, maxval=200) 43 | mixing_coeffs = mixing_coeffs / 
mixing_coeffs.sum() 44 | initial_probs = jnp.full((K, n_vars), 1.0 / K) 45 | 46 | self._probs = initial_probs 47 | self.model = (mixing_coeffs, initial_probs) 48 | 49 | @property 50 | def mixing_coeffs(self): 51 | return self._model.mixture_distribution.probs 52 | 53 | @property 54 | def probs(self): 55 | return self._probs 56 | 57 | @property 58 | def model(self): 59 | return self._model 60 | 61 | @model.setter 62 | def model(self, value): 63 | mixing_coeffs, probs = value 64 | self._model = MixtureSameFamily( 65 | mixture_distribution=distrax.Categorical(probs=mixing_coeffs), 66 | components_distribution=distrax.Independent(distrax.Bernoulli(probs=probs)), 67 | ) 68 | 69 | def responsibilities(self, observations): 70 | """ 71 | Finds responsibilities 72 | 73 | Parameters 74 | ---------- 75 | observations : array(N, seq_len) 76 | Dataset 77 | 78 | Returns 79 | ------- 80 | * array 81 | Responsibilities 82 | """ 83 | return jnp.nan_to_num(self._model.posterior_marginal(observations).probs) 84 | 85 | def expected_log_likelihood(self, observations): 86 | """ 87 | Calculates expected log likelihood 88 | 89 | Parameters 90 | ---------- 91 | observations : array(N, seq_len) 92 | Dataset 93 | 94 | Returns 95 | ------- 96 | * int 97 | Log likelihood 98 | """ 99 | return jnp.sum(jnp.nan_to_num(self._model.log_prob(observations))) 100 | 101 | def _m_step(self, observations): 102 | """ 103 | Maximization step 104 | 105 | Parameters 106 | ---------- 107 | observations : array(N, seq_len) 108 | Dataset 109 | 110 | Returns 111 | ------- 112 | * array 113 | Mixing coefficients 114 | 115 | * array 116 | Probabilities 117 | """ 118 | n_obs, _ = observations.shape 119 | 120 | # Computes responsibilities, or posterior probability p(z|x) 121 | def m_step_per_bernoulli(responsibility): 122 | norm_const = responsibility.sum() 123 | mu = jnp.sum(responsibility[:, None] * observations, axis=0) / norm_const 124 | return mu, norm_const 125 | 126 | mus, ns = vmap(m_step_per_bernoulli, 
in_axes=(1))( 127 | self.responsibilities(observations) 128 | ) 129 | return ns / n_obs, mus 130 | 131 | def fit_em(self, observations, num_of_iters=10): 132 | """ 133 | Fits the model using em algorithm. 134 | 135 | Parameters 136 | ---------- 137 | observations : array(N, seq_len) 138 | Dataset 139 | 140 | num_of_iters : int 141 | The number of iterations the training process takes place 142 | 143 | Returns 144 | ------- 145 | * array 146 | Log likelihoods found per iteration 147 | 148 | * array 149 | Responsibilities 150 | """ 151 | iterations = jnp.arange(num_of_iters) 152 | 153 | def train_step(params, i): 154 | self.model = params 155 | 156 | log_likelihood = self.expected_log_likelihood(observations) 157 | responsibilities = self.responsibilities(observations) 158 | 159 | mixing_coeffs, probs = self._m_step(observations) 160 | 161 | return (mixing_coeffs, probs), (log_likelihood, responsibilities) 162 | 163 | initial_params = (self.mixing_coeffs, self.probs) 164 | 165 | final_params, history = scan(train_step, initial_params, iterations) 166 | 167 | self.model = final_params 168 | _, probs = final_params 169 | self._probs = probs 170 | 171 | ll_hist, responsibility_hist = history 172 | 173 | ll_hist = jnp.append(ll_hist, self.expected_log_likelihood(observations)) 174 | responsibility_hist = jnp.vstack( 175 | [responsibility_hist, jnp.array([self.responsibilities(observations)])] 176 | ) 177 | 178 | return ll_hist, responsibility_hist 179 | 180 | def _make_minibatches(self, observations, batch_size, rng_key): 181 | """ 182 | Creates minibatches consists of the random permutations of the 183 | given observation sequences 184 | 185 | Parameters 186 | ---------- 187 | observations : array(N, seq_len) 188 | Dataset 189 | 190 | batch_size : int 191 | The number of observation sequences that will be included in 192 | each minibatch 193 | 194 | rng_key : array 195 | Random key of shape (2,) and dtype uint32 196 | 197 | Returns 198 | ------- 199 | * 
array(num_batches, batch_size, max_len) 200 | Minibatches 201 | """ 202 | num_train = len(observations) 203 | perm = permutation(rng_key, num_train) 204 | 205 | def create_mini_batch(batch_idx): 206 | return observations[batch_idx] 207 | 208 | num_batches = num_train // batch_size 209 | batch_indices = perm.reshape((num_batches, -1)) 210 | minibatches = vmap(create_mini_batch)(batch_indices) 211 | 212 | return minibatches 213 | 214 | @jit 215 | def loss_fn(self, params, batch): 216 | """ 217 | Calculates expected mean negative loglikelihood. 218 | 219 | Parameters 220 | ---------- 221 | params : tuple 222 | Consists of mixing coefficients and probabilities of the Bernoulli distribution respectively. 223 | 224 | batch : array 225 | The subset of observations 226 | 227 | Returns 228 | ------- 229 | * int 230 | Negative log likelihood 231 | """ 232 | mixing_coeffs, probs = params 233 | self.model = (softmax(mixing_coeffs), expit(probs)) 234 | return -self.expected_log_likelihood(batch) / len(batch) 235 | 236 | @jit 237 | def update(self, i, opt_state, batch): 238 | """ 239 | Updates the optimizer state after taking derivative 240 | i : int 241 | The current iteration 242 | 243 | opt_state : jax.experimental.optimizers.OptimizerState 244 | The current state of the parameters 245 | 246 | batch : array 247 | The subset of observations 248 | 249 | Returns 250 | ------- 251 | * jax.experimental.optimizers.OptimizerState 252 | The updated state 253 | 254 | * int 255 | Loss value calculated on the current batch 256 | """ 257 | params = get_params(opt_state) 258 | loss, grads = value_and_grad(self.loss_fn)(params, batch) 259 | return opt_update(i, grads, opt_state), loss 260 | 261 | def fit_sgd( 262 | self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=1 263 | ): 264 | """ 265 | Fits the model using gradient descent algorithm with the given hyperparameters. 
266 | 267 | Parameters 268 | ---------- 269 | observations : array 270 | The observation sequences which Bernoulli Mixture Model is trained on 271 | 272 | batch_size : int 273 | The size of the batch 274 | 275 | rng_key : array 276 | Random key of shape (2,) and dtype uint32 277 | 278 | optimizer : jax.experimental.optimizers.Optimizer 279 | Optimizer to be used 280 | 281 | num_epochs : int 282 | The number of epoch the training process takes place 283 | 284 | Returns 285 | ------- 286 | * array 287 | Mean loss values found per epoch 288 | 289 | * array 290 | Mixing coefficients found per epoch 291 | 292 | * array 293 | Probabilities of Bernoulli distribution found per epoch 294 | 295 | * array 296 | Responsibilites found per epoch 297 | """ 298 | global opt_init, opt_update, get_params 299 | 300 | if rng_key is None: 301 | rng_key = PRNGKey(0) 302 | 303 | if optimizer is not None: 304 | opt_init, opt_update, get_params = optimizer 305 | 306 | opt_state = opt_init((softmax(self.mixing_coeffs), logit(self.probs))) 307 | itercount = itertools.count() 308 | 309 | def epoch_step(opt_state, key): 310 | def train_step(opt_state, batch): 311 | opt_state, loss = self.update(next(itercount), opt_state, batch) 312 | return opt_state, loss 313 | 314 | batches = self._make_minibatches(observations, batch_size, key) 315 | opt_state, losses = scan(train_step, opt_state, batches) 316 | 317 | params = get_params(opt_state) 318 | mixing_coeffs, probs_logits = params 319 | probs = expit(probs_logits) 320 | self.model = (softmax(mixing_coeffs), probs) 321 | self._probs = probs 322 | 323 | return opt_state, ( 324 | losses.mean(), 325 | *params, 326 | self.responsibilities(observations), 327 | ) 328 | 329 | epochs = split(rng_key, num_epochs) 330 | opt_state, history = scan(epoch_step, opt_state, epochs) 331 | params = get_params(opt_state) 332 | mixing_coeffs, probs_logits = params 333 | probs = expit(probs_logits) 334 | self.model = (softmax(mixing_coeffs), probs) 335 | self._probs = 
probs 336 | return history 337 | 338 | def plot(self, n_row, n_col, file_name): 339 | """ 340 | Plots the mean of each Bernoulli distribution as an image. 341 | 342 | Parameters 343 | ---------- 344 | n_row : int 345 | The number of rows of the figure 346 | n_col : int 347 | The number of columns of the figure 348 | file_name : str 349 | The path where the figure will be stored 350 | """ 351 | if n_row * n_col != len(self.mixing_coeffs): 352 | raise TypeError( 353 | "The number of rows and columns does not match with the number of component distribution." 354 | ) 355 | fig, axes = plt.subplots(n_row, n_col) 356 | 357 | for (coeff, mean), ax in zip( 358 | zip(self.mixing_coeffs, self.probs), axes.flatten() 359 | ): 360 | ax.imshow(mean.reshape(28, 28), cmap=plt.cm.gray) 361 | ax.set_title("%1.2f" % coeff) 362 | ax.axis("off") 363 | 364 | fig.tight_layout(pad=1.0) 365 | pml.savefig(f"{file_name}.pdf") 366 | plt.show() 367 | -------------------------------------------------------------------------------- /probml_utils/mixture_lib.py: -------------------------------------------------------------------------------- 1 | # Mixture distributions 2 | # Author Aleyna Kara(@karalleyna) 3 | 4 | import distrax 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | 10 | class MixtureSameFamily(distrax.MixtureSameFamily): 11 | def _per_mixture_component_log_prob(self, value): 12 | """Per mixture component log probability. 13 | https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/distributions/mixture_same_family.py 14 | 15 | Parameters 16 | ---------- 17 | value: array 18 | Represents observations from the mixture. Must 19 | be broadcastable with the mixture's batch shape. 20 | 21 | Returns 22 | ------- 23 | *array 24 | Represents, for each observation and for each mixture 25 | component, the log joint probability of that mixture component and 26 | the observation. 
The shape will be equal to the concatenation of (1) the 27 | broadcast shape of the observations and the batch shape, and (2) the 28 | number of mixture components. 29 | """ 30 | # Add component axis to make input broadcast with components distribution. 31 | expanded = jnp.expand_dims(value, axis=-1 - len(self.event_shape)) 32 | # Compute `log_prob` in every component. 33 | lp = self.components_distribution.log_prob(expanded) 34 | # Last batch axis is number of components, i.e. last axis of `lp` below. 35 | # Last axis of mixture log probs are components. 36 | return lp + self._mixture_log_probs 37 | 38 | def log_prob(self, value): 39 | """See "distrax.Distribution.log_prob and distrax.MixtureSameFamily". 40 | https://github.com/deepmind/distrax/blob/master/distrax/_src/distributions/distribution.py 41 | https://github.com/deepmind/distrax/blob/master/distrax/_src/distributions/mixture_same_family.py 42 | """ 43 | # Reduce last axis of mixture log probs are components 44 | return jax.scipy.special.logsumexp(self._per_mixture_component_log_prob(value), axis=-1) 45 | 46 | def posterior_marginal(self, observations): 47 | """Compute the marginal posterior distribution for a batch of observations. 48 | https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MixtureSameFamily?version=nightly#posterior_marginal 49 | 50 | Parameters 51 | ---------- 52 | observations: 53 | An array representing observations from the mixture. Must 54 | be broadcastable with the mixture's batch shape. 55 | 56 | Returns 57 | ------- 58 | * array 59 | Posterior marginals that is a `Categorical` distribution object representing 60 | the marginal probability of the components of the mixture. The batch 61 | shape of the `Categorical` will be the broadcast shape of `observations` 62 | and the mixture batch shape; the number of classes will equal the 63 | number of mixture components. 
64 | """ 65 | return distrax.Categorical(logits=self._per_mixture_component_log_prob(observations)) 66 | 67 | def posterior_mode(self, observations): 68 | """Compute the posterior mode for a batch of distributions. 69 | https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MixtureSameFamily?version=nightly#posterior_mode 70 | 71 | Parameters 72 | ---------- 73 | observations: 74 | Represents observations from the mixture. Must 75 | be broadcastable with the mixture's batch shape. 76 | 77 | Returns 78 | ------- 79 | * array 80 | Represents the mode (most likely component) for each 81 | observation. The shape will be equal to the broadcast shape of the 82 | observations and the batch shape. 83 | """ 84 | return jnp.argmax(self._per_mixture_component_log_prob(observations), axis=-1) 85 | -------------------------------------------------------------------------------- /probml_utils/mlp_flax.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import scipy.stats 7 | import einops 8 | 9 | from functools import partial 10 | import jax 11 | import jax.random as jr 12 | import jax.numpy as jnp 13 | from itertools import repeat 14 | 15 | import jax 16 | from typing import Any, Callable, Sequence 17 | from jax import lax, random, numpy as jnp 18 | from flax import linen as nn 19 | import flax 20 | from flax.training import train_state 21 | 22 | import optax 23 | import jaxopt 24 | import distrax 25 | 26 | 27 | 28 | 29 | class LogRegNetwork(nn.Module): 30 | nclasses: int 31 | 32 | @nn.compact 33 | def __call__(self, x): 34 | logits = nn.Dense(self.nclasses)(x) 35 | return logits 36 | 37 | class MLPNetwork(nn.Module): 38 | nfeatures_per_layer: Sequence[int] 39 | 40 | @nn.compact 41 | def __call__(self, inputs): 42 | x = inputs 43 | nlayers = len(self.nfeatures_per_layer) 44 | for i, feat in enumerate(self.nfeatures_per_layer): 45 | x = nn.Dense(feat, 
name=f'layers_{i}')(x) 46 | if i != (nlayers - 1): 47 | #x = nn.relu(x) 48 | x = nn.gelu(x) 49 | return x 50 | 51 | 52 | def logprior_fn(params, sigma): 53 | # log p(params) 54 | leaves, _ = jax.tree_util.tree_flatten(params) 55 | flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves]) 56 | return jnp.sum(distrax.Normal(0, sigma).log_prob(flat_params)) 57 | 58 | def l2_regularizer(params, l2reg): 59 | sqnorm = jaxopt.tree_util.tree_l2_norm(params, squared=True) 60 | return 0.5 * l2reg * sqnorm 61 | 62 | @partial(jax.jit, static_argnums=(1,2)) 63 | def get_batch_train_ixs(key, num_train, batch_size): 64 | # return indices of training set in a random order 65 | # Based on https://github.com/google/flax/blob/main/examples/mnist/train.py#L74 66 | steps_per_epoch = num_train // batch_size 67 | batch_ixs = jax.random.permutation(key, num_train) 68 | batch_ixs = batch_ixs[:steps_per_epoch * batch_size] 69 | batch_ixs = batch_ixs.reshape(steps_per_epoch, batch_size) 70 | return batch_ixs 71 | 72 | 73 | class NeuralNetClassifier: 74 | def __init__(self, network, key, nclasses, *, l2reg=1e-5, standardize = True, 75 | optimizer = 'adam+warmup', batch_size=128, max_iter=100, num_epochs=10, print_every=0): 76 | # optimizer is one of {'adam+warmup'} or an optax object 77 | self.nclasses = nclasses 78 | self.network = network 79 | self.standardize = standardize 80 | self.max_iter = max_iter 81 | self.num_epochs = num_epochs 82 | self.optimizer = optimizer 83 | self.batch_size = batch_size 84 | self.l2reg = l2reg 85 | self.print_every = print_every 86 | self.params = None # must first call fit 87 | self.key = key 88 | 89 | def predict(self, X): 90 | if self.params is None: 91 | raise ValueError('need to call fit before predict') 92 | if self.standardize: 93 | X = X - self.mean 94 | X = X / self.std 95 | return jax.nn.softmax(self.network.apply(self.params, X)) 96 | 97 | def fit(self, X, y): 98 | """Fit model. 
We assume y is (N) integer labels, not one-hot.""" 99 | if self.standardize: 100 | self.mean = jnp.mean(X, axis=0) 101 | self.std = jnp.std(X, axis=0) + 1e-5 102 | X = X - self.mean 103 | X = X / self.std 104 | if self.params is None: # initialize model parameters 105 | nfeatures = X.shape[1] 106 | x = jr.normal(self.key, (nfeatures,)) # single random input 107 | self.params = self.network.init(self.key, x) 108 | ntrain = X.shape[0] 109 | if isinstance(self.optimizer, str) and (self.optimizer.lower() == "adam+warmup"): 110 | total_steps = self.num_epochs*(ntrain//self.batch_size) 111 | warmup_cosine_decay_scheduler = optax.warmup_cosine_decay_schedule( 112 | init_value=1e-3, peak_value=1e-1, warmup_steps=int(total_steps*0.1), 113 | decay_steps=total_steps, end_value=1e-3) 114 | self.optimizer = optax.adam(learning_rate=warmup_cosine_decay_scheduler) 115 | return self.fit_optax(self.key, X, y) 116 | else: 117 | return self.fit_optax(self.key, X, y) 118 | 119 | 120 | def fit_optax(self, key, X, y): 121 | # based on https://github.com/google/flax/blob/main/examples/mnist/train.py 122 | ntrain = X.shape[0] # full dataset 123 | @jax.jit 124 | def train_step(state, Xb, yb): 125 | # loss = -1/N [ (sum_n log p(yn|xn, theta)) + log p(theta) ] 126 | # We estimate this from a minibatch. 127 | # We assume yb is integer, not one-hot. 
128 | def loss_fn(params): 129 | logits = state.apply_fn({'params': params}, Xb) 130 | #loglik = jnp.mean(distrax.Categorical(logits).log_prob(yb)) 131 | #sigma = np.sqrt(1/self.l2reg) 132 | #logjoint = loglik + (1/ntrain)*logprior_fn(params, sigma) 133 | #loss = -logjoint 134 | loss = optax.softmax_cross_entropy_with_integer_labels(logits, yb) 135 | loss = jnp.mean(loss) + l2_regularizer(params, self.l2reg) 136 | return loss, logits 137 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 138 | (loss, logits), grads = grad_fn(state.params) 139 | accuracy = jnp.mean(jnp.argmax(logits, -1) == yb) 140 | return grads, loss, accuracy 141 | 142 | def train_epoch(key, state): 143 | key, sub_key = jr.split(key) 144 | batch_ixs = get_batch_train_ixs(sub_key, ntrain, self.batch_size) # shuffles 145 | epoch_loss = [] 146 | epoch_accuracy = [] 147 | for batch_ix in batch_ixs: 148 | X_batch, y_batch = X[batch_ix], y[batch_ix] 149 | grads, loss, accuracy = train_step(state, X_batch, y_batch) 150 | state = state.apply_gradients(grads=grads) 151 | epoch_loss.append(loss) 152 | epoch_accuracy.append(accuracy) 153 | train_loss = np.mean(epoch_loss) 154 | train_accuracy = np.mean(epoch_accuracy) 155 | return state, train_loss, train_accuracy 156 | 157 | # main loop 158 | state = train_state.TrainState.create( 159 | apply_fn=self.network.apply, params=self.params['params'], tx=self.optimizer) 160 | for epoch in range(self.num_epochs): 161 | key, sub_key = jr.split(key) 162 | state, train_loss, train_accuracy = train_epoch(sub_key, state) 163 | if (self.print_every > 0) and (epoch % self.print_every == 0): 164 | print('epoch {:d}, train loss {:0.3f}, train accuracy {:0.3f}'.format( 165 | epoch, train_loss, train_accuracy)) 166 | 167 | self.params = {'params': state.params} 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /probml_utils/mlp_flax_demo.py: 
-------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | np.set_printoptions(precision=3) 6 | import scipy.stats 7 | import einops 8 | import matplotlib 9 | from functools import partial 10 | from collections import namedtuple 11 | import jax 12 | import jax.random as jr 13 | import jax.numpy as jnp 14 | from jax import vmap, grad, jit 15 | import jax.scipy as jsp 16 | import itertools 17 | from itertools import repeat 18 | from time import time 19 | import chex 20 | import typing 21 | 22 | import jax 23 | from typing import Any, Callable, Sequence 24 | from jax import lax, random, numpy as jnp 25 | from flax import linen as nn 26 | import flax 27 | 28 | import jaxopt 29 | import optax 30 | 31 | 32 | from probml_utils.mlp_flax import MLPNetwork, NeuralNetClassifier 33 | 34 | def print_vec(probs): 35 | str = ['{:0.3f}'.format(p) for p in probs] 36 | print(str) 37 | 38 | ### Create synthetic dataset from GMM, compare MLP predictions (with different optimizers) to Bayes optimal 39 | 40 | @chex.dataclass 41 | class GenParams: 42 | nclasses: int 43 | nfeatures: int 44 | prior: chex.Array 45 | mus: chex.Array # (C,D) 46 | Sigmas: chex.Array #(C,D,D) 47 | 48 | def make_params(key, nclasses, nfeatures, scale_factor=1): 49 | mus = jr.normal(key, (nclasses, nfeatures)) # (C,D) 50 | # shared covariance -> linearly separable 51 | #Sigma = scale_factor * jnp.eye(nfeatures) 52 | #Sigmas = jnp.array([Sigma for _ in range(nclasses)]) # (C,D,D) 53 | # diagonal covariance -> nonlinear decision boundaries 54 | sigmas = jr.uniform(key, shape=(nclasses, nfeatures), minval=0.5, maxval=5) 55 | Sigmas = jnp.array([scale_factor*jnp.diag(sigmas[y]) for y in range(nclasses)]) 56 | prior = jnp.ones(nclasses)/nclasses 57 | return GenParams(nclasses=nclasses, nfeatures=nfeatures, prior=prior, mus=mus, Sigmas=Sigmas) 58 | 59 | def sample_data(key, params, nsamples): 60 | y 
= jr.categorical(key, logits=jnp.log(params.prior), shape=(nsamples,)) 61 | X = jr.multivariate_normal(key, params.mus[y], params.Sigmas[y]) 62 | return X, y 63 | 64 | def predict_bayes(X, params): 65 | def lik_fn(y): 66 | return jsp.stats.multivariate_normal.pdf(X, params.mus[y], params.Sigmas[y]) 67 | liks = vmap(lik_fn)(jnp.arange(params.nclasses)) # liks(k,n)=p(X(n,:) | y=k) 68 | joint = jnp.einsum('kn,k -> nk', liks, params.prior) # joint(n,k) = liks(k,n) * prior(k) 69 | norm = joint.sum(axis=1) # norm(n) = sum_k joint(n,k) = p(X(n,:) 70 | post = joint / jnp.expand_dims(norm, axis=1) # post(n,k) = p(y = k | xn) 71 | return post 72 | 73 | def compare_bayes(optimizer, name, nhidden, scale_factor): 74 | nclasses = 4 75 | key = jr.PRNGKey(0) 76 | key, subkey = jr.split(key) 77 | params = make_params(subkey, nclasses=nclasses, nfeatures=10, scale_factor=scale_factor) 78 | key, subkey = jr.split(key) 79 | Xtrain, ytrain = sample_data(subkey, params, nsamples=1000) 80 | key, subkey = jr.split(key) 81 | Xtest, ytest = sample_data(subkey, params, nsamples=1000) 82 | 83 | yprobs_train_bayes = predict_bayes(Xtrain, params) 84 | yprobs_test_bayes = predict_bayes(Xtest, params) 85 | 86 | ypred_train_bayes = jnp.argmax(yprobs_train_bayes, axis=1) 87 | error_rate_train_bayes = jnp.sum(ypred_train_bayes != ytrain) / len(ytrain) 88 | 89 | ypred_test_bayes = jnp.argmax(yprobs_test_bayes, axis=1) 90 | error_rate_test_bayes = jnp.sum(ypred_test_bayes != ytest) / len(ytest) 91 | 92 | nhidden = nhidden + (nclasses,) # set nhidden() to get logistic regression 93 | network = MLPNetwork(nhidden) 94 | mlp = NeuralNetClassifier(network, key, nclasses, l2reg=1e-5, optimizer = optimizer, 95 | batch_size=32, num_epochs=30, print_every=1) 96 | mlp.fit(Xtrain, ytrain) 97 | 98 | yprobs_train_mlp = np.array(mlp.predict(Xtrain)) 99 | yprobs_test_mlp = np.array(mlp.predict(Xtest)) 100 | 101 | ypred_train_mlp = jnp.argmax(yprobs_train_mlp, axis=1) 102 | error_rate_train_mlp = 
jnp.sum(ypred_train_mlp != ytrain) / len(ytrain) 103 | 104 | ypred_test_mlp = jnp.argmax(yprobs_test_mlp, axis=1) 105 | error_rate_test_mlp = jnp.sum(ypred_test_mlp != ytest) / len(ytest) 106 | 107 | delta_train = jnp.max(yprobs_train_bayes - yprobs_train_mlp) 108 | delta_test = jnp.max(yprobs_test_bayes - yprobs_test_mlp) 109 | 110 | print('Evaluating training method {:s} on model with {} hidden layers'.format(name, nhidden)) 111 | print('Train error rate {:.3f} (Bayes {:.3f}), Test error rate {:.3f} (Bayes {:.3f})'.format( 112 | error_rate_train_mlp, error_rate_train_bayes, error_rate_test_mlp, error_rate_test_bayes)) 113 | #print('Max diff in probs from Bayes: train {:.3f}, test {:.3f}'.format( 114 | # delta_train, delta_test)) 115 | print('\n') 116 | 117 | def eval_mlp_on_gmm(sf): 118 | # scale_factor = 5 means the class conditional densities have higher variance (more overlap) 119 | compare_bayes(optax.adam(1e-3), "adam(1e-3)", nhidden=(), scale_factor=sf) 120 | compare_bayes("adam+warmup", "adam+warmup", nhidden=(), scale_factor=sf) 121 | 122 | compare_bayes(optax.adam(1e-3), "adam(1e-3)", nhidden=(10,), scale_factor=sf) 123 | compare_bayes("adam+warmup", "adam+warmup", nhidden=(10,), scale_factor=sf) 124 | 125 | compare_bayes(optax.adam(1e-3), "adam(1e-3)", nhidden=(10,10), scale_factor=sf) 126 | compare_bayes("adam+warmup", "adam+warmup", nhidden=(10,10), scale_factor=sf) 127 | 128 | 129 | 130 | def main(): 131 | eval_mlp_on_gmm(sf=5) 132 | 133 | if __name__ == "__main__": 134 | main() -------------------------------------------------------------------------------- /probml_utils/mlp_logreg_flax_demo.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Compare logistic regression in sklearn to our MLP model with no hidden layers on some small synthetic data 4 | # We find comparable results in predictive probabilites provided we use learning rate warmup. 
5 | 6 | from functools import partial 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | np.set_printoptions(precision=3) 10 | import scipy.stats 11 | import einops 12 | import matplotlib 13 | from functools import partial 14 | from collections import namedtuple 15 | import jax 16 | import jax.random as jr 17 | import jax.numpy as jnp 18 | from jax import vmap, grad, jit 19 | import jax.scipy as jsp 20 | import itertools 21 | from itertools import repeat 22 | from time import time 23 | import chex 24 | import typing 25 | 26 | import jax 27 | from typing import Any, Callable, Sequence 28 | from jax import lax, random, numpy as jnp 29 | from flax import linen as nn 30 | import flax 31 | 32 | import jaxopt 33 | import optax 34 | 35 | import sklearn.datasets 36 | import sklearn 37 | from sklearn.preprocessing import PolynomialFeatures, StandardScaler 38 | from sklearn.pipeline import Pipeline 39 | from sklearn.linear_model import LogisticRegression 40 | 41 | 42 | from probml_utils.mlp_flax import MLPNetwork, NeuralNetClassifier 43 | 44 | def make_data(seed, n_samples, class_sep, n_features): 45 | X, y = sklearn.datasets.make_classification(n_samples=n_samples, n_features=n_features, n_informative=5, 46 | n_redundant=5, n_repeated=0, n_classes=10, n_clusters_per_class=1, weights=None, flip_y=0.01, 47 | class_sep=class_sep, hypercube=True, shift=0.0, scale=1.0, shuffle=True, random_state=seed) 48 | return X, y 49 | 50 | def fit_predict_logreg(Xtrain, ytrain, Xtest, ytest, l2reg): 51 | # Use sklearn to fit logistic regression model 52 | classifier = Pipeline([ 53 | ('standardscaler', StandardScaler()), 54 | ('logreg', LogisticRegression(random_state=0, max_iter=100, C=1/l2reg))]) 55 | classifier.fit(Xtrain, ytrain) 56 | train_probs = classifier.predict_proba(Xtrain) 57 | test_probs = classifier.predict_proba(Xtest) 58 | return train_probs, test_probs 59 | 60 | def compare_probs(logreg_probs, probs, labels): 61 | delta = np.max(logreg_probs - probs) 62 | 
logreg_pred = np.argmax(logreg_probs, axis=1) 63 | pred = np.argmax(probs, axis=1) 64 | logreg_error_rate = np.mean(logreg_pred != labels) 65 | error_rate = np.mean(pred != labels) 66 | return delta, logreg_error_rate, error_rate 67 | 68 | def compare_logreg(optimizer, name=None, batch_size=None, num_epochs=30, 69 | n_samples=1000, class_sep=1, n_features=10): 70 | key = jr.PRNGKey(0) 71 | l2reg = 1e-5 72 | X_train, y_train = make_data(0, n_samples, class_sep, n_features) 73 | X_test, y_test = make_data(1, 1000, class_sep, n_features) 74 | nclasses = len(np.unique(y_train)) 75 | train_probs_logreg, test_probs_logreg = fit_predict_logreg(X_train, y_train, X_test, y_test, l2reg) 76 | 77 | #network = MLPNetwork((5, nclasses,)) 78 | network = MLPNetwork((nclasses,)) # no hidden layers == logistic regression 79 | model = NeuralNetClassifier(network, key, nclasses, l2reg=l2reg, optimizer = optimizer, 80 | batch_size=batch_size, num_epochs=num_epochs, print_every=1) 81 | model.fit(X_train, y_train) 82 | train_probs = np.array(model.predict(X_train)) 83 | test_probs = np.array(model.predict(X_test)) 84 | 85 | train_delta, train_logreg_error_rate, train_error_rate = compare_probs(train_probs_logreg, train_probs, y_train) 86 | test_delta, test_logreg_error_rate, test_error_rate = compare_probs(test_probs_logreg, test_probs, y_test) 87 | print('max difference in train probabilities from logreg to {:s} is {:.3f}'.format(name, train_delta)) 88 | print('misclassification rates: logreg train = {:.3f}, model train = {:.3f}'.format( 89 | train_logreg_error_rate, train_error_rate)) 90 | print('misclassification rates: logreg test = {:.3f}, model test = {:.3f}'.format( 91 | test_logreg_error_rate, test_error_rate)) 92 | 93 | 94 | compare_logreg(optax.adam(1e-3), name="adam 1e-3, bs=32", batch_size=32) 95 | 96 | compare_logreg("adam+warmup", name="adam+warmup, bs=32", batch_size=32) -------------------------------------------------------------------------------- 
/probml_utils/mnist_helper_tf.py: -------------------------------------------------------------------------------- 1 | # Helper functions for DNN demos related to mnist images 2 | import os 3 | try: 4 | import tensorflow as tf 5 | except ModuleNotFoundError: 6 | os.system("pip install tensorflow") 7 | import tensorflow as tf 8 | 9 | from tensorflow import keras 10 | assert tf.__version__ >= "2.0" 11 | 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | from IPython import display 15 | from time import time 16 | 17 | def get_dataset(FASHION=False): 18 | if FASHION: 19 | (train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data() 20 | class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 21 | 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 22 | else: 23 | (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data() 24 | class_names = [str(x) for x in range(10)] 25 | train_images = train_images / 255.0 26 | test_images = test_images / 255.0 27 | return train_images, train_labels, test_images, test_labels, class_names 28 | 29 | def plot_dataset(train_images, train_labels, class_names): 30 | plt.figure(figsize=(10,10)) 31 | for i in range(25): 32 | plt.subplot(5,5,i+1) 33 | plt.xticks([]) 34 | plt.yticks([]) 35 | plt.grid(False) 36 | plt.imshow(train_images[i], cmap=plt.cm.binary) 37 | plt.xlabel(class_names[train_labels[i]]) 38 | #save_fig("fashion-mnist-data.pdf") 39 | plt.show() 40 | 41 | def plot_image_and_label(predictions_array, true_label, img, class_names): 42 | plt.grid(False) 43 | plt.xticks([]) 44 | plt.yticks([]) 45 | img = np.reshape(img, (28, 28)) # drop any trailing dimension of size 1 46 | plt.imshow(img, cmap=plt.cm.binary) 47 | predicted_label = np.argmax(predictions_array) 48 | if predicted_label == true_label: 49 | color = 'blue' 50 | else: 51 | color = 'red' 52 | plt.xlabel("truth={}, pred={}, score={:2.0f}%".format( 53 | class_names[true_label], 
54 | class_names[predicted_label], 55 | 100*np.max(predictions_array)), 56 | color=color) 57 | 58 | def plot_label_dist(predictions_array, true_label): 59 | plt.grid(False) 60 | plt.xticks([]) 61 | plt.yticks([]) 62 | thisplot = plt.bar(range(10), predictions_array, color="#777777") 63 | plt.ylim([0, 1]) 64 | predicted_label = np.argmax(predictions_array) 65 | thisplot[predicted_label].set_color('red') 66 | thisplot[true_label].set_color('blue') 67 | 68 | def find_interesting_test_images(predictions, test_labels): 69 | # We select the first 9 images plus 6 error images 70 | pred = np.argmax(predictions, axis=1) 71 | errors = np.where(pred != test_labels)[0] 72 | print(errors.shape) 73 | ndx1 = range(9) 74 | ndx2 = errors[:6] 75 | ndx = np.concatenate((ndx1, ndx2)) 76 | return ndx 77 | 78 | def plot_interesting_test_results(test_images, test_labels, predictions, 79 | class_names, ndx): 80 | # Plot some test images, their predicted label, and the true label 81 | # Color correct predictions in blue, incorrect predictions in red 82 | num_rows = 5 83 | num_cols = 3 84 | num_images = num_rows*num_cols 85 | plt.figure(figsize=(2*2*num_cols, 2*num_rows)) 86 | for i in range(num_images): 87 | n = ndx[i] 88 | plt.subplot(num_rows, 2*num_cols, 2*i+1) 89 | plot_image_and_label(predictions[n], test_labels[n], test_images[n], 90 | class_names) 91 | plt.subplot(num_rows, 2*num_cols, 2*i+2) 92 | plot_label_dist(predictions[n], test_labels[n]) 93 | plt.show() 94 | -------------------------------------------------------------------------------- /probml_utils/multivariate_t_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | import jax.numpy as jnp 4 | from jax.scipy.special import gammaln 5 | from jax.numpy.linalg import slogdet, solve 6 | 7 | 8 | @jax.jit 9 | def log_p_of_multi_t(x, nu, mu, Sigma): 10 | """ 11 | Computing the logarithm of probability density of the multivariate T distribution, 12 | 
https://en.wikipedia.org/wiki/Multivariate_t-distribution 13 | --------------------------------------------------------- 14 | x: array(dim) 15 | Data point that we want to evaluate log pdf at 16 | nu: int 17 | Degree of freedom of the multivariate T distribution 18 | mu: array(dim) 19 | Location parameter of the multivariate T distribution 20 | Sigma: array(dim, dim) 21 | Positive-definite real scale matrix of the multivariate T distribution 22 | -------------------------------------------------------------------------- 23 | * float 24 | Log probability of the multivariate T distribution at x 25 | """ 26 | dim = mu.shape[0] 27 | # Logarithm of the normalizing constant 28 | l0 = gammaln((nu+dim)/2.0) - (gammaln(nu/2.0) + dim/2.0*(jnp.log(nu)+jnp.log(np.pi)) + slogdet(Sigma)[1]) 29 | # Logarithm of the unnormalized pdf 30 | l1 = -(nu+dim)/2.0 * jnp.log(1 + 1/nu*(x-mu).dot(solve(Sigma, x-mu))) 31 | return l0 + l1 32 | 33 | 34 | def log_predic_t(x, obs, hyper_params): 35 | """ 36 | Evaluating the logarithm of probability of the posterior predictive multivariate T distribution. 37 | The likelihood of the observation given the parameter is Gaussian distribution. 38 | The prior distribution is Normal Inverse Wishart (NIW) with parameters given by hyper_params. 
39 | --------------------------------------------------------------------------------------------- 40 | x: array(dim) 41 | Data point that we want to evalute the log probability 42 | obs: array(n, dim) 43 | Observations that the posterior distritbuion is conditioned on 44 | hyper_params: () 45 | The set of hyper parameters of the NIW prior 46 | ------------------------------------------------ 47 | * float 48 | Log probability of the multivariate T distribution at x 49 | """ 50 | mu0, kappa0, nu0, Sigma0 = hyper_params 51 | n, dim = obs.shape 52 | # Use the prior marginal distribution if no observation 53 | if n==0: 54 | nu_t = nu0 - dim + 1 55 | mu_t = mu0 56 | Sigma_t = Sigma0*(kappa0+1)/(kappa0*nu_t) 57 | return log_p_of_multi_t(x, nu_t, mu_t, Sigma_t) 58 | # Update the distribution using sufficient statistics 59 | obs_mean = jnp.mean(obs, axis=0) 60 | S = (obs-obs_mean).T @ (obs-obs_mean) 61 | nu_n = nu0 + n 62 | kappa_n = kappa0 + n 63 | mu_n = kappa0/kappa_n*mu0 + n/kappa_n*obs_mean 64 | Lambda_n = Sigma0 + S + kappa0*n/kappa_n*jnp.outer(obs_mean-mu0, obs_mean-mu0) 65 | nu_t = nu_n - dim + 1 66 | mu_t = mu_n 67 | Sigma_t = Lambda_n*(kappa_n+1)/(kappa_n*nu_t) 68 | return log_p_of_multi_t(x, nu_t, mu_t, Sigma_t) -------------------------------------------------------------------------------- /probml_utils/nb_utils.py: -------------------------------------------------------------------------------- 1 | import nbformat as nbf 2 | 3 | 4 | def get_ipynb_from_code(code): 5 | """ 6 | Get the ipynb from the code. 
7 | """ 8 | nb = nbf.v4.new_notebook() 9 | nb["cells"] = [nbf.v4.new_code_cell(code)] 10 | return nb 11 | -------------------------------------------------------------------------------- /probml_utils/pgmpy_utils.py: -------------------------------------------------------------------------------- 1 | # utility functions for pgmpy library 2 | # authors: murphyk@, Drishttii@ 3 | 4 | #!pip install pgmpy 5 | #!pip install graphviz 6 | import os 7 | try: 8 | import pgmpy 9 | except: 10 | os.system("pip install pgmpy") 11 | import pgmpy 12 | 13 | import numpy as np 14 | import itertools 15 | from graphviz import Digraph 16 | 17 | def get_state_names(model, name): 18 | state_names = dict() 19 | cpd = model.get_cpds(name) 20 | state_names = cpd.state_names 21 | return state_names 22 | 23 | def get_all_state_names(model): 24 | state_names = dict() 25 | for cpd in model.get_cpds(): 26 | for k, v in cpd.state_names.items(): 27 | state_names[k] = v 28 | return state_names 29 | 30 | def get_lengths(state_names, name): 31 | row = [] 32 | for k, v in state_names.items(): 33 | if (k==name): 34 | col = len(v) 35 | else: 36 | row.append(len(v)) 37 | return row, col 38 | 39 | def get_all_perms(states, name): 40 | all_list = [] 41 | for k, v in states.items(): 42 | if (k==name): 43 | continue 44 | else: 45 | all_list.append(states[k]) 46 | res = list(itertools.product(*all_list)) 47 | resu = [] 48 | for j in res: 49 | j = str(j) 50 | j = j.replace('(', '') 51 | j = j.replace(',', '') 52 | j = j.replace(')', '') 53 | j = j.replace("'", '') 54 | j = j.replace(' ', ', ') 55 | resu.append(j) 56 | return resu 57 | 58 | def visualize_model(model): 59 | h = Digraph('model_') 60 | # Adding each node 61 | for cpd in model.get_cpds(): 62 | name = cpd.variable 63 | states = get_state_names(model, name) 64 | cpd = model.get_cpds(name) 65 | values = cpd.values 66 | if values.ndim > 2: 67 | values = values.reshape(values.shape[0], -1) 68 | values = values.T 69 | the_string = "" 70 | 71 | if 
len(states) == 1: 72 | rows = len(states[name]) 73 | cols = rows 74 | row_string = "" 75 | for row in range(rows): 76 | col_string = "" 77 | for col in range(cols): 78 | if (row==0): 79 | inp = states[name] 80 | col_string = col_string + "" + str(inp[col]) + "" 81 | else: 82 | two_dec = format(values[col], ".2f") 83 | col_string = col_string + "" + str(two_dec) + "" 84 | 85 | row_string = row_string + "" + str(col_string) + "" 86 | 87 | else: 88 | #lis = get_list(states, name) 89 | res = get_all_perms(states, name) 90 | r, c = get_lengths(states, name) 91 | rows = np.prod(r) + 1 92 | cols = c + 1 93 | row_string = "" 94 | for row in range(rows): 95 | col_string = "" 96 | if (row==0): 97 | for col in range(cols): 98 | if (col==0): 99 | col_string = col_string + "" + " " + "" 100 | else: 101 | inp = states[name] 102 | col_string = col_string + "" + str(inp[col-1]) + "" 103 | 104 | row_string = row_string + "" + str(col_string) + "" 105 | 106 | else: 107 | for col in range(cols): 108 | if (col==0): 109 | col_string = col_string + "" + str(res[row-1]) + "" 110 | else: 111 | two_dec = format(values[row-1][col-1], ".2f") 112 | col_string = col_string + "" + str(two_dec) + "" 113 | 114 | row_string = row_string + "" + str(col_string) + "" 115 | 116 | h.node(name, label = '''< 117 | 118 | 119 | {}
{}
>'''.format(rows ,name, row_string)) 120 | 121 | edges = (model.edges()) 122 | for item in edges: 123 | edge = list(item) 124 | h.edge(edge[0], edge[1]) 125 | 126 | return h 127 | 128 | def get_marginals(model, evidence={}, inference_engine=None): 129 | if inference_engine is None: 130 | inference_engine = pgmpy.inference.VariableElimination(model) # more efficient to precompute this 131 | nodes = model.nodes() 132 | num_nodes = len(nodes) 133 | state_names = get_all_state_names(model) 134 | marginals = dict() 135 | for n in nodes: 136 | if n in evidence: # observed nodes 137 | v = evidence[n] 138 | if type(v) == str: 139 | v_ndx = state_names[n].index(v) 140 | else: 141 | v_ndx = v 142 | nstates = model.get_cardinality(n) 143 | marginals[n] = np.zeros(nstates) 144 | marginals[n][v_ndx] = 1.0 # delta function on observed value 145 | else: 146 | probs = inference_engine.query([n], evidence=evidence).values 147 | marginals[n] = probs 148 | return marginals 149 | 150 | def visualize_marginals(model, evidence, marginals): 151 | h = Digraph('pgm') 152 | for node_name, probs in marginals.items(): 153 | states = get_state_names(model, node_name) 154 | rows = 2 #len(probs) 155 | cols = len(states[node_name]) 156 | row_string = "" 157 | for row in range(rows): 158 | col_string = "" 159 | for col in range(cols): 160 | if (row==0): 161 | inp = states[node_name] 162 | col_string = col_string + "" + str(inp[col]) + "" 163 | else: 164 | inp = round(probs[col], 2) 165 | col_string = col_string + "" + str(inp) + "" 166 | 167 | row_string = row_string + "" + str(col_string) + "" 168 | 169 | if node_name in evidence.keys(): 170 | h.node(node_name, label = '''< 171 | 172 | 173 | {}
{}
>'''.format(cols , node_name, row_string)) 174 | else: 175 | h.node(node_name, label = '''< 176 | 177 | 178 | {}
{}
>'''.format(cols , node_name, row_string)) 179 | 180 | edges = (model.edges()) 181 | for item in edges: 182 | edge = list(item) 183 | h.edge(edge[0], edge[1]) 184 | 185 | return h 186 | -------------------------------------------------------------------------------- /probml_utils/plotting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import warnings 4 | 5 | DEFAULT_WIDTH = 6.0 6 | DEFAULT_HEIGHT = 1.5 7 | SIZE_SMALL = 9 # Caption size in the pml book 8 | 9 | 10 | def latexify( 11 | width_scale_factor=1, 12 | height_scale_factor=1, 13 | fig_width=None, 14 | fig_height=None, 15 | font_size=SIZE_SMALL, 16 | ): 17 | f""" 18 | width_scale_factor: float, DEFAULT_WIDTH will be divided by this number, DEFAULT_WIDTH is page width: {DEFAULT_WIDTH} inches. 19 | height_scale_factor: float, DEFAULT_HEIGHT will be divided by this number, DEFAULT_HEIGHT is {DEFAULT_HEIGHT} inches. 20 | fig_width: float, width of the figure in inches (if this is specified, width_scale_factor is ignored) 21 | fig_height: float, height of the figure in inches (if this is specified, height_scale_factor is ignored) 22 | font_size: float, font size 23 | """ 24 | if "LATEXIFY" not in os.environ: 25 | warnings.warn("LATEXIFY environment variable not set, not latexifying") 26 | return 27 | if fig_width is None: 28 | fig_width = DEFAULT_WIDTH / width_scale_factor 29 | if fig_height is None: 30 | fig_height = DEFAULT_HEIGHT / height_scale_factor 31 | 32 | # use TrueType fonts so they are embedded 33 | # https://stackoverflow.com/questions/9054884/how-to-embed-fonts-in-pdfs-produced-by-matplotlib 34 | # https://jdhao.github.io/2018/01/18/mpl-plotting-notes-201801/ 35 | plt.rcParams["pdf.fonttype"] = 42 36 | 37 | # Font sizes 38 | # SIZE_MEDIUM = 14 39 | # SIZE_LARGE = 24 40 | # https://stackoverflow.com/a/39566040 41 | plt.rc("font", size=font_size) # controls default text sizes 42 | plt.rc("axes", 
titlesize=font_size) # fontsize of the axes title 43 | plt.rc("axes", labelsize=font_size) # fontsize of the x and y labels 44 | plt.rc("xtick", labelsize=font_size) # fontsize of the tick labels 45 | plt.rc("ytick", labelsize=font_size) # fontsize of the tick labels 46 | plt.rc("legend", fontsize=font_size) # legend fontsize 47 | plt.rc("figure", titlesize=font_size) # fontsize of the figure title 48 | 49 | # latexify: https://nipunbatra.github.io/blog/visualisation/2014/06/02/latexify.html 50 | plt.rcParams["backend"] = "ps" 51 | plt.rc("text", usetex=True) 52 | plt.rc("font", family="serif") 53 | plt.rc("figure", figsize=(fig_width, fig_height)) 54 | 55 | 56 | def is_latexify_enabled(): 57 | """ 58 | returns true if LATEXIFY environment variable is set 59 | """ 60 | return "LATEXIFY" in os.environ 61 | 62 | 63 | def _get_fig_name(fname_full): 64 | fname_full = fname_full.replace("_latexified", "") 65 | LATEXIFY = "LATEXIFY" in os.environ 66 | # extention = "_latexified.pdf" if LATEXIFY else ".png" 67 | extention = "_latexified.pdf" if LATEXIFY else ".pdf" 68 | if fname_full[-4:] in [".png", ".pdf", ".jpg"]: 69 | fname = fname_full[:-4] 70 | warnings.warn( 71 | f"renaming {fname_full} to {fname}{extention} because LATEXIFY is {LATEXIFY}", 72 | ) 73 | else: 74 | fname = fname_full 75 | return fname + extention 76 | 77 | 78 | def savefig( 79 | f_name, tight_layout=True, tight_bbox=False, pad_inches=0.0, *args, **kwargs 80 | ): 81 | if len(f_name) == 0: 82 | return 83 | if "FIG_DIR" not in os.environ: 84 | warnings.warn("set FIG_DIR environment variable to save figures") 85 | return 86 | 87 | fig_dir = os.environ["FIG_DIR"] 88 | # Auto create the directory if it doesn't exist 89 | if not os.path.exists(fig_dir): 90 | os.makedirs(fig_dir) 91 | 92 | fname_full = os.path.join(fig_dir, f_name) 93 | fname_full = _get_fig_name(fname_full) 94 | 95 | print("saving image to {}".format(fname_full)) 96 | if tight_layout: 97 | plt.tight_layout(pad=pad_inches) 98 | print("Figure 
size:", plt.gcf().get_size_inches()) 99 | 100 | if tight_bbox: 101 | # This changes the size of the figure 102 | plt.savefig( 103 | fname_full, pad_inches=pad_inches, bbox_inches="tight", *args, **kwargs 104 | ) 105 | else: 106 | plt.savefig(fname_full, pad_inches=pad_inches, *args, **kwargs) 107 | 108 | if "DUAL_SAVE" in os.environ: 109 | if fname_full.endswith(".pdf"): 110 | fname_full = fname_full[:-4] + ".png" 111 | else: 112 | fname_full = fname_full[:-4] + ".pdf" 113 | if tight_bbox: 114 | # This changes the size of the figure 115 | plt.savefig( 116 | fname_full, pad_inches=pad_inches, bbox_inches="tight", *args, **kwargs 117 | ) 118 | else: 119 | plt.savefig(fname_full, pad_inches=pad_inches, *args, **kwargs) 120 | -------------------------------------------------------------------------------- /probml_utils/prefit_voting_classifier.py: -------------------------------------------------------------------------------- 1 | # Make ensemble of pre-fit estimators 2 | # https://gist.github.com/tomquisel/a421235422fdf6b51ec2ccc5e3dee1b4 3 | # tomquisel 4 | 5 | import numpy as np 6 | from sklearn.utils.validation import check_is_fitted 7 | 8 | 9 | class PrefitVotingClassifier(object): 10 | """Stripped-down version of VotingClassifier that uses prefit estimators""" 11 | 12 | def __init__(self, estimators, voting="hard", weights=None): 13 | self.estimators = [e[1] for e in estimators] 14 | self.named_estimators = dict(estimators) 15 | self.voting = voting 16 | self.weights = weights 17 | 18 | def fit(self, X, y, sample_weight=None): 19 | raise NotImplementedError 20 | 21 | def predict(self, X): 22 | """Predict class labels for X. 23 | Parameters 24 | ---------- 25 | X : {array-like, sparse matrix}, shape = [n_samples, n_features] 26 | Training vectors, where n_samples is the number of samples and 27 | n_features is the number of features. 28 | Returns 29 | ---------- 30 | maj : array-like, shape = [n_samples] 31 | Predicted class labels. 
32 | """ 33 | 34 | check_is_fitted(self, "estimators") 35 | if self.voting == "soft": 36 | maj = np.argmax(self.predict_proba(X), axis=1) 37 | 38 | else: # 'hard' voting 39 | predictions = self._predict(X) 40 | maj = np.apply_along_axis( 41 | lambda x: np.argmax(np.bincount(x, weights=self.weights)), axis=1, arr=predictions.astype("int") 42 | ) 43 | return maj 44 | 45 | def _collect_probas(self, X): 46 | """Collect results from clf.predict calls.""" 47 | return np.asarray([clf.predict_proba(X) for clf in self.estimators]) 48 | 49 | def _predict_proba(self, X): 50 | """Predict class probabilities for X in 'soft' voting""" 51 | if self.voting == "hard": 52 | raise AttributeError("predict_proba is not available when" " voting=%r" % self.voting) 53 | check_is_fitted(self, "estimators") 54 | avg = np.average(self._collect_probas(X), axis=0, weights=self.weights) 55 | return avg 56 | 57 | @property 58 | def predict_proba(self): 59 | """Compute probabilities of possible outcomes for samples in X. 60 | Parameters 61 | ---------- 62 | X : {array-like, sparse matrix}, shape = [n_samples, n_features] 63 | Training vectors, where n_samples is the number of samples and 64 | n_features is the number of features. 65 | Returns 66 | ---------- 67 | avg : array-like, shape = [n_samples, n_classes] 68 | Weighted average probability for each class per sample. 69 | """ 70 | return self._predict_proba 71 | 72 | def transform(self, X): 73 | """Return class labels or probabilities for X for each estimator. 74 | Parameters 75 | ---------- 76 | X : {array-like, sparse matrix}, shape = [n_samples, n_features] 77 | Training vectors, where n_samples is the number of samples and 78 | n_features is the number of features. 79 | Returns 80 | ------- 81 | If `voting='soft'`: 82 | array-like = [n_classifiers, n_samples, n_classes] 83 | Class probabilities calculated by each classifier. 
84 | If `voting='hard'`: 85 | array-like = [n_samples, n_classifiers] 86 | Class labels predicted by each classifier. 87 | """ 88 | check_is_fitted(self, "estimators") 89 | if self.voting == "soft": 90 | return self._collect_probas(X) 91 | else: 92 | return self._predict(X) 93 | 94 | def _predict(self, X): 95 | """Collect results from clf.predict calls.""" 96 | return np.asarray([clf.predict(X) for clf in self.estimators]).T 97 | -------------------------------------------------------------------------------- /probml_utils/pyprobml_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from numpy import linalg 5 | from mpl_toolkits.mplot3d import Axes3D 6 | 7 | from inspect import getsourcefile 8 | from os.path import abspath 9 | 10 | 11 | # https://stackoverflow.com/questions/2632199/how-do-i-get-the-path-of-the-current-executed-file-in-python?lq=1 12 | def get_current_path(): 13 | current_path = abspath(getsourcefile(lambda: 0)) # fullname of current file 14 | # current_path = os.path.dirname(__file__) 15 | current_dir = os.path.dirname(current_path) 16 | return current_dir 17 | 18 | 19 | def test(): 20 | print("welcome to python probabilistic ML library") 21 | print(get_current_path()) 22 | 23 | 24 | # https://stackoverflow.com/questions/10685495/reducing-the-size-of-pdf-figure-file-in-matplotlib 25 | 26 | # save_fig is now depricated 27 | # def save_fig(fname, *args, **kwargs): 28 | # # figdir = '../figures' # default directory one above where code lives 29 | # current_dir = get_current_path() 30 | # figdir = os.path.join(current_dir, "..", "figures") 31 | 32 | # if not os.path.exists(figdir): 33 | # print("making directory {}".format(figdir)) 34 | # os.mkdir(figdir) 35 | 36 | # fname_full = os.path.join(figdir, fname) 37 | # print("saving image to {}".format(fname_full)) 38 | # # plt.tight_layout() 39 | 40 | # # use TrueType fonts so they are 
embedded 41 | # # https://stackoverflow.com/questions/9054884/how-to-embed-fonts-in-pdfs-produced-by-matplotlib 42 | # # https://jdhao.github.io/2018/01/18/mpl-plotting-notes-201801/ 43 | # plt.rcParams["pdf.fonttype"] = 42 44 | 45 | # # Font sizes 46 | # SIZE_SMALL = 12 47 | # SIZE_MEDIUM = 14 48 | # SIZE_LARGE = 24 49 | # # https://stackoverflow.com/a/39566040 50 | # plt.rc("font", size=SIZE_SMALL) # controls default text sizes 51 | # plt.rc("axes", titlesize=SIZE_SMALL) # fontsize of the axes title 52 | # plt.rc("axes", labelsize=SIZE_SMALL) # fontsize of the x and y labels 53 | # plt.rc("xtick", labelsize=SIZE_SMALL) # fontsize of the tick labels 54 | # plt.rc("ytick", labelsize=SIZE_SMALL) # fontsize of the tick labels 55 | # plt.rc("legend", fontsize=SIZE_SMALL) # legend fontsize 56 | # plt.rc("figure", titlesize=SIZE_LARGE) # fontsize of the figure title 57 | 58 | # plt.savefig(fname_full, *args, **kwargs) 59 | 60 | 61 | # def savefig(fname, *args, **kwargs): 62 | # save_fig(fname, *args, **kwargs) 63 | 64 | 65 | from matplotlib.patches import Ellipse, transforms 66 | 67 | # https://matplotlib.org/devdocs/gallery/statistics/confidence_ellipse.html 68 | def plot_ellipse(Sigma, mu, ax, n_std=3.0, facecolor="none", edgecolor="k", plot_center="true", **kwargs): 69 | cov = Sigma 70 | pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1]) 71 | 72 | ell_radius_x = np.sqrt(1 + pearson) 73 | ell_radius_y = np.sqrt(1 - pearson) 74 | ellipse = Ellipse( 75 | (0, 0), width=ell_radius_x * 2, height=ell_radius_y * 2, facecolor=facecolor, edgecolor=edgecolor, **kwargs 76 | ) 77 | 78 | scale_x = np.sqrt(cov[0, 0]) * n_std 79 | mean_x = mu[0] 80 | 81 | scale_y = np.sqrt(cov[1, 1]) * n_std 82 | mean_y = mu[1] 83 | 84 | transf = transforms.Affine2D().rotate_deg(45).scale(scale_x, scale_y).translate(mean_x, mean_y) 85 | 86 | ellipse.set_transform(transf + ax.transData) 87 | 88 | if plot_center: 89 | ax.plot(mean_x, mean_y, ".") 90 | return ax.add_patch(ellipse) 91 | 92 | 93 | def 
plot_ellipse_test(): 94 | fig, ax = plt.subplots() 95 | Sigma = np.array([[5, 1], [1, 5]]) 96 | plot_ellipse(Sigma, np.zeros(2), ax, n_std=1) 97 | plt.axis("equal") 98 | plt.show() 99 | 100 | 101 | def convergence_test(fval, previous_fval, threshold=1e-4, warn=False): 102 | eps = 2e-10 103 | converged = 0 104 | delta_fval = np.abs(fval - previous_fval) 105 | avg_fval = (np.abs(fval) + abs(previous_fval) + eps) / 2.0 106 | if (delta_fval / avg_fval) < threshold: 107 | converged = 1 108 | 109 | if warn and (fval - previous_fval) < -2 * eps: 110 | print("convergenceTest:fvalDecrease", "objective decreased!") 111 | return converged 112 | 113 | 114 | def hinton_diagram(matrix, max_weight=None, ax=None): 115 | """Draw Hinton diagram for visualizing a weight matrix.""" 116 | if not max_weight: 117 | max_weight = 2 ** np.ceil(np.log(np.abs(matrix).max()) / np.log(2)) 118 | 119 | ax.patch.set_facecolor("white") 120 | ax.set_aspect("equal", "box") 121 | 122 | for (x, y), w in np.ndenumerate(matrix): 123 | color = "lawngreen" if w > 0 else "royalblue" 124 | size = np.sqrt(np.abs(w) / max_weight) 125 | rect = plt.Rectangle([x - size / 2, y - size / 2], size, size, facecolor=color, edgecolor=color) 126 | ax.add_patch(rect) 127 | nr, nc = matrix.shape 128 | ax.set_xticks(np.arange(0, nr)) 129 | ax.set_yticks(np.arange(0, nc)) 130 | ax.grid(linestyle="--", linewidth=2) 131 | ax.autoscale_view() 132 | ax.invert_yaxis() 133 | 134 | 135 | def kdeg(x, X, h): 136 | """ 137 | KDE under a gaussian kernel 138 | 139 | Parameters 140 | ---------- 141 | x: array(eval, D) 142 | X: array(obs, D) 143 | h: float 144 | 145 | Returns 146 | ------- 147 | array(eval): 148 | KDE around the observed values 149 | """ 150 | N, D = X.shape 151 | nden, _ = x.shape 152 | 153 | Xhat = X.reshape(D, 1, N) 154 | xhat = x.reshape(D, nden, 1) 155 | u = xhat - Xhat 156 | u = linalg.norm(u, ord=2, axis=0) ** 2 / (2 * h**2) 157 | px = np.exp(-u).sum(axis=1) / (N * h * np.sqrt(2 * np.pi)) 158 | return px 159 | 160 
| 161 | def scale_3d(ax, x_scale, y_scale, z_scale, factor): 162 | scale = np.diag([x_scale, y_scale, z_scale, 1.0]) 163 | scale = scale * (1.0 / scale.max()) 164 | scale[3, 3] = factor 165 | 166 | def short_proj(): 167 | return np.dot(Axes3D.get_proj(ax), scale) 168 | 169 | return short_proj 170 | 171 | 172 | def style3d(ax, x_scale, y_scale, z_scale, factor=0.62): 173 | plt.gca().patch.set_facecolor("white") 174 | ax.w_xaxis.set_pane_color((0, 0, 0, 0)) 175 | ax.w_yaxis.set_pane_color((0, 0, 0, 0)) 176 | ax.w_zaxis.set_pane_color((0, 0, 0, 0)) 177 | ax.get_proj = scale_3d(ax, x_scale, y_scale, z_scale, factor) 178 | 179 | 180 | if __name__ == "__main__": 181 | test() 182 | -------------------------------------------------------------------------------- /probml_utils/rvm_classifier.py: -------------------------------------------------------------------------------- 1 | # This file implements Relevance Vector Machine Classifier. 2 | # Author Srikar Reddy Jilugu(@always-newbie161) 3 | 4 | import numpy as np 5 | 6 | # This is a python implementation of Relevance Vector Machine Classifier, 7 | # it's based on github.com/ctgk/PRML/blob/master/prml/kernel/relevance_vector_classifier.py 8 | class RVC: 9 | def sigmoid(self, a): 10 | return np.tanh(a * 0.5) * 0.5 + 0.5 11 | 12 | # Kernel matrix using rbf kernel with gamma = 0.3. 13 | def kernel_mat(self, X, Y): 14 | (x, y) = (np.tile(X, (len(Y), 1, 1)).transpose(1, 0, 2), np.tile(Y, (len(X), 1, 1))) 15 | d = np.repeat(1 / (0.3 * 0.3), X.shape[-1]) * (x - y) ** 2 16 | return np.exp(-0.5 * np.sum(d, axis=-1)) 17 | 18 | def __init__(self, alpha=1.0): 19 | self.threshold_alpha = 1e8 20 | self.alpha = alpha 21 | self.iter_max = 100 22 | self.relevance_vectors_ = [] 23 | 24 | # estimates for singulat matrices. 25 | def ps_inv(self, m): 26 | # assuming it is a square matrix. 
27 | a = m.shape[0] 28 | i = np.eye(a, a) 29 | return np.linalg.lstsq(m, i, rcond=None)[0] 30 | 31 | """ 32 | For the current fixed values of alpha, the most probable 33 | weights are found by maximizing w over p(w/t,alpha) 34 | using the Laplace approximation of finding an hessian. 35 | (E step) 36 | w = mean of p(w/t,alpha) 37 | cov = negative hessian of p(w/t,alpha) 38 | 39 | """ 40 | 41 | def _map_estimate(self, X, t, w, n_iter=10): 42 | for _ in range(n_iter): 43 | y = self.sigmoid(X @ w) 44 | g = X.T @ (y - t) + self.alpha * w 45 | H = (X.T * y * (1 - y)) @ X + np.diag(self.alpha) # negated Hessian of p(w/t,alpha) 46 | w -= np.linalg.lstsq(H, g, rcond=None)[0] # works even if for singular matrices. 47 | return w, self.ps_inv(H) # inverse of H is the covariance of the gaussian approximation. 48 | 49 | """ 50 | Fitting of input-target pairs works by 51 | iteratively finding the most probable weights(done by _map_estimate method) 52 | and optimizing the hyperparameters(alpha) until there is no 53 | siginificant change in alpha. 54 | 55 | (M step) 56 | Optimizing alpha: 57 | For the given targets and current variance(sigma^2) alpha is optimized over p(t/alpha,variance) 58 | It is done by Mackay approach(ARD). 59 | alpha(new) = gamma/mean^2 60 | where gamma = 1 - alpha(old)*covariance. 61 | 62 | After finding the hyperparameters(alpha), 63 | the samples which have alpha less than the threshold(hence weight >> 0) 64 | are choosen as relevant vectors. 
65 | 66 | Now predicted y = sign(phi(X) @ mean) ( mean contains the optimal weights) 67 | """ 68 | 69 | def fit(self, X, y): 70 | Phi = self.kernel_mat(X, X) 71 | N = len(y) 72 | self.alpha = np.zeros(N) + self.alpha 73 | mean = np.zeros(N) 74 | for i in range(self.iter_max): 75 | param = np.copy(self.alpha) 76 | mean, cov = self._map_estimate(Phi, y, mean, 10) 77 | gamma = 1 - self.alpha * np.diag(cov) 78 | self.alpha = gamma / np.square(mean) 79 | np.clip(self.alpha, 0, 1e10, out=self.alpha) 80 | if np.allclose(param, self.alpha): 81 | break 82 | 83 | ret_alpha = self.alpha < self.threshold_alpha 84 | self.relevance_vectors_ = X[ret_alpha] 85 | self.y = y[ret_alpha] 86 | self.alpha = self.alpha[ret_alpha] 87 | Phi = self.kernel_mat(self.relevance_vectors_, self.relevance_vectors_) 88 | mean = mean[ret_alpha] 89 | self.mean, self.covariance = self._map_estimate(Phi, self.y, mean, 100) 90 | 91 | # gives probability for target to be class 0. 92 | def predict_proba(self, X): 93 | phi = self.kernel_mat(X, self.relevance_vectors_) 94 | mu_a = phi @ self.mean 95 | var_a = np.sum(phi @ self.covariance * phi, axis=1) 96 | return 1 - self.sigmoid(mu_a / np.sqrt(1 + np.pi * var_a / 8)) 97 | 98 | def predict(self, X): 99 | phi = self.kernel_mat(X, self.relevance_vectors_) 100 | return (phi @ self.mean > 0).astype(np.int) 101 | -------------------------------------------------------------------------------- /probml_utils/rvm_regressor.py: -------------------------------------------------------------------------------- 1 | """ 2 | code taken from 3 | https://github.com/ctgk/PRML/blob/master/prml/kernel/relevance_vector_regressor.py 4 | """ 5 | 6 | import numpy as np 7 | 8 | 9 | class RelevanceVectorRegressor(object): 10 | def __init__(self, kernel, alpha=1.0, beta=1.0): 11 | """ 12 | construct relevance vector regressor 13 | Parameters 14 | ---------- 15 | kernel : Kernel 16 | kernel function to compute components of feature vectors 17 | alpha : float 18 | initial precision of 
prior weight distribution 19 | beta : float 20 | precision of observation 21 | """ 22 | self.kernel = kernel 23 | self.alpha = alpha 24 | self.beta = beta 25 | 26 | def fit(self, X, t, iter_max=1000): 27 | """ 28 | maximize evidence with respect to hyperparameter 29 | Parameters 30 | ---------- 31 | X : (sample_size, n_features) ndarray 32 | input 33 | t : (sample_size,) ndarray 34 | corresponding target 35 | iter_max : int 36 | maximum number of iterations 37 | Attributes 38 | ------- 39 | X : (N, n_features) ndarray 40 | relevance vector 41 | t : (N,) ndarray 42 | corresponding target 43 | alpha : (N,) ndarray 44 | hyperparameter for each weight or training sample 45 | cov : (N, N) ndarray 46 | covariance matrix of weight 47 | mean : (N,) ndarray 48 | mean of each weight 49 | """ 50 | if X.ndim == 1: 51 | X = X[:, None] 52 | assert X.ndim == 2 53 | assert t.ndim == 1 54 | N = len(t) 55 | Phi = self.kernel(X, X) 56 | self.alpha = np.zeros(N) + self.alpha 57 | for _ in range(iter_max): 58 | params = np.hstack([self.alpha, self.beta]) 59 | precision = np.diag(self.alpha) + self.beta * Phi.T @ Phi 60 | covariance = np.linalg.inv(precision) 61 | mean = self.beta * covariance @ Phi.T @ t 62 | gamma = 1 - self.alpha * np.diag(covariance) 63 | self.alpha = gamma / np.square(mean) 64 | np.clip(self.alpha, 0, 1e10, out=self.alpha) 65 | self.beta = (N - np.sum(gamma)) / np.sum((t - Phi.dot(mean)) ** 2) 66 | if np.allclose(params, np.hstack([self.alpha, self.beta])): 67 | break 68 | mask = self.alpha < 1e9 69 | self.X = X[mask] 70 | self.t = t[mask] 71 | self.alpha = self.alpha[mask] 72 | Phi = self.kernel(self.X, self.X) 73 | precision = np.diag(self.alpha) + self.beta * Phi.T @ Phi 74 | self.covariance = np.linalg.inv(precision) 75 | self.mean = self.beta * self.covariance @ Phi.T @ self.t 76 | 77 | def predict(self, X, with_error=True): 78 | """ 79 | predict output with this model 80 | Parameters 81 | ---------- 82 | X : (sample_size, n_features) 83 | input 84 | 
with_error : bool 85 | if True, predict with standard deviation of the outputs 86 | Returns 87 | ------- 88 | mean : (sample_size,) ndarray 89 | mean of predictive distribution 90 | std : (sample_size,) ndarray 91 | standard deviation of predictive distribution 92 | """ 93 | if X.ndim == 1: 94 | X = X[:, None] 95 | assert X.ndim == 2 96 | phi = self.kernel(X, self.X) 97 | mean = phi @ self.mean 98 | if with_error: 99 | var = 1 / self.beta + np.sum(phi @ self.covariance * phi, axis=1) 100 | return mean, np.sqrt(var) 101 | return mean 102 | -------------------------------------------------------------------------------- /probml_utils/svi_gmm_model_tfp.py: -------------------------------------------------------------------------------- 1 | # SVI for a GMM 2 | # Modified from 3 | # https://github.com/brendanhasz/svi-gaussian-mixture-model/blob/master/BayesianGaussianMixtureModel.ipynb 4 | 5 | #pip install tf-nightly 6 | #pip install --upgrade tfp-nightly -q 7 | 8 | # Imports 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | import tensorflow as tf 13 | import tensorflow_probability as tfp 14 | tfd = tfp.distributions 15 | 16 | 17 | class GaussianMixtureModel(tf.keras.Model): 18 | """A Bayesian Gaussian mixture model. 19 | 20 | Assumes Gaussians' variances in each dimension are independent. 21 | 22 | Parameters 23 | ---------- 24 | Nc : int > 0 25 | Number of mixture components. 26 | Nd : int > 0 27 | Number of dimensions. 
28 | """ 29 | 30 | def __init__(self, Nc, Nd): 31 | 32 | # Initialize 33 | super(GaussianMixtureModel, self).__init__() 34 | self.Nc = Nc 35 | self.Nd = Nd 36 | 37 | # Variational distribution variables for means 38 | self.locs = tf.Variable(tf.random.normal((Nc, Nd))) 39 | self.scales = tf.Variable(tf.pow(tf.random.gamma((Nc, Nd), 5, 5), -0.5)) 40 | 41 | # Variational distribution variables for standard deviations 42 | self.alpha = tf.Variable(tf.random.uniform((Nc, Nd), 4., 6.)) 43 | self.beta = tf.Variable(tf.random.uniform((Nc, Nd), 4., 6.)) 44 | 45 | # Variational distribution variables for component weights 46 | self.counts = tf.Variable(2*tf.ones((Nc,))) 47 | 48 | # Prior distributions for the means 49 | self.mu_prior = tfd.Normal(tf.zeros((Nc, Nd)), tf.ones((Nc, Nd))) 50 | 51 | # Prior distributions for the standard deviations 52 | self.sigma_prior = tfd.Gamma(5*tf.ones((Nc, Nd)), 5*tf.ones((Nc, Nd))) 53 | 54 | # Prior distributions for the component weights 55 | self.theta_prior = tfd.Dirichlet(2*tf.ones((Nc,))) 56 | 57 | 58 | 59 | def call(self, x, sampling=True, independent=True): 60 | """Compute losses given a batch of data. 61 | 62 | Parameters 63 | ---------- 64 | x : tf.Tensor 65 | A batch of data 66 | sampling : bool 67 | Whether to sample from the variational posterior 68 | distributions (if True, the default), or just use the 69 | mean of the variational distributions (if False). 
70 | 71 | Returns 72 | ------- 73 | log_likelihoods : tf.Tensor 74 | Log likelihood for each sample 75 | kl_sum : tf.Tensor 76 | Sum of the KL divergences between the variational 77 | distributions and their priors 78 | """ 79 | 80 | # The variational distributions 81 | mu = tfd.Normal(self.locs, self.scales) 82 | sigma = tfd.Gamma(self.alpha, self.beta) 83 | theta = tfd.Dirichlet(self.counts) 84 | 85 | # Sample from the variational distributions 86 | if sampling: 87 | Nb = x.shape[0] #number of samples in the batch 88 | mu_sample = mu.sample(Nb) 89 | sigma_sample = tf.pow(sigma.sample(Nb), -0.5) 90 | theta_sample = theta.sample(Nb) 91 | else: 92 | mu_sample = tf.reshape(mu.mean(), (1, self.Nc, self.Nd)) 93 | sigma_sample = tf.pow(tf.reshape(sigma.mean(), (1, self.Nc, self.Nd)), -0.5) 94 | theta_sample = tf.reshape(theta.mean(), (1, self.Nc)) 95 | 96 | # The mixture density 97 | density = tfd.Mixture( 98 | cat=tfd.Categorical(probs=theta_sample), 99 | components=[ 100 | tfd.MultivariateNormalDiag(loc=mu_sample[:, i, :], 101 | scale_diag=sigma_sample[:, i, :]) 102 | for i in range(self.Nc)]) 103 | 104 | # Compute the mean log likelihood 105 | log_likelihoods = density.log_prob(x) 106 | 107 | # Compute the KL divergence sum 108 | mu_div = tf.reduce_sum(tfd.kl_divergence(mu, self.mu_prior)) 109 | sigma_div = tf.reduce_sum(tfd.kl_divergence(sigma, self.sigma_prior)) 110 | theta_div = tf.reduce_sum(tfd.kl_divergence(theta, self.theta_prior)) 111 | kl_sum = mu_div + sigma_div + theta_div 112 | 113 | # Return both losses 114 | return log_likelihoods, kl_sum 115 | 116 | def fit(self, dataset, N, nepochs): 117 | # How infer N from a tfds?? 
118 | optimizer = tf.keras.optimizers.Adam(lr=1e-3) 119 | 120 | @tf.function 121 | def train_step(data): 122 | with tf.GradientTape() as tape: 123 | log_likelihoods, kl_sum = self(data) 124 | elbo_loss = kl_sum/N - tf.reduce_mean(log_likelihoods) 125 | gradients = tape.gradient(elbo_loss, self.trainable_variables) 126 | optimizer.apply_gradients(zip(gradients, self.trainable_variables)) 127 | 128 | for epoch in range(nepochs): 129 | for data in dataset: 130 | train_step(data) 131 | 132 | -------------------------------------------------------------------------------- /probml_utils/url_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import requests 3 | from typing import Any 4 | from TexSoup import TexSoup 5 | import re 6 | import os 7 | import pandas as pd 8 | import firebase_admin 9 | from firebase_admin import credentials, firestore, initialize_app 10 | 11 | 12 | def is_dead_url(link): 13 | """ 14 | check if given link is dead or not 15 | """ 16 | resp = requests.get(link) 17 | if resp.status_code != 200: 18 | return 1 19 | return 0 20 | 21 | 22 | def check_dead_urls(urls: Any, print_dead_url=True): 23 | """ 24 | returns if urls are dead or not 25 | """ 26 | cnt = 0 27 | mapping_values, mapping_treedef = jax.tree_flatten( 28 | urls 29 | ) # pick only values (leaf noded) 30 | status = [] 31 | for url in mapping_values: 32 | if is_dead_url(url): 33 | if print_dead_url: 34 | print(url) 35 | status.append(1) 36 | cnt += 1 37 | else: 38 | status.append(0) 39 | if print_dead_url: 40 | print(f"{cnt} dead urls detected!") 41 | return mapping_treedef.unflatten(status) # convert to original structure 42 | 43 | 44 | def github_url_to_colab_url(url): 45 | """ 46 | convert github .ipynb url to colab .ipynb url 47 | """ 48 | if not (url.startswith("https://github.com")): 49 | raise ValueError("INVALID URL: not a Github url") 50 | 51 | if not (url.endswith(".ipynb")): 52 | raise ValueError("INVALID URL: not a .ipynb 
file") 53 | 54 | base_url_colab = "https://colab.research.google.com/github/" 55 | base_url_github = "https://github.com/" 56 | 57 | return url.replace(base_url_github, base_url_colab) 58 | 59 | 60 | def colab_url_to_github_url(url): 61 | """ 62 | convert colab .ipynb url to github .ipynb url 63 | """ 64 | if not (url.startswith("https://colab.research.google.com/github")): 65 | raise ValueError("INVALID URL: not a colab github url") 66 | 67 | if not (url.endswith(".ipynb")): 68 | raise ValueError("INVALID URL: not a .ipynb file") 69 | 70 | base_url_colab = "https://colab.research.google.com/github/" 71 | base_url_github = "https://github.com/" 72 | return url.replace(base_url_colab, base_url_github) 73 | 74 | 75 | def colab_to_githubraw_url(url): 76 | """ 77 | convert colab .ipynb url to github raw .ipynb url 78 | """ 79 | if not (url.startswith("https://colab.research.google.com/github")): 80 | raise ValueError("INVALID URL: not a colab github url") 81 | 82 | if not (url.endswith(".ipynb")): 83 | raise ValueError("INVALID URL: not a .ipynb file") 84 | 85 | base_url_colab = "https://colab.research.google.com/github/" 86 | base_url_githubraw = "https://raw.githubusercontent.com/" 87 | return ( 88 | url.replace(base_url_colab, base_url_githubraw) 89 | .replace("blob/", "") 90 | .replace("tree/", "") 91 | ) 92 | 93 | 94 | def github_to_rawcontent_url(github_url): 95 | return github_url.replace("github.com", "raw.githubusercontent.com").replace( 96 | "blob/", "" 97 | ) 98 | 99 | 100 | def extract_scripts_name_from_caption(caption): 101 | """ 102 | extract foo.py from ...{https//:foo.py}{foo.py}... 
103 | Input: caption 104 | Output: ['foo.py'] 105 | """ 106 | py_pattern = r"\{\S+?\.py\}" 107 | ipynb_pattern = r"\}{\S+?\.ipynb?\}" 108 | matches = re.findall(py_pattern, str(caption)) + re.findall( 109 | ipynb_pattern, str(caption) 110 | ) 111 | 112 | extracted_scripts = [] 113 | for each in matches: 114 | if "https" not in each: 115 | each = each.replace("{", "").replace("}", "").replace("\\_", "_") 116 | extracted_scripts.append(each) 117 | return extracted_scripts 118 | 119 | 120 | def make_url_from_fig_no_and_script_name( 121 | fig_no, 122 | script_name, 123 | base_url="https://github.com/probml/pyprobml/blob/master/notebooks", 124 | book_no=1, 125 | convert_to_which_url="github", 126 | ): 127 | """ 128 | create mapping between fig_no and actual_url path 129 | (fig_no=1.3,script_name=iris_plot.ipynb) converted to https://github.com/probml/pyprobml/blob/master/notebooks/book1/01/iris_plot.ipynb 130 | convert_to_which_url = Union["github","colab","gihub-raw"] 131 | """ 132 | chapter_no = int(fig_no.strip().split(".")[0]) 133 | base_url_ipynb = os.path.join(base_url, f"book{book_no}/{chapter_no:02d}") 134 | if ".py" in script_name: 135 | script_name = script_name[:-3] + ".ipynb" 136 | 137 | github_url = os.path.join(base_url_ipynb, script_name) 138 | if convert_to_which_url == "colab": 139 | return github_url_to_colab_url(github_url) 140 | elif convert_to_which_url == "gihub-raw": 141 | return github_to_rawcontent_url(github_url) 142 | return github_url 143 | 144 | 145 | def make_url_from_chapter_no_and_script_name( 146 | chapter_no, 147 | script_name, 148 | base_url="https://github.com/probml/pyprobml/blob/master/notebooks", 149 | book_no=1, 150 | convert_to_which_url="github", 151 | ): 152 | """ 153 | create mapping between chapter_no and actual_url path 154 | (chapter_no = 3,script_name=iris_plot.ipynb) converted to https://github.com/probml/pyprobml/blob/master/notebooks/book1/01/iris_plot.ipynb 155 | convert_to_which_url = 
Union["github","colab","gihub-raw"] 156 | """ 157 | base_url_ipynb = os.path.join(base_url, f"book{book_no}/{int(chapter_no):02d}") 158 | if script_name.strip().endswith(".py"): 159 | script_name = script_name[:-3] + ".ipynb" 160 | github_url = os.path.join(base_url_ipynb, script_name) 161 | 162 | if convert_to_which_url == "colab": 163 | return github_url_to_colab_url(github_url) 164 | elif convert_to_which_url == "github-raw": 165 | return github_to_rawcontent_url(github_url) 166 | return github_url 167 | 168 | 169 | def dict_to_csv(key_value_dict, csv_name, columns=["key", "url"]): 170 | df = pd.DataFrame(key_value_dict.items(), columns=columns) 171 | df.set_index(keys=columns[0], inplace=True, drop=True) 172 | df.to_csv(csv_name) 173 | 174 | 175 | def figure_url_mapping_from_lof( 176 | lof_file_path, 177 | csv_name, 178 | convert_to_which_url="colab", 179 | base_url="https://github.com/probml/pyprobml/blob/master/notebooks", 180 | book_no=1, 181 | ): 182 | f""" 183 | create mappng of fig_no to url by parsing lof_file and save mapping in {csv_name} 184 | convert_to_which_url = Union["github","colab","gihub-raw"] 185 | """ 186 | with open(lof_file_path) as fp: 187 | LoF_File_Contents = fp.read() 188 | soup = TexSoup(LoF_File_Contents) 189 | 190 | # create mapping of fig_no to list of script_name 191 | 192 | url_mapping = {} 193 | for caption in soup.find_all("numberline"): 194 | fig_no = str(caption.contents[0]) 195 | extracted_scripts = extract_scripts_name_from_caption(str(caption)) 196 | if len(extracted_scripts) == 1: 197 | url_mapping[fig_no] = make_url_from_fig_no_and_script_name( 198 | fig_no, 199 | extracted_scripts[0], 200 | convert_to_which_url=convert_to_which_url, 201 | base_url=base_url, 202 | book_no=book_no, 203 | ) 204 | elif len(extracted_scripts) > 1: 205 | url_mapping[fig_no] = make_url_from_fig_no_and_script_name( 206 | fig_no, 207 | "fig_" + fig_no.replace(".", "_") + ".ipynb", 208 | convert_to_which_url=convert_to_which_url, 209 | 
base_url=base_url, 210 | book_no=book_no, 211 | ) 212 | 213 | if csv_name: 214 | dict_to_csv(url_mapping, csv_name) 215 | print(f"Mapping of {len(url_mapping)} urls is saved in {csv_name}") 216 | return url_mapping 217 | 218 | 219 | def non_figure_notebook_url_mapping( 220 | notebooks_path, 221 | csv_name, 222 | convert_to_which_url="colab", 223 | base_url="https://github.com/probml/pyprobml/blob/master/notebooks", 224 | book_no=1, 225 | ): 226 | f""" 227 | create mapping of notebook_name to url using notebooks in given path - {notebooks_path} and save mapping in {csv_name} 228 | convert_to_which_url = Union["github","colab","gihub-raw"] 229 | """ 230 | url_mapping = {} 231 | for notebook_path in notebooks_path: 232 | parts = notebook_path.split("/") 233 | script_name = parts[-1] 234 | chapter_no = parts[-2] 235 | url = make_url_from_chapter_no_and_script_name( 236 | chapter_no, 237 | script_name, 238 | convert_to_which_url=convert_to_which_url, 239 | base_url=base_url, 240 | book_no=book_no, 241 | ) 242 | key = script_name.split(".")[0] # remove extension 243 | url_mapping[key] = url 244 | if csv_name: 245 | dict_to_csv(url_mapping, csv_name) 246 | print(f"Mapping of {len(url_mapping)} urls is saved in {csv_name}") 247 | return url_mapping 248 | 249 | 250 | def create_firestore_db(key_path): 251 | cred = credentials.Certificate(key_path) 252 | try: 253 | default_app = initialize_app(cred) # this should called only once 254 | except ValueError: 255 | firebase_admin.delete_app( 256 | firebase_admin.get_app() 257 | ) # delete current firebase app 258 | default_app = initialize_app(cred) 259 | db = firestore.client() 260 | return db 261 | 262 | 263 | def upload_urls_to_firestore( 264 | key_path, 265 | csv_path, 266 | level1_collection="figures", 267 | level2_document=None, 268 | level3_collection=None, 269 | ): 270 | 271 | f""" 272 | extract key-value pair from {csv_path} and upload in firestore database 273 | """ 274 | assert level2_document in [ 275 | "book1", 276 | 
"book2", 277 | ], "Incorrect level2_document value: possible values of level2_document should be ['book1', 'book2']" 278 | 279 | db = create_firestore_db(key_path) 280 | 281 | collection = ( 282 | db.collection(level1_collection) 283 | .document(level2_document) 284 | .collection(level3_collection) 285 | ) 286 | 287 | df = pd.read_csv( 288 | csv_path, dtype=str 289 | ) # put dtype=str otherwise fig_no 3.30 will converted to 3.3 290 | 291 | assert sorted(df.columns) == [ 292 | "key", 293 | "url", 294 | ], f"columns of {csv_path} should be only 'key' and 'url'" 295 | 296 | print("Uploading...") 297 | for (key, url) in list(zip(df["key"], df["url"])): 298 | collection.document(key).set({"link": url}) 299 | print(f"{len(df)} urls uploaded!") 300 | -------------------------------------------------------------------------------- /probml_utils/vae_celeba_lightning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Ang Ming Liang 3 | 4 | Please run the following command before running the script 5 | 6 | wget -q https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py 7 | or curl https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py > data.py 8 | 9 | Then, make sure to get your kaggle.json from kaggle.com then run 10 | 11 | mkdir /root/.kaggle 12 | cp kaggle.json /root/.kaggle/kaggle.json 13 | chmod 600 /root/.kaggle/kaggle.json 14 | rm kaggle.json 15 | 16 | to copy kaggle.json into a folder first 17 | """ 18 | import os 19 | 20 | try: 21 | import torch 22 | except ModuleNotFoundError: 23 | os.system("pip install torch") 24 | import torch 25 | try: 26 | import torchvision.transforms as transforms 27 | except: 28 | os.system("pip install torchvision") 29 | import torchvision.transforms as transforms 30 | 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | 34 | try: 35 | from pytorch_lightning import LightningModule, Trainer 36 | except: 37 | os.system("pip install 
pytorch-lightning") 38 | from pytorch_lightning import LightningModule, Trainer 39 | 40 | from probml_utils.vae_lightning_data import CelebADataModule 41 | 42 | from argparse import ArgumentParser 43 | 44 | 45 | IMAGE_SIZE = 64 46 | CROP = 128 47 | DATA_PATH = "kaggle" 48 | 49 | trans = [] 50 | trans.append(transforms.RandomHorizontalFlip()) 51 | if CROP > 0: 52 | trans.append(transforms.CenterCrop(CROP)) 53 | trans.append(transforms.Resize(IMAGE_SIZE)) 54 | trans.append(transforms.ToTensor()) 55 | transform = transforms.Compose(trans) 56 | 57 | 58 | class VAE(LightningModule): 59 | """ 60 | Standard VAE with Gaussian Prior and approx posterior. 61 | """ 62 | 63 | def __init__( 64 | self, 65 | input_height: int, 66 | hidden_dims = None, 67 | in_channels = 3, 68 | enc_out_dim: int = 512, 69 | kl_coeff: float = 0.1, 70 | latent_dim: int = 256, 71 | lr: float = 1e-4 72 | ): 73 | """ 74 | Args: 75 | input_height: height of the images 76 | enc_type: option between resnet18 or resnet50 77 | first_conv: use standard kernel_size 7, stride 2 at start or 78 | replace it with kernel_size 3, stride 1 conv 79 | maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 2 80 | enc_out_dim: set according to the out_channel count of 81 | encoder used (512 for resnet18, 2048 for resnet50) 82 | kl_coeff: coefficient for kl term of the loss 83 | latent_dim: dim of latent space 84 | lr: learning rate for Adam 85 | """ 86 | 87 | super(VAE, self).__init__() 88 | 89 | self.save_hyperparameters() 90 | 91 | self.lr = lr 92 | self.kl_coeff = kl_coeff 93 | self.enc_out_dim = enc_out_dim 94 | self.latent_dim = latent_dim 95 | self.input_height = input_height 96 | 97 | modules = [] 98 | if hidden_dims is None: 99 | hidden_dims = [32, 64, 128, 256, 512] 100 | 101 | # Build Encoder 102 | for h_dim in hidden_dims: 103 | modules.append( 104 | nn.Sequential( 105 | nn.Conv2d(in_channels, out_channels=h_dim, 106 | kernel_size= 3, stride= 2, padding = 1), 107 | nn.BatchNorm2d(h_dim), 
108 | nn.LeakyReLU()) 109 | ) 110 | in_channels = h_dim 111 | 112 | self.encoder = nn.Sequential(*modules) 113 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 114 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 115 | 116 | # Build Decoder 117 | modules = [] 118 | 119 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 120 | 121 | hidden_dims.reverse() 122 | 123 | for i in range(len(hidden_dims) - 1): 124 | modules.append( 125 | nn.Sequential( 126 | nn.ConvTranspose2d(hidden_dims[i], 127 | hidden_dims[i + 1], 128 | kernel_size=3, 129 | stride = 2, 130 | padding=1, 131 | output_padding=1), 132 | nn.BatchNorm2d(hidden_dims[i + 1]), 133 | nn.LeakyReLU()) 134 | ) 135 | 136 | self.decoder = nn.Sequential(*modules) 137 | 138 | self.final_layer = nn.Sequential( 139 | nn.ConvTranspose2d(hidden_dims[-1], 140 | hidden_dims[-1], 141 | kernel_size=3, 142 | stride=2, 143 | padding=1, 144 | output_padding=1), 145 | nn.BatchNorm2d(hidden_dims[-1]), 146 | nn.LeakyReLU(), 147 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 148 | kernel_size= 3, padding= 1), 149 | nn.Sigmoid()) 150 | 151 | @staticmethod 152 | def pretrained_weights_available(): 153 | return list(VAE.pretrained_urls.keys()) 154 | 155 | def from_pretrained(self, checkpoint_name): 156 | if checkpoint_name not in VAE.pretrained_urls: 157 | raise KeyError(str(checkpoint_name) + ' not present in pretrained weights.') 158 | 159 | return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False) 160 | 161 | def forward(self, x): 162 | mu, log_var = self.encode(x) 163 | p, q, z = self.sample(mu, log_var) 164 | 165 | return self.decode(z) 166 | 167 | def encode(self, x): 168 | x = self.encoder(x) 169 | x = torch.flatten(x, start_dim=1) 170 | mu = self.fc_mu(x) 171 | log_var = self.fc_var(x) 172 | return mu, log_var 173 | 174 | def _run_step(self, x): 175 | mu, log_var = self.encode(x) 176 | p, q, z = self.sample(mu, log_var) 177 | 178 | return z, self.decode(z), p, q 179 | 180 | 
def decode(self, z): 181 | result = self.decoder_input(z) 182 | result = result.view(-1, 512, 2, 2) 183 | result = self.decoder(result) 184 | result = self.final_layer(result) 185 | return result 186 | 187 | def sample(self, mu, log_var): 188 | std = torch.exp(log_var / 2) 189 | p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std)) 190 | q = torch.distributions.Normal(mu, std) 191 | z = q.rsample() 192 | return p, q, z 193 | 194 | def step(self, batch, batch_idx): 195 | x, y = batch 196 | z, x_hat, p, q = self._run_step(x) 197 | 198 | recon_loss = F.mse_loss(x_hat, x, reduction='mean') 199 | 200 | log_qz = q.log_prob(z) 201 | log_pz = p.log_prob(z) 202 | 203 | kl = log_qz - log_pz 204 | kl = kl.mean() 205 | kl *= self.kl_coeff 206 | 207 | loss = kl + recon_loss 208 | 209 | logs = { 210 | "recon_loss": recon_loss, 211 | "kl": kl, 212 | "loss": loss, 213 | } 214 | return loss, logs 215 | 216 | def training_step(self, batch, batch_idx): 217 | loss, logs = self.step(batch, batch_idx) 218 | self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False) 219 | return loss 220 | 221 | def validation_step(self, batch, batch_idx): 222 | loss, logs = self.step(batch, batch_idx) 223 | self.log_dict({f"val_{k}": v for k, v in logs.items()}) 224 | return loss 225 | 226 | def configure_optimizers(self): 227 | return torch.optim.Adam(self.parameters(), lr=self.lr) 228 | 229 | if __name__ == "__main__": 230 | parser = ArgumentParser(description='Hyperparameters for our experiments') 231 | parser.add_argument('--latent-dim', type=int, default=256, help="size of latent dim for our vae") 232 | parser.add_argument('--epochs', type=int, default=50, help="num epochs") 233 | parser.add_argument('--gpus', type=int, default=1, help="gpus, if no gpu set to 0, to run on all gpus set to -1") 234 | parser.add_argument('--bs', type=int, default=500, help="batch size") 235 | parser.add_argument('--kl-coeff', type=int, default=5, help="kl coeff aka 
beta term in the elbo loss function") 236 | parser.add_argument('--lr', type=int, default=0.01, help="learning rate") 237 | hparams = parser.parse_args() 238 | 239 | m = VAE(input_height=IMAGE_SIZE, latent_dim=hparams.latent_dim, kl_coeff=hparams.kl_coeff, lr=hparams.lr) 240 | runner = Trainer(gpus = hparams.gpus, 241 | max_epochs = hparams.epochs) 242 | dm = CelebADataModule(data_dir=DATA_PATH, 243 | target_type='attr', 244 | train_transform=transform, 245 | val_transform=transform, 246 | download=True, 247 | batch_size=hparams.bs) 248 | runner.fit(m, datamodule=dm) 249 | torch.save(m.state_dict(), "vae-celeba-conv.ckpt") 250 | -------------------------------------------------------------------------------- /probml_utils/vae_conv_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | try: 3 | import torch 4 | except ModuleNotFoundError: 5 | os.system("pip install torch") 6 | import torch 7 | try: 8 | import torchvision.transforms as transforms 9 | except: 10 | os.system("pip install torchvision'") 11 | import torchvision.transforms as transforms 12 | 13 | import torch.nn as nn 14 | import numpy as np 15 | from torch.nn import functional as F 16 | from torchvision.datasets import MNIST 17 | from torch.utils.data import DataLoader 18 | import torchvision.transforms as transforms 19 | try: 20 | from pytorch_lightning import LightningModule, Trainer 21 | except: 22 | os.system("pip install pytorch-lightning") 23 | from pytorch_lightning import LightningModule, Trainer 24 | 25 | from argparse import ArgumentParser 26 | 27 | class ConvVAEModule(nn.Module): 28 | def __init__(self, input_shape, 29 | encoder_conv_filters, 30 | decoder_conv_t_filters, 31 | latent_dim, 32 | deterministic=False): 33 | super(ConvVAEModule, self).__init__() 34 | self.input_shape = input_shape 35 | 36 | self.latent_dim = latent_dim 37 | self.deterministic = deterministic 38 | 39 | all_channels = [self.input_shape[0]] + encoder_conv_filters 40 
| 41 | self.enc_convs = nn.ModuleList([]) 42 | 43 | # encoder_conv_layers 44 | for i in range(len(encoder_conv_filters)): 45 | self.enc_convs.append(nn.Conv2d(all_channels[i], all_channels[i + 1], 46 | kernel_size=3, stride=2, padding=1)) 47 | if not self.latent_dim == 2: 48 | self.enc_convs.append(nn.BatchNorm2d(all_channels[i + 1])) 49 | self.enc_convs.append(nn.LeakyReLU()) 50 | 51 | self.flatten_out_size = self.flatten_enc_out_shape(input_shape) 52 | 53 | if self.latent_dim == 2: 54 | self.mu_linear = nn.Linear(self.flatten_out_size, self.latent_dim) 55 | else: 56 | self.mu_linear = nn.Sequential( 57 | nn.Linear(self.flatten_out_size, self.latent_dim), 58 | nn.LeakyReLU(), 59 | nn.Dropout(0.2) 60 | ) 61 | 62 | if self.latent_dim == 2: 63 | self.log_var_linear = nn.Linear(self.flatten_out_size, self.latent_dim) 64 | else: 65 | self.log_var_linear = nn.Sequential( 66 | nn.Linear(self.flatten_out_size, self.latent_dim), 67 | nn.LeakyReLU(), 68 | nn.Dropout(0.2) 69 | ) 70 | 71 | if self.latent_dim == 2: 72 | self.decoder_linear = nn.Linear(self.latent_dim, self.flatten_out_size) 73 | else: 74 | self.decoder_linear = nn.Sequential( 75 | nn.Linear(self.latent_dim, self.flatten_out_size), 76 | nn.LeakyReLU(), 77 | nn.Dropout(0.2) 78 | ) 79 | 80 | all_t_channels = [encoder_conv_filters[-1]] + decoder_conv_t_filters 81 | 82 | self.dec_t_convs = nn.ModuleList([]) 83 | 84 | num = len(decoder_conv_t_filters) 85 | 86 | # decoder_trans_conv_layers 87 | for i in range(num - 1): 88 | self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2)) 89 | self.dec_t_convs.append(nn.ConvTranspose2d(all_t_channels[i], all_t_channels[i + 1], 90 | 3, stride=1, padding=1)) 91 | if not self.latent_dim == 2: 92 | self.dec_t_convs.append(nn.BatchNorm2d(all_t_channels[i + 1])) 93 | self.dec_t_convs.append(nn.LeakyReLU()) 94 | 95 | self.dec_t_convs.append(nn.UpsamplingNearest2d(scale_factor=2)) 96 | self.dec_t_convs.append(nn.ConvTranspose2d(all_t_channels[num - 1], all_t_channels[num], 97 
| 3, stride=1, padding=1)) 98 | self.dec_t_convs.append(nn.Sigmoid()) 99 | 100 | def reparameterize(self, mu, log_var): 101 | std = torch.exp(0.5 * log_var) # standard deviation 102 | eps = torch.randn_like(std) # `randn_like` as we need the same size 103 | sample = mu + (eps * std) # sampling 104 | return sample 105 | 106 | def _run_step(self, x): 107 | mu, log_var = self.encode(x) 108 | std = torch.exp(0.5*log_var) 109 | p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std)) 110 | q = torch.distributions.Normal(mu, std) 111 | z = self.reparameterize(mu,log_var) 112 | recon = self.decode(z) 113 | return z, recon, p, q 114 | 115 | def flatten_enc_out_shape(self, input_shape): 116 | x = torch.zeros(1, *input_shape) 117 | for l in self.enc_convs: 118 | x = l(x) 119 | self.shape_before_flattening = x.shape 120 | return int(np.prod(self.shape_before_flattening)) 121 | 122 | def encode(self, x): 123 | for l in self.enc_convs: 124 | x = l(x) 125 | x = x.view(x.size()[0], -1) # flatten 126 | mu = self.mu_linear(x) 127 | log_var = self.log_var_linear(x) 128 | return mu, log_var 129 | 130 | def decode(self, z): 131 | z = self.decoder_linear(z) 132 | recon = z.view(z.size()[0], *self.shape_before_flattening[1:]) 133 | for l in self.dec_t_convs: 134 | recon = l(recon) 135 | return recon 136 | 137 | def forward(self, x): 138 | mu, log_var = self.encode(x) 139 | if self.deterministic: 140 | return self.decode(mu), mu, None 141 | else: 142 | z = self.reparameterize(mu, log_var) 143 | recon = self.decode(z) 144 | return recon, mu, log_var 145 | 146 | class ConvVAE(LightningModule): 147 | def __init__(self,input_shape, 148 | encoder_conv_filters, 149 | decoder_conv_t_filters, 150 | latent_dim, 151 | kl_coeff=0.1, 152 | lr = 0.001): 153 | super(ConvVAE, self).__init__() 154 | self.kl_coeff = kl_coeff 155 | self.lr = lr 156 | self.vae = ConvVAEModule(input_shape, encoder_conv_filters, decoder_conv_t_filters, latent_dim) 157 | 158 | def step(self, batch, 
batch_idx): 159 | x, y = batch 160 | z, x_hat, p, q = self.vae._run_step(x) 161 | 162 | recon_loss = F.binary_cross_entropy(x_hat, x, reduction='sum') 163 | 164 | log_qz = q.log_prob(z) 165 | log_pz = p.log_prob(z) 166 | 167 | kl = log_qz - log_pz 168 | kl = kl.sum() # I tried sum, here 169 | kl *= self.kl_coeff 170 | 171 | loss = kl + recon_loss 172 | 173 | logs = { 174 | "recon_loss": recon_loss, 175 | "kl": kl, 176 | "loss": loss, 177 | } 178 | return loss, logs 179 | 180 | def training_step(self, batch, batch_idx): 181 | loss, logs = self.step(batch, batch_idx) 182 | self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False) 183 | return loss 184 | 185 | def validation_step(self, batch, batch_idx): 186 | loss, logs = self.step(batch, batch_idx) 187 | self.log_dict({f"val_{k}": v for k, v in logs.items()}) 188 | return loss 189 | 190 | def configure_optimizers(self): 191 | return torch.optim.Adam(self.parameters(), lr=self.lr) 192 | 193 | if __name__ == "__main__": 194 | parser = ArgumentParser(description='Hyperparameters for our experiments') 195 | parser.add_argument('--bs', type=int, default=500, help="batch size") 196 | parser.add_argument('--epochs', type=int, default=50, help="num epochs") 197 | parser.add_argument('--latent-dim', type=int, default=2, help="size of latent dim for our vae") 198 | parser.add_argument('--lr', type=float, default=0.001, help="learning rate") 199 | parser.add_argument('--kl-coeff', type=int, default=5, help="kl coeff aka beta term in the elbo loss function") 200 | hparams = parser.parse_args() 201 | 202 | m = ConvVAE((1, 28, 28), 203 | encoder_conv_filters=[28,64,64], 204 | decoder_conv_t_filters=[64,28,1], 205 | latent_dim=hparams.latent_dim, kl_coeff=hparams.kl_coeff, lr=hparams.lr) 206 | 207 | mnist_full = MNIST(".", train=True, download=True, 208 | transform=transforms.Compose([transforms.ToTensor(), 209 | transforms.Resize((32,32))])) 210 | dm = DataLoader(mnist_full, batch_size=hparams.bs) 
211 | trainer = Trainer(gpus=1, weights_summary='full', max_epochs=hparams.epochs) 212 | trainer.fit(m, dm) 213 | torch.save(m.state_dict(), "vae-mnist-conv.ckpt") -------------------------------------------------------------------------------- /probml_utils/vae_lightning_data.py: -------------------------------------------------------------------------------- 1 | # Source: https://raw.githubusercontent.com/sayantanauddy/vae_lightning/main/data.py 2 | from functools import partial 3 | import pandas as pd 4 | import os 5 | import PIL 6 | import glob 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader, random_split 10 | from torchvision import transforms, utils, io 11 | from torchvision.datasets.utils import verify_str_arg 12 | 13 | import pytorch_lightning as pl 14 | 15 | 16 | class CelebADataset(Dataset): 17 | """CelebA Dataset class""" 18 | 19 | def __init__(self, 20 | root, 21 | split="train", 22 | target_type="attr", 23 | transform=None, 24 | target_transform=None, 25 | download=False 26 | ): 27 | """ 28 | """ 29 | 30 | self.root = root 31 | self.split = split 32 | self.target_type = target_type 33 | self.transform = transform 34 | self.target_transform = target_transform 35 | 36 | if isinstance(target_type, list): 37 | self.target_type = target_type 38 | else: 39 | self.target_type = [target_type] 40 | 41 | if not self.target_type and self.target_transform is not None: 42 | raise RuntimeError('target_transform is specified but target_type is empty') 43 | 44 | if download: 45 | self.download_from_kaggle() 46 | 47 | split_map = { 48 | "train": 0, 49 | "valid": 1, 50 | "test": 2, 51 | "all": None, 52 | } 53 | 54 | split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))] 55 | 56 | fn = partial(os.path.join, self.root) 57 | splits = pd.read_csv(fn("list_eval_partition.csv"), delim_whitespace=False, header=0, index_col=0) 58 | # This file is not available in Kaggle 59 | # identity = 
pd.read_csv(fn("identity_CelebA.csv"), delim_whitespace=True, header=None, index_col=0) 60 | bbox = pd.read_csv(fn("list_bbox_celeba.csv"), delim_whitespace=False, header=0, index_col=0) 61 | landmarks_align = pd.read_csv(fn("list_landmarks_align_celeba.csv"), delim_whitespace=False, header=0, index_col=0) 62 | attr = pd.read_csv(fn("list_attr_celeba.csv"), delim_whitespace=False, header=0, index_col=0) 63 | 64 | mask = slice(None) if split_ is None else (splits['partition'] == split_) 65 | 66 | self.filename = splits[mask].index.values 67 | # self.identity = torch.as_tensor(identity[mask].values) 68 | self.bbox = torch.as_tensor(bbox[mask].values) 69 | self.landmarks_align = torch.as_tensor(landmarks_align[mask].values) 70 | self.attr = torch.as_tensor(attr[mask].values) 71 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 72 | self.attr_names = list(attr.columns) 73 | 74 | def download_from_kaggle(self): 75 | 76 | # Annotation files will be downloaded at the end 77 | label_files = ['list_attr_celeba.csv', 'list_bbox_celeba.csv', 'list_eval_partition.csv', 'list_landmarks_align_celeba.csv'] 78 | 79 | # Check if files have been downloaded already 80 | files_exist = False 81 | for label_file in label_files: 82 | if os.path.isfile(os.path.join(self.root, label_file)): 83 | files_exist = True 84 | else: 85 | files_exist = False 86 | 87 | if files_exist: 88 | print("Files exist already") 89 | else: 90 | print("Downloading dataset. 
Please wait while the download and extraction processes complete") 91 | # Download files from Kaggle using its API as per 92 | # https://stackoverflow.com/questions/55934733/documentation-for-kaggle-api-within-python 93 | 94 | # Kaggle authentication 95 | # Remember to place the API token from Kaggle in $HOME/.kaggle 96 | from kaggle.api.kaggle_api_extended import KaggleApi 97 | api = KaggleApi() 98 | api.authenticate() 99 | 100 | # Download all files of a dataset 101 | # Signature: dataset_download_files(dataset, path=None, force=False, quiet=True, unzip=False) 102 | api.dataset_download_files(dataset='jessicali9530/celeba-dataset', 103 | path=self.root, 104 | unzip=True, 105 | force=False, 106 | quiet=False) 107 | 108 | # Download the label files 109 | # Signature: dataset_download_file(dataset, file_name, path=None, force=False, quiet=True) 110 | for label_file in label_files: 111 | api.dataset_download_file(dataset='jessicali9530/celeba-dataset', 112 | file_name=label_file, 113 | path=self.root, 114 | force=False, 115 | quiet=False) 116 | 117 | # Clear any remaining *.csv.zip files 118 | files_to_delete = glob.glob(os.path.join(self.root,"*.csv.zip")) 119 | for f in files_to_delete: 120 | os.remove(f) 121 | 122 | print("Done!") 123 | 124 | 125 | def __getitem__(self, index: int): 126 | X = PIL.Image.open(os.path.join(self.root, 127 | "img_align_celeba", 128 | "img_align_celeba", 129 | self.filename[index])) 130 | 131 | target = [] 132 | for t in self.target_type: 133 | if t == "attr": 134 | target.append(self.attr[index, :]) 135 | # elif t == "identity": 136 | # target.append(self.identity[index, 0]) 137 | elif t == "bbox": 138 | target.append(self.bbox[index, :]) 139 | elif t == "landmarks": 140 | target.append(self.landmarks_align[index, :]) 141 | else: 142 | raise ValueError(f"Target type {t} is not recognized") 143 | 144 | if self.transform is not None: 145 | X = self.transform(X) 146 | 147 | if target: 148 | target = tuple(target) if len(target) > 1 else
target[0] 149 | 150 | if self.target_transform is not None: 151 | target = self.target_transform(target) 152 | else: 153 | target = None 154 | 155 | return X, target 156 | 157 | def __len__(self) -> int: 158 | return len(self.attr) 159 | 160 | 161 | class CelebADataModule(pl.LightningDataModule): 162 | 163 | def __init__(self, 164 | data_dir, 165 | target_type="attr", 166 | train_transform=None, 167 | val_transform=None, 168 | target_transform=None, 169 | download=False, 170 | batch_size=32, 171 | num_workers=8): 172 | 173 | super().__init__() 174 | 175 | self.data_dir = data_dir 176 | self.target_type = target_type 177 | self.train_transform = train_transform 178 | self.val_transform = val_transform 179 | self.target_transform = target_transform 180 | self.download = download 181 | 182 | self.batch_size = batch_size 183 | self.num_workers = num_workers 184 | 185 | def setup(self, stage=None): 186 | 187 | # Training dataset 188 | self.celebA_trainset = CelebADataset(root=self.data_dir, 189 | split='train', 190 | target_type=self.target_type, 191 | download=self.download, 192 | transform=self.train_transform, 193 | target_transform=self.target_transform) 194 | 195 | # Validation dataset 196 | self.celebA_valset = CelebADataset(root=self.data_dir, 197 | split='valid', 198 | target_type=self.target_type, 199 | download=False, 200 | transform=self.val_transform, 201 | target_transform=self.target_transform) 202 | 203 | # Test dataset 204 | self.celebA_testset = CelebADataset(root=self.data_dir, 205 | split='test', 206 | target_type=self.target_type, 207 | download=False, 208 | transform=self.val_transform, 209 | target_transform=self.target_transform) 210 | 211 | def train_dataloader(self): 212 | return DataLoader(self.celebA_trainset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=self.num_workers) 213 | 214 | def val_dataloader(self): 215 | return DataLoader(self.celebA_valset, batch_size=self.batch_size, shuffle=False, drop_last=False,
num_workers=self.num_workers) 216 | 217 | def test_dataloader(self): 218 | return DataLoader(self.celebA_testset, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=self.num_workers) 219 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=50.0", 4 | "setuptools_scm[toml]>=6.0", 5 | "setuptools_scm_git_archive", 6 | "wheel>=0.33", 7 | "numpy>=1.16", 8 | "cython>=0.29", 9 | ] 10 | 11 | [tool.setuptools_scm] 12 | write_to = "probml_utils/_version.py" -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | setuptools_scm[toml] 2 | setuptools_scm_git_archive -------------------------------------------------------------------------------- /requirements-extra.txt: -------------------------------------------------------------------------------- 1 | pgmpy # torch will be installed by this 2 | imageio 3 | tensorflow 4 | einops 5 | torchvision 6 | umap-learn 7 | pytorch_lightning -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | arviz 2 | jaxlib 3 | jax 4 | jaxopt 5 | jupyter 6 | matplotlib 7 | scikit-learn 8 | scipy 9 | graphviz 10 | distrax==0.1.3 11 | umap-learn 12 | pandas 13 | TexSoup 14 | firebase_admin 15 | regex -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = probml_utils 3 | author = Kevin P Murphy 4 | author_email = murphyk@gmail.com 5 | description = Utilities for probabilistic ML 6 | url = https://github.com/probml/probml-utils 7 | license = MIT 8 |
long_description_content_type = text/markdown 9 | long_description = file: README.md 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("requirements.txt") as f: 4 | requirements = f.read().strip().split("\n") 5 | 6 | setup( 7 | packages=find_packages(), 8 | python_requires=">=3.6", 9 | install_requires=requirements, 10 | include_package_data=True, 11 | ) 12 | -------------------------------------------------------------------------------- /tests/LogisticRegression_test.py: -------------------------------------------------------------------------------- 1 | from probml_utils.LogisticRegression import binary_loss_function, multi_loss_function, fit 2 | import jax.numpy as jnp 3 | import jax 4 | from sklearn import datasets 5 | from sklearn.linear_model import LogisticRegression 6 | import pytest 7 | 8 | def test_binary_loss_function(): 9 | x = jnp.array([[1, 1, 2]]) 10 | weights = jnp.array([-1, 0.4, -0.2]) 11 | y = jnp.array([0.26]) 12 | loss = 0.5832 13 | assert jnp.allclose(loss, binary_loss_function(weights, auxs=[x, y, 0.1])[0], rtol=1e-2) 14 | 15 | def test_multi_loss_function(): 16 | x = jnp.array([[1, 1, 2]]) 17 | weights = jnp.array([ 18 | [0, 1, 2],[0.3, 0.4, 0.5], [0.5, 0.2, -0.3] 19 | ]) 20 | y = jnp.array([[0.22, 0.37, 0.41]]) 21 | loss = 1.110 22 | assert jnp.allclose(loss, multi_loss_function(weights, auxs=[x, y, 0.1])[0], rtol=1e-2) 23 | 24 | def test_fit_function(): 25 | iris = datasets.load_iris() 26 | train_x = iris["data"][:, (2,3)] 27 | train_y = (iris["target"] == 2).astype(jnp.int32) 28 | 29 | #lr from scratch using jax 30 | weights, intercept_, coef_ = fit(train_x, train_y, 10000, lambd=0.01) 31 | 32 | #lr from sklearn 33 | lr = LogisticRegression(C=100) 34 | lr.fit(train_x, train_y) 35 | 36 | assert jnp.allclose(intercept_, lr.intercept_, rtol=1e-2) 37
| assert jnp.allclose(coef_.T, lr.coef_, rtol=1e-2) 38 | 39 | 40 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probml/probml-utils/638ebb99f2e1ecc32f97bc02e8b81a8798c38c77/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_extension.py: -------------------------------------------------------------------------------- 1 | import os 2 | from probml_utils import _get_fig_name 3 | 4 | 5 | def test_extension(): 6 | os.environ["LATEXIFY"] = "" 7 | assert _get_fig_name("test.pdf") == "test_latexified.pdf" 8 | assert _get_fig_name("test.png") == "test_latexified.pdf" 9 | assert _get_fig_name("test.jpg") == "test_latexified.pdf" 10 | assert _get_fig_name("test") == "test_latexified.pdf" 11 | os.environ.pop("LATEXIFY") 12 | assert _get_fig_name("test.pdf") == "test.pdf" 13 | assert _get_fig_name("test.png") == "test.pdf" 14 | assert _get_fig_name("test.jpg") == "test.pdf" 15 | assert _get_fig_name("test") == "test.pdf" 16 | -------------------------------------------------------------------------------- /tests/test_import.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import probml_utils as pml 3 | 4 | 5 | def test_import_scripts(): 6 | # Files 7 | pml.__version__ 8 | pml.savefig 9 | pml.latexify 10 | pml.hinton_diagram 11 | pml.plot_ellipse 12 | pml.convergence_test 13 | pml.kdeg 14 | pml.scale_3d 15 | pml.style3d 16 | 17 | 18 | def test_import_modules(): 19 | from probml_utils import fisher_lda_fit 20 | from probml_utils import gauss_utils 21 | from probml_utils import gmm_lib 22 | from probml_utils import mix_bernoulli_lib 23 | from probml_utils import mixture_lib 24 | from probml_utils import pgmpy_utils 25 | from probml_utils import plotting 26 | from probml_utils import
prefit_voting_classifier 27 | from probml_utils import pyprobml_utils 28 | from probml_utils import rvm_classifier, rvm_regressor 29 | from probml_utils import mnist_helper_tf 30 | from probml_utils import vae_conv_mnist 31 | from probml_utils import lvm_plots_utils 32 | from probml_utils import url_utils 33 | from probml_utils import mix_bernoulli_em_mnist 34 | from probml_utils import vae_celeba_lightning 35 | from probml_utils import mfa_celeba_helpers 36 | from probml_utils import download_celeba 37 | from probml_utils import ae_mnist_conv 38 | -------------------------------------------------------------------------------- /tests/test_latexify_status.py: -------------------------------------------------------------------------------- 1 | import probml_utils as pml 2 | import os 3 | 4 | def test_latexify_disabled(): 5 | if "LATEXIFY" in os.environ: 6 | os.environ.pop("LATEXIFY") 7 | assert pml.is_latexify_enabled() is False 8 | 9 | def test_latexify_enabled(): 10 | os.environ["LATEXIFY"] = "" 11 | assert pml.is_latexify_enabled() is True -------------------------------------------------------------------------------- /tests/test_save.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import os 4 | import probml_utils as pml 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | @pytest.mark.parametrize("latexify", [True, False]) 9 | def test_save(latexify): 10 | if latexify: 11 | os.environ["LATEXIFY"] = "" 12 | suffix = "_latexified" 13 | else: 14 | if "LATEXIFY" in os.environ: 15 | os.environ.pop("LATEXIFY") 16 | suffix = "" 17 | 18 | os.environ["FIG_DIR"] = "figures" 19 | pml.latexify(width_scale_factor=2, fig_height=1.5) 20 | plt.plot([1.0, 2.0], [3.0, 4.0]) 21 | save_name = os.path.join(os.environ["FIG_DIR"], f"test{suffix}.pdf") 22 | if os.path.exists(save_name): 23 | os.remove(save_name) 24 | pml.savefig("test") 25 | assert os.path.exists(save_name) 26
| -------------------------------------------------------------------------------- /tests/test_url_utils.py: -------------------------------------------------------------------------------- 1 | from probml_utils.url_utils import check_dead_urls, github_url_to_colab_url, colab_url_to_github_url 2 | import pytest 3 | 4 | def test_dead_urls(): 5 | links = {'1.3': 'https://github.com/probml/pyprobml/blob/master/notebooks/book1/01/iris_plot.ipynb', 6 | '1.4': 'https://github.com/probml/pyprobml/blob/master/notebooks/book1/01/iris_dtree.ipynb', 7 | '1.5': 'https://github.com/probml/pyprobml/blob/master/notebooks/book1/01/linreg_residuals_plot_broken.ipynb', 8 | '1.6': 'https://github.com/probml/pyprobml/blob/master/notebooks/book1/01/linreg_2d_surface_demo_broken.ipynb'} 9 | 10 | status = check_dead_urls(links) 11 | status_true = {'1.3': 0, '1.4': 0, '1.5': 1, '1.6': 1} 12 | assert status == status_true 13 | 14 | def test_github_to_colab(): 15 | links = { 16 | "https://github.com/probml/pyprobml/blob/master/notebooks/book1/13/mlp_1d_regression_hetero_tfp.ipynb":"https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/13/mlp_1d_regression_hetero_tfp.ipynb"} 17 | 18 | for link, expected in links.items(): 19 | assert github_url_to_colab_url(link) == expected 20 | 21 | invalid_links = ["https://google.com/probml/pyprobml/blob/master/notebooks/book1/13/mlp_1d_regression_hetero_tfp.ipynb", 22 | "https://github.com/probml/pyprobml/blob/master/notebooks/book1/13/mlp_1d_regression_hetero_tfp.py"] 23 | 24 | for link in invalid_links: 25 | with pytest.raises(ValueError): 26 | github_url_to_colab_url(link) 27 | 28 | def test_colab_to_github(): 29 | links = { 30 | "https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/13/mlp_1d_regression_hetero_tfp.ipynb":"https://github.com/probml/pyprobml/blob/master/notebooks/book1/13/mlp_1d_regression_hetero_tfp.ipynb"} 31 | 32 | for link, expected in links.items(): 33 | assert colab_url_to_github_url(link) ==
expected 34 | 35 | invalid_links = ["https://github.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/13/mlp_1d_regression_hetero_tfp.ipynb", 36 | "https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/13/mlp_1d_regression_hetero_tfp.txt"] 37 | 38 | for link in invalid_links: 39 | with pytest.raises(ValueError): 40 | colab_url_to_github_url(link) --------------------------------------------------------------------------------