├── results
├── VAE
│ ├── trainVAE-g.pt
│ ├── trainVAE.mat
│ ├── trainVAE.pt
│ └── trainVAE.txt
├── OTFlow
│ ├── trainOTFlow.mat
│ ├── trainOTFlow.pt
│ └── trainOTFlow
├── realNVP
│ ├── trainRealNVP.mat
│ ├── trainRealNVP.pt
│ └── trainRealNVP.txt
├── DCGAN
│ ├── trainWGAN-vaeInit.mat
│ ├── trainDCGAN-vaeInit-d.pt
│ ├── trainDCGAN-vaeInit-g.pt
│ ├── trainDCGAN-vaeInit.mat
│ ├── trainWGAN-vaeInit-d.pt
│ ├── trainWGAN-vaeInit-g.pt
│ ├── trainDCGAN-randomInit-d.pt
│ ├── trainDCGAN-randomInit-g.pt
│ ├── trainDCGAN-randomInit.mat
│ ├── trainWGAN-randomInit-d.pt
│ ├── trainWGAN-randomInit-g.pt
│ ├── trainWGAN-randomInit.mat
│ ├── trainWGAN-vaeInit-his.png
│ ├── trainDCGAN-randomInit.txt
│ └── trainDCGAN-vaeInit.txt
└── WGAN
│ ├── trainWGAN-randomInit.txt
│ └── trainWGAN-vaeInit.txt
├── .idea
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
└── modules.xml
├── testAll.sh
├── LICENSE
├── epsTest.py
├── plotting.py
├── runAll.sh
├── .gitignore
├── vae.py
├── README.md
├── modelMNIST.py
├── trainVAEmnist.py
├── trainOTFlow.py
├── realNVP.py
├── trainRealNVP.py
├── trainDCGANmnist.py
├── trainWGANmnist.py
├── OTFlow.py
└── Phi.py
/results/VAE/trainVAE-g.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/VAE/trainVAE-g.pt
--------------------------------------------------------------------------------
/results/VAE/trainVAE.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/VAE/trainVAE.mat
--------------------------------------------------------------------------------
/results/VAE/trainVAE.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/VAE/trainVAE.pt
--------------------------------------------------------------------------------
/results/OTFlow/trainOTFlow.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/OTFlow/trainOTFlow.mat
--------------------------------------------------------------------------------
/results/OTFlow/trainOTFlow.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/OTFlow/trainOTFlow.pt
--------------------------------------------------------------------------------
/results/realNVP/trainRealNVP.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/realNVP/trainRealNVP.mat
--------------------------------------------------------------------------------
/results/realNVP/trainRealNVP.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/realNVP/trainRealNVP.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainWGAN-vaeInit.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-vaeInit.mat
--------------------------------------------------------------------------------
/results/DCGAN/trainDCGAN-vaeInit-d.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-vaeInit-d.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainDCGAN-vaeInit-g.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-vaeInit-g.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainDCGAN-vaeInit.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-vaeInit.mat
--------------------------------------------------------------------------------
/results/DCGAN/trainWGAN-vaeInit-d.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-vaeInit-d.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainWGAN-vaeInit-g.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-vaeInit-g.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainDCGAN-randomInit-d.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-randomInit-d.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainDCGAN-randomInit-g.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-randomInit-g.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainDCGAN-randomInit.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainDCGAN-randomInit.mat
--------------------------------------------------------------------------------
/results/DCGAN/trainWGAN-randomInit-d.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-randomInit-d.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainWGAN-randomInit-g.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-randomInit-g.pt
--------------------------------------------------------------------------------
/results/DCGAN/trainWGAN-randomInit.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-randomInit.mat
--------------------------------------------------------------------------------
/results/DCGAN/trainWGAN-vaeInit-his.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EmoryMLIP/DeepGenerativeModelingIntro/HEAD/results/DCGAN/trainWGAN-vaeInit-his.png
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/results/realNVP/trainRealNVP.txt:
--------------------------------------------------------------------------------
1 | ------device=cpu, K=6, width=128, batch_size=256, num_steps=20000------
2 | ------out_file: results/realNVP/trainRealNVP------
3 | step J_ML
4 | 001000 -4.7034e+00
5 | 002000 -5.5544e+00
6 | 003000 -5.5929e+00
7 | 004000 -5.6134e+00
8 | 005000 -5.6250e+00
9 | 006000 -5.6343e+00
10 | 007000 -5.6390e+00
11 | 008000 -5.6381e+00
12 | 009000 -5.6421e+00
13 | 010000 -5.6439e+00
14 | 011000 -5.6426e+00
15 | 012000 -5.6476e+00
16 | 013000 -5.6472e+00
17 | 014000 -5.6476e+00
18 | 015000 -5.6507e+00
19 | 016000 -5.6488e+00
20 | 017000 -5.6525e+00
21 | 018000 -5.6557e+00
22 | 019000 -5.6495e+00
23 | 020000 -5.6543e+00
24 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/testAll.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python3 -u trainRealNVP.py --batch_size 64 --width 32 --K 4 --num_steps 50 --plot_interval 10 --out_file /tmp/realNVP/test 2> testLog.txt
4 | python3 -u trainOTFlow.py --batch_size 64 --width 32 --nt 4 --num_steps 50 --plot_interval 10 --out_file /tmp/OTFlow/test 2>> testLog.txt
5 | python3 -u trainVAEmnist.py --batch_size 256 --num_epochs 2 --width_enc 8 --width_dec 8 --out_file /tmp/VAE/test 2>> testLog.txt
6 | python3 -u trainDCGANmnist.py --batch_size 64 --num_steps 50 --width_disc 8 --width_dec 8 --plot_interval 10 --out_file /tmp/DCGAN/test 2>> testLog.txt
7 | python3 -u trainWGANmnist.py --batch_size 64 --num_steps 50 --width_disc 8 --width_dec 8 --plot_interval 10 --out_file /tmp/WGAN/test 2>> testLog.txt
8 | python3 -u trainDCGANmnist.py --batch_size 64 --num_steps 50 --width_disc 8 --width_dec 8 --plot_interval 10 --init_g /tmp/VAE/test-g.pt --out_file /tmp/DCGAN/test2 2>> testLog.txt
9 | python3 -u trainWGANmnist.py --batch_size 64 --num_steps 50 --width_disc 8 --width_dec 8 --plot_interval 10 --init_g /tmp/VAE/test-g.pt --out_file /tmp/WGAN/test2 2>> testLog.txt
10 |
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Lars Ruthotto
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/epsTest.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def epsTest(X, Y, eps=1e-1):
4 | """
5 | Test for equal distributions suggested in Székely, G. J., InterStat, M. R., 2004. (n.d.).
6 | Testing for equal distributions in high dimension. Personal.Bgsu.Edu.
7 |
8 | :param X: Samples from first distribution
9 | :param Y: Samples from second distribution
10 | :param eps: conditioning paramter
11 | :return:
12 | """
13 | nx = X.shape[0]
14 | ny = Y.shape[0]
15 |
16 | X = X.view(nx, -1)
17 | Y = Y.view(ny, -1)
18 |
19 | sX = torch.norm(X, dim=1) ** 2;
20 | sY = torch.norm(Y, dim=1) ** 2;
21 |
22 | CXX = sX.unsqueeze(1) + sX.unsqueeze(0) - 2 * X @ X.t()
23 | CXX = torch.sqrt(CXX + eps)
24 |
25 | CYY = sY.unsqueeze(1) + sY.unsqueeze(0) - 2 * Y @ Y.t()
26 | CYY = torch.sqrt(CYY + eps)
27 |
28 | CXY = sX.unsqueeze(1) + sY.unsqueeze(0) - 2 * X @ Y.t()
29 | CXY = torch.sqrt(CXY + eps)
30 |
31 | D = (nx * ny) / (nx + ny) * (2.0 / (nx * ny) * torch.sum(CXY)
32 | - 1.0 / nx ** 2 * (torch.sum(CXX)) - 1.0 / ny ** 2 * (torch.sum(CYY)));
33 |
34 | return D / (nx + ny)
--------------------------------------------------------------------------------
/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 |
4 |
5 | def plot_x(xs,domain=(0,1,0,1)):
6 | plt.plot(xs[:, 0], xs[:, 1], "bs")
7 | plt.axis("square")
8 | plt.axis(domain)
9 | plt.xticks(domain[0:2])
10 | plt.yticks(domain[2:])
11 | plt.xlabel("$\mathbf{x}_1$", labelpad=-20)
12 | plt.ylabel("$\mathbf{x}_2$", labelpad=-30)
13 |
14 |
15 | def plot_z(zs):
16 | plt.plot(zs[:, 0], zs[:, 1], "or")
17 | plt.axis("square")
18 | plt.axis((-3.5, 3.5, -3.5, 3.5))
19 | plt.xticks((-3.5, 3.5))
20 | plt.yticks((-3.5, 3.5))
21 | plt.xlabel("$\mathbf{z}_1$", labelpad=-20)
22 | plt.ylabel("$\mathbf{z}_2$", labelpad=-30)
23 |
24 |
25 | def plot_px(log_px,domain=(0,1,0,1)):
26 | px = torch.exp(log_px)
27 | img = px
28 | plt.imshow(img.t(), extent=domain,origin='lower')
29 | plt.axis("square")
30 | plt.axis(domain)
31 | plt.xticks(domain[0:2])
32 | plt.yticks(domain[2:])
33 | plt.xlabel("$\mathbf{x}_1$", labelpad=-20)
34 | plt.ylabel("$\mathbf{x}_2$", labelpad=-30)
35 |
36 | def plot_pz(zz,domain=(-3.5, 3.5, -3.5, 3.5)):
37 | plt.hist2d(zz[:,0], zz[:,1],bins=256,range=[[-domain[0], domain[1]], [domain[2], domain[3]]])
38 | plt.axis("square")
39 | plt.axis(domain)
40 | plt.xticks(domain[0:2])
41 | plt.yticks(domain[2:])
42 | plt.xlabel("$\mathbf{z}_1$", labelpad=-20)
43 | plt.ylabel("$\mathbf{z}_2$", labelpad=-30)
--------------------------------------------------------------------------------
/runAll.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | python3 -u trainRealNVP.py --batch_size 256 --width 128 --K 6 --num_steps 20000 --plot_interval 1000 --out_file results/realNVP/trainRealNVP | tee results/realNVP/trainRealNVP.txt
4 | python3 -u trainOTFlow.py --batch_size 256 --width 32 --nt 4 --num_steps 20000 --plot_interval 1000 --out_file results/OTFlow/trainOTFlow |tee results/OTFlow/trainOTFlow
5 | python3 -u trainVAEmnist.py --batch_size 64 --num_epochs 50 --width_enc 32 --width_dec 32 --out_file results/VAE/trainVAE |tee results/VAE/trainVAE.txt
6 | python3 -u trainDCGANmnist.py --batch_size 64 --num_steps 50000 --width_disc 32 --width_dec 32 --plot_interval 1000 --out_file results/DCGAN/trainDCGAN-randomInit |tee results/DCGAN/trainDCGAN-randomInit.txt
7 | python3 -u trainWGANmnist.py --batch_size 64 --num_steps 50000 --width_disc 32 --width_dec 32 --plot_interval 1000 --out_file results/DCGAN/trainWGAN-randomInit |tee results/WGAN/trainWGAN-randomInit.txt
8 | python3 -u trainDCGANmnist.py --batch_size 64 --num_steps 50000 --width_disc 32 --width_dec 32 --plot_interval 1000 --init_g results/VAE/trainVAE-g.pt --out_file results/WGAN/trainDCGAN-vaeInit |tee results/DCGAN/trainDCGAN-vaeInit.txt
9 | python3 -u trainWGANmnist.py --batch_size 64 --num_steps 50000 --width_disc 32 --width_dec 32 --plot_interval 1000 --init_g results/VAE/trainVAE-g.pt --out_file results/WGAN/trainWGAN-vaeInit |tee results/WGAN/trainWGAN-vaeInit.txt
10 |
11 |
12 |
--------------------------------------------------------------------------------
/results/OTFlow/trainOTFlow:
--------------------------------------------------------------------------------
1 | ------device=cpu, nTh=2, width=32, batch_size=256, num_steps=20000------
2 | ------out_file: results/OTFlow/trainOTFlow------
3 | step J J_L J_ML J_HJB
4 | 001000 1.6560e+01 3.1925e-01 1.5329e+00 1.8234e-01
5 | 002000 9.6450e+00 3.3269e-01 8.6591e-01 1.3064e-01
6 | 003000 7.5762e+00 3.5424e-01 6.5991e-01 1.2457e-01
7 | 004000 6.8949e+00 3.6180e-01 5.9144e-01 1.2373e-01
8 | 005000 6.4849e+00 3.6729e-01 5.4807e-01 1.2739e-01
9 | 006000 6.1083e+00 3.6985e-01 5.1064e-01 1.2642e-01
10 | 007000 5.8387e+00 3.6832e-01 4.8543e-01 1.2322e-01
11 | 008000 5.7509e+00 3.6623e-01 4.7970e-01 1.1755e-01
12 | 009000 5.6017e+00 3.6534e-01 4.6686e-01 1.1355e-01
13 | 010000 5.4999e+00 3.6431e-01 4.5894e-01 1.0924e-01
14 | 011000 5.3779e+00 3.6332e-01 4.4920e-01 1.0452e-01
15 | 012000 5.2646e+00 3.6242e-01 4.3959e-01 1.0126e-01
16 | 013000 5.1163e+00 3.6151e-01 4.2714e-01 9.6683e-02
17 | 014000 5.0453e+00 3.5966e-01 4.2190e-01 9.3342e-02
18 | 015000 4.9553e+00 3.5814e-01 4.1512e-01 8.9203e-02
19 | 016000 4.8560e+00 3.5744e-01 4.0607e-01 8.7583e-02
20 | 017000 4.7735e+00 3.5802e-01 3.9844e-01 8.6216e-02
21 | 018000 4.7460e+00 3.5733e-01 3.9660e-01 8.4522e-02
22 | 019000 4.6819e+00 3.5838e-01 3.9125e-01 8.2200e-02
23 | 020000 4.6584e+00 3.5816e-01 3.8998e-01 8.0097e-02
24 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/results/WGAN/trainWGAN-randomInit.txt:
--------------------------------------------------------------------------------
1 | ------device=cuda:0, q=2, batch_size=64, num_steps=50000, w_disc=32, w_dec=32------
2 | ------out_file: results/DCGAN/trainWGAN-randomInit------
3 | step J_GAN J_Gen ProbDist
4 | 001000 9.0434e-02 3.6219e-02 9.5863e+02
5 | 002000 3.0011e-02 1.7892e-02 2.3149e+02
6 | 003000 2.4966e-02 1.2125e-02 1.6959e+02
7 | 004000 2.0141e-02 7.8579e-03 1.4928e+02
8 | 005000 1.7243e-02 4.9258e-03 1.3729e+02
9 | 006000 1.4991e-02 4.5250e-03 1.2822e+02
10 | 007000 1.3449e-02 4.1725e-03 1.2226e+02
11 | 008000 1.2059e-02 3.4655e-03 1.1704e+02
12 | 009000 1.0941e-02 4.7856e-03 1.1407e+02
13 | 010000 9.9864e-03 3.7305e-03 1.1019e+02
14 | 011000 9.3497e-03 6.0258e-03 1.0907e+02
15 | 012000 8.7816e-03 7.7004e-03 1.0663e+02
16 | 013000 8.3173e-03 5.8963e-03 1.0627e+02
17 | 014000 7.9613e-03 4.1980e-03 1.0445e+02
18 | 015000 7.5758e-03 4.1143e-03 1.0471e+02
19 | 016000 7.2847e-03 5.2026e-03 1.0285e+02
20 | 017000 7.0688e-03 5.4935e-03 1.0152e+02
21 | 018000 6.7694e-03 6.2888e-03 1.0057e+02
22 | 019000 6.4661e-03 5.1770e-03 1.0075e+02
23 | 020000 6.2607e-03 5.3424e-03 9.9820e+01
24 | 021000 6.0518e-03 3.9571e-03 9.9159e+01
25 | 022000 5.8393e-03 3.5221e-03 9.8704e+01
26 | 023000 5.6738e-03 3.4962e-03 9.8189e+01
27 | 024000 5.5298e-03 2.3215e-03 9.7550e+01
28 | 025000 5.4000e-03 1.7744e-03 9.6966e+01
29 | 026000 5.2774e-03 1.0285e-03 9.6718e+01
30 | 027000 5.2354e-03 9.6602e-04 9.6132e+01
31 | 028000 5.1198e-03 8.4378e-04 9.6180e+01
32 | 029000 5.0476e-03 1.6136e-03 9.5849e+01
33 | 030000 4.9132e-03 1.7841e-03 9.5108e+01
34 | 031000 4.8423e-03 2.8072e-03 9.5625e+01
35 | 032000 4.7830e-03 2.4392e-03 9.5571e+01
36 | 033000 4.7676e-03 2.6003e-03 9.5113e+01
37 | 034000 4.6369e-03 3.0293e-03 9.4740e+01
38 | 035000 4.5631e-03 2.2041e-03 9.4086e+01
39 | 036000 4.4883e-03 3.2473e-03 9.3532e+01
40 | 037000 4.4246e-03 2.9545e-03 9.3902e+01
41 | 038000 4.3665e-03 3.2789e-03 9.4058e+01
42 | 039000 4.3227e-03 2.3956e-03 9.3894e+01
43 | 040000 4.2762e-03 1.6265e-03 9.2691e+01
44 | 041000 4.2137e-03 1.3642e-03 9.4176e+01
45 | 042000 4.1963e-03 1.1782e-03 9.3716e+01
46 | 043000 4.1571e-03 8.6539e-04 9.3948e+01
47 | 044000 4.1267e-03 1.3113e-03 9.3176e+01
48 | 045000 4.1107e-03 4.9682e-04 9.3393e+01
49 | 046000 4.0724e-03 6.0957e-04 9.3622e+01
50 | 047000 4.0431e-03 4.6356e-04 9.2811e+01
51 | 048000 4.0419e-03 2.5190e-05 9.3241e+01
52 | 049000 4.0043e-03 -5.5302e-06 9.3191e+01
53 | 050000 3.9637e-03 3.9199e-04 9.2859e+01
54 |
--------------------------------------------------------------------------------
/results/WGAN/trainWGAN-vaeInit.txt:
--------------------------------------------------------------------------------
1 | initialize g with weights in results/VAE/trainVAE-g.pt
2 | ------device=cuda:0, q=2, batch_size=64, num_steps=50000, w_disc=32, w_dec=32------
3 | ------out_file: results/DCGAN/trainWGAN-vaeInit------
4 | step J_GAN J_Gen ProbDist
5 | 001000 9.8016e-02 3.7898e-02 2.1267e+02
6 | 002000 4.0626e-02 2.3535e-02 1.7130e+02
7 | 003000 2.4519e-02 5.8282e-03 1.2997e+02
8 | 004000 1.9377e-02 3.2689e-02 1.1579e+02
9 | 005000 1.4879e-02 3.2376e-02 1.0757e+02
10 | 006000 1.2726e-02 3.8867e-02 1.0229e+02
11 | 007000 1.1038e-02 3.1195e-02 1.0047e+02
12 | 008000 1.0462e-02 3.2218e-02 9.9698e+01
13 | 009000 9.8914e-03 3.0616e-02 9.7902e+01
14 | 010000 9.3823e-03 2.8720e-02 9.7652e+01
15 | 011000 9.0339e-03 2.7796e-02 9.6153e+01
16 | 012000 8.5916e-03 2.4935e-02 9.5867e+01
17 | 013000 8.2343e-03 2.5079e-02 9.4840e+01
18 | 014000 7.9187e-03 2.3666e-02 9.3587e+01
19 | 015000 7.6963e-03 1.9979e-02 9.2664e+01
20 | 016000 7.5541e-03 1.8075e-02 9.2843e+01
21 | 017000 7.3329e-03 1.6485e-02 9.1934e+01
22 | 018000 7.1320e-03 1.3612e-02 9.0954e+01
23 | 019000 6.9459e-03 1.4304e-02 9.0953e+01
24 | 020000 6.7611e-03 1.5827e-02 9.0958e+01
25 | 021000 6.6543e-03 1.6127e-02 9.0671e+01
26 | 022000 6.5316e-03 1.5347e-02 9.0589e+01
27 | 023000 6.4119e-03 1.5592e-02 9.0986e+01
28 | 024000 6.3097e-03 1.5181e-02 9.0269e+01
29 | 025000 6.2119e-03 1.5757e-02 8.9563e+01
30 | 026000 6.1119e-03 1.5680e-02 8.9119e+01
31 | 027000 6.0015e-03 1.5852e-02 8.8693e+01
32 | 028000 5.8711e-03 1.4648e-02 8.9046e+01
33 | 029000 5.7979e-03 1.3644e-02 8.7846e+01
34 | 030000 5.7710e-03 1.3030e-02 8.9200e+01
35 | 031000 5.7028e-03 1.4651e-02 8.8510e+01
36 | 032000 5.6183e-03 1.4748e-02 8.8049e+01
37 | 033000 5.5419e-03 1.4654e-02 8.8659e+01
38 | 034000 5.4964e-03 1.3939e-02 8.8095e+01
39 | 035000 5.4264e-03 1.3998e-02 8.8708e+01
40 | 036000 5.3579e-03 1.2901e-02 8.7748e+01
41 | 037000 5.3108e-03 1.3888e-02 8.7898e+01
42 | 038000 5.2779e-03 1.4576e-02 8.7882e+01
43 | 039000 5.1993e-03 1.3370e-02 8.8319e+01
44 | 040000 5.1644e-03 1.2683e-02 8.7140e+01
45 | 041000 5.1564e-03 1.3523e-02 8.8754e+01
46 | 042000 5.1318e-03 1.3073e-02 8.8003e+01
47 | 043000 5.1313e-03 1.1972e-02 8.7905e+01
48 | 044000 5.0500e-03 1.1824e-02 8.7697e+01
49 | 045000 5.0123e-03 1.1280e-02 8.7181e+01
50 | 046000 4.9632e-03 1.1529e-02 8.8270e+01
51 | 047000 4.9310e-03 1.2047e-02 8.7663e+01
52 | 048000 4.8943e-03 1.1841e-02 8.8133e+01
53 | 049000 4.8601e-03 1.1867e-02 8.7772e+01
54 | 050000 4.8166e-03 1.1109e-02 8.7176e+01
55 |
--------------------------------------------------------------------------------
/results/DCGAN/trainDCGAN-randomInit.txt:
--------------------------------------------------------------------------------
1 | ------device=cuda:0, q=2, batch_size=64, num_steps=50000, w_disc=32, w_dec=32,------
2 | ------out_file: results/DCGAN/trainDCGAN-randomInit------
3 | step J_GAN J_Gen ProbDist
4 | 001000 -1.6240e-01 -7.0221e-02 8.2363e+02
5 | 002000 -4.4369e-02 -1.9166e-02 5.2543e+02
6 | 003000 -5.2540e-05 -1.7542e-05 7.4219e+02
7 | 004000 -7.8862e-02 -1.9111e-02 4.9744e+02
8 | 005000 -2.6664e-01 -1.0555e-01 2.5069e+02
9 | 006000 -2.9983e-01 -1.2022e-01 2.2268e+02
10 | 007000 -2.5933e-01 -1.0743e-01 2.0266e+02
11 | 008000 -2.7962e-01 -1.1531e-01 1.7890e+02
12 | 009000 -3.2447e-01 -1.3183e-01 1.7197e+02
13 | 010000 -3.0801e-01 -1.2520e-01 1.6948e+02
14 | 011000 -2.9975e-01 -1.2312e-01 1.7978e+02
15 | 012000 -3.0585e-01 -1.2521e-01 1.6989e+02
16 | 013000 -3.3236e-01 -1.3747e-01 1.6467e+02
17 | 014000 -3.8498e-01 -1.6704e-01 1.4661e+02
18 | 015000 -3.6886e-01 -1.5774e-01 1.4381e+02
19 | 016000 -3.6060e-01 -1.5909e-01 1.3357e+02
20 | 017000 -3.5183e-01 -1.5064e-01 1.3173e+02
21 | 018000 -3.4364e-01 -1.4763e-01 1.3900e+02
22 | 019000 -4.0491e-01 -1.7854e-01 1.3424e+02
23 | 020000 -3.9583e-01 -1.7362e-01 1.2394e+02
24 | 021000 -3.5344e-01 -1.5191e-01 1.2611e+02
25 | 022000 -3.6819e-01 -1.5867e-01 1.3833e+02
26 | 023000 -3.4235e-01 -1.4709e-01 1.2866e+02
27 | 024000 -3.5760e-01 -1.5520e-01 1.2599e+02
28 | 025000 -3.3611e-01 -1.4285e-01 1.2532e+02
29 | 026000 -3.5132e-01 -1.4984e-01 1.2587e+02
30 | 027000 -3.5437e-01 -1.5670e-01 1.2115e+02
31 | 028000 -3.8594e-01 -1.6715e-01 1.2182e+02
32 | 029000 -3.3948e-01 -1.4934e-01 1.1738e+02
33 | 030000 -3.3234e-01 -1.4675e-01 1.1683e+02
34 | 031000 -3.2204e-01 -1.4313e-01 1.1764e+02
35 | 032000 -3.2760e-01 -1.4531e-01 1.2158e+02
36 | 033000 -3.2468e-01 -1.3883e-01 1.1953e+02
37 | 034000 -3.1281e-01 -1.3564e-01 1.1946e+02
38 | 035000 -3.1644e-01 -1.3420e-01 1.1847e+02
39 | 036000 -3.3908e-01 -1.4482e-01 1.1678e+02
40 | 037000 -3.4147e-01 -1.5113e-01 1.1595e+02
41 | 038000 -2.9663e-01 -1.2712e-01 1.1481e+02
42 | 039000 -2.8986e-01 -1.2577e-01 1.1691e+02
43 | 040000 -2.8042e-01 -1.1591e-01 1.1890e+02
44 | 041000 -2.8173e-01 -1.2197e-01 1.1177e+02
45 | 042000 -2.7905e-01 -1.1820e-01 1.1523e+02
46 | 043000 -2.7027e-01 -1.1974e-01 1.1492e+02
47 | 044000 -2.7638e-01 -1.1535e-01 1.1516e+02
48 | 045000 -2.8521e-01 -1.2421e-01 1.1780e+02
49 | 046000 -2.7991e-01 -1.2134e-01 1.1599e+02
50 | 047000 -2.5829e-01 -1.1118e-01 1.2519e+02
51 | 048000 -3.1127e-01 -1.3053e-01 1.1762e+02
52 | 049000 -2.6953e-01 -1.1793e-01 1.1570e+02
53 | 050000 -2.9524e-01 -1.3025e-01 1.1301e+02
54 |
--------------------------------------------------------------------------------
/results/DCGAN/trainDCGAN-vaeInit.txt:
--------------------------------------------------------------------------------
1 | initialize g with weights in results/VAE/trainVAE-g.pt
2 | ------device=cuda:0, q=2, batch_size=64, num_steps=50000, w_disc=32, w_dec=32,------
3 | ------out_file: results/DCGAN/trainDCGAN-vaeInit------
4 | step J_GAN J_Gen ProbDist
5 | 001000 -1.5065e-01 -5.7034e-02 2.3224e+02
6 | 002000 -2.4258e-01 -1.1107e-01 1.5204e+02
7 | 003000 -2.6185e-01 -1.1970e-01 1.2679e+02
8 | 004000 -2.5740e-01 -1.1888e-01 1.2323e+02
9 | 005000 -2.6669e-01 -1.2124e-01 1.1961e+02
10 | 006000 -2.3286e-01 -1.0311e-01 1.1147e+02
11 | 007000 -2.2996e-01 -1.0447e-01 1.0950e+02
12 | 008000 -2.1535e-01 -1.0029e-01 1.1002e+02
13 | 009000 -2.3136e-01 -1.0422e-01 1.1030e+02
14 | 010000 -2.1367e-01 -9.4286e-02 1.1156e+02
15 | 011000 -2.0159e-01 -8.8985e-02 1.1032e+02
16 | 012000 -2.1341e-01 -9.4397e-02 1.1280e+02
17 | 013000 -2.1593e-01 -9.6983e-02 1.1122e+02
18 | 014000 -1.8815e-01 -8.5642e-02 1.1019e+02
19 | 015000 -2.0207e-01 -9.2386e-02 1.0957e+02
20 | 016000 -2.0112e-01 -9.0253e-02 1.1057e+02
21 | 017000 -2.1544e-01 -9.3563e-02 1.0798e+02
22 | 018000 -2.1136e-01 -9.1400e-02 1.0599e+02
23 | 019000 -2.2456e-01 -9.6762e-02 1.0651e+02
24 | 020000 -1.8257e-01 -8.3164e-02 1.0817e+02
25 | 021000 -2.0020e-01 -8.9516e-02 1.0715e+02
26 | 022000 -1.9245e-01 -8.6282e-02 1.0793e+02
27 | 023000 -1.8340e-01 -7.9962e-02 1.0357e+02
28 | 024000 -1.6653e-01 -7.4786e-02 1.0495e+02
29 | 025000 -1.6308e-01 -7.4571e-02 1.0434e+02
30 | 026000 -1.5791e-01 -6.9270e-02 1.0587e+02
31 | 027000 -1.7612e-01 -7.6740e-02 1.0506e+02
32 | 028000 -1.8177e-01 -8.1500e-02 1.0438e+02
33 | 029000 -1.7027e-01 -7.6036e-02 1.0431e+02
34 | 030000 -1.5498e-01 -6.8826e-02 1.0276e+02
35 | 031000 -1.7734e-01 -7.9945e-02 1.0110e+02
36 | 032000 -1.5316e-01 -6.9440e-02 1.0415e+02
37 | 033000 -1.7264e-01 -7.7597e-02 1.0333e+02
38 | 034000 -1.6258e-01 -7.1064e-02 1.0077e+02
39 | 035000 -1.5907e-01 -7.0660e-02 9.8916e+01
40 | 036000 -1.5300e-01 -6.9610e-02 1.0103e+02
41 | 037000 -1.2759e-01 -5.4200e-02 1.0244e+02
42 | 038000 -1.3363e-01 -5.9072e-02 1.0061e+02
43 | 039000 -1.3983e-01 -6.1393e-02 1.0243e+02
44 | 040000 -1.2590e-01 -5.7687e-02 1.0177e+02
45 | 041000 -1.3524e-01 -6.0630e-02 1.0402e+02
46 | 042000 -1.3517e-01 -6.1084e-02 1.0345e+02
47 | 043000 -1.3010e-01 -5.7706e-02 1.0231e+02
48 | 044000 -1.3336e-01 -5.9891e-02 1.0317e+02
49 | 045000 -1.3037e-01 -5.9534e-02 1.0790e+02
50 | 046000 -1.3391e-01 -5.7453e-02 1.0566e+02
51 | 047000 -1.5771e-01 -6.9367e-02 1.0580e+02
52 | 048000 -1.3496e-01 -5.8078e-02 1.0980e+02
53 | 049000 -1.1879e-01 -5.2954e-02 1.0454e+02
54 | 050000 -1.3243e-01 -5.8981e-02 1.0384e+02
55 |
--------------------------------------------------------------------------------
/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 | import torch.nn.functional as F
6 |
7 | class VAE(nn.Module):
8 | """
9 | Variational Autoencoder (VAE)
10 |
11 | """
12 | def __init__(self,e,g):
13 | """
14 | Initialize VAE
15 | :param e: encoder network e(z|x), provides parameters of approximate posteriors
16 | :param g: generator g(z), provides samples in data space
17 | """
18 | super(VAE, self).__init__()
19 | self.e = e
20 | self.g = g
21 |
22 | def ELBO(self,x):
23 | """
24 | Empirical lower bound on p_{\theta}(x)
25 | :param x: sample from dataset
26 | :return:
27 | """
28 |
29 | s = x.shape[0] # number of samples
30 | mu, logvar = self.e(x) # parameters of approximate posterior
31 | z, eps = self.sample_ezx(x,mu, logvar)
32 | gz = self.g(z)
33 |
34 | log_pzx = torch.sum(self.log_prob_pzx(z,x,gz)[0])
35 |
36 | log_ezx = -0.5*torch.norm(eps)**2 - 0.5*torch.sum(logvar) - (z.shape[1]/2)*np.log(2*np.pi)
37 |
38 | return (-log_pzx+log_ezx)/s, (-log_pzx/s).item(), (log_ezx/s).item(), gz.detach(), mu.detach()
39 |
40 | def sample_ezx(self, x , mu=None, logvar=None, sample=None):
41 | """
42 | draw sample from approximate posterior
43 |
44 | :param x: sample from dataset
45 | :param mu: mean of approximate posterior (optional; will be computed here if is None)
46 | :param logvar: log-variance of approximate posterior (optional; will be computed here if is None)
47 | :param sample: flag whether to sample or return the mean
48 | :return:
49 | """
50 | if mu is None or logvar is None:
51 | mu, logvar = self.e(x)
52 |
53 | if sample is None:
54 | sample = self.training
55 |
56 | if sample:
57 | std = torch.exp(logvar)
58 | eps = torch.randn_like(std)
59 | return std * eps + mu, eps
60 | else:
61 | return mu, mu
62 |
63 | def log_prob_ezx(self,z,x):
64 | """
65 | :param z: latent sample
66 | :param x: data sample
67 | :return: log(e(z|x))
68 | """
69 | q = z.shape[1]
70 |
71 | mu, logvar = self.e(x)
72 | ezx = -torch.sum((0.5 / torch.exp(logvar)) * (z - mu) ** 2, dim=1) - 0.5 * torch.sum(logvar,dim=1) - (q/2)*np.log(2*np.pi)
73 | return ezx
74 |
75 | def log_prob_pzx(self,z,x,gz=None):
76 | """
77 | :param z: latent sample
78 | :param x: data sample
79 | :return: log(p(z,x)) = log(p(x|z)) + log(p(z)), log(p(x|z|), log(p(z))
80 | """
81 | if gz is None:
82 | gz = self.g(z)
83 | n = x.shape[1]
84 | px = -F.binary_cross_entropy(gz.view(-1, 784), x.view(-1, 784), reduction='none')
85 | px = torch.sum(px,dim=1)
86 | pz = - 0.5 * torch.norm(z, dim=1) ** 2 - (n/2)*np.log(2*np.pi)
87 | return px + pz, px, pz
88 |
89 |
90 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepGenerativeModelingIntro
2 |
3 | PyTorch Code used in the paper *Introduction to Deep Generative Modeling*:
4 |
5 | @article{RuthottoHaber2021,
6 | title = {An Introduction to Deep Generative Modeling},
7 | year = {2021},
8 | journal = {arXiv preprint arXiv:tbd},
9 | author = {L. Ruthotto and E. Haber},
10 | pages = {25 pages},
11 | url={https://arxiv.org/abs/2103.05180}
12 | }
13 |
14 | ## Run Examples from the Terminal
15 |
16 | To reproduce the examples from the paper (up to randomization), we provide a shell script `runAll.sh`.
17 |
18 | ## Run Examples in Colab
19 |
20 | The `examples` directory contains interactive version of the examples from the paper. Those can be run locally or using
21 | Google Colab. For the latter option, you may click the badges below:
22 |
23 | 1. [Two-Dimensional Normalizing Flow Examples with Real NVP](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/RealNVP.ipynb) [](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/RealNVP.ipynb)
24 | 1. [Two-Dimensional Continuous Normalizing Flow Example with OT-Flow](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/OTFlow.ipynb) [](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/OTFlow.ipynb)
25 | 1. [Variational Autoencoder for MNIST Image Generation](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/VAE.ipynb) [](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/VAE.ipynb)
26 | 1. [DCGAN for MNIST Image Generation](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/DCGAN.ipynb) [](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/DCGAN.ipynb)
27 | 1. [WGAN for MNIST Image Generation](https://github.com/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/WGAN.ipynb) [](https://colab.research.google.com/github/EmoryMLIP/DeepGenerativeModelingIntro/blob/main/examples/WGAN.ipynb)
28 |
29 | ## Dependencies
30 |
31 | The code is based on pytorch and some other standard machine learning packages. In addition, training the continuous normalizing flow example requires [OT-Flow](https://github.com/EmoryMLIP/OT-Flow).
32 |
33 | ## Acknowledgements
34 |
35 | This material is in part based upon work supported by the National Science Foundation under Grant Number 1751636, the Air Force Office of Scientific Research under Grant Number 20RT0237, and
36 | the US DOE's Office of Advanced Scientific Computing Research Field Work Proposal 20-023231. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the funding agencies.
37 |
--------------------------------------------------------------------------------
/modelMNIST.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch
3 | from torch import nn
4 |
5 | class Generator(nn.Module):
6 | def __init__(self,w,q):
7 | """
8 | Initialize generator
9 |
10 | :param w: number of channels on the finest level
11 | :param q: latent space dimension
12 | """
13 | super(Generator, self).__init__()
14 | self.w = w
15 | self.fc = nn.Linear(q, w * 2 * 7 * 7)
16 | self.conv2 = nn.ConvTranspose2d(w * 2, w, kernel_size=4, stride=2, padding=1)
17 | self.conv1 = nn.ConvTranspose2d(w, 1, kernel_size=4, stride=2, padding=1)
18 |
19 | self.bn1 = nn.BatchNorm2d(w)
20 | self.bn2 = nn.BatchNorm2d(2*w)
21 |
22 | def forward(self, z):
23 | """
24 | :param z: latent space sample
25 | :return: g(z)
26 | """
27 | gz = self.fc(z)
28 | gz = gz.view(gz.size(0), self.w * 2, 7, 7)
29 | gz = self.bn2(gz)
30 | gz = F.relu(gz)
31 | gz = self.conv2(gz)
32 | gz = self.bn1(gz)
33 |
34 | gz = F.relu(gz)
35 | gz = torch.sigmoid(self.conv1(gz))
36 | return gz
37 |
38 | class Encoder(nn.Module):
39 | def __init__(self,w,q):
40 | """
41 | Initialize the encoder for the VAE
42 |
43 | :param w: number of channels on finest level
44 | :param q: latent space dimension
45 | """
46 | super(Encoder, self).__init__()
47 | self.conv1 = nn.Conv2d(1, w, kernel_size=4, stride=2, padding=1)
48 | self.conv2 = nn.Conv2d(w, w * 2, kernel_size=4, stride=2, padding=1)
49 | self.fc_mu = nn.Linear(w * 2 * 7 * 7, q)
50 | self.fc_logvar = nn.Linear(w * 2 * 7 * 7, q)
51 |
52 | def forward(self, x):
53 | """
54 | :param x: MNIST image
55 | :return: mu,logvar that parameterize e(z|x) = N(mu, diag(exp(logvar)))
56 | """
57 | x = F.relu(self.conv1(x))
58 | x = F.relu(self.conv2(x))
59 | x = x.view(x.size(0), -1)
60 | mu = self.fc_mu(x)
61 | logvar = self.fc_logvar(x)
62 | return mu, logvar
63 |
64 | class Discriminator(nn.Module):
65 | def __init__(self, w,useSigmoid=True):
66 | """
67 | Discriminator for GANs
68 | :param w: number of channels on finest level
69 | :param useSigmoid: true --> DCGAN, false --> WGAN
70 | """
71 | super(Discriminator, self).__init__()
72 | self.w = w
73 | self.useSigmoid = useSigmoid
74 | self.conv1 = nn.Conv2d(1, w, kernel_size=4, stride=2, padding=1)
75 | self.conv2 = nn.Conv2d(w, w * 2, kernel_size=4, stride=2, padding=1)
76 | self.fc = nn.Linear(w * 2 * 7 * 7, 1)
77 | self.bn1 = nn.BatchNorm2d(w)
78 | self.bn2 = nn.BatchNorm2d(2*w)
79 |
80 | def forward(self,x):
81 | """
82 | :param x: MNIST image or generated image
83 | :return: d(x), value of discriminator
84 | """
85 | x = (x-0.5)/0.5 #
86 | x = self.conv1(x)
87 | x = self.bn1(x)
88 | x = F.leaky_relu(x,0.2)
89 | x = self.conv2(x)
90 | x = self.bn2(x)
91 |
92 | x = F.leaky_relu(x,0.2)
93 | x = x.view(x.shape[0],-1)
94 | x = self.fc(x)
95 | if self.useSigmoid:
96 | x = torch.sigmoid(x)
97 | return x
98 |
--------------------------------------------------------------------------------
/trainVAEmnist.py:
--------------------------------------------------------------------------------
1 | from vae import *
2 | from torch import distributions
3 | import argparse
4 |
5 | import numpy as np
6 |
7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8 |
9 | ## load MNIST
10 | parser = argparse.ArgumentParser('VAE')
11 | parser.add_argument("--batch_size" , type=int, default=128, help="batch size")
12 | parser.add_argument("--q" , type=int, default=2, help="latent space dimension")
13 | parser.add_argument("--width_enc" , type=int, default=4, help="width of encoder")
14 | parser.add_argument("--width_dec" , type=int, default=4, help="width of decoder")
15 | parser.add_argument("--num_epochs" , type=int, default=2, help="number of epochs")
16 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png")
17 | args = parser.parse_args()
18 |
19 |
20 | import torchvision.transforms as transforms
21 | from torch.utils.data import DataLoader
22 | from torchvision.datasets import MNIST
23 |
24 | img_transform = transforms.Compose([
25 | transforms.ToTensor()
26 | ])
27 |
28 | train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
29 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
30 |
31 | test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform)
32 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True)
33 |
34 |
35 | from modelMNIST import Encoder, Generator
36 | g = Generator(args.width_dec,args.q)
37 | e = Encoder(args.width_enc,args.q)
38 |
39 | vae = VAE(e,g).to(device)
40 |
41 | optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3, weight_decay=1e-5)
42 |
43 | his = np.zeros((args.num_epochs,6))
44 |
45 | print((3*"--" + "device=%s, q=%d, batch_size=%d, num_epochs=%d, w_enc=%d, w_dec=%d" + 3*"--") % (device, args.q, args.batch_size, args.num_epochs, args.width_enc, args.width_dec))
46 |
47 | if args.out_file is not None:
48 | import os
49 | out_dir, fname = os.path.split(args.out_file)
50 | if not os.path.exists(out_dir):
51 | os.makedirs(out_dir)
52 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file))
53 |
54 | print((7*"%7s ") % ("epoch","Jtrain","pzxtrain","ezxtrain","Jval","pzxval","ezxval"))
55 |
56 |
57 | for epoch in range(args.num_epochs):
58 | vae.train()
59 |
60 | train_loss = 0.0
61 | train_pzx = 0.0
62 | train_ezx = 0.0
63 | num_ex = 0
64 | for image_batch, _ in train_dataloader:
65 | image_batch = image_batch.to(device)
66 |
67 | # take a step
68 | loss, pzx, ezx,gz,mu = vae.ELBO(image_batch)
69 | optimizer.zero_grad()
70 | loss.backward()
71 | optimizer.step()
72 |
73 | # update history
74 | train_loss += loss.item()*image_batch.shape[0]
75 | train_pzx += pzx*image_batch.shape[0]
76 | train_ezx += ezx*image_batch.shape[0]
77 | num_ex += image_batch.shape[0]
78 |
79 | train_loss /= num_ex
80 | train_pzx /= num_ex
81 | train_ezx /= num_ex
82 |
83 | # evaluate validation points
84 | vae.eval()
85 | val_loss = 0.0
86 | val_pzx = 0.0
87 | val_ezx = 0.0
88 | num_ex = 0
89 | for image_batch, label_batch in test_dataloader:
90 | with torch.no_grad():
91 | image_batch = image_batch.to(device)
92 | # vae reconstruction
93 | loss, pzx, ezx, gz, mu = vae.ELBO(image_batch)
94 | val_loss += loss.item() * image_batch.shape[0]
95 | val_pzx += pzx * image_batch.shape[0]
96 | val_ezx += ezx * image_batch.shape[0]
97 | num_ex += image_batch.shape[0]
98 |
99 | val_loss /= num_ex
100 | val_pzx/= num_ex
101 | val_ezx/= num_ex
102 |
103 | print(("%06d " + 6*"%1.4e ") %
104 | (epoch + 1, train_loss, train_pzx, train_ezx, val_loss, val_pzx, val_ezx))
105 |
106 | his[epoch,:] = [train_loss, train_pzx, train_ezx, val_loss, val_pzx, val_ezx]
107 |
108 | if args.out_file is not None:
109 | torch.save(vae.g.state_dict(), ("%s-g.pt") % (args.out_file))
110 | torch.save(vae.state_dict(), ("%s.pt") % (args.out_file))
111 | from scipy.io import savemat
112 | savemat(("%s.mat") % (args.out_file), {"his":his})
--------------------------------------------------------------------------------
/results/VAE/trainVAE.txt:
--------------------------------------------------------------------------------
1 | ------device=cuda:0, q=2, batch_size=64, num_epochs=50, w_enc=32, w_dec=32------
2 | ------out_file: results/VAE/trainVAE------
3 | epoch Jtrain pzxtrain ezxtrain Jval pzxval ezxval
4 | 000001 1.7771e+02 1.7577e+02 1.9351e+00 1.6330e+02 1.6098e+02 2.3209e+00
5 | 000002 1.5944e+02 1.5729e+02 2.1508e+00 1.5661e+02 1.5412e+02 2.4921e+00
6 | 000003 1.5594e+02 1.5366e+02 2.2728e+00 1.5409e+02 1.5157e+02 2.5213e+00
7 | 000004 1.5415e+02 1.5182e+02 2.3312e+00 1.5290e+02 1.5019e+02 2.7143e+00
8 | 000005 1.5297e+02 1.5059e+02 2.3822e+00 1.5206e+02 1.4936e+02 2.7023e+00
9 | 000006 1.5203e+02 1.4962e+02 2.4179e+00 1.5153e+02 1.4874e+02 2.7933e+00
10 | 000007 1.5135e+02 1.4891e+02 2.4409e+00 1.5074e+02 1.4801e+02 2.7246e+00
11 | 000008 1.5071e+02 1.4825e+02 2.4669e+00 1.5018e+02 1.4744e+02 2.7378e+00
12 | 000009 1.5024e+02 1.4776e+02 2.4783e+00 1.5039e+02 1.4754e+02 2.8493e+00
13 | 000010 1.4981e+02 1.4731e+02 2.4965e+00 1.4941e+02 1.4657e+02 2.8432e+00
14 | 000011 1.4940e+02 1.4690e+02 2.5029e+00 1.4950e+02 1.4683e+02 2.6757e+00
15 | 000012 1.4911e+02 1.4658e+02 2.5252e+00 1.4935e+02 1.4650e+02 2.8510e+00
16 | 000013 1.4881e+02 1.4628e+02 2.5295e+00 1.4872e+02 1.4596e+02 2.7607e+00
17 | 000014 1.4846e+02 1.4592e+02 2.5445e+00 1.4897e+02 1.4617e+02 2.7960e+00
18 | 000015 1.4821e+02 1.4566e+02 2.5490e+00 1.5004e+02 1.4718e+02 2.8554e+00
19 | 000016 1.4801e+02 1.4544e+02 2.5663e+00 1.4870e+02 1.4595e+02 2.7539e+00
20 | 000017 1.4781e+02 1.4524e+02 2.5726e+00 1.4874e+02 1.4584e+02 2.9001e+00
21 | 000018 1.4756e+02 1.4498e+02 2.5795e+00 1.4822e+02 1.4548e+02 2.7444e+00
22 | 000019 1.4738e+02 1.4480e+02 2.5752e+00 1.4822e+02 1.4539e+02 2.8321e+00
23 | 000020 1.4718e+02 1.4459e+02 2.5877e+00 1.4834e+02 1.4547e+02 2.8747e+00
24 | 000021 1.4702e+02 1.4443e+02 2.5927e+00 1.4831e+02 1.4548e+02 2.8288e+00
25 | 000022 1.4686e+02 1.4426e+02 2.5941e+00 1.4804e+02 1.4520e+02 2.8367e+00
26 | 000023 1.4672e+02 1.4412e+02 2.5977e+00 1.4806e+02 1.4512e+02 2.9358e+00
27 | 000024 1.4656e+02 1.4396e+02 2.6019e+00 1.4791e+02 1.4513e+02 2.7793e+00
28 | 000025 1.4642e+02 1.4381e+02 2.6084e+00 1.4795e+02 1.4518e+02 2.7680e+00
29 | 000026 1.4627e+02 1.4366e+02 2.6123e+00 1.4769e+02 1.4493e+02 2.7604e+00
30 | 000027 1.4618e+02 1.4356e+02 2.6213e+00 1.4777e+02 1.4488e+02 2.8827e+00
31 | 000028 1.4603e+02 1.4341e+02 2.6209e+00 1.4794e+02 1.4521e+02 2.7254e+00
32 | 000029 1.4599e+02 1.4336e+02 2.6311e+00 1.4751e+02 1.4472e+02 2.7926e+00
33 | 000030 1.4581e+02 1.4317e+02 2.6360e+00 1.4796e+02 1.4527e+02 2.6887e+00
34 | 000031 1.4575e+02 1.4312e+02 2.6355e+00 1.4765e+02 1.4489e+02 2.7599e+00
35 | 000032 1.4565e+02 1.4302e+02 2.6385e+00 1.4778e+02 1.4483e+02 2.9423e+00
36 | 000033 1.4554e+02 1.4290e+02 2.6463e+00 1.4728e+02 1.4438e+02 2.8942e+00
37 | 000034 1.4546e+02 1.4282e+02 2.6412e+00 1.4735e+02 1.4453e+02 2.8193e+00
38 | 000035 1.4537e+02 1.4271e+02 2.6566e+00 1.4747e+02 1.4465e+02 2.8138e+00
39 | 000036 1.4531e+02 1.4265e+02 2.6536e+00 1.4742e+02 1.4460e+02 2.8190e+00
40 | 000037 1.4525e+02 1.4260e+02 2.6542e+00 1.4755e+02 1.4467e+02 2.8760e+00
41 | 000038 1.4516e+02 1.4251e+02 2.6503e+00 1.4730e+02 1.4440e+02 2.9045e+00
42 | 000039 1.4506e+02 1.4240e+02 2.6628e+00 1.4730e+02 1.4440e+02 2.9018e+00
43 | 000040 1.4497e+02 1.4231e+02 2.6587e+00 1.4748e+02 1.4469e+02 2.7950e+00
44 | 000041 1.4493e+02 1.4226e+02 2.6631e+00 1.4783e+02 1.4494e+02 2.8864e+00
45 | 000042 1.4486e+02 1.4218e+02 2.6778e+00 1.4716e+02 1.4433e+02 2.8293e+00
46 | 000043 1.4477e+02 1.4211e+02 2.6594e+00 1.4718e+02 1.4440e+02 2.7866e+00
47 | 000044 1.4470e+02 1.4204e+02 2.6662e+00 1.4723e+02 1.4444e+02 2.7945e+00
48 | 000045 1.4463e+02 1.4195e+02 2.6749e+00 1.4786e+02 1.4504e+02 2.8218e+00
49 | 000046 1.4459e+02 1.4192e+02 2.6716e+00 1.4733e+02 1.4446e+02 2.8692e+00
50 | 000047 1.4450e+02 1.4182e+02 2.6802e+00 1.4754e+02 1.4470e+02 2.8353e+00
51 | 000048 1.4446e+02 1.4179e+02 2.6775e+00 1.4713e+02 1.4431e+02 2.8164e+00
52 | 000049 1.4440e+02 1.4172e+02 2.6808e+00 1.4727e+02 1.4449e+02 2.7826e+00
53 | 000050 1.4438e+02 1.4169e+02 2.6882e+00 1.4746e+02 1.4459e+02 2.8640e+00
54 |
--------------------------------------------------------------------------------
/trainOTFlow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import argparse
4 | import numpy as np
5 | from torch import distributions
6 | from sklearn import datasets
7 | import matplotlib.pyplot as plt
8 | import sys,os
9 |
10 | ot_flow_dir = '/Users/lruthot/Google Drive/OT-Flow/'
11 | if not os.path.exists(ot_flow_dir):
12 | raise Exception("Cannot find OT-Flow in %s" %(ot_flow_dir))
13 | sys.path.append(os.path.dirname(ot_flow_dir))
14 |
15 | from src.plotter import plot4
16 | from src.OTFlowProblem import *
17 |
18 | device = "cpu"
19 |
20 | parser = argparse.ArgumentParser('OTFlow')
21 | parser.add_argument("--batch_size" , type=int, default=1024, help="batch size")
22 | parser.add_argument("--noise" , type=int, default=0.05, help="noise of moons")
23 | parser.add_argument("--width" , type=int, default=32, help="width of neural net")
24 | parser.add_argument('--alph' , type=str, default='1.0,10.0,5.0',help="alph[0]-> weight for transport costs, alph[1] and alph[2]-> HJB penalties")
25 | parser.add_argument("--nTh" , type=int, default=2, help="number of layers")
26 | parser.add_argument("--nt" , type=int, default=4, help="number of time steps in training")
27 | parser.add_argument("--nt_val" , type=int, default=8, help="number of time steps in validation")
28 | parser.add_argument("--num_steps" , type=int, default=10000, help="number of training steps")
29 | parser.add_argument("--plot_interval" , type=int, default=500, help="plot solution every so many steps")
30 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png")
31 |
32 | args = parser.parse_args()
33 | args.alph = [float(item) for item in args.alph.split(',')]
34 |
35 |
36 | def compute_loss(net, x, nt):
37 | Jc , cs = OTFlowProblem(x, net, [0,1], nt=nt, stepper="rk4", alph=net.alph)
38 | return Jc, cs
39 |
40 |
41 | net = Phi(nTh=args.nTh, m=args.width, d=2, alph=args.alph)
42 | optim = torch.optim.Adam(net.parameters(), lr=1e-2) # lr=0.04 good
43 |
44 | his = np.zeros((0,4))
45 |
46 |
47 | print((3*"--" + "device=%s, nTh=%d, width=%d, batch_size=%d, num_steps=%d" + 3*"--") % (device, args.nTh, args.width, args.batch_size, args.num_steps, ))
48 |
49 | out_dir = "results/OTFlow-noise-%1.5f-nTh-%d-width-%d" % (args.noise, args.nTh, args.width)
50 |
51 | if args.out_file is not None:
52 | import os
53 | out_dir, fname = os.path.split(args.out_file)
54 | if not os.path.exists(out_dir):
55 | os.makedirs(out_dir)
56 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file))
57 |
58 | print((5*"%7s ") % ("step","J", "J_L", "J_ML","J_HJB"))
59 |
60 | train_J = 0.0
61 | train_L = 0.0
62 | train_JML = 0.0
63 | num_step = 0
64 | train_HJB = 0.0
65 |
66 | for step in range(args.num_steps):
67 |
68 | x = torch.tensor(datasets.make_moons(n_samples=args.batch_size, noise=args.noise)[0], dtype=torch.float32)
69 | optim.zero_grad()
70 | loss, costs = compute_loss(net, x, nt=args.nt)
71 | loss.backward()
72 | optim.step()
73 |
74 | train_J += loss.item()
75 | train_L += costs[0].item()
76 | train_JML += costs[1].item()
77 | train_HJB += costs[2].item()
78 | num_step += 1
79 |
80 | if (step + 1) % args.plot_interval == 0:
81 | train_J /= num_step
82 | train_JML /= num_step
83 | train_L /= num_step
84 | train_HJB /= num_step
85 |
86 |
87 | print(("%06d " + 4*"%1.4e ") %
88 | (step + 1, train_J, train_L, train_JML, train_HJB))
89 | his = np.vstack([his, [train_J, train_L, train_JML, train_HJB]])
90 | train_J = 0.0
91 | train_L = 0.0
92 | train_JML = 0.0
93 | num_step = 0
94 | train_HJB = 0.0
95 |
96 | with torch.no_grad():
97 | nSamples = 10000
98 | xs = torch.tensor(datasets.make_moons(n_samples=nSamples, noise=args.noise)[0], dtype=torch.float32)
99 | zs = torch.randn(nSamples, 2) # sampling from the standard normal (rho_1)
100 | if args.out_file is not None:
101 | plot4(net, xs, zs, args.nt_val, "%s-step-%d.png" % (args.out_file,step+1), doPaths=True)
102 | plt.show()
103 |
104 | if args.out_file is not None:
105 | torch.save(net.state_dict(), ("%s.pt") % (args.out_file))
106 | from scipy.io import savemat
107 | savemat(("%s.mat") % (args.out_file), {"his":his})
--------------------------------------------------------------------------------
/realNVP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class NF(nn.Module):
5 | """
6 | Normalizing flow for density estimation and sampling
7 |
8 | """
9 | def __init__(self, layers, prior):
10 | """
11 | Initialize normalizing flow
12 |
13 | :param layers: list of layers f_j with tractable inverse and log-determinant (e.g., RealNVPLayer)
14 | :param prior: latent distribution, e.g., distributions.MultivariateNormal(torch.zeros(d), torch.eye(d))
15 | """
16 | super(NF, self).__init__()
17 | self.prior = prior
18 | self.layers = layers
19 |
20 | def g(self, z):
21 | """
22 | :param z: latent variable
23 | :return: g(z) and hidden states
24 | """
25 | y = z
26 | ys = [torch.clone(y).detach()]
27 | for i in range(len(self.layers)):
28 | y, _ = self.layers[i].f(y)
29 | ys.append(torch.clone(y).detach())
30 |
31 | return y, ys
32 |
33 | def ginv(self, x):
34 | """
35 | :param x: sample from dataset
36 | :return: g^(-1)(x), value of log-determinant, and hidden layers
37 | """
38 | p = x
39 | log_det_ginv = torch.zeros(x.shape[0])
40 | ps = [torch.clone(p).detach()]
41 | for i in reversed(range(len(self.layers))):
42 | p, log_det_finv = self.layers[i].finv(p)
43 | ps.append(torch.clone(p).detach().cpu())
44 | log_det_ginv += log_det_finv
45 |
46 | return p, log_det_ginv, ps
47 |
48 | def log_prob(self, x):
49 | """
50 | Compute log-probability of a sample using change of variable formula
51 |
52 | :param x: sample from dataset
53 | :return: logp_{\theta}(x)
54 | """
55 | z, log_det_ginv, _ = self.ginv(x)
56 | return self.prior.log_prob(z) + log_det_ginv
57 |
58 | def sample(self, s):
59 | """
60 | Draw random samples from p_{\theta}
61 |
62 | :param s: number of samples to draw
63 | :return:
64 | """
65 | z = self.prior.sample((s, 1)).squeeze(1)
66 | x, _ = self.g(z)
67 | return x
68 |
69 |
70 | class RealNVPLayer(nn.Module):
71 | """
72 | Real non-volume preserving flow layer
73 |
74 | Reference: Dinh, L., Sohl-Dickstein, J., & Bengio, S. (2016, May 27).
75 | Density estimation using Real NVP. arXiv.org.
76 | """
77 | def __init__(self, s, t, mask):
78 | """
79 | Initialize real NVP layer
80 | :param s: network to compute the shift
81 | :param t: network to compute the translation
82 | :param mask: splits the feature vector into two parts
83 | """
84 | super(RealNVPLayer, self).__init__()
85 | self.mask = mask
86 | self.t = t
87 | self.s = s
88 |
89 | def f(self, y):
90 | """
91 | apply the layer function f
92 | :param y: feature vector
93 | :return:
94 | """
95 | y1 = y * self.mask
96 | s = self.s(y1)
97 | t = self.t(y1)
98 | y2 = (y * torch.exp(s) + t) * (1 - self.mask)
99 | return y1 + y2, torch.sum(s, dim=1)
100 |
101 | def finv(self, y):
102 | """
103 | apply the inverse of the layer function
104 | :param y: feature vector
105 | :return:
106 | """
107 | y1 = self.mask * y
108 | s = self.s(y1)
109 | t = self.t(y1)
110 | y2 = (1 - self.mask) * (y - t) * torch.exp(-s)
111 | return y1 + y2, -torch.sum(s, dim=1)
112 |
113 |
114 | if __name__ == "__main__":
115 | # layers and masks
116 | K = 6
117 | w = 128
118 | layers = torch.nn.ModuleList()
119 | for k in range(K):
120 | mask = torch.tensor([1 - (k % 2), k % 2])
121 | t = nn.Sequential(nn.Linear(2, w), nn.LeakyReLU(), nn.Linear(w, w), nn.LeakyReLU(), nn.Linear(w, 2),
122 | nn.Tanh())
123 | s = nn.Sequential(nn.Linear(2, w), nn.LeakyReLU(), nn.Linear(w, w), nn.LeakyReLU(), nn.Linear(w, 2),
124 | nn.Tanh())
125 | layer = RealNVPLayer(s, t, mask)
126 | layers.append(layer)
127 |
128 | from torch import distributions
129 | prior = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
130 |
131 | flow = NF(layers, prior)
132 |
133 | x = flow.sample(200).detach()
134 | # test inverse
135 | xt = flow.ginv(flow.g(x)[0])[0].detach()
136 | print(torch.norm(x - xt) / torch.norm(x))
137 |
138 | # test inverse
139 | xt = flow.g(flow.ginv(x)[0])[0].detach()
140 | print(torch.norm(x - xt) / torch.norm(x))
141 |
142 |
143 |
--------------------------------------------------------------------------------
/trainRealNVP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import argparse
4 | import numpy as np
5 | from torch import distributions
6 | from sklearn import datasets
7 | import matplotlib.pyplot as plt
8 |
9 | device = "cpu"
10 |
11 | parser = argparse.ArgumentParser('RealNVP')
12 | parser.add_argument("--batch_size" , type=int, default=256, help="batch size")
13 | parser.add_argument("--noise" , type=int, default=0.05, help="noise of moons")
14 | parser.add_argument("--width" , type=int, default=128, help="width neural nets")
15 | parser.add_argument("--K" , type=int, default=6, help="number of layers")
16 | parser.add_argument("--num_steps" , type=int, default=20000, help="number of training steps")
17 | parser.add_argument("--plot_interval" , type=int, default=1000, help="plot solution every so many steps")
18 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png")
19 |
20 | args = parser.parse_args()
21 |
22 | from realNVP import NF, RealNVPLayer
23 | K = args.K
24 | w = args.width
25 |
26 | layers = torch.nn.ModuleList()
27 | for k in range(K):
28 | mask = torch.tensor([1 - (k % 2), k % 2])
29 | t = nn.Sequential(nn.Linear(2, w), nn.LeakyReLU(), nn.Linear(w, w), nn.LeakyReLU(), nn.Linear(w, 2),
30 | nn.Tanh())
31 | s = nn.Sequential(nn.Linear(2, w), nn.LeakyReLU(), nn.Linear(w, w), nn.LeakyReLU(), nn.Linear(w, 2),
32 | nn.Tanh())
33 | layer = RealNVPLayer(s, t, mask)
34 | layers.append(layer)
35 |
36 | prior = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
37 | flow = NF(layers, prior).to(device)
38 | optimizer = torch.optim.Adam(flow.parameters(), lr=1e-4)
39 |
40 | his = np.zeros((0,1))
41 |
42 | print((3*"--" + "device=%s, K=%d, width=%d, batch_size=%d, num_steps=%d" + 3*"--") % (device, args.K, args.width, args.batch_size, args.num_steps, ))
43 |
44 | if args.out_file is not None:
45 | import os
46 | out_dir, fname = os.path.split(args.out_file)
47 | if not os.path.exists(out_dir):
48 | os.makedirs(out_dir)
49 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file))
50 |
51 | print((2*"%7s ") % ("step","J_ML"))
52 |
53 |
54 | train_JML = 0.0
55 | num_step = 0
56 |
57 | for step in range(args.num_steps):
58 |
59 | x = torch.tensor(datasets.make_moons(n_samples=args.batch_size, noise=args.noise)[0], dtype=torch.float32)
60 | loss = -flow.log_prob(x).mean()
61 | optimizer.zero_grad()
62 | loss.backward()
63 | optimizer.step()
64 | train_JML += loss.item()
65 | num_step += 1
66 |
67 | if (step + 1) % args.plot_interval == 0:
68 | train_JML /= num_step
69 |
70 | print(("%06d " + "%1.4e ") %
71 | (step + 1, train_JML))
72 | his = np.vstack([his, [train_JML]])
73 |
74 | zs = flow.ginv(x)[0].detach()
75 | xs = flow.sample(200).detach()
76 | x1 = torch.linspace(-1.2, 2.1, 100)
77 | x2 = torch.linspace(-1.2, 2.1, 100)
78 | xg = torch.meshgrid(x1, x2)
79 | xx = torch.cat((xg[0].reshape(-1, 1), xg[1].reshape(-1, 1)), 1)
80 | log_px = flow.log_prob(xx).detach()
81 | train_JML = 0.0
82 | num_step = 0
83 |
84 | plt.Figure()
85 | plt.rcParams.update({'font.size': 16, "text.usetex": True})
86 |
87 | plt.subplot(1,3,1)
88 | plt.plot(xs[:, 0], xs[:, 1], "bs")
89 | plt.axis((-1.2, 2.1, -1.2, 2.1))
90 | plt.xticks((-1.2, 2.1))
91 | plt.yticks((-1.2, 2.1))
92 | plt.xlabel("$\mathbf{x}_1$", labelpad=-20)
93 | plt.ylabel("$\mathbf{x}_2$", labelpad=-30)
94 | plt.title("$g_{\\theta}(\mathcal{Z})$")
95 |
96 | plt.subplot(1,3,2)
97 | plt.plot(zs[:, 0], zs[:, 1], "or")
98 | plt.axis((-3.5, 3.5, -3.5, 3.5))
99 | plt.xticks((-3.5, 3.5))
100 | plt.yticks((-3.5, 3.5))
101 | plt.xlabel("$\mathbf{z}_1$", labelpad=-20)
102 | plt.ylabel("$\mathbf{z}_2$", labelpad=-30)
103 | plt.title("$g^{-1}_{\\theta}(\mathcal{Z})$")
104 |
105 | plt.subplot(1,3,3)
106 | img = log_px-torch.min(log_px)
107 | img/=torch.max(img)
108 | plt.imshow(img.reshape(len(x1), len(x2)), extent=(-1.2, 2.1, -1.2, 2.1))
109 | plt.axis((-1.2, 2.1, -1.2, 2.1))
110 | plt.xticks((-1.2, 2.1))
111 | plt.yticks((-1.2, 2.1))
112 | plt.xlabel("$\mathbf{x}_1$", labelpad=-20)
113 | plt.ylabel("$\mathbf{x}_2$", labelpad=-30)
114 | plt.title("$p_{\\theta}(\mathbf{x}), step=%d$" % (step+1))
115 |
116 | plt.margins(0, 0)
117 | if args.out_file is not None:
118 | plt.savefig("%s-step-%d.png" % (args.out_file,step+1), bbox_inches='tight', pad_inches=0)
119 | plt.show()
120 |
121 |
122 | if args.out_file is not None:
123 | torch.save(flow.state_dict(), ("%s.pt") % (args.out_file))
124 | from scipy.io import savemat
125 | savemat(("%s.mat") % (args.out_file), {"his":his})
--------------------------------------------------------------------------------
/trainDCGANmnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8 |
9 | parser = argparse.ArgumentParser('DCGAN')
10 | parser.add_argument("--batch_size" , type=int, default=64, help="batch size")
11 | parser.add_argument("--q" , type=int, default=2, help="latent space dimension")
12 | parser.add_argument("--width_disc" , type=int, default=32, help="width of discriminator")
13 | parser.add_argument("--width_dec" , type=int, default=32, help="width of decoder")
14 | parser.add_argument("--num_steps" , type=int, default=50, help="number of training steps")
15 | parser.add_argument("--plot_interval" , type=int, default=5, help="plot solution every so many steps")
16 | parser.add_argument("--init_g", type=str, default=None, help="path to .pt file that contains weights of a trained generator")
17 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png")
18 |
19 | args = parser.parse_args()
20 |
21 | import torchvision.transforms as transforms
22 | from torch.utils.data import DataLoader
23 | from torchvision.datasets import MNIST
24 |
25 | img_transform = transforms.Compose([
26 | transforms.ToTensor()
27 | ])
28 |
29 | train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
30 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
31 |
32 | from modelMNIST import Generator, Discriminator
33 | g = Generator(args.width_dec,args.q).to(device)
34 | d = Discriminator(args.width_disc,useSigmoid=True).to(device)
35 |
36 | optimizer_g = torch.optim.Adam(params=g.parameters(), lr=0.0002, betas=(0.5, 0.999))
37 | optimizer_d = torch.optim.Adam(params=d.parameters(), lr=0.0002, betas=(0.5, 0.999))
38 |
39 | his = np.zeros((0,3))
40 |
41 | if args.init_g is not None:
42 | print("initialize g with weights in %s" % args.init_g)
43 | g.load_state_dict(torch.load(args.init_g))
44 |
45 | print((3*"--" + "device=%s, q=%d, batch_size=%d, num_steps=%d, w_disc=%d, w_dec=%d," + 3*"--") % (device, args.q, args.batch_size, args.num_steps, args.width_disc, args.width_dec))
46 | if args.out_file is not None:
47 | import os
48 | out_dir, fname = os.path.split(args.out_file)
49 | if not os.path.exists(out_dir):
50 | os.makedirs(out_dir)
51 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file))
52 | print((4*"%7s ") % ("step","J_GAN","J_Gen","ProbDist"))
53 |
54 | from epsTest import epsTest
55 |
56 | train_JGAN = 0.0
57 | train_JGen = 0.0
58 | train_epsTest = 0.0
59 | num_ex = 0
60 |
61 | def inf_train_gen():
62 | while True:
63 | for images, targets in enumerate(train_dataloader):
64 | yield images,targets
65 |
66 | get_true_images = inf_train_gen()
67 |
68 | for step in range(args.num_steps):
69 | g.train()
70 | d.train()
71 | # update discriminator using - J_GAN = - E_x [log(d(x)] - E_z[1-log(d(g(z))]
72 | x = get_true_images.__next__()[1][0]
73 | x = x.to(device)
74 | dx = d(x)
75 | z = torch.randn((x.shape[0],args.q),device=device)
76 | gz = g(z)
77 | dgz = d(gz)
78 | J_GAN = - torch.mean(torch.log(dx)) - torch.mean(torch.log(1-dgz))
79 | optimizer_d.zero_grad()
80 | J_GAN.backward()
81 | optimizer_d.step()
82 |
83 | # update the generator using J_Gen = - E_z[log(d(g(z))]
84 | optimizer_g.zero_grad()
85 | z = torch.randn((x.shape[0], args.q), device=device)
86 | gz = g(z)
87 | dgz = d(gz)
88 | # J_Gen = -torch.mean(torch.log(dgz))
89 | J_Gen = torch.mean(torch.log(1-dgz))
90 | J_Gen.backward()
91 | optimizer_g.step()
92 |
93 | # update history
94 | train_JGAN -= J_GAN.item()*x.shape[0]
95 | train_JGen += J_Gen.item()*x.shape[0]
96 | train_epsTest += epsTest(gz.detach(), x)
97 | num_ex += x.shape[0]
98 |
99 | if (step + 1) % args.plot_interval == 0:
100 | train_JGAN /= num_ex
101 | train_JGen /= num_ex
102 |
103 |
104 | print(("%06d " + 3*"%1.4e ") %
105 | (step + 1, train_JGAN, train_JGen, train_epsTest))
106 | his = np.vstack([his, [train_JGAN, train_JGen, train_epsTest]])
107 |
108 | plt.Figure()
109 | img = gz.detach().cpu()
110 | img -= torch.min(img)
111 | img /=torch.max(img)
112 | plt.imshow(torchvision.utils.make_grid(img, 16, 5,pad_value=1.0).permute((1, 2, 0)))
113 | plt.title("trainDCGANmnist: step=%d" % (step + 1))
114 | if args.out_file is not None:
115 | plt.savefig(("%s-step-%d.png") % (args.out_file, step + 1))
116 | plt.show()
117 |
118 | train_JGAN = 0.0
119 | train_JGen = 0.0
120 | train_epsTest = 0.0
121 | num_ex = 0
122 |
123 | if args.out_file is not None:
124 | torch.save(g.state_dict(), ("%s-g.pt") % (args.out_file))
125 | torch.save(d.state_dict(), ("%s-d.pt") % (args.out_file))
126 |
127 | from scipy.io import savemat
128 | savemat(("%s.mat") % (args.out_file), {"his":his})
129 |
130 | plt.Figure()
131 | plt.subplot(1,2,1)
132 | plt.plot(his[:,0:2])
133 | plt.legend(("JGAN","JGen"))
134 | plt.title("GAN Objectives")
135 | plt.subplot(1,2,2)
136 | plt.plot(his[:,2])
137 | plt.title("epsTest")
138 | if args.out_file is not None:
139 | plt.savefig(("%s-his.png") % (args.out_file))
140 | plt.show()
--------------------------------------------------------------------------------
/trainWGANmnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8 |
9 | ## load MNIST
10 | parser = argparse.ArgumentParser('WGAN')
11 | parser.add_argument("--batch_size" , type=int, default=64, help="batch size")
12 | parser.add_argument("--q" , type=int, default=2, help="latent space dimension")
13 | parser.add_argument("--width_disc" , type=int, default=32, help="width of discriminator")
14 | parser.add_argument("--width_dec" , type=int, default=32, help="width of decoder")
15 | parser.add_argument("--clip_limit" , type=float, default=1e-2, help="limit for weights of discriminator")
16 | parser.add_argument("--iter_disc" , type=int, default=5, help="number of iterations for discriminator")
17 | parser.add_argument("--num_steps" , type=int, default=50, help="number of training steps")
18 | parser.add_argument("--plot_interval" , type=int, default=5, help="plot solution every so many steps")
19 | parser.add_argument("--init_g", type=str, default=None, help="path to .pt file that contains weights of a trained generator")
20 | parser.add_argument("--out_file", type=str, default=None, help="base filename saving trained model (extension .pt), history (extension .mat), and intermediate plots (extension .png")
21 |
22 | args = parser.parse_args()
23 |
24 | import torchvision.transforms as transforms
25 | from torch.utils.data import DataLoader
26 | from torchvision.datasets import MNIST
27 |
28 | img_transform = transforms.Compose([
29 | transforms.ToTensor()
30 | ])
31 |
32 | train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
33 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
34 |
35 | from modelMNIST import Generator, Discriminator
36 | g = Generator(args.width_dec,args.q).to(device)
37 | d = Discriminator(args.width_disc,useSigmoid=False).to(device)
38 |
39 | if args.init_g is not None:
40 | print("initialize g with weights in %s" % args.init_g)
41 | g.load_state_dict(torch.load(args.init_g))
42 |
43 | optimizer_g = torch.optim.RMSprop(g.parameters(), lr=0.00005)
44 | optimizer_d = torch.optim.RMSprop(d.parameters(), lr=0.00005)
45 | his = np.zeros((0,3))
46 |
47 | print((3*"--" + "device=%s, q=%d, batch_size=%d, num_steps=%d, w_disc=%d, w_dec=%d" + 3*"--") % (device, args.q, args.batch_size, args.num_steps, args.width_disc, args.width_dec))
48 | if args.out_file is not None:
49 | import os
50 | out_dir, fname = os.path.split(args.out_file)
51 | if not os.path.exists(out_dir):
52 | os.makedirs(out_dir)
53 | print((3*"--" + "out_file: %s" + 3*"--") % (args.out_file))
54 |
55 | print((4*"%7s ") % ("step","J_GAN","J_Gen","ProbDist"))
56 |
57 |
58 | from epsTest import epsTest
59 |
60 | train_JGAN = 0.0
61 | train_JGen = 0.0
62 | train_epsTest = 0.0
63 | num_ex = 0
64 |
65 | def inf_train_gen():
66 | while True:
67 | for images, targets in enumerate(train_dataloader):
68 | yield images,targets
69 |
70 | get_true_images = inf_train_gen()
71 |
72 |
73 | for step in range(args.num_steps):
74 | g.train()
75 | d.train()
76 | # update discriminator using ascent on J_GAN = E_x [d(x)] - E_z[d(g(z))]
77 | for iter_critic in range(args.iter_disc):
78 | x = get_true_images.__next__()[1][0]
79 | x = x.to(device)
80 |
81 | for p in d.parameters():
82 | p.data.clamp_(-0.01, 0.01)
83 |
84 | dx = d(x)
85 | z = torch.randn((x.shape[0],args.q),device=device)
86 | gz = g(z)
87 | dgz = d(gz)
88 | J_GAN = -(torch.mean(dx) - torch.mean(dgz))
89 | optimizer_d.zero_grad()
90 | J_GAN.backward()
91 | optimizer_d.step()
92 | train_JGAN -= J_GAN.item() * x.shape[0]
93 |
94 | # update the generator using descent on J_Gen = - E_z[d(g(z))]
95 | optimizer_g.zero_grad()
96 | z = torch.randn((x.shape[0], args.q), device=device)
97 | gz = g(z)
98 | dgz = d(gz)
99 | J_Gen = -torch.mean(dgz)
100 | J_Gen.backward()
101 | optimizer_g.step()
102 |
103 | # update history
104 | train_JGen += J_Gen.item()*x.shape[0]
105 | train_epsTest += epsTest(gz.detach(),x)
106 |
107 | num_ex += x.shape[0]
108 |
109 | if (step+1) % args.plot_interval==0:
110 | train_JGAN /= args.iter_disc * num_ex
111 | train_JGen /= num_ex
112 |
113 | print(("%06d " + 3 * "%1.4e ") %
114 | (step + 1, train_JGAN, train_JGen, train_epsTest))
115 | his = np.vstack([his, [train_JGAN, train_JGen, train_epsTest]])
116 |
117 | plt.Figure()
118 | img = gz.detach().cpu()
119 | img -= torch.min(img)
120 | img /= torch.max(img)
121 | plt.imshow(torchvision.utils.make_grid(img, 8, 5).permute((1, 2, 0)))
122 | plt.title("trainWGANmnist: step=%d" % (step+1))
123 | if args.out_file is not None:
124 | plt.savefig(("%s-step-%d.png") % (args.out_file,step+1))
125 | plt.show()
126 |
127 | train_JGAN = 0.0
128 | train_JGen = 0.0
129 | train_epsTest = 0.0
130 |
131 | num_ex = 0
132 |
133 | if args.out_file is not None:
134 | torch.save(g.state_dict(), ("%s-g.pt") % (args.out_file))
135 | torch.save(d.state_dict(), ("%s-d.pt") % (args.out_file))
136 |
137 | from scipy.io import savemat
138 | savemat(("%s.mat") % (args.out_file), {"his":his})
139 |
140 | plt.Figure()
141 | plt.subplot(1,2,1)
142 | plt.plot(his[:,0:2])
143 | plt.legend(("JGAN","JGen"))
144 | plt.title("GAN Objectives")
145 | plt.subplot(1,2,2)
146 | plt.plot(his[:,2])
147 | plt.title("epsTest")
148 | if args.out_file is not None:
149 | plt.savefig(("%s-his.png") % (args.out_file))
150 | plt.show()
--------------------------------------------------------------------------------
/OTFlow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.functional import pad
4 | from Phi import *
5 | from torch import distributions
6 |
7 |
8 | class OTFlow(nn.Module):
9 | """
10 | OT-Flow for density estimation and sampling as described in
11 |
12 | @article{onken2020otflow,
13 | title={OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport},
14 | author={Derek Onken and Samy Wu Fung and Xingjian Li and Lars Ruthotto},
15 | year={2020},
16 | journal = {arXiv preprint arXiv:2006.00104},
17 | }
18 |
19 | """
20 | def __init__(self, net, nt, alph, prior, T=1.0):
21 | """
22 | Initialize OT-Flow
23 |
24 | :param net: network for value function
25 | :param nt: number of rk4 steps
26 | :param alph: penalty parameters
27 | :param prior: latent distribution, e.g., distributions.MultivariateNormal(torch.zeros(d), torch.eye(d))
28 | """
29 | super(OTFlow, self).__init__()
30 | self.prior = prior
31 | self.nt = nt
32 | self.T = T
33 | self.net = net
34 | self.alph = alph
35 |
36 |
37 | def g(self, z, nt = None, storeAll=False):
38 | """
39 | :param z: latent variable
40 | :return: g(z) and hidden states
41 | """
42 | return self.integrate(z,[self.T, 0.0], nt,storeAll)
43 |
44 | def ginv(self, x, nt=None, storeAll=False):
45 | """
46 | :param x: sample from dataset
47 | :return: g^(-1)(x), value of log-determinant, and hidden layers
48 | """
49 |
50 | return self.integrate(x,[0.0, self.T], nt,storeAll)
51 |
52 | def log_prob(self, x, nt=None):
53 | """
54 | Compute log-probability of a sample using change of variable formula
55 |
56 | :param x: sample from dataset
57 | :return: logp_{\theta}(x)
58 | """
59 | z, _, log_det_ginv, v, r = self.ginv(x,nt)
60 | return self.prior.log_prob(z) + log_det_ginv, v, r
61 |
62 | def sample(self, s,nt=None):
63 | """
64 | Draw random samples from p_{\theta}
65 |
66 | :param s: number of samples to draw
67 | :return:
68 | """
69 | z = self.prior.sample((s, 1)).squeeze(1)
70 | x, _, _, _, _ = self.g(z,nt)
71 | return x
72 |
73 | def f(self,x, t):
74 | """
75 | neural ODE combining the characteristics and log-determinant (see Eq. (2)), the transport costs (see Eq. (5)), and
76 | the HJB regularizer (see Eq. (7)).
77 |
78 | d_t [x ; l ; v ; r] = odefun( [x ; l ; v ; r] , t )
79 |
80 | x - particle position
81 | l - log determinant
82 | v - accumulated transport costs (Lagrangian)
83 | r - accumulates violation of HJB condition along trajectory
84 | """
85 | nex, d = x.shape
86 | z = pad(x[:, :d], (0, 1, 0, 0), value=t)
87 | gradPhi, trH = self.net.trHess(z)
88 |
89 | dx = -(1.0 / self.alph[0]) * gradPhi[:, 0:d]
90 | dl = -(1.0 / self.alph[0]) * trH
91 | dv = 0.5 * torch.sum(torch.pow(dx, 2), 1)
92 | dr = torch.abs(-gradPhi[:, -1] + self.alph[0] * dv)
93 |
94 | return dx, dl, dv, dr
95 |
96 | def integrate(self, y, tspan, nt=None,storeAll=False):
97 | """
98 | RK4 time-stepping to integrate the neural ODE
99 |
100 | :param y: initial state
101 | :param tspan: time interval (can go backward in time)
102 | :param nt: number of time steps (default is self.nt)
103 | :return: y (final state), ys (all states), l (log determinant), v (transport costs), r (HJB penalty)
104 | """
105 | if nt is None:
106 | nt = self.nt
107 |
108 | nex, d = y.shape
109 | h = (tspan[1] - tspan[0])/ nt
110 | tk = tspan[0]
111 |
112 | l = torch.zeros((nex), device=y.device, dtype=y.dtype)
113 | v = torch.zeros((nex), device=y.device, dtype=y.dtype)
114 | r = torch.zeros((nex), device=y.device, dtype=y.dtype)
115 | if storeAll:
116 | ys = [torch.clone(y).detach().cpu()]
117 | else:
118 | ys = None
119 |
120 | w = [(h/6.0),2.0*(h/6.0),2.0*(h/6.0),1.0*(h/6.0)]
121 | for i in range(nt):
122 | y0 = y
123 |
124 | dy, dl, dv, dr = self.f(y0, tk)
125 | y = y0 + w[0] * dy
126 | l += w[0] * dl
127 | v += w[0] * dv
128 | r += w[0] * dr
129 |
130 | dy, dl, dv, dr = self.f(y0 + 0.5 * h * dy, tk + (h / 2))
131 | y += w[1] * dy
132 | l += w[1] * dl
133 | v += w[1] * dv
134 | r += w[1] * dr
135 |
136 | dy, dl, dv, dr = self.f(y0 + 0.5 * h * dy, tk + (h / 2))
137 | y += w[2] * dy
138 | l += w[2] * dl
139 | v += w[2] * dv
140 | r += w[2] * dr
141 |
142 | dy, dl, dv, dr = self.f(y0 + h * dy, tk + h)
143 | y += w[3] * dy
144 | l += w[3] * dl
145 | v += w[3] * dv
146 | r += w[3] * dr
147 |
148 | if storeAll:
149 | ys.append(torch.clone(y).detach().cpu())
150 | tk +=h
151 |
152 | return y, ys, l, v, r
153 |
154 | if __name__ == "__main__":
155 | # layers and masks
156 |
157 | nt = 16
158 | alph = [1.0, 5.0, 10.0]
159 | prior = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
160 | net = Phi(nTh=2, m=16, d=2, alph=alph)
161 | T=1.0
162 |
163 | flow = OTFlow(net,nt,alph,prior,T)
164 |
165 | x = flow.sample(200).detach()
166 | # test inverse
167 | xt = flow.ginv(flow.g(x)[0])[0].detach()
168 | print(torch.norm(x - xt) / torch.norm(x))
169 |
170 | # test inverse
171 | xt = flow.g(flow.ginv(x)[0])[0].detach()
172 | print(torch.norm(x - xt) / torch.norm(x))
173 |
174 |
175 |
--------------------------------------------------------------------------------
/Phi.py:
--------------------------------------------------------------------------------
1 | # Phi.py
2 | # neural network to model the potential function
3 | import torch
4 | import torch.nn as nn
5 | import copy
6 | import math
7 |
8 | def antiderivTanh(x): # activation function aka the antiderivative of tanh
9 | return torch.abs(x) + torch.log(1+torch.exp(-2.0*torch.abs(x)))
10 |
11 | def derivTanh(x): # act'' aka the second derivative of the activation function antiderivTanh
12 | return 1 - torch.pow( torch.tanh(x) , 2 )
13 |
14 | class ResNN(nn.Module):
15 | def __init__(self, d, m, nTh=2):
16 | """
17 | ResNet N portion of Phi
18 |
19 | This implementation was first described in:
20 |
21 | @article{onken2020otflow,
22 | title={OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport},
23 | author={Derek Onken and Samy Wu Fung and Xingjian Li and Lars Ruthotto},
24 | year={2020},
25 | journal = {arXiv preprint arXiv:2006.00104},
26 | }
27 | :param d: int, dimension of space input (expect inputs to be d+1 for space-time)
28 | :param m: int, hidden dimension
29 | :param nTh: int, number of resNet layers , (number of theta layers)
30 | """
31 | super().__init__()
32 |
33 | if nTh < 2:
34 | print("nTh must be an integer >= 2")
35 | exit(1)
36 |
37 | self.d = d
38 | self.m = m
39 | self.nTh = nTh
40 | self.layers = nn.ModuleList([])
41 | self.layers.append(nn.Linear(d + 1, m, bias=True)) # opening layer
42 | self.layers.append(nn.Linear(m,m, bias=True)) # resnet layers
43 | for i in range(nTh-2):
44 | self.layers.append(copy.deepcopy(self.layers[1]))
45 | self.act = antiderivTanh
46 | self.h = 1.0 / (self.nTh-1) # step size for the ResNet
47 |
48 | def forward(self, x):
49 | """
50 | N(s;theta). the forward propogation of the ResNet
51 | :param x: tensor nex-by-d+1, inputs
52 | :return: tensor nex-by-m, outputs
53 | """
54 |
55 | x = self.act(self.layers[0].forward(x))
56 |
57 | for i in range(1,self.nTh):
58 | x = x + self.h * self.act(self.layers[i](x))
59 |
60 | return x
61 |
62 |
63 |
64 | class Phi(nn.Module):
65 | def __init__(self, nTh, m, d, r=10, alph=[1.0] * 5):
66 | """
67 | neural network approximating Phi (see Eq. (9) in our paper)
68 |
69 | Phi( x,t ) = w'*ResNet( [x;t]) + 0.5*[x' t] * A'A * [x;t] + b'*[x;t] + c
70 |
71 | :param nTh: int, number of resNet layers , (number of theta layers)
72 | :param m: int, hidden dimension
73 | :param d: int, dimension of space input (expect inputs to be d+1 for space-time)
74 | :param r: int, rank r for the A matrix
75 | :param alph: list, alpha values / weighted multipliers for the optimization problem
76 | """
77 | super().__init__()
78 |
79 | self.m = m
80 | self.nTh = nTh
81 | self.d = d
82 | self.alph = alph
83 |
84 | r = min(r,d+1) # if number of dimensions is smaller than default r, use that
85 |
86 | self.A = nn.Parameter(torch.zeros(r, d+1) , requires_grad=True)
87 | self.A = nn.init.xavier_uniform_(self.A)
88 | self.c = nn.Linear( d+1 , 1 , bias=True) # b'*[x;t] + c
89 | self.w = nn.Linear( m , 1 , bias=False)
90 |
91 | self.N = ResNN(d, m, nTh=nTh)
92 |
93 | # set initial values
94 | self.w.weight.data = torch.ones(self.w.weight.data.shape)
95 | self.c.weight.data = torch.zeros(self.c.weight.data.shape)
96 | self.c.bias.data = torch.zeros(self.c.bias.data.shape)
97 |
98 |
99 |
100 | def forward(self, x):
101 | """ calculating Phi(s, theta)...not used in OT-Flow """
102 |
103 | # force A to be symmetric
104 | symA = torch.matmul(torch.t(self.A), self.A) # A'A
105 |
106 | return self.w( self.N(x)) + 0.5 * torch.sum( torch.matmul(x , symA) * x , dim=1, keepdims=True) + self.c(x)
107 |
108 |
109 | def trHess(self,x,d=None, justGrad=False ):
110 | """
111 | compute gradient of Phi wrt x and trace(Hessian of Phi); see Eq. (11) and Eq. (13), respectively
112 | recomputes the forward propogation portions of Phi
113 |
114 | :param x: input data, torch Tensor nex-by-d
115 | :param justGrad: boolean, if True only return gradient, if False return (grad, trHess)
116 | :return: gradient , trace(hessian) OR just gradient
117 | """
118 |
119 | # code in E = eye(d+1,d) as index slicing instead of matrix multiplication
120 | # assumes specific N.act as the antiderivative of tanh
121 |
122 | N = self.N
123 | m = N.layers[0].weight.shape[0]
124 | nex = x.shape[0] # number of examples in the batch
125 | if d is None:
126 | d = x.shape[1]-1
127 | symA = torch.matmul(self.A.t(), self.A)
128 |
129 | u = [] # hold the u_0,u_1,...,u_M for the forward pass
130 | z = N.nTh*[None] # hold the z_0,z_1,...,z_M for the backward pass
131 | # preallocate z because we will store in the backward pass and we want the indices to match the paper
132 |
133 | # Forward of ResNet N and fill u
134 | opening = N.layers[0].forward(x) # K_0 * S + b_0
135 | u.append(N.act(opening)) # u0
136 | feat = u[0]
137 |
138 | for i in range(1,N.nTh):
139 | feat = feat + N.h * N.act(N.layers[i](feat))
140 | u.append(feat)
141 |
142 | # going to be used more than once
143 | tanhopen = torch.tanh(opening) # act'( K_0 * S + b_0 )
144 |
145 | # compute gradient and fill z
146 | for i in range(N.nTh-1,0,-1): # work backwards, placing z_i in appropriate spot
147 | if i == N.nTh-1:
148 | term = self.w.weight.t()
149 | else:
150 | term = z[i+1]
151 |
152 | # z_i = z_{i+1} + h K_i' diag(...) z_{i+1}
153 | z[i] = term + N.h * torch.mm( N.layers[i].weight.t() , torch.tanh( N.layers[i].forward(u[i-1]) ).t() * term)
154 |
155 | # z_0 = K_0' diag(...) z_1
156 | z[0] = torch.mm( N.layers[0].weight.t() , tanhopen.t() * z[1] )
157 | grad = z[0] + torch.mm(symA, x.t() ) + self.c.weight.t()
158 |
159 | if justGrad:
160 | return grad.t()
161 |
162 | # -----------------
163 | # trace of Hessian
164 | #-----------------
165 |
166 | # t_0, the trace of the opening layer
167 | Kopen = N.layers[0].weight[:,0:d] # indexed version of Kopen = torch.mm( N.layers[0].weight, E )
168 | temp = derivTanh(opening.t()) * z[1]
169 | trH = torch.sum(temp.reshape(m, -1, nex) * torch.pow(Kopen.unsqueeze(2), 2), dim=(0, 1)) # trH = t_0
170 |
171 | # grad_s u_0 ^ T
172 | temp = tanhopen.t() # act'( K_0 * S + b_0 )
173 | Jac = Kopen.unsqueeze(2) * temp.unsqueeze(1) # K_0' * act'( K_0 * S + b_0 )
174 | # Jac is shape m by d by nex
175 |
176 | # t_i, trace of the resNet layers
177 | # KJ is the K_i^T * grad_s u_{i-1}^T
178 | for i in range(1,N.nTh):
179 | KJ = torch.mm(N.layers[i].weight , Jac.reshape(m,-1) )
180 | KJ = KJ.reshape(m,-1,nex)
181 | if i == N.nTh-1:
182 | term = self.w.weight.t()
183 | else:
184 | term = z[i+1]
185 |
186 | temp = N.layers[i].forward(u[i-1]).t() # (K_i * u_{i-1} + b_i)
187 | t_i = torch.sum( ( derivTanh(temp) * term ).reshape(m,-1,nex) * torch.pow(KJ,2) , dim=(0, 1) )
188 | trH = trH + N.h * t_i # add t_i to the accumulate trace
189 | if i < N.nTh:
190 | Jac = Jac + N.h * torch.tanh(temp).reshape(m, -1, nex) * KJ # update Jacobian
191 |
192 | return grad.t(), trH + torch.trace(symA[0:d,0:d])
193 | # indexed version of: return grad.t() , trH + torch.trace( torch.mm( E.t() , torch.mm( symA , E) ) )
194 |
195 |
196 |
197 | if __name__ == "__main__":
198 |
199 | import time
200 | import math
201 |
202 | # test case
203 | d = 2
204 | m = 5
205 |
206 | net = Phi(nTh=2, m=m, d=d)
207 | net.N.layers[0].weight.data = 0.1 + 0.0 * net.N.layers[0].weight.data
208 | net.N.layers[0].bias.data = 0.2 + 0.0 * net.N.layers[0].bias.data
209 | net.N.layers[1].weight.data = 0.3 + 0.0 * net.N.layers[1].weight.data
210 | net.N.layers[1].weight.data = 0.3 + 0.0 * net.N.layers[1].weight.data
211 |
212 | # number of samples-by-(d+1)
213 | x = torch.Tensor([[1.0 ,4.0 , 0.5],[2.0,5.0,0.6],[3.0,6.0,0.7],[0.0,0.0,0.0]])
214 | y = net(x)
215 | print(y)
216 |
217 | # test timings
218 | d = 400
219 | m = 32
220 | nex = 1000
221 |
222 | net = Phi(nTh=5, m=m, d=d)
223 | net.eval()
224 | x = torch.randn(nex,d+1)
225 | y = net(x)
226 |
227 | end = time.time()
228 | g,h = net.trHess(x)
229 | print('traceHess takes ', time.time()-end)
230 |
231 | end = time.time()
232 | g = net.trHess(x, justGrad=True)
233 | print('JustGrad takes ', time.time()-end)
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
--------------------------------------------------------------------------------