├── results ├── VAE │ ├── trainVAE-g.pt │ ├── trainVAE.mat │ ├── trainVAE.pt │ └── trainVAE.txt ├── OTFlow │ ├── trainOTFlow.mat │ ├── trainOTFlow.pt │ └── trainOTFlow ├── realNVP │ ├── trainRealNVP.mat │ ├── trainRealNVP.pt │ └── trainRealNVP.txt ├── DCGAN │ ├── trainWGAN-vaeInit.mat │ ├── trainDCGAN-vaeInit-d.pt │ ├── trainDCGAN-vaeInit-g.pt │ ├── trainDCGAN-vaeInit.mat │ ├── trainWGAN-vaeInit-d.pt │ ├── trainWGAN-vaeInit-g.pt │ ├── trainDCGAN-randomInit-d.pt │ ├── trainDCGAN-randomInit-g.pt │ ├── trainDCGAN-randomInit.mat │ ├── trainWGAN-randomInit-d.pt │ ├── trainWGAN-randomInit-g.pt │ ├── trainWGAN-randomInit.mat │ ├── trainWGAN-vaeInit-his.png │ ├── trainDCGAN-randomInit.txt │ └── trainDCGAN-vaeInit.txt └── WGAN │ ├── trainWGAN-randomInit.txt │ └── trainWGAN-vaeInit.txt ├── .idea ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml └── modules.xml ├── testAll.sh ├── LICENSE ├── epsTest.py ├── plotting.py ├── runAll.sh ├── .gitignore ├── vae.py ├── README.md ├── modelMNIST.py ├── trainVAEmnist.py ├── trainOTFlow.py ├── realNVP.py ├── trainRealNVP.py ├── trainDCGANmnist.py ├── trainWGANmnist.py ├── OTFlow.py └── Phi.py /results/VAE/trainVAE-g.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/VAE/trainVAE-g.pt -------------------------------------------------------------------------------- /results/VAE/trainVAE.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/VAE/trainVAE.mat -------------------------------------------------------------------------------- /results/VAE/trainVAE.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/VAE/trainVAE.pt 
-------------------------------------------------------------------------------- /results/OTFlow/trainOTFlow.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/OTFlow/trainOTFlow.mat -------------------------------------------------------------------------------- /results/OTFlow/trainOTFlow.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/OTFlow/trainOTFlow.pt -------------------------------------------------------------------------------- /results/realNVP/trainRealNVP.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/realNVP/trainRealNVP.mat -------------------------------------------------------------------------------- /results/realNVP/trainRealNVP.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/realNVP/trainRealNVP.pt -------------------------------------------------------------------------------- /results/DCGAN/trainWGAN-vaeInit.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-vaeInit.mat -------------------------------------------------------------------------------- /results/DCGAN/trainDCGAN-vaeInit-d.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-vaeInit-d.pt -------------------------------------------------------------------------------- /results/DCGAN/trainDCGAN-vaeInit-g.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-vaeInit-g.pt -------------------------------------------------------------------------------- /results/DCGAN/trainDCGAN-vaeInit.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-vaeInit.mat -------------------------------------------------------------------------------- /results/DCGAN/trainWGAN-vaeInit-d.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-vaeInit-d.pt -------------------------------------------------------------------------------- /results/DCGAN/trainWGAN-vaeInit-g.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-vaeInit-g.pt -------------------------------------------------------------------------------- /results/DCGAN/trainDCGAN-randomInit-d.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-randomInit-d.pt -------------------------------------------------------------------------------- /results/DCGAN/trainDCGAN-randomInit-g.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-randomInit-g.pt -------------------------------------------------------------------------------- /results/DCGAN/trainDCGAN-randomInit.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-randomInit.mat -------------------------------------------------------------------------------- /results/DCGAN/trainWGAN-randomInit-d.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-randomInit-d.pt -------------------------------------------------------------------------------- /results/DCGAN/trainWGAN-randomInit-g.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-randomInit-g.pt -------------------------------------------------------------------------------- /results/DCGAN/trainWGAN-randomInit.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-randomInit.mat -------------------------------------------------------------------------------- /results/DCGAN/trainWGAN-vaeInit-his.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-vaeInit-his.png -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /results/realNVP/trainRealNVP.txt: 
-------------------------------------------------------------------------------- 1 | ------device=cpu, K=6, width=128, batch_size=256, num_steps=20000------ 2 | ------out_file: results/realNVP/trainRealNVP------ 3 | step J_ML 4 | 001000 -4.7034e+00 5 | 002000 -5.5544e+00 6 | 003000 -5.5929e+00 7 | 004000 -5.6134e+00 8 | 005000 -5.6250e+00 9 | 006000 -5.6343e+00 10 | 007000 -5.6390e+00 11 | 008000 -5.6381e+00 12 | 009000 -5.6421e+00 13 | 010000 -5.6439e+00 14 | 011000 -5.6426e+00 15 | 012000 -5.6476e+00 16 | 013000 -5.6472e+00 17 | 014000 -5.6476e+00 18 | 015000 -5.6507e+00 19 | 016000 -5.6488e+00 20 | 017000 -5.6525e+00 21 | 018000 -5.6557e+00 22 | 019000 -5.6495e+00 23 | 020000 -5.6543e+00 24 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 19 | -------------------------------------------------------------------------------- /testAll.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python3 -u trainRealNVP.py --batch_size 64 --width 32 --K 4 --num_steps 50 --plot_interval 10 --out_file /tmp/realNVP/test 2> testLog.txt 4 | python3 -u trainOTFlow.py --batch_size 64 --width 32 --nt 4 --num_steps 50 --plot_interval 10 --out_file /tmp/OTFlow/test 2>> testLog.txt 5 | python3 -u trainVAEmnist.py --batch_size 256 --num_epochs 2 --width_enc 8 --width_dec 8 --out_file /tmp/VAE/test 2>> testLog.txt 6 | python3 -u trainDCGANmnist.py --batch_size 64 --num_steps 50 --width_disc 8 --width_dec 8 --plot_interval 10 --out_file /tmp/DCGAN/test 2>> testLog.txt 7 | python3 -u trainWGANmnist.py --batch_size 64 --num_steps 50 --width_disc 8 --width_dec 8 --plot_interval 10 --out_file /tmp/WGAN/test 2>> testLog.txt 8 | python3 -u trainDCGANmnist.py --batch_size 64 --num_steps 50 --width_disc 8 --width_dec 8 --plot_interval 10 --init_g /tmp/VAE/test-g.pt 
--out_file /tmp/DCGAN/test2 2>> testLog.txt 9 | python3 -u trainWGANmnist.py --batch_size 64 --num_steps 50 --width_disc 8 --width_dec 8 --plot_interval 10 --init_g /tmp/VAE/test-g.pt --out_file /tmp/WGAN/test2 2>> testLog.txt 10 | 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Lars Ruthotto 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /epsTest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def epsTest(X, Y, eps=1e-1): 4 | """ 5 | Test for equal distributions suggested in Székely, G. J., InterStat, M. R., 2004. (n.d.). 6 | Testing for equal distributions in high dimension. Personal.Bgsu.Edu. 
7 | 8 | :param X: Samples from first distribution 9 | :param Y: Samples from second distribution 10 | :param eps: conditioning paramter 11 | :return: 12 | """ 13 | nx = X.shape[0] 14 | ny = Y.shape[0] 15 | 16 | X = X.view(nx, -1) 17 | Y = Y.view(ny, -1) 18 | 19 | sX = torch.norm(X, dim=1) ** 2; 20 | sY = torch.norm(Y, dim=1) ** 2; 21 | 22 | CXX = sX.unsqueeze(1) + sX.unsqueeze(0) - 2 * X @ X.t() 23 | CXX = torch.sqrt(CXX + eps) 24 | 25 | CYY = sY.unsqueeze(1) + sY.unsqueeze(0) - 2 * Y @ Y.t() 26 | CYY = torch.sqrt(CYY + eps) 27 | 28 | CXY = sX.unsqueeze(1) + sY.unsqueeze(0) - 2 * X @ Y.t() 29 | CXY = torch.sqrt(CXY + eps) 30 | 31 | D = (nx * ny) / (nx + ny) * (2.0 / (nx * ny) * torch.sum(CXY) 32 | - 1.0 / nx ** 2 * (torch.sum(CXX)) - 1.0 / ny ** 2 * (torch.sum(CYY))); 33 | 34 | return D / (nx + ny) -------------------------------------------------------------------------------- /plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | 4 | 5 | def plot_x(xs,domain=(0,1,0,1)): 6 | plt.plot(xs[:, 0], xs[:, 1], "bs") 7 | plt.axis("square") 8 | plt.axis(domain) 9 | plt.xticks(domain[0:2]) 10 | plt.yticks(domain[2:]) 11 | plt.xlabel("$\mathbf{x}_1$", labelpad=-20) 12 | plt.ylabel("$\mathbf{x}_2$", labelpad=-30) 13 | 14 | 15 | def plot_z(zs): 16 | plt.plot(zs[:, 0], zs[:, 1], "or") 17 | plt.axis("square") 18 | plt.axis((-3.5, 3.5, -3.5, 3.5)) 19 | plt.xticks((-3.5, 3.5)) 20 | plt.yticks((-3.5, 3.5)) 21 | plt.xlabel("$\mathbf{z}_1$", labelpad=-20) 22 | plt.ylabel("$\mathbf{z}_2$", labelpad=-30) 23 | 24 | 25 | def plot_px(log_px,domain=(0,1,0,1)): 26 | px = torch.exp(log_px) 27 | img = px 28 | plt.imshow(img.t(), extent=domain,origin='lower') 29 | plt.axis("square") 30 | plt.axis(domain) 31 | plt.xticks(domain[0:2]) 32 | plt.yticks(domain[2:]) 33 | plt.xlabel("$\mathbf{x}_1$", labelpad=-20) 34 | plt.ylabel("$\mathbf{x}_2$", labelpad=-30) 35 | 36 | def plot_pz(zz,domain=(-3.5, 3.5, 
-3.5, 3.5)): 37 | plt.hist2d(zz[:,0], zz[:,1],bins=256,range=[[-domain[0], domain[1]], [domain[2], domain[3]]]) 38 | plt.axis("square") 39 | plt.axis(domain) 40 | plt.xticks(domain[0:2]) 41 | plt.yticks(domain[2:]) 42 | plt.xlabel("$\mathbf{z}_1$", labelpad=-20) 43 | plt.ylabel("$\mathbf{z}_2$", labelpad=-30) -------------------------------------------------------------------------------- /runAll.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python3 -u trainRealNVP.py --batch_size 256 --width 128 --K 6 --num_steps 20000 --plot_interval 1000 --out_file results/realNVP/trainRealNVP | tee results/realNVP/trainRealNVP.txt 4 | python3 -u trainOTFlow.py --batch_size 256 --width 32 --nt 4 --num_steps 20000 --plot_interval 1000 --out_file results/OTFlow/trainOTFlow |tee results/OTFlow/trainOTFlow 5 | python3 -u trainVAEmnist.py --batch_size 64 --num_epochs 50 --width_enc 32 --width_dec 32 --out_file results/VAE/trainVAE |tee results/VAE/trainVAE.txt 6 | python3 -u trainDCGANmnist.py --batch_size 64 --num_steps 50000 --width_disc 32 --width_dec 32 --plot_interval 1000 --out_file results/DCGAN/trainDCGAN-randomInit |tee results/DCGAN/trainDCGAN-randomInit.txt 7 | python3 -u trainWGANmnist.py --batch_size 64 --num_steps 50000 --width_disc 32 --width_dec 32 --plot_interval 1000 --out_file results/DCGAN/trainWGAN-randomInit |tee results/WGAN/trainWGAN-randomInit.txt 8 | python3 -u trainDCGANmnist.py --batch_size 64 --num_steps 50000 --width_disc 32 --width_dec 32 --plot_interval 1000 --init_g results/VAE/trainVAE-g.pt --out_file results/WGAN/trainDCGAN-vaeInit |tee results/DCGAN/trainDCGAN-vaeInit.txt 9 | python3 -u trainWGANmnist.py --batch_size 64 --num_steps 50000 --width_disc 32 --width_dec 32 --plot_interval 1000 --init_g results/VAE/trainVAE-g.pt --out_file results/WGAN/trainWGAN-vaeInit |tee results/WGAN/trainWGAN-vaeInit.txt 10 | 11 | 12 | 
-------------------------------------------------------------------------------- /results/OTFlow/trainOTFlow: -------------------------------------------------------------------------------- 1 | ------device=cpu, nTh=2, width=32, batch_size=256, num_steps=20000------ 2 | ------out_file: results/OTFlow/trainOTFlow------ 3 | step J J_L J_ML J_HJB 4 | 001000 1.6560e+01 3.1925e-01 1.5329e+00 1.8234e-01 5 | 002000 9.6450e+00 3.3269e-01 8.6591e-01 1.3064e-01 6 | 003000 7.5762e+00 3.5424e-01 6.5991e-01 1.2457e-01 7 | 004000 6.8949e+00 3.6180e-01 5.9144e-01 1.2373e-01 8 | 005000 6.4849e+00 3.6729e-01 5.4807e-01 1.2739e-01 9 | 006000 6.1083e+00 3.6985e-01 5.1064e-01 1.2642e-01 10 | 007000 5.8387e+00 3.6832e-01 4.8543e-01 1.2322e-01 11 | 008000 5.7509e+00 3.6623e-01 4.7970e-01 1.1755e-01 12 | 009000 5.6017e+00 3.6534e-01 4.6686e-01 1.1355e-01 13 | 010000 5.4999e+00 3.6431e-01 4.5894e-01 1.0924e-01 14 | 011000 5.3779e+00 3.6332e-01 4.4920e-01 1.0452e-01 15 | 012000 5.2646e+00 3.6242e-01 4.3959e-01 1.0126e-01 16 | 013000 5.1163e+00 3.6151e-01 4.2714e-01 9.6683e-02 17 | 014000 5.0453e+00 3.5966e-01 4.2190e-01 9.3342e-02 18 | 015000 4.9553e+00 3.5814e-01 4.1512e-01 8.9203e-02 19 | 016000 4.8560e+00 3.5744e-01 4.0607e-01 8.7583e-02 20 | 017000 4.7735e+00 3.5802e-01 3.9844e-01 8.6216e-02 21 | 018000 4.7460e+00 3.5733e-01 3.9660e-01 8.4522e-02 22 | 019000 4.6819e+00 3.5838e-01 3.9125e-01 8.2200e-02 23 | 020000 4.6584e+00 3.5816e-01 3.8998e-01 8.0097e-02 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | 
share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /results/WGAN/trainWGAN-randomInit.txt: -------------------------------------------------------------------------------- 1 | ------device=cuda:0, q=2, batch_size=64, num_steps=50000, w_disc=32, w_dec=32------ 2 | ------out_file: results/DCGAN/trainWGAN-randomInit------ 3 | step J_GAN J_Gen ProbDist 4 | 001000 9.0434e-02 3.6219e-02 9.5863e+02 5 | 002000 3.0011e-02 1.7892e-02 2.3149e+02 6 | 003000 2.4966e-02 1.2125e-02 1.6959e+02 7 | 004000 2.0141e-02 7.8579e-03 1.4928e+02 8 | 005000 1.7243e-02 4.9258e-03 1.3729e+02 9 | 006000 1.4991e-02 4.5250e-03 1.2822e+02 10 | 007000 1.3449e-02 4.1725e-03 1.2226e+02 11 | 008000 1.2059e-02 3.4655e-03 1.1704e+02 12 | 009000 1.0941e-02 4.7856e-03 1.1407e+02 13 | 010000 9.9864e-03 3.7305e-03 1.1019e+02 14 | 011000 9.3497e-03 6.0258e-03 1.0907e+02 15 | 012000 8.7816e-03 7.7004e-03 1.0663e+02 16 | 013000 8.3173e-03 5.8963e-03 1.0627e+02 17 | 014000 7.9613e-03 4.1980e-03 1.0445e+02 18 | 015000 7.5758e-03 4.1143e-03 1.0471e+02 19 | 016000 7.2847e-03 5.2026e-03 1.0285e+02 20 | 017000 7.0688e-03 5.4935e-03 1.0152e+02 21 | 018000 6.7694e-03 6.2888e-03 1.0057e+02 22 | 019000 6.4661e-03 5.1770e-03 1.0075e+02 23 | 020000 6.2607e-03 5.3424e-03 9.9820e+01 24 | 021000 6.0518e-03 3.9571e-03 9.9159e+01 25 | 022000 5.8393e-03 3.5221e-03 9.8704e+01 26 | 023000 5.6738e-03 3.4962e-03 
9.8189e+01 27 | 024000 5.5298e-03 2.3215e-03 9.7550e+01 28 | 025000 5.4000e-03 1.7744e-03 9.6966e+01 29 | 026000 5.2774e-03 1.0285e-03 9.6718e+01 30 | 027000 5.2354e-03 9.6602e-04 9.6132e+01 31 | 028000 5.1198e-03 8.4378e-04 9.6180e+01 32 | 029000 5.0476e-03 1.6136e-03 9.5849e+01 33 | 030000 4.9132e-03 1.7841e-03 9.5108e+01 34 | 031000 4.8423e-03 2.8072e-03 9.5625e+01 35 | 032000 4.7830e-03 2.4392e-03 9.5571e+01 36 | 033000 4.7676e-03 2.6003e-03 9.5113e+01 37 | 034000 4.6369e-03 3.0293e-03 9.4740e+01 38 | 035000 4.5631e-03 2.2041e-03 9.4086e+01 39 | 036000 4.4883e-03 3.2473e-03 9.3532e+01 40 | 037000 4.4246e-03 2.9545e-03 9.3902e+01 41 | 038000 4.3665e-03 3.2789e-03 9.4058e+01 42 | 039000 4.3227e-03 2.3956e-03 9.3894e+01 43 | 040000 4.2762e-03 1.6265e-03 9.2691e+01 44 | 041000 4.2137e-03 1.3642e-03 9.4176e+01 45 | 042000 4.1963e-03 1.1782e-03 9.3716e+01 46 | 043000 4.1571e-03 8.6539e-04 9.3948e+01 47 | 044000 4.1267e-03 1.3113e-03 9.3176e+01 48 | 045000 4.1107e-03 4.9682e-04 9.3393e+01 49 | 046000 4.0724e-03 6.0957e-04 9.3622e+01 50 | 047000 4.0431e-03 4.6356e-04 9.2811e+01 51 | 048000 4.0419e-03 2.5190e-05 9.3241e+01 52 | 049000 4.0043e-03 -5.5302e-06 9.3191e+01 53 | 050000 3.9637e-03 3.9199e-04 9.2859e+01 54 | -------------------------------------------------------------------------------- /results/WGAN/trainWGAN-vaeInit.txt: -------------------------------------------------------------------------------- 1 | initialize g with weights in results/VAE/trainVAE-g.pt 2 | ------device=cuda:0, q=2, batch_size=64, num_steps=50000, w_disc=32, w_dec=32------ 3 | ------out_file: results/DCGAN/trainWGAN-vaeInit------ 4 | step J_GAN J_Gen ProbDist 5 | 001000 9.8016e-02 3.7898e-02 2.1267e+02 6 | 002000 4.0626e-02 2.3535e-02 1.7130e+02 7 | 003000 2.4519e-02 5.8282e-03 1.2997e+02 8 | 004000 1.9377e-02 3.2689e-02 1.1579e+02 9 | 005000 1.4879e-02 3.2376e-02 1.0757e+02 10 | 006000 1.2726e-02 3.8867e-02 1.0229e+02 11 | 007000 1.1038e-02 3.1195e-02 1.0047e+02 12 | 008000 1.0462e-02 
3.2218e-02 9.9698e+01 13 | 009000 9.8914e-03 3.0616e-02 9.7902e+01 14 | 010000 9.3823e-03 2.8720e-02 9.7652e+01 15 | 011000 9.0339e-03 2.7796e-02 9.6153e+01 16 | 012000 8.5916e-03 2.4935e-02 9.5867e+01 17 | 013000 8.2343e-03 2.5079e-02 9.4840e+01 18 | 014000 7.9187e-03 2.3666e-02 9.3587e+01 19 | 015000 7.6963e-03 1.9979e-02 9.2664e+01 20 | 016000 7.5541e-03 1.8075e-02 9.2843e+01 21 | 017000 7.3329e-03 1.6485e-02 9.1934e+01 22 | 018000 7.1320e-03 1.3612e-02 9.0954e+01 23 | 019000 6.9459e-03 1.4304e-02 9.0953e+01 24 | 020000 6.7611e-03 1.5827e-02 9.0958e+01 25 | 021000 6.6543e-03 1.6127e-02 9.0671e+01 26 | 022000 6.5316e-03 1.5347e-02 9.0589e+01 27 | 023000 6.4119e-03 1.5592e-02 9.0986e+01 28 | 024000 6.3097e-03 1.5181e-02 9.0269e+01 29 | 025000 6.2119e-03 1.5757e-02 8.9563e+01 30 | 026000 6.1119e-03 1.5680e-02 8.9119e+01 31 | 027000 6.0015e-03 1.5852e-02 8.8693e+01 32 | 028000 5.8711e-03 1.4648e-02 8.9046e+01 33 | 029000 5.7979e-03 1.3644e-02 8.7846e+01 34 | 030000 5.7710e-03 1.3030e-02 8.9200e+01 35 | 031000 5.7028e-03 1.4651e-02 8.8510e+01 36 | 032000 5.6183e-03 1.4748e-02 8.8049e+01 37 | 033000 5.5419e-03 1.4654e-02 8.8659e+01 38 | 034000 5.4964e-03 1.3939e-02 8.8095e+01 39 | 035000 5.4264e-03 1.3998e-02 8.8708e+01 40 | 036000 5.3579e-03 1.2901e-02 8.7748e+01 41 | 037000 5.3108e-03 1.3888e-02 8.7898e+01 42 | 038000 5.2779e-03 1.4576e-02 8.7882e+01 43 | 039000 5.1993e-03 1.3370e-02 8.8319e+01 44 | 040000 5.1644e-03 1.2683e-02 8.7140e+01 45 | 041000 5.1564e-03 1.3523e-02 8.8754e+01 46 | 042000 5.1318e-03 1.3073e-02 8.8003e+01 47 | 043000 5.1313e-03 1.1972e-02 8.7905e+01 48 | 044000 5.0500e-03 1.1824e-02 8.7697e+01 49 | 045000 5.0123e-03 1.1280e-02 8.7181e+01 50 | 046000 4.9632e-03 1.1529e-02 8.8270e+01 51 | 047000 4.9310e-03 1.2047e-02 8.7663e+01 52 | 048000 4.8943e-03 1.1841e-02 8.8133e+01 53 | 049000 4.8601e-03 1.1867e-02 8.7772e+01 54 | 050000 4.8166e-03 1.1109e-02 8.7176e+01 55 | -------------------------------------------------------------------------------- 
/results/DCGAN/trainDCGAN-randomInit.txt: -------------------------------------------------------------------------------- 1 | ------device=cuda:0, q=2, batch_size=64, num_steps=50000, w_disc=32, w_dec=32,------ 2 | ------out_file: results/DCGAN/trainDCGAN-randomInit------ 3 | step J_GAN J_Gen ProbDist 4 | 001000 -1.6240e-01 -7.0221e-02 8.2363e+02 5 | 002000 -4.4369e-02 -1.9166e-02 5.2543e+02 6 | 003000 -5.2540e-05 -1.7542e-05 7.4219e+02 7 | 004000 -7.8862e-02 -1.9111e-02 4.9744e+02 8 | 005000 -2.6664e-01 -1.0555e-01 2.5069e+02 9 | 006000 -2.9983e-01 -1.2022e-01 2.2268e+02 10 | 007000 -2.5933e-01 -1.0743e-01 2.0266e+02 11 | 008000 -2.7962e-01 -1.1531e-01 1.7890e+02 12 | 009000 -3.2447e-01 -1.3183e-01 1.7197e+02 13 | 010000 -3.0801e-01 -1.2520e-01 1.6948e+02 14 | 011000 -2.9975e-01 -1.2312e-01 1.7978e+02 15 | 012000 -3.0585e-01 -1.2521e-01 1.6989e+02 16 | 013000 -3.3236e-01 -1.3747e-01 1.6467e+02 17 | 014000 -3.8498e-01 -1.6704e-01 1.4661e+02 18 | 015000 -3.6886e-01 -1.5774e-01 1.4381e+02 19 | 016000 -3.6060e-01 -1.5909e-01 1.3357e+02 20 | 017000 -3.5183e-01 -1.5064e-01 1.3173e+02 21 | 018000 -3.4364e-01 -1.4763e-01 1.3900e+02 22 | 019000 -4.0491e-01 -1.7854e-01 1.3424e+02 23 | 020000 -3.9583e-01 -1.7362e-01 1.2394e+02 24 | 021000 -3.5344e-01 -1.5191e-01 1.2611e+02 25 | 022000 -3.6819e-01 -1.5867e-01 1.3833e+02 26 | 023000 -3.4235e-01 -1.4709e-01 1.2866e+02 27 | 024000 -3.5760e-01 -1.5520e-01 1.2599e+02 28 | 025000 -3.3611e-01 -1.4285e-01 1.2532e+02 29 | 026000 -3.5132e-01 -1.4984e-01 1.2587e+02 30 | 027000 -3.5437e-01 -1.5670e-01 1.2115e+02 31 | 028000 -3.8594e-01 -1.6715e-01 1.2182e+02 32 | 029000 -3.3948e-01 -1.4934e-01 1.1738e+02 33 | 030000 -3.3234e-01 -1.4675e-01 1.1683e+02 34 | 031000 -3.2204e-01 -1.4313e-01 1.1764e+02 35 | 032000 -3.2760e-01 -1.4531e-01 1.2158e+02 36 | 033000 -3.2468e-01 -1.3883e-01 1.1953e+02 37 | 034000 -3.1281e-01 -1.3564e-01 1.1946e+02 38 | 035000 -3.1644e-01 -1.3420e-01 1.1847e+02 39 | 036000 -3.3908e-01 -1.4482e-01 1.1678e+02 40 | 
037000 -3.4147e-01 -1.5113e-01 1.1595e+02 41 | 038000 -2.9663e-01 -1.2712e-01 1.1481e+02 42 | 039000 -2.8986e-01 -1.2577e-01 1.1691e+02 43 | 040000 -2.8042e-01 -1.1591e-01 1.1890e+02 44 | 041000 -2.8173e-01 -1.2197e-01 1.1177e+02 45 | 042000 -2.7905e-01 -1.1820e-01 1.1523e+02 46 | 043000 -2.7027e-01 -1.1974e-01 1.1492e+02 47 | 044000 -2.7638e-01 -1.1535e-01 1.1516e+02 48 | 045000 -2.8521e-01 -1.2421e-01 1.1780e+02 49 | 046000 -2.7991e-01 -1.2134e-01 1.1599e+02 50 | 047000 -2.5829e-01 -1.1118e-01 1.2519e+02 51 | 048000 -3.1127e-01 -1.3053e-01 1.1762e+02 52 | 049000 -2.6953e-01 -1.1793e-01 1.1570e+02 53 | 050000 -2.9524e-01 -1.3025e-01 1.1301e+02 54 | -------------------------------------------------------------------------------- /results/DCGAN/trainDCGAN-vaeInit.txt: -------------------------------------------------------------------------------- 1 | initialize g with weights in results/VAE/trainVAE-g.pt 2 | ------device=cuda:0, q=2, batch_size=64, num_steps=50000, w_disc=32, w_dec=32,------ 3 | ------out_file: results/DCGAN/trainDCGAN-vaeInit------ 4 | step J_GAN J_Gen ProbDist 5 | 001000 -1.5065e-01 -5.7034e-02 2.3224e+02 6 | 002000 -2.4258e-01 -1.1107e-01 1.5204e+02 7 | 003000 -2.6185e-01 -1.1970e-01 1.2679e+02 8 | 004000 -2.5740e-01 -1.1888e-01 1.2323e+02 9 | 005000 -2.6669e-01 -1.2124e-01 1.1961e+02 10 | 006000 -2.3286e-01 -1.0311e-01 1.1147e+02 11 | 007000 -2.2996e-01 -1.0447e-01 1.0950e+02 12 | 008000 -2.1535e-01 -1.0029e-01 1.1002e+02 13 | 009000 -2.3136e-01 -1.0422e-01 1.1030e+02 14 | 010000 -2.1367e-01 -9.4286e-02 1.1156e+02 15 | 011000 -2.0159e-01 -8.8985e-02 1.1032e+02 16 | 012000 -2.1341e-01 -9.4397e-02 1.1280e+02 17 | 013000 -2.1593e-01 -9.6983e-02 1.1122e+02 18 | 014000 -1.8815e-01 -8.5642e-02 1.1019e+02 19 | 015000 -2.0207e-01 -9.2386e-02 1.0957e+02 20 | 016000 -2.0112e-01 -9.0253e-02 1.1057e+02 21 | 017000 -2.1544e-01 -9.3563e-02 1.0798e+02 22 | 018000 -2.1136e-01 -9.1400e-02 1.0599e+02 23 | 019000 -2.2456e-01 -9.6762e-02 1.0651e+02 24 | 020000 
-1.8257e-01 -8.3164e-02 1.0817e+02 25 | 021000 -2.0020e-01 -8.9516e-02 1.0715e+02 26 | 022000 -1.9245e-01 -8.6282e-02 1.0793e+02 27 | 023000 -1.8340e-01 -7.9962e-02 1.0357e+02 28 | 024000 -1.6653e-01 -7.4786e-02 1.0495e+02 29 | 025000 -1.6308e-01 -7.4571e-02 1.0434e+02 30 | 026000 -1.5791e-01 -6.9270e-02 1.0587e+02 31 | 027000 -1.7612e-01 -7.6740e-02 1.0506e+02 32 | 028000 -1.8177e-01 -8.1500e-02 1.0438e+02 33 | 029000 -1.7027e-01 -7.6036e-02 1.0431e+02 34 | 030000 -1.5498e-01 -6.8826e-02 1.0276e+02 35 | 031000 -1.7734e-01 -7.9945e-02 1.0110e+02 36 | 032000 -1.5316e-01 -6.9440e-02 1.0415e+02 37 | 033000 -1.7264e-01 -7.7597e-02 1.0333e+02 38 | 034000 -1.6258e-01 -7.1064e-02 1.0077e+02 39 | 035000 -1.5907e-01 -7.0660e-02 9.8916e+01 40 | 036000 -1.5300e-01 -6.9610e-02 1.0103e+02 41 | 037000 -1.2759e-01 -5.4200e-02 1.0244e+02 42 | 038000 -1.3363e-01 -5.9072e-02 1.0061e+02 43 | 039000 -1.3983e-01 -6.1393e-02 1.0243e+02 44 | 040000 -1.2590e-01 -5.7687e-02 1.0177e+02 45 | 041000 -1.3524e-01 -6.0630e-02 1.0402e+02 46 | 042000 -1.3517e-01 -6.1084e-02 1.0345e+02 47 | 043000 -1.3010e-01 -5.7706e-02 1.0231e+02 48 | 044000 -1.3336e-01 -5.9891e-02 1.0317e+02 49 | 045000 -1.3037e-01 -5.9534e-02 1.0790e+02 50 | 046000 -1.3391e-01 -5.7453e-02 1.0566e+02 51 | 047000 -1.5771e-01 -6.9367e-02 1.0580e+02 52 | 048000 -1.3496e-01 -5.8078e-02 1.0980e+02 53 | 049000 -1.1879e-01 -5.2954e-02 1.0454e+02 54 | 050000 -1.3243e-01 -5.8981e-02 1.0384e+02 55 | -------------------------------------------------------------------------------- /vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | import torch.nn.functional as F 6 | 7 | class VAE(nn.Module): 8 | """ 9 | Variational Autoencoder (VAE) 10 | 11 | """ 12 | def __init__(self,e,g): 13 | """ 14 | Initialize VAE 15 | :param e: encoder network e(z|x), provides parameters of approximate posteriors 16 | :param g: generator g(z), provides 
samples in data space 17 | """ 18 | super(VAE, self).__init__() 19 | self.e = e 20 | self.g = g 21 | 22 | def ELBO(self,x): 23 | """ 24 | Empirical lower bound on p_{\theta}(x) 25 | :param x: sample from dataset 26 | :return: 27 | """ 28 | 29 | s = x.shape[0] # number of samples 30 | mu, logvar = self.e(x) # parameters of approximate posterior 31 | z, eps = self.sample_ezx(x,mu, logvar) 32 | gz = self.g(z) 33 | 34 | log_pzx = torch.sum(self.log_prob_pzx(z,x,gz)[0]) 35 | 36 | log_ezx = -0.5*torch.norm(eps)**2 - 0.5*torch.sum(logvar) - (z.shape[1]/2)*np.log(2*np.pi) 37 | 38 | return (-log_pzx+log_ezx)/s, (-log_pzx/s).item(), (log_ezx/s).item(), gz.detach(), mu.detach() 39 | 40 | def sample_ezx(self, x , mu=None, logvar=None, sample=None): 41 | """ 42 | draw sample from approximate posterior 43 | 44 | :param x: sample from dataset 45 | :param mu: mean of approximate posterior (optional; will be computed here if is None) 46 | :param logvar: log-variance of approximate posterior (optional; will be computed here if is None) 47 | :param sample: flag whether to sample or return the mean 48 | :return: 49 | """ 50 | if mu is None or logvar is None: 51 | mu, logvar = self.e(x) 52 | 53 | if sample is None: 54 | sample = self.training 55 | 56 | if sample: 57 | std = torch.exp(logvar) 58 | eps = torch.randn_like(std) 59 | return std * eps + mu, eps 60 | else: 61 | return mu, mu 62 | 63 | def log_prob_ezx(self,z,x): 64 | """ 65 | :param z: latent sample 66 | :param x: data sample 67 | :return: log(e(z|x)) 68 | """ 69 | q = z.shape[1] 70 | 71 | mu, logvar = self.e(x) 72 | ezx = -torch.sum((0.5 / torch.exp(logvar)) * (z - mu) ** 2, dim=1) - 0.5 * torch.sum(logvar,dim=1) - (q/2)*np.log(2*np.pi) 73 | return ezx 74 | 75 | def log_prob_pzx(self,z,x,gz=None): 76 | """ 77 | :param z: latent sample 78 | :param x: data sample 79 | :return: log(p(z,x)) = log(p(x|z)) + log(p(z)), log(p(x|z|), log(p(z)) 80 | """ 81 | if gz is None: 82 | gz = self.g(z) 83 | n = x.shape[1] 84 | px = 
-F.binary_cross_entropy(gz.view(-1, 784), x.view(-1, 784), reduction='none') 85 | px = torch.sum(px,dim=1) 86 | pz = - 0.5 * torch.norm(z, dim=1) ** 2 - (n/2)*np.log(2*np.pi) 87 | return px + pz, px, pz 88 | 89 | 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepGenerativeModelingIntro 2 | 3 | PyTorch Code used in the paper *Introduction to Deep Generative Modeling*: 4 | 5 | @article{RuthottoHaber2021, 6 | title = {An Introduction to Deep Generative Modeling}, 7 | year = {2021}, 8 | journal = {arXiv preprint arXiv:tbd}, 9 | author = {L. Ruthotto and E. Haber}, 10 | pages = {25 pages}, 11 | url={https://arxiv.org/abs/2103.05180} 12 | } 13 | 14 | ## Run Examples from the Terminal 15 | 16 | To reproduce the examples from the paper (up to randomization), we provide a shell script `runAll.sh`. 17 | 18 | ## Run Examples in Colab 19 | 20 | The `examples` directory contains interactive version of the examples from the paper. Those can be run locally or using 21 | Google Colab. For the latter option, you may click the badges below: 22 | 23 | 1. [Two-Dimensional Normalizing Flow Examples with Real NVP](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/RealNVP.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/RealNVP.ipynb) 24 | 1. [Two-Dimensional Continuous Normalizing Flow Example with OT-Flow](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/OTFlow.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/OTFlow.ipynb) 25 | 1. 
[Variational Autoencoder for MNIST Image Generation](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/VAE.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/VAE.ipynb) 26 | 1. [DCGAN for MNIST Image Generation](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/DCGAN.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/DCGAN.ipynb) 27 | 1. [WGAN for MNIST Image Generation](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/WGAN.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/WGAN.ipynb) 28 | 29 | ## Dependencies 30 | 31 | The code is based on pytorch and some other standard machine learning packages. In addition, training the continuous normalizing flow example requires [OT-Flow](https://github.com/EmoryMLIP/OT-Flow). 32 | 33 | ## Acknowledgements 34 | 35 | This material is in part based upon work supported by the National Science Foundation under Grant Number 1751636, the Air Force Office of Scientific Research under Grant Number 20RT0237, and 36 | the US DOE's Office of Advanced Scientific Computing Research Field Work Proposal 20-023231. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the funding agencies. 
37 | -------------------------------------------------------------------------------- /modelMNIST.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | from torch import nn 4 | 5 | class Generator(nn.Module): 6 | def __init__(self,w,q): 7 | """ 8 | Initialize generator 9 | 10 | :param w: number of channels on the finest level 11 | :param q: latent space dimension 12 | """ 13 | super(Generator, self).__init__() 14 | self.w = w 15 | self.fc = nn.Linear(q, w * 2 * 7 * 7) 16 | self.conv2 = nn.ConvTranspose2d(w * 2, w, kernel_size=4, stride=2, padding=1) 17 | self.conv1 = nn.ConvTranspose2d(w, 1, kernel_size=4, stride=2, padding=1) 18 | 19 | self.bn1 = nn.BatchNorm2d(w) 20 | self.bn2 = nn.BatchNorm2d(2*w) 21 | 22 | def forward(self, z): 23 | """ 24 | :param z: latent space sample 25 | :return: g(z) 26 | """ 27 | gz = self.fc(z) 28 | gz = gz.view(gz.size(0), self.w * 2, 7, 7) 29 | gz = self.bn2(gz) 30 | gz = F.relu(gz) 31 | gz = self.conv2(gz) 32 | gz = self.bn1(gz) 33 | 34 | gz = F.relu(gz) 35 | gz = torch.sigmoid(self.conv1(gz)) 36 | return gz 37 | 38 | class Encoder(nn.Module): 39 | def __init__(self,w,q): 40 | """ 41 | Initialize the encoder for the VAE 42 | 43 | :param w: number of channels on finest level 44 | :param q: latent space dimension 45 | """ 46 | super(Encoder, self).__init__() 47 | self.conv1 = nn.Conv2d(1, w, kernel_size=4, stride=2, padding=1) 48 | self.conv2 = nn.Conv2d(w, w * 2, kernel_size=4, stride=2, padding=1) 49 | self.fc_mu = nn.Linear(w * 2 * 7 * 7, q) 50 | self.fc_logvar = nn.Linear(w * 2 * 7 * 7, q) 51 | 52 | def forward(self, x): 53 | """ 54 | :param x: MNIST image 55 | :return: mu,logvar that parameterize e(z|x) = N(mu, diag(exp(logvar))) 56 | """ 57 | x = F.relu(self.conv1(x)) 58 | x = F.relu(self.conv2(x)) 59 | x = x.view(x.size(0), -1) 60 | mu = self.fc_mu(x) 61 | logvar = self.fc_logvar(x) 62 | return mu, logvar 63 | 64 | class 
Discriminator(nn.Module): 65 | def __init__(self, w,useSigmoid=True): 66 | """ 67 | Discriminator for GANs 68 | :param w: number of channels on finest level 69 | :param useSigmoid: true --> DCGAN, false --> WGAN 70 | """ 71 | super(Discriminator, self).__init__() 72 | self.w = w 73 | self.useSigmoid = useSigmoid 74 | self.conv1 = nn.Conv2d(1, w, kernel_size=4, stride=2, padding=1) 75 | self.conv2 = nn.Conv2d(w, w * 2, kernel_size=4, stride=2, padding=1) 76 | self.fc = nn.Linear(w * 2 * 7 * 7, 1) 77 | self.bn1 = nn.BatchNorm2d(w) 78 | self.bn2 = nn.BatchNorm2d(2*w) 79 | 80 | def forward(self,x): 81 | """ 82 | :param x: MNIST image or generated image 83 | :return: d(x), value of discriminator 84 | """ 85 | x = (x-0.5)/0.5 # 86 | x = self.conv1(x) 87 | x = self.bn1(x) 88 | x = F.leaky_relu(x,0.2) 89 | x = self.conv2(x) 90 | x = self.bn2(x) 91 | 92 | x = F.leaky_relu(x,0.2) 93 | x = x.view(x.shape[0],-1) 94 | x = self.fc(x) 95 | if self.useSigmoid: 96 | x = torch.sigmoid(x) 97 | return x 98 | -------------------------------------------------------------------------------- /trainVAEmnist.py: -------------------------------------------------------------------------------- 1 | from vae import * 2 | from torch import distributions 3 | import argparse 4 | 5 | import numpy as np 6 | 7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | 9 | ## load MNIST 10 | parser = argparse.ArgumentParser('VAE') 11 | parser.add_argument("--batch_size" , type=int, default=128, help="batch size") 12 | parser.add_argument("--q" , type=int, default=2, help="latent space dimension") 13 | parser.add_argument("--width_enc" , type=int, default=4, help="width of encoder") 14 | parser.add_argument("--width_dec" , type=int, default=4, help="width of decoder") 15 | parser.add_argument("--num_epochs" , type=int, default=2, help="number of epochs") 16 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history 
(extension .mat), and intermediate plots (extension .png") 17 | args = parser.parse_args() 18 | 19 | 20 | import torchvision.transforms as transforms 21 | from torch.utils.data import DataLoader 22 | from torchvision.datasets import MNIST 23 | 24 | img_transform = transforms.Compose([ 25 | transforms.ToTensor() 26 | ]) 27 | 28 | train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform) 29 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 30 | 31 | test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform) 32 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True) 33 | 34 | 35 | from modelMNIST import Encoder, Generator 36 | g = Generator(args.width_dec,args.q) 37 | e = Encoder(args.width_enc,args.q) 38 | 39 | vae = VAE(e,g).to(device) 40 | 41 | optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3, weight_decay=1e-5) 42 | 43 | his = np.zeros((args.num_epochs,6)) 44 | 45 | print((3*"--" + "device=%s, q=%d, batch_size=%d, num_epochs=%d, w_enc=%d, w_dec=%d" + 3*"--") % (device, args.q, args.batch_size, args.num_epochs, args.width_enc, args.width_dec)) 46 | 47 | if args.out_file is not None: 48 | import os 49 | out_dir, fname = os.path.split(args.out_file) 50 | if not os.path.exists(out_dir): 51 | os.makedirs(out_dir) 52 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file)) 53 | 54 | print((7*"%7s ") % ("epoch","Jtrain","pzxtrain","ezxtrain","Jval","pzxval","ezxval")) 55 | 56 | 57 | for epoch in range(args.num_epochs): 58 | vae.train() 59 | 60 | train_loss = 0.0 61 | train_pzx = 0.0 62 | train_ezx = 0.0 63 | num_ex = 0 64 | for image_batch, _ in train_dataloader: 65 | image_batch = image_batch.to(device) 66 | 67 | # take a step 68 | loss, pzx, ezx,gz,mu = vae.ELBO(image_batch) 69 | optimizer.zero_grad() 70 | loss.backward() 71 | optimizer.step() 72 | 73 | # update history 74 | train_loss += 
loss.item()*image_batch.shape[0] 75 | train_pzx += pzx*image_batch.shape[0] 76 | train_ezx += ezx*image_batch.shape[0] 77 | num_ex += image_batch.shape[0] 78 | 79 | train_loss /= num_ex 80 | train_pzx /= num_ex 81 | train_ezx /= num_ex 82 | 83 | # evaluate validation points 84 | vae.eval() 85 | val_loss = 0.0 86 | val_pzx = 0.0 87 | val_ezx = 0.0 88 | num_ex = 0 89 | for image_batch, label_batch in test_dataloader: 90 | with torch.no_grad(): 91 | image_batch = image_batch.to(device) 92 | # vae reconstruction 93 | loss, pzx, ezx, gz, mu = vae.ELBO(image_batch) 94 | val_loss += loss.item() * image_batch.shape[0] 95 | val_pzx += pzx * image_batch.shape[0] 96 | val_ezx += ezx * image_batch.shape[0] 97 | num_ex += image_batch.shape[0] 98 | 99 | val_loss /= num_ex 100 | val_pzx/= num_ex 101 | val_ezx/= num_ex 102 | 103 | print(("%06d " + 6*"%1.4e ") % 104 | (epoch + 1, train_loss, train_pzx, train_ezx, val_loss, val_pzx, val_ezx)) 105 | 106 | his[epoch,:] = [train_loss, train_pzx, train_ezx, val_loss, val_pzx, val_ezx] 107 | 108 | if args.out_file is not None: 109 | torch.save(vae.g.state_dict(), ("%s-g.pt") % (args.out_file)) 110 | torch.save(vae.state_dict(), ("%s.pt") % (args.out_file)) 111 | from scipy.io import savemat 112 | savemat(("%s.mat") % (args.out_file), {"his":his}) -------------------------------------------------------------------------------- /results/VAE/trainVAE.txt: -------------------------------------------------------------------------------- 1 | ------device=cuda:0, q=2, batch_size=64, num_epochs=50, w_enc=32, w_dec=32------ 2 | ------out_file: results/VAE/trainVAE------ 3 | epoch Jtrain pzxtrain ezxtrain Jval pzxval ezxval 4 | 000001 1.7771e+02 1.7577e+02 1.9351e+00 1.6330e+02 1.6098e+02 2.3209e+00 5 | 000002 1.5944e+02 1.5729e+02 2.1508e+00 1.5661e+02 1.5412e+02 2.4921e+00 6 | 000003 1.5594e+02 1.5366e+02 2.2728e+00 1.5409e+02 1.5157e+02 2.5213e+00 7 | 000004 1.5415e+02 1.5182e+02 2.3312e+00 1.5290e+02 1.5019e+02 2.7143e+00 8 | 000005 1.5297e+02 
1.5059e+02 2.3822e+00 1.5206e+02 1.4936e+02 2.7023e+00 9 | 000006 1.5203e+02 1.4962e+02 2.4179e+00 1.5153e+02 1.4874e+02 2.7933e+00 10 | 000007 1.5135e+02 1.4891e+02 2.4409e+00 1.5074e+02 1.4801e+02 2.7246e+00 11 | 000008 1.5071e+02 1.4825e+02 2.4669e+00 1.5018e+02 1.4744e+02 2.7378e+00 12 | 000009 1.5024e+02 1.4776e+02 2.4783e+00 1.5039e+02 1.4754e+02 2.8493e+00 13 | 000010 1.4981e+02 1.4731e+02 2.4965e+00 1.4941e+02 1.4657e+02 2.8432e+00 14 | 000011 1.4940e+02 1.4690e+02 2.5029e+00 1.4950e+02 1.4683e+02 2.6757e+00 15 | 000012 1.4911e+02 1.4658e+02 2.5252e+00 1.4935e+02 1.4650e+02 2.8510e+00 16 | 000013 1.4881e+02 1.4628e+02 2.5295e+00 1.4872e+02 1.4596e+02 2.7607e+00 17 | 000014 1.4846e+02 1.4592e+02 2.5445e+00 1.4897e+02 1.4617e+02 2.7960e+00 18 | 000015 1.4821e+02 1.4566e+02 2.5490e+00 1.5004e+02 1.4718e+02 2.8554e+00 19 | 000016 1.4801e+02 1.4544e+02 2.5663e+00 1.4870e+02 1.4595e+02 2.7539e+00 20 | 000017 1.4781e+02 1.4524e+02 2.5726e+00 1.4874e+02 1.4584e+02 2.9001e+00 21 | 000018 1.4756e+02 1.4498e+02 2.5795e+00 1.4822e+02 1.4548e+02 2.7444e+00 22 | 000019 1.4738e+02 1.4480e+02 2.5752e+00 1.4822e+02 1.4539e+02 2.8321e+00 23 | 000020 1.4718e+02 1.4459e+02 2.5877e+00 1.4834e+02 1.4547e+02 2.8747e+00 24 | 000021 1.4702e+02 1.4443e+02 2.5927e+00 1.4831e+02 1.4548e+02 2.8288e+00 25 | 000022 1.4686e+02 1.4426e+02 2.5941e+00 1.4804e+02 1.4520e+02 2.8367e+00 26 | 000023 1.4672e+02 1.4412e+02 2.5977e+00 1.4806e+02 1.4512e+02 2.9358e+00 27 | 000024 1.4656e+02 1.4396e+02 2.6019e+00 1.4791e+02 1.4513e+02 2.7793e+00 28 | 000025 1.4642e+02 1.4381e+02 2.6084e+00 1.4795e+02 1.4518e+02 2.7680e+00 29 | 000026 1.4627e+02 1.4366e+02 2.6123e+00 1.4769e+02 1.4493e+02 2.7604e+00 30 | 000027 1.4618e+02 1.4356e+02 2.6213e+00 1.4777e+02 1.4488e+02 2.8827e+00 31 | 000028 1.4603e+02 1.4341e+02 2.6209e+00 1.4794e+02 1.4521e+02 2.7254e+00 32 | 000029 1.4599e+02 1.4336e+02 2.6311e+00 1.4751e+02 1.4472e+02 2.7926e+00 33 | 000030 1.4581e+02 1.4317e+02 2.6360e+00 1.4796e+02 1.4527e+02 
2.6887e+00 34 | 000031 1.4575e+02 1.4312e+02 2.6355e+00 1.4765e+02 1.4489e+02 2.7599e+00 35 | 000032 1.4565e+02 1.4302e+02 2.6385e+00 1.4778e+02 1.4483e+02 2.9423e+00 36 | 000033 1.4554e+02 1.4290e+02 2.6463e+00 1.4728e+02 1.4438e+02 2.8942e+00 37 | 000034 1.4546e+02 1.4282e+02 2.6412e+00 1.4735e+02 1.4453e+02 2.8193e+00 38 | 000035 1.4537e+02 1.4271e+02 2.6566e+00 1.4747e+02 1.4465e+02 2.8138e+00 39 | 000036 1.4531e+02 1.4265e+02 2.6536e+00 1.4742e+02 1.4460e+02 2.8190e+00 40 | 000037 1.4525e+02 1.4260e+02 2.6542e+00 1.4755e+02 1.4467e+02 2.8760e+00 41 | 000038 1.4516e+02 1.4251e+02 2.6503e+00 1.4730e+02 1.4440e+02 2.9045e+00 42 | 000039 1.4506e+02 1.4240e+02 2.6628e+00 1.4730e+02 1.4440e+02 2.9018e+00 43 | 000040 1.4497e+02 1.4231e+02 2.6587e+00 1.4748e+02 1.4469e+02 2.7950e+00 44 | 000041 1.4493e+02 1.4226e+02 2.6631e+00 1.4783e+02 1.4494e+02 2.8864e+00 45 | 000042 1.4486e+02 1.4218e+02 2.6778e+00 1.4716e+02 1.4433e+02 2.8293e+00 46 | 000043 1.4477e+02 1.4211e+02 2.6594e+00 1.4718e+02 1.4440e+02 2.7866e+00 47 | 000044 1.4470e+02 1.4204e+02 2.6662e+00 1.4723e+02 1.4444e+02 2.7945e+00 48 | 000045 1.4463e+02 1.4195e+02 2.6749e+00 1.4786e+02 1.4504e+02 2.8218e+00 49 | 000046 1.4459e+02 1.4192e+02 2.6716e+00 1.4733e+02 1.4446e+02 2.8692e+00 50 | 000047 1.4450e+02 1.4182e+02 2.6802e+00 1.4754e+02 1.4470e+02 2.8353e+00 51 | 000048 1.4446e+02 1.4179e+02 2.6775e+00 1.4713e+02 1.4431e+02 2.8164e+00 52 | 000049 1.4440e+02 1.4172e+02 2.6808e+00 1.4727e+02 1.4449e+02 2.7826e+00 53 | 000050 1.4438e+02 1.4169e+02 2.6882e+00 1.4746e+02 1.4459e+02 2.8640e+00 54 | -------------------------------------------------------------------------------- /trainOTFlow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import argparse 4 | import numpy as np 5 | from torch import distributions 6 | from sklearn import datasets 7 | import matplotlib.pyplot as plt 8 | import sys,os 9 | 10 | ot_flow_dir = 
'/Users/lruthot/Google Drive/OT-Flow/' 11 | if not os.path.exists(ot_flow_dir): 12 | raise Exception("Cannot find OT-Flow in %s" %(ot_flow_dir)) 13 | sys.path.append(os.path.dirname(ot_flow_dir)) 14 | 15 | from src.plotter import plot4 16 | from src.OTFlowProblem import * 17 | 18 | device = "cpu" 19 | 20 | parser = argparse.ArgumentParser('OTFlow') 21 | parser.add_argument("--batch_size" , type=int, default=1024, help="batch size") 22 | parser.add_argument("--noise" , type=int, default=0.05, help="noise of moons") 23 | parser.add_argument("--width" , type=int, default=32, help="width of neural net") 24 | parser.add_argument('--alph' , type=str, default='1.0,10.0,5.0',help="alph[0]-> weight for transport costs, alph[1] and alph[2]-> HJB penalties") 25 | parser.add_argument("--nTh" , type=int, default=2, help="number of layers") 26 | parser.add_argument("--nt" , type=int, default=4, help="number of time steps in training") 27 | parser.add_argument("--nt_val" , type=int, default=8, help="number of time steps in validation") 28 | parser.add_argument("--num_steps" , type=int, default=10000, help="number of training steps") 29 | parser.add_argument("--plot_interval" , type=int, default=500, help="plot solution every so many steps") 30 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png") 31 | 32 | args = parser.parse_args() 33 | args.alph = [float(item) for item in args.alph.split(',')] 34 | 35 | 36 | def compute_loss(net, x, nt): 37 | Jc , cs = OTFlowProblem(x, net, [0,1], nt=nt, stepper="rk4", alph=net.alph) 38 | return Jc, cs 39 | 40 | 41 | net = Phi(nTh=args.nTh, m=args.width, d=2, alph=args.alph) 42 | optim = torch.optim.Adam(net.parameters(), lr=1e-2) # lr=0.04 good 43 | 44 | his = np.zeros((0,4)) 45 | 46 | 47 | print((3*"--" + "device=%s, nTh=%d, width=%d, batch_size=%d, num_steps=%d" + 3*"--") % (device, args.nTh, args.width, 
args.batch_size, args.num_steps, )) 48 | 49 | out_dir = "results/OTFlow-noise-%1.5f-nTh-%d-width-%d" % (args.noise, args.nTh, args.width) 50 | 51 | if args.out_file is not None: 52 | import os 53 | out_dir, fname = os.path.split(args.out_file) 54 | if not os.path.exists(out_dir): 55 | os.makedirs(out_dir) 56 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file)) 57 | 58 | print((5*"%7s ") % ("step","J", "J_L", "J_ML","J_HJB")) 59 | 60 | train_J = 0.0 61 | train_L = 0.0 62 | train_JML = 0.0 63 | num_step = 0 64 | train_HJB = 0.0 65 | 66 | for step in range(args.num_steps): 67 | 68 | x = torch.tensor(datasets.make_moons(n_samples=args.batch_size, noise=args.noise)[0], dtype=torch.float32) 69 | optim.zero_grad() 70 | loss, costs = compute_loss(net, x, nt=args.nt) 71 | loss.backward() 72 | optim.step() 73 | 74 | train_J += loss.item() 75 | train_L += costs[0].item() 76 | train_JML += costs[1].item() 77 | train_HJB += costs[2].item() 78 | num_step += 1 79 | 80 | if (step + 1) % args.plot_interval == 0: 81 | train_J /= num_step 82 | train_JML /= num_step 83 | train_L /= num_step 84 | train_HJB /= num_step 85 | 86 | 87 | print(("%06d " + 4*"%1.4e ") % 88 | (step + 1, train_J, train_L, train_JML, train_HJB)) 89 | his = np.vstack([his, [train_J, train_L, train_JML, train_HJB]]) 90 | train_J = 0.0 91 | train_L = 0.0 92 | train_JML = 0.0 93 | num_step = 0 94 | train_HJB = 0.0 95 | 96 | with torch.no_grad(): 97 | nSamples = 10000 98 | xs = torch.tensor(datasets.make_moons(n_samples=nSamples, noise=args.noise)[0], dtype=torch.float32) 99 | zs = torch.randn(nSamples, 2) # sampling from the standard normal (rho_1) 100 | if args.out_file is not None: 101 | plot4(net, xs, zs, args.nt_val, "%s-step-%d.png" % (args.out_file,step+1), doPaths=True) 102 | plt.show() 103 | 104 | if args.out_file is not None: 105 | torch.save(net.state_dict(), ("%s.pt") % (args.out_file)) 106 | from scipy.io import savemat 107 | savemat(("%s.mat") % (args.out_file), {"his":his}) 
-------------------------------------------------------------------------------- /realNVP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class NF(nn.Module): 5 | """ 6 | Normalizing flow for density estimation and sampling 7 | 8 | """ 9 | def __init__(self, layers, prior): 10 | """ 11 | Initialize normalizing flow 12 | 13 | :param layers: list of layers f_j with tractable inverse and log-determinant (e.g., RealNVPLayer) 14 | :param prior: latent distribution, e.g., distributions.MultivariateNormal(torch.zeros(d), torch.eye(d)) 15 | """ 16 | super(NF, self).__init__() 17 | self.prior = prior 18 | self.layers = layers 19 | 20 | def g(self, z): 21 | """ 22 | :param z: latent variable 23 | :return: g(z) and hidden states 24 | """ 25 | y = z 26 | ys = [torch.clone(y).detach()] 27 | for i in range(len(self.layers)): 28 | y, _ = self.layers[i].f(y) 29 | ys.append(torch.clone(y).detach()) 30 | 31 | return y, ys 32 | 33 | def ginv(self, x): 34 | """ 35 | :param x: sample from dataset 36 | :return: g^(-1)(x), value of log-determinant, and hidden layers 37 | """ 38 | p = x 39 | log_det_ginv = torch.zeros(x.shape[0]) 40 | ps = [torch.clone(p).detach()] 41 | for i in reversed(range(len(self.layers))): 42 | p, log_det_finv = self.layers[i].finv(p) 43 | ps.append(torch.clone(p).detach().cpu()) 44 | log_det_ginv += log_det_finv 45 | 46 | return p, log_det_ginv, ps 47 | 48 | def log_prob(self, x): 49 | """ 50 | Compute log-probability of a sample using change of variable formula 51 | 52 | :param x: sample from dataset 53 | :return: logp_{\theta}(x) 54 | """ 55 | z, log_det_ginv, _ = self.ginv(x) 56 | return self.prior.log_prob(z) + log_det_ginv 57 | 58 | def sample(self, s): 59 | """ 60 | Draw random samples from p_{\theta} 61 | 62 | :param s: number of samples to draw 63 | :return: 64 | """ 65 | z = self.prior.sample((s, 1)).squeeze(1) 66 | x, _ = self.g(z) 67 | return x 68 | 69 | 70 | class 
RealNVPLayer(nn.Module): 71 | """ 72 | Real non-volume preserving flow layer 73 | 74 | Reference: Dinh, L., Sohl-Dickstein, J., & Bengio, S. (2016, May 27). 75 | Density estimation using Real NVP. arXiv.org. 76 | """ 77 | def __init__(self, s, t, mask): 78 | """ 79 | Initialize real NVP layer 80 | :param s: network to compute the log-scale (applied as y * exp(s) in f) 81 | :param t: network to compute the translation 82 | :param mask: splits the feature vector into two parts 83 | """ 84 | super(RealNVPLayer, self).__init__() 85 | self.mask = mask 86 | self.t = t 87 | self.s = s 88 | 89 | def f(self, y): 90 | """ 91 | apply the layer function f 92 | :param y: feature vector 93 | :return: transformed feature vector and log-determinant of the Jacobian 94 | """ 95 | y1 = y * self.mask 96 | s = self.s(y1) 97 | t = self.t(y1) 98 | y2 = (y * torch.exp(s) + t) * (1 - self.mask) 99 | return y1 + y2, torch.sum(s, dim=1) 100 | 101 | def finv(self, y): 102 | """ 103 | apply the inverse of the layer function 104 | :param y: feature vector 105 | :return: inverse-transformed feature vector and log-determinant of the inverse Jacobian 106 | """ 107 | y1 = self.mask * y 108 | s = self.s(y1) 109 | t = self.t(y1) 110 | y2 = (1 - self.mask) * (y - t) * torch.exp(-s) 111 | return y1 + y2, -torch.sum(s, dim=1) 112 | 113 | 114 | if __name__ == "__main__": 115 | # layers and masks 116 | K = 6 117 | w = 128 118 | layers = torch.nn.ModuleList() 119 | for k in range(K): 120 | mask = torch.tensor([1 - (k % 2), k % 2]) 121 | t = nn.Sequential(nn.Linear(2, w), nn.LeakyReLU(), nn.Linear(w, w), nn.LeakyReLU(), nn.Linear(w, 2), 122 | nn.Tanh()) 123 | s = nn.Sequential(nn.Linear(2, w), nn.LeakyReLU(), nn.Linear(w, w), nn.LeakyReLU(), nn.Linear(w, 2), 124 | nn.Tanh()) 125 | layer = RealNVPLayer(s, t, mask) 126 | layers.append(layer) 127 | 128 | from torch import distributions 129 | prior = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)) 130 | 131 | flow = NF(layers, prior) 132 | 133 | x = flow.sample(200).detach() 134 | # test inverse 135 | xt = flow.ginv(flow.g(x)[0])[0].detach() 136 | print(torch.norm(x - xt) / torch.norm(x)) 137 | 138 | # test inverse
139 | xt = flow.g(flow.ginv(x)[0])[0].detach() 140 | print(torch.norm(x - xt) / torch.norm(x)) 141 | 142 | 143 | -------------------------------------------------------------------------------- /trainRealNVP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import argparse 4 | import numpy as np 5 | from torch import distributions 6 | from sklearn import datasets 7 | import matplotlib.pyplot as plt 8 | 9 | device = "cpu" 10 | 11 | parser = argparse.ArgumentParser('RealNVP') 12 | parser.add_argument("--batch_size" , type=int, default=256, help="batch size") 13 | parser.add_argument("--noise" , type=int, default=0.05, help="noise of moons") 14 | parser.add_argument("--width" , type=int, default=128, help="width neural nets") 15 | parser.add_argument("--K" , type=int, default=6, help="number of layers") 16 | parser.add_argument("--num_steps" , type=int, default=20000, help="number of training steps") 17 | parser.add_argument("--plot_interval" , type=int, default=1000, help="plot solution every so many steps") 18 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png") 19 | 20 | args = parser.parse_args() 21 | 22 | from realNVP import NF, RealNVPLayer 23 | K = args.K 24 | w = args.width 25 | 26 | layers = torch.nn.ModuleList() 27 | for k in range(K): 28 | mask = torch.tensor([1 - (k % 2), k % 2]) 29 | t = nn.Sequential(nn.Linear(2, w), nn.LeakyReLU(), nn.Linear(w, w), nn.LeakyReLU(), nn.Linear(w, 2), 30 | nn.Tanh()) 31 | s = nn.Sequential(nn.Linear(2, w), nn.LeakyReLU(), nn.Linear(w, w), nn.LeakyReLU(), nn.Linear(w, 2), 32 | nn.Tanh()) 33 | layer = RealNVPLayer(s, t, mask) 34 | layers.append(layer) 35 | 36 | prior = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)) 37 | flow = NF(layers, prior).to(device) 38 | optimizer = 
torch.optim.Adam(flow.parameters(), lr=1e-4) 39 | 40 | his = np.zeros((0,1)) 41 | 42 | print((3*"--" + "device=%s, K=%d, width=%d, batch_size=%d, num_steps=%d" + 3*"--") % (device, args.K, args.width, args.batch_size, args.num_steps, )) 43 | 44 | if args.out_file is not None: 45 | import os 46 | out_dir, fname = os.path.split(args.out_file) 47 | if not os.path.exists(out_dir): 48 | os.makedirs(out_dir) 49 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file)) 50 | 51 | print((2*"%7s ") % ("step","J_ML")) 52 | 53 | 54 | train_JML = 0.0 55 | num_step = 0 56 | 57 | for step in range(args.num_steps): 58 | 59 | x = torch.tensor(datasets.make_moons(n_samples=args.batch_size, noise=args.noise)[0], dtype=torch.float32) 60 | loss = -flow.log_prob(x).mean() 61 | optimizer.zero_grad() 62 | loss.backward() 63 | optimizer.step() 64 | train_JML += loss.item() 65 | num_step += 1 66 | 67 | if (step + 1) % args.plot_interval == 0: 68 | train_JML /= num_step 69 | 70 | print(("%06d " + "%1.4e ") % 71 | (step + 1, train_JML)) 72 | his = np.vstack([his, [train_JML]]) 73 | 74 | zs = flow.ginv(x)[0].detach() 75 | xs = flow.sample(200).detach() 76 | x1 = torch.linspace(-1.2, 2.1, 100) 77 | x2 = torch.linspace(-1.2, 2.1, 100) 78 | xg = torch.meshgrid(x1, x2) 79 | xx = torch.cat((xg[0].reshape(-1, 1), xg[1].reshape(-1, 1)), 1) 80 | log_px = flow.log_prob(xx).detach() 81 | train_JML = 0.0 82 | num_step = 0 83 | 84 | plt.Figure() 85 | plt.rcParams.update({'font.size': 16, "text.usetex": True}) 86 | 87 | plt.subplot(1,3,1) 88 | plt.plot(xs[:, 0], xs[:, 1], "bs") 89 | plt.axis((-1.2, 2.1, -1.2, 2.1)) 90 | plt.xticks((-1.2, 2.1)) 91 | plt.yticks((-1.2, 2.1)) 92 | plt.xlabel("$\mathbf{x}_1$", labelpad=-20) 93 | plt.ylabel("$\mathbf{x}_2$", labelpad=-30) 94 | plt.title("$g_{\\theta}(\mathcal{Z})$") 95 | 96 | plt.subplot(1,3,2) 97 | plt.plot(zs[:, 0], zs[:, 1], "or") 98 | plt.axis((-3.5, 3.5, -3.5, 3.5)) 99 | plt.xticks((-3.5, 3.5)) 100 | plt.yticks((-3.5, 3.5)) 101 | 
plt.xlabel("$\mathbf{z}_1$", labelpad=-20) 102 | plt.ylabel("$\mathbf{z}_2$", labelpad=-30) 103 | plt.title("$g^{-1}_{\\theta}(\mathcal{Z})$") 104 | 105 | plt.subplot(1,3,3) 106 | img = log_px-torch.min(log_px) 107 | img/=torch.max(img) 108 | plt.imshow(img.reshape(len(x1), len(x2)), extent=(-1.2, 2.1, -1.2, 2.1)) 109 | plt.axis((-1.2, 2.1, -1.2, 2.1)) 110 | plt.xticks((-1.2, 2.1)) 111 | plt.yticks((-1.2, 2.1)) 112 | plt.xlabel("$\mathbf{x}_1$", labelpad=-20) 113 | plt.ylabel("$\mathbf{x}_2$", labelpad=-30) 114 | plt.title("$p_{\\theta}(\mathbf{x}), step=%d$" % (step+1)) 115 | 116 | plt.margins(0, 0) 117 | if args.out_file is not None: 118 | plt.savefig("%s-step-%d.png" % (args.out_file,step+1), bbox_inches='tight', pad_inches=0) 119 | plt.show() 120 | 121 | 122 | if args.out_file is not None: 123 | torch.save(flow.state_dict(), ("%s.pt") % (args.out_file)) 124 | from scipy.io import savemat 125 | savemat(("%s.mat") % (args.out_file), {"his":his}) -------------------------------------------------------------------------------- /trainDCGANmnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | 9 | parser = argparse.ArgumentParser('DCGAN') 10 | parser.add_argument("--batch_size" , type=int, default=64, help="batch size") 11 | parser.add_argument("--q" , type=int, default=2, help="latent space dimension") 12 | parser.add_argument("--width_disc" , type=int, default=32, help="width of discriminator") 13 | parser.add_argument("--width_dec" , type=int, default=32, help="width of decoder") 14 | parser.add_argument("--num_steps" , type=int, default=50, help="number of training steps") 15 | parser.add_argument("--plot_interval" , type=int, default=5, help="plot solution every so many steps") 16 | parser.add_argument("--init_g", type=str, 
default=None, help="path to .pt file that contains weights of a trained generator") 17 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png") 18 | 19 | args = parser.parse_args() 20 | 21 | import torchvision.transforms as transforms 22 | from torch.utils.data import DataLoader 23 | from torchvision.datasets import MNIST 24 | 25 | img_transform = transforms.Compose([ 26 | transforms.ToTensor() 27 | ]) 28 | 29 | train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform) 30 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 31 | 32 | from modelMNIST import Generator, Discriminator 33 | g = Generator(args.width_dec,args.q).to(device) 34 | d = Discriminator(args.width_disc,useSigmoid=True).to(device) 35 | 36 | optimizer_g = torch.optim.Adam(params=g.parameters(), lr=0.0002, betas=(0.5, 0.999)) 37 | optimizer_d = torch.optim.Adam(params=d.parameters(), lr=0.0002, betas=(0.5, 0.999)) 38 | 39 | his = np.zeros((0,3)) 40 | 41 | if args.init_g is not None: 42 | print("initialize g with weights in %s" % args.init_g) 43 | g.load_state_dict(torch.load(args.init_g)) 44 | 45 | print((3*"--" + "device=%s, q=%d, batch_size=%d, num_steps=%d, w_disc=%d, w_dec=%d," + 3*"--") % (device, args.q, args.batch_size, args.num_steps, args.width_disc, args.width_dec)) 46 | if args.out_file is not None: 47 | import os 48 | out_dir, fname = os.path.split(args.out_file) 49 | if not os.path.exists(out_dir): 50 | os.makedirs(out_dir) 51 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file)) 52 | print((4*"%7s ") % ("step","J_GAN","J_Gen","ProbDist")) 53 | 54 | from epsTest import epsTest 55 | 56 | train_JGAN = 0.0 57 | train_JGen = 0.0 58 | train_epsTest = 0.0 59 | num_ex = 0 60 | 61 | def inf_train_gen(): 62 | while True: 63 | for images, targets in enumerate(train_dataloader): 64 | yield 
images,targets 65 | 66 | get_true_images = inf_train_gen() 67 | 68 | for step in range(args.num_steps): 69 | g.train() 70 | d.train() 71 | # update discriminator using - J_GAN = - E_x [log(d(x)] - E_z[1-log(d(g(z))] 72 | x = get_true_images.__next__()[1][0] 73 | x = x.to(device) 74 | dx = d(x) 75 | z = torch.randn((x.shape[0],args.q),device=device) 76 | gz = g(z) 77 | dgz = d(gz) 78 | J_GAN = - torch.mean(torch.log(dx)) - torch.mean(torch.log(1-dgz)) 79 | optimizer_d.zero_grad() 80 | J_GAN.backward() 81 | optimizer_d.step() 82 | 83 | # update the generator using J_Gen = - E_z[log(d(g(z))] 84 | optimizer_g.zero_grad() 85 | z = torch.randn((x.shape[0], args.q), device=device) 86 | gz = g(z) 87 | dgz = d(gz) 88 | # J_Gen = -torch.mean(torch.log(dgz)) 89 | J_Gen = torch.mean(torch.log(1-dgz)) 90 | J_Gen.backward() 91 | optimizer_g.step() 92 | 93 | # update history 94 | train_JGAN -= J_GAN.item()*x.shape[0] 95 | train_JGen += J_Gen.item()*x.shape[0] 96 | train_epsTest += epsTest(gz.detach(), x) 97 | num_ex += x.shape[0] 98 | 99 | if (step + 1) % args.plot_interval == 0: 100 | train_JGAN /= num_ex 101 | train_JGen /= num_ex 102 | 103 | 104 | print(("%06d " + 3*"%1.4e ") % 105 | (step + 1, train_JGAN, train_JGen, train_epsTest)) 106 | his = np.vstack([his, [train_JGAN, train_JGen, train_epsTest]]) 107 | 108 | plt.Figure() 109 | img = gz.detach().cpu() 110 | img -= torch.min(img) 111 | img /=torch.max(img) 112 | plt.imshow(torchvision.utils.make_grid(img, 16, 5,pad_value=1.0).permute((1, 2, 0))) 113 | plt.title("trainDCGANmnist: step=%d" % (step + 1)) 114 | if args.out_file is not None: 115 | plt.savefig(("%s-step-%d.png") % (args.out_file, step + 1)) 116 | plt.show() 117 | 118 | train_JGAN = 0.0 119 | train_JGen = 0.0 120 | train_epsTest = 0.0 121 | num_ex = 0 122 | 123 | if args.out_file is not None: 124 | torch.save(g.state_dict(), ("%s-g.pt") % (args.out_file)) 125 | torch.save(d.state_dict(), ("%s-d.pt") % (args.out_file)) 126 | 127 | from scipy.io import savemat 128 
| savemat(("%s.mat") % (args.out_file), {"his":his}) 129 | 130 | plt.Figure() 131 | plt.subplot(1,2,1) 132 | plt.plot(his[:,0:2]) 133 | plt.legend(("JGAN","JGen")) 134 | plt.title("GAN Objectives") 135 | plt.subplot(1,2,2) 136 | plt.plot(his[:,2]) 137 | plt.title("epsTest") 138 | if args.out_file is not None: 139 | plt.savefig(("%s-his.png") % (args.out_file)) 140 | plt.show() -------------------------------------------------------------------------------- /trainWGANmnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | 9 | ## load MNIST 10 | parser = argparse.ArgumentParser('WGAN') 11 | parser.add_argument("--batch_size" , type=int, default=64, help="batch size") 12 | parser.add_argument("--q" , type=int, default=2, help="latent space dimension") 13 | parser.add_argument("--width_disc" , type=int, default=32, help="width of discriminator") 14 | parser.add_argument("--width_dec" , type=int, default=32, help="width of decoder") 15 | parser.add_argument("--clip_limit" , type=float, default=1e-2, help="limit for weights of discriminator") 16 | parser.add_argument("--iter_disc" , type=int, default=5, help="number of iterations for discriminator") 17 | parser.add_argument("--num_steps" , type=int, default=50, help="number of training steps") 18 | parser.add_argument("--plot_interval" , type=int, default=5, help="plot solution every so many steps") 19 | parser.add_argument("--init_g", type=str, default=None, help="path to .pt file that contains weights of a trained generator") 20 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png") 21 | 22 | args = parser.parse_args() 23 | 24 | import torchvision.transforms 
as transforms 25 | from torch.utils.data import DataLoader 26 | from torchvision.datasets import MNIST 27 | 28 | img_transform = transforms.Compose([ 29 | transforms.ToTensor() 30 | ]) 31 | 32 | train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform) 33 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 34 | 35 | from modelMNIST import Generator, Discriminator 36 | g = Generator(args.width_dec,args.q).to(device) 37 | d = Discriminator(args.width_disc,useSigmoid=False).to(device) 38 | 39 | if args.init_g is not None: 40 | print("initialize g with weights in %s" % args.init_g) 41 | g.load_state_dict(torch.load(args.init_g)) 42 | 43 | optimizer_g = torch.optim.RMSprop(g.parameters(), lr=0.00005) 44 | optimizer_d = torch.optim.RMSprop(d.parameters(), lr=0.00005) 45 | his = np.zeros((0,3)) 46 | 47 | print((3*"--" + "device=%s, q=%d, batch_size=%d, num_steps=%d, w_disc=%d, w_dec=%d" + 3*"--") % (device, args.q, args.batch_size, args.num_steps, args.width_disc, args.width_dec)) 48 | if args.out_file is not None: 49 | import os 50 | out_dir, fname = os.path.split(args.out_file) 51 | if not os.path.exists(out_dir): 52 | os.makedirs(out_dir) 53 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file)) 54 | 55 | print((4*"%7s ") % ("step","J_GAN","J_Gen","ProbDist")) 56 | 57 | 58 | from epsTest import epsTest 59 | 60 | train_JGAN = 0.0 61 | train_JGen = 0.0 62 | train_epsTest = 0.0 63 | num_ex = 0 64 | 65 | def inf_train_gen(): 66 | while True: 67 | for images, targets in enumerate(train_dataloader): 68 | yield images,targets 69 | 70 | get_true_images = inf_train_gen() 71 | 72 | 73 | for step in range(args.num_steps): 74 | g.train() 75 | d.train() 76 | # update discriminator using ascent on J_GAN = E_x [d(x)] - E_z[d(g(z))] 77 | for iter_critic in range(args.iter_disc): 78 | x = get_true_images.__next__()[1][0] 79 | x = x.to(device) 80 | 81 | for p in d.parameters(): 82 | p.data.clamp_(-0.01, 
0.01) 83 | 84 | dx = d(x) 85 | z = torch.randn((x.shape[0],args.q),device=device) 86 | gz = g(z) 87 | dgz = d(gz) 88 | J_GAN = -(torch.mean(dx) - torch.mean(dgz)) 89 | optimizer_d.zero_grad() 90 | J_GAN.backward() 91 | optimizer_d.step() 92 | train_JGAN -= J_GAN.item() * x.shape[0] 93 | 94 | # update the generator using descent on J_Gen = - E_z[d(g(z))] 95 | optimizer_g.zero_grad() 96 | z = torch.randn((x.shape[0], args.q), device=device) 97 | gz = g(z) 98 | dgz = d(gz) 99 | J_Gen = -torch.mean(dgz) 100 | J_Gen.backward() 101 | optimizer_g.step() 102 | 103 | # update history 104 | train_JGen += J_Gen.item()*x.shape[0] 105 | train_epsTest += epsTest(gz.detach(),x) 106 | 107 | num_ex += x.shape[0] 108 | 109 | if (step+1) % args.plot_interval==0: 110 | train_JGAN /= args.iter_disc * num_ex 111 | train_JGen /= num_ex 112 | 113 | print(("%06d " + 3 * "%1.4e ") % 114 | (step + 1, train_JGAN, train_JGen, train_epsTest)) 115 | his = np.vstack([his, [train_JGAN, train_JGen, train_epsTest]]) 116 | 117 | plt.Figure() 118 | img = gz.detach().cpu() 119 | img -= torch.min(img) 120 | img /= torch.max(img) 121 | plt.imshow(torchvision.utils.make_grid(img, 8, 5).permute((1, 2, 0))) 122 | plt.title("trainWGANmnist: step=%d" % (step+1)) 123 | if args.out_file is not None: 124 | plt.savefig(("%s-step-%d.png") % (args.out_file,step+1)) 125 | plt.show() 126 | 127 | train_JGAN = 0.0 128 | train_JGen = 0.0 129 | train_epsTest = 0.0 130 | 131 | num_ex = 0 132 | 133 | if args.out_file is not None: 134 | torch.save(g.state_dict(), ("%s-g.pt") % (args.out_file)) 135 | torch.save(d.state_dict(), ("%s-d.pt") % (args.out_file)) 136 | 137 | from scipy.io import savemat 138 | savemat(("%s.mat") % (args.out_file), {"his":his}) 139 | 140 | plt.Figure() 141 | plt.subplot(1,2,1) 142 | plt.plot(his[:,0:2]) 143 | plt.legend(("JGAN","JGen")) 144 | plt.title("GAN Objectives") 145 | plt.subplot(1,2,2) 146 | plt.plot(his[:,2]) 147 | plt.title("epsTest") 148 | if args.out_file is not None: 149 | 
plt.savefig(("%s-his.png") % (args.out_file)) 150 | plt.show() -------------------------------------------------------------------------------- /OTFlow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.functional import pad 4 | from Phi import * 5 | from torch import distributions 6 | 7 | 8 | class OTFlow(nn.Module): 9 | """ 10 | OT-Flow for density estimation and sampling as described in 11 | 12 | @article{onken2020otflow, 13 | title={OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport}, 14 | author={Derek Onken and Samy Wu Fung and Xingjian Li and Lars Ruthotto}, 15 | year={2020}, 16 | journal = {arXiv preprint arXiv:2006.00104}, 17 | } 18 | 19 | """ 20 | def __init__(self, net, nt, alph, prior, T=1.0): 21 | """ 22 | Initialize OT-Flow 23 | 24 | :param net: network for value function 25 | :param nt: number of rk4 steps 26 | :param alph: penalty parameters 27 | :param prior: latent distribution, e.g., distributions.MultivariateNormal(torch.zeros(d), torch.eye(d)) 28 | """ 29 | super(OTFlow, self).__init__() 30 | self.prior = prior 31 | self.nt = nt 32 | self.T = T 33 | self.net = net 34 | self.alph = alph 35 | 36 | 37 | def g(self, z, nt = None, storeAll=False): 38 | """ 39 | :param z: latent variable 40 | :return: g(z) and hidden states 41 | """ 42 | return self.integrate(z,[self.T, 0.0], nt,storeAll) 43 | 44 | def ginv(self, x, nt=None, storeAll=False): 45 | """ 46 | :param x: sample from dataset 47 | :return: g^(-1)(x), value of log-determinant, and hidden layers 48 | """ 49 | 50 | return self.integrate(x,[0.0, self.T], nt,storeAll) 51 | 52 | def log_prob(self, x, nt=None): 53 | """ 54 | Compute log-probability of a sample using change of variable formula 55 | 56 | :param x: sample from dataset 57 | :return: logp_{\theta}(x) 58 | """ 59 | z, _, log_det_ginv, v, r = self.ginv(x,nt) 60 | return self.prior.log_prob(z) + log_det_ginv, v, r 61 | 
# OTFlow (continued)
    def sample(self, s,nt=None):
        """
        Draw random samples from p_{\theta}

        :param s: number of samples to draw
        :param nt: optional number of RK4 steps (defaults to self.nt)
        :return: tensor of s samples, obtained by pushing prior draws through g
        """
        z = self.prior.sample((s, 1)).squeeze(1)
        x, _, _, _, _ = self.g(z,nt)
        return x

    def f(self,x, t):
        """
        neural ODE combining the characteristics and log-determinant (see Eq. (2)), the transport costs (see Eq. (5)), and
        the HJB regularizer (see Eq. (7)).

        d_t [x ; l ; v ; r] = odefun( [x ; l ; v ; r] , t )

        x - particle position
        l - log determinant
        v - accumulated transport costs (Lagrangian)
        r - accumulates violation of HJB condition along trajectory
        """
        nex, d = x.shape
        # append time t as an extra coordinate: z = [x ; t]
        z = pad(x[:, :d], (0, 1, 0, 0), value=t)
        gradPhi, trH = self.net.trHess(z)

        dx = -(1.0 / self.alph[0]) * gradPhi[:, 0:d]   # characteristics (spatial gradient of Phi)
        dl = -(1.0 / self.alph[0]) * trH               # log-determinant via trace of the Hessian
        dv = 0.5 * torch.sum(torch.pow(dx, 2), 1)      # transport cost (Lagrangian)
        # HJB violation; gradPhi[:, -1] is the time derivative of Phi
        dr = torch.abs(-gradPhi[:, -1] + self.alph[0] * dv)

        return dx, dl, dv, dr

    def integrate(self, y, tspan, nt=None,storeAll=False):
        """
        RK4 time-stepping to integrate the neural ODE

        :param y: initial state
        :param tspan: time interval (can go backward in time)
        :param nt: number of time steps (default is self.nt)
        :param storeAll: if True, keep a (detached, CPU) copy of the state after every step
        :return: y (final state), ys (all states), l (log determinant), v (transport costs), r (HJB penalty)
        """
        if nt is None:
            nt = self.nt

        nex, d = y.shape
        h = (tspan[1] - tspan[0])/ nt   # step size; negative when integrating backward in time
        tk = tspan[0]

        # per-example accumulators for log-determinant, transport cost, HJB penalty
        l = torch.zeros((nex), device=y.device, dtype=y.dtype)
        v = torch.zeros((nex), device=y.device, dtype=y.dtype)
        r = torch.zeros((nex), device=y.device, dtype=y.dtype)
        if storeAll:
            ys = [torch.clone(y).detach().cpu()]
        else:
            ys = None

        # classical RK4 weights h/6 * (1, 2, 2, 1)
        w = [(h/6.0),2.0*(h/6.0),2.0*(h/6.0),1.0*(h/6.0)]
        for i in range(nt):
            y0 = y

            # stage 1: slope at (y0, tk)
            dy, dl, dv, dr = self.f(y0, tk)
            y = y0 + w[0] * dy
            l += w[0] * dl
            v += w[0] * dv
            r += w[0] * dr

            # stage 2: slope at the midpoint using the stage-1 slope
            dy, dl, dv, dr = self.f(y0 + 0.5 * h * dy, tk + (h / 2))
            y += w[1] * dy
            l += w[1] * dl
            v += w[1] * dv
            r += w[1] * dr

            # stage 3: slope at the midpoint using the stage-2 slope
            dy, dl, dv, dr = self.f(y0 + 0.5 * h * dy, tk + (h / 2))
            y += w[2] * dy
            l += w[2] * dl
            v += w[2] * dv
            r += w[2] * dr

            # stage 4: slope at the end of the step using the stage-3 slope
            dy, dl, dv, dr = self.f(y0 + h * dy, tk + h)
            y += w[3] * dy
            l += w[3] * dl
            v += w[3] * dv
            r += w[3] * dr

            if storeAll:
                ys.append(torch.clone(y).detach().cpu())
            tk +=h

        return y, ys, l, v, r

if __name__ == "__main__":
    # layers and masks

    # small self-test: build a 2-d flow and check that g and ginv invert each other
    nt = 16
    alph = [1.0, 5.0, 10.0]
    prior = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
    net = Phi(nTh=2, m=16, d=2, alph=alph)
    T=1.0

    flow = OTFlow(net,nt,alph,prior,T)

    x = flow.sample(200).detach()
    # test inverse (relative error should be small)
    xt = flow.ginv(flow.g(x)[0])[0].detach()
    print(torch.norm(x - xt) / torch.norm(x))

    # test inverse (other composition order)
    xt = flow.g(flow.ginv(x)[0])[0].detach()
    print(torch.norm(x - xt) / torch.norm(x))


-------------------------------------------------------------------------------- /Phi.py: --------------------------------------------------------------------------------
# Phi.py
# neural network to model the potential function
import torch
import torch.nn as nn
import copy
import math

def antiderivTanh(x): # activation function aka the antiderivative of tanh
    # equals log(2*cosh(x)) - log(2), written via |x| to avoid overflow for large |x|
    return torch.abs(x) + torch.log(1+torch.exp(-2.0*torch.abs(x)))

def derivTanh(x): # act'' aka the second derivative of the activation function antiderivTanh
    return 1 - torch.pow( torch.tanh(x) , 2 )

class ResNN(nn.Module):
    def __init__(self, d, m, nTh=2):
        """
        ResNet N portion of Phi

        This implementation was first described in:
        @article{onken2020otflow,
            title={OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport},
            author={Derek Onken and Samy Wu Fung and Xingjian Li and Lars Ruthotto},
            year={2020},
            journal = {arXiv preprint arXiv:2006.00104},
        }
        :param d: int, dimension of space input (expect inputs to be d+1 for space-time)
        :param m: int, hidden dimension
        :param nTh: int, number of resNet layers , (number of theta layers)
        """
        super().__init__()

        if nTh < 2:
            print("nTh must be an integer >= 2")
            exit(1)

        self.d = d
        self.m = m
        self.nTh = nTh
        self.layers = nn.ModuleList([])
        self.layers.append(nn.Linear(d + 1, m, bias=True)) # opening layer
        self.layers.append(nn.Linear(m,m, bias=True)) # resnet layers
        for i in range(nTh-2):
            # deepcopy gives each additional layer its own (identically initialized) parameters
            self.layers.append(copy.deepcopy(self.layers[1]))
        self.act = antiderivTanh
        self.h = 1.0 / (self.nTh-1) # step size for the ResNet

    def forward(self, x):
        """
        N(s;theta). the forward propogation of the ResNet
        :param x: tensor nex-by-d+1, inputs
        :return: tensor nex-by-m, outputs
        """

        x = self.act(self.layers[0].forward(x))

        # residual updates: x <- x + h * act(K_i x + b_i)
        for i in range(1,self.nTh):
            x = x + self.h * self.act(self.layers[i](x))

        return x



class Phi(nn.Module):
    # NOTE(review): mutable default argument alph=[1.0]*5 is shared across calls;
    # safe as long as no caller mutates it — confirm before refactoring.
    def __init__(self, nTh, m, d, r=10, alph=[1.0] * 5):
        """
        neural network approximating Phi (see Eq. (9) in our paper)

        Phi( x,t ) = w'*ResNet( [x;t]) + 0.5*[x' t] * A'A * [x;t] + b'*[x;t] + c

        :param nTh: int, number of resNet layers , (number of theta layers)
        :param m: int, hidden dimension
        :param d: int, dimension of space input (expect inputs to be d+1 for space-time)
        :param r: int, rank r for the A matrix
        :param alph: list, alpha values / weighted multipliers for the optimization problem
        """
        super().__init__()

        self.m = m
        self.nTh = nTh
        self.d = d
        self.alph = alph

        r = min(r,d+1) # if number of dimensions is smaller than default r, use that

        # low-rank factor of the quadratic term; A'A is symmetric positive semidefinite
        self.A = nn.Parameter(torch.zeros(r, d+1) , requires_grad=True)
        self.A = nn.init.xavier_uniform_(self.A)
        self.c = nn.Linear( d+1 , 1 , bias=True) # b'*[x;t] + c
        self.w = nn.Linear( m , 1 , bias=False)

        self.N = ResNN(d, m, nTh=nTh)

        # set initial values (w = ones, linear term = zero)
        self.w.weight.data = torch.ones(self.w.weight.data.shape)
        self.c.weight.data = torch.zeros(self.c.weight.data.shape)
        self.c.bias.data = torch.zeros(self.c.bias.data.shape)



    def forward(self, x):
        """ calculating Phi(s, theta)...not used in OT-Flow """

        # force A to be symmetric
        symA = torch.matmul(torch.t(self.A), self.A) # A'A

        return self.w( self.N(x)) + 0.5 * torch.sum( torch.matmul(x , symA) * x , dim=1, keepdims=True) + self.c(x)


    def trHess(self,x,d=None, justGrad=False ):
        """
        compute gradient of Phi wrt x and trace(Hessian of Phi); see Eq. (11) and Eq.
(13), respectively 112 | recomputes the forward propogation portions of Phi 113 | 114 | :param x: input data, torch Tensor nex-by-d 115 | :param justGrad: boolean, if True only return gradient, if False return (grad, trHess) 116 | :return: gradient , trace(hessian) OR just gradient 117 | """ 118 | 119 | # code in E = eye(d+1,d) as index slicing instead of matrix multiplication 120 | # assumes specific N.act as the antiderivative of tanh 121 | 122 | N = self.N 123 | m = N.layers[0].weight.shape[0] 124 | nex = x.shape[0] # number of examples in the batch 125 | if d is None: 126 | d = x.shape[1]-1 127 | symA = torch.matmul(self.A.t(), self.A) 128 | 129 | u = [] # hold the u_0,u_1,...,u_M for the forward pass 130 | z = N.nTh*[None] # hold the z_0,z_1,...,z_M for the backward pass 131 | # preallocate z because we will store in the backward pass and we want the indices to match the paper 132 | 133 | # Forward of ResNet N and fill u 134 | opening = N.layers[0].forward(x) # K_0 * S + b_0 135 | u.append(N.act(opening)) # u0 136 | feat = u[0] 137 | 138 | for i in range(1,N.nTh): 139 | feat = feat + N.h * N.act(N.layers[i](feat)) 140 | u.append(feat) 141 | 142 | # going to be used more than once 143 | tanhopen = torch.tanh(opening) # act'( K_0 * S + b_0 ) 144 | 145 | # compute gradient and fill z 146 | for i in range(N.nTh-1,0,-1): # work backwards, placing z_i in appropriate spot 147 | if i == N.nTh-1: 148 | term = self.w.weight.t() 149 | else: 150 | term = z[i+1] 151 | 152 | # z_i = z_{i+1} + h K_i' diag(...) z_{i+1} 153 | z[i] = term + N.h * torch.mm( N.layers[i].weight.t() , torch.tanh( N.layers[i].forward(u[i-1]) ).t() * term) 154 | 155 | # z_0 = K_0' diag(...) 
z_1 156 | z[0] = torch.mm( N.layers[0].weight.t() , tanhopen.t() * z[1] ) 157 | grad = z[0] + torch.mm(symA, x.t() ) + self.c.weight.t() 158 | 159 | if justGrad: 160 | return grad.t() 161 | 162 | # ----------------- 163 | # trace of Hessian 164 | #----------------- 165 | 166 | # t_0, the trace of the opening layer 167 | Kopen = N.layers[0].weight[:,0:d] # indexed version of Kopen = torch.mm( N.layers[0].weight, E ) 168 | temp = derivTanh(opening.t()) * z[1] 169 | trH = torch.sum(temp.reshape(m, -1, nex) * torch.pow(Kopen.unsqueeze(2), 2), dim=(0, 1)) # trH = t_0 170 | 171 | # grad_s u_0 ^ T 172 | temp = tanhopen.t() # act'( K_0 * S + b_0 ) 173 | Jac = Kopen.unsqueeze(2) * temp.unsqueeze(1) # K_0' * act'( K_0 * S + b_0 ) 174 | # Jac is shape m by d by nex 175 | 176 | # t_i, trace of the resNet layers 177 | # KJ is the K_i^T * grad_s u_{i-1}^T 178 | for i in range(1,N.nTh): 179 | KJ = torch.mm(N.layers[i].weight , Jac.reshape(m,-1) ) 180 | KJ = KJ.reshape(m,-1,nex) 181 | if i == N.nTh-1: 182 | term = self.w.weight.t() 183 | else: 184 | term = z[i+1] 185 | 186 | temp = N.layers[i].forward(u[i-1]).t() # (K_i * u_{i-1} + b_i) 187 | t_i = torch.sum( ( derivTanh(temp) * term ).reshape(m,-1,nex) * torch.pow(KJ,2) , dim=(0, 1) ) 188 | trH = trH + N.h * t_i # add t_i to the accumulate trace 189 | if i < N.nTh: 190 | Jac = Jac + N.h * torch.tanh(temp).reshape(m, -1, nex) * KJ # update Jacobian 191 | 192 | return grad.t(), trH + torch.trace(symA[0:d,0:d]) 193 | # indexed version of: return grad.t() , trH + torch.trace( torch.mm( E.t() , torch.mm( symA , E) ) ) 194 | 195 | 196 | 197 | if __name__ == "__main__": 198 | 199 | import time 200 | import math 201 | 202 | # test case 203 | d = 2 204 | m = 5 205 | 206 | net = Phi(nTh=2, m=m, d=d) 207 | net.N.layers[0].weight.data = 0.1 + 0.0 * net.N.layers[0].weight.data 208 | net.N.layers[0].bias.data = 0.2 + 0.0 * net.N.layers[0].bias.data 209 | net.N.layers[1].weight.data = 0.3 + 0.0 * net.N.layers[1].weight.data 210 | 
net.N.layers[1].weight.data = 0.3 + 0.0 * net.N.layers[1].weight.data 211 | 212 | # number of samples-by-(d+1) 213 | x = torch.Tensor([[1.0 ,4.0 , 0.5],[2.0,5.0,0.6],[3.0,6.0,0.7],[0.0,0.0,0.0]]) 214 | y = net(x) 215 | print(y) 216 | 217 | # test timings 218 | d = 400 219 | m = 32 220 | nex = 1000 221 | 222 | net = Phi(nTh=5, m=m, d=d) 223 | net.eval() 224 | x = torch.randn(nex,d+1) 225 | y = net(x) 226 | 227 | end = time.time() 228 | g,h = net.trHess(x) 229 | print('traceHess takes ', time.time()-end) 230 | 231 | end = time.time() 232 | g = net.trHess(x, justGrad=True) 233 | print('JustGrad takes ', time.time()-end) 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | --------------------------------------------------------------------------------