├── .gitignore
├── Pipfile
├── Pipfile.lock
├── README.md
├── assets
│   ├── mnist_recon.gif
│   ├── mnist_samples.gif
│   ├── omni_recon.gif
│   └── omni_samples.gif
├── data_loader
│   ├── cifar10.py
│   ├── data_loader.py
│   ├── fixed_mnist.py
│   ├── omniglot.py
│   └── stoch_mnist.py
├── exp.sh
├── main.py
├── model
│   ├── bernoulli_vae.py
│   ├── conv_vae.py
│   └── vae_base.py
├── requirements.txt
└── utils
    ├── config.py
    ├── draw_figs.py
    └── to_sheets.py
/.gitignore:
--------------------------------------------------------------------------------
1 | dataset
2 | result
3 | dfc_exp.sh
4 |
5 |
6 | ###### python ######
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # Environments
96 | .env
97 | .venv
98 | env/
99 | venv/
100 | ENV/
101 | env.bak/
102 | venv.bak/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 | .dmypy.json
117 | dmypy.json
118 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | name = "pypi"
3 | url = "https://pypi.org/simple"
4 | verify_ssl = true
5 |
6 | [dev-packages]
7 |
8 | [requires]
9 | python_version = "3.6"
10 |
11 | [packages]
12 | numpy = "*"
13 | "h5py" = "*"
14 | scipy = "*"
15 | matplotlib = "*"
16 | imageio = "*"
17 | pathlib = "*"
18 | gspread = "*"
19 | "oauth2client" = "*"
20 | tensorboardx = "*"
21 | torch = "*"
22 | torchvision = "*"
23 | urllib3 = ">=1.24.2"
24 |
--------------------------------------------------------------------------------
/Pipfile.lock:
--------------------------------------------------------------------------------
1 | {
2 | "_meta": {
3 | "hash": {
4 | "sha256": "de41289fb28f71e8475cba544066bcd5ea9a0353f7fb699cd317ecba70a4f749"
5 | },
6 | "pipfile-spec": 6,
7 | "requires": {
8 | "python_version": "3.6"
9 | },
10 | "sources": [
11 | {
12 | "name": "pypi",
13 | "url": "https://pypi.org/simple",
14 | "verify_ssl": true
15 | }
16 | ]
17 | },
18 | "default": {
19 | "certifi": {
20 | "hashes": [
21 | "sha256:046832c04d4e752f37383b628bc601a7ea7211496b4638f6514d0e5b9acc4939",
22 | "sha256:945e3ba63a0b9f577b1395204e13c3a231f9bc0223888be653286534e5873695"
23 | ],
24 | "version": "==2019.6.16"
25 | },
26 | "chardet": {
27 | "hashes": [
28 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae",
29 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"
30 | ],
31 | "version": "==3.0.4"
32 | },
33 | "cycler": {
34 | "hashes": [
35 | "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d",
36 | "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8"
37 | ],
38 | "version": "==0.10.0"
39 | },
40 | "gspread": {
41 | "hashes": [
42 | "sha256:dd945e3ae5d3d0325ad9982e0d5667f79ca121d0bb6f35274dc84371bbb79dd5",
43 | "sha256:f7ce6c06250f694976c3cd4944e3b607b0810b93383839e5b67c7199ce2f0d3d"
44 | ],
45 | "index": "pypi",
46 | "version": "==3.1.0"
47 | },
48 | "h5py": {
49 | "hashes": [
50 | "sha256:05750b91640273c69989c657eaac34b091abdd75efc8c4824c82aaf898a2da0a",
51 | "sha256:082a27208aa3a2286e7272e998e7e225b2a7d4b7821bd840aebf96d50977abbb",
52 | "sha256:08e2e8297195f9e813e894b6c63f79372582787795bba2014a2db6a2de95f713",
53 | "sha256:0dd2adeb2e9de5081eb8dcec88874e7fd35dae9a21557be3a55a3c7d491842a4",
54 | "sha256:0f94de7a10562b991967a66bbe6dda9808e18088676834c0a4dcec3fdd3bcc6f",
55 | "sha256:106e42e2e01e486a3d32eeb9ba0e3a7f65c12fa8998d63625fa41fb8bdc44cdb",
56 | "sha256:1606c66015f04719c41a9863c156fc0e6b992150de21c067444bcb82e7d75579",
57 | "sha256:1854c4beff9961e477e133143c5e5e355dac0b3ebf19c52cf7cc1b1ef757703c",
58 | "sha256:1e9fb6f1746500ea91a00193ce2361803c70c6b13f10aae9a33ad7b5bd28e800",
59 | "sha256:2cca17e80ddb151894333377675db90cd0279fa454776e0a4f74308376afd050",
60 | "sha256:30e365e8408759db3778c361f1e4e0fe8e98a875185ae46c795a85e9bafb9cdf",
61 | "sha256:3206bac900e16eda81687d787086f4ffd4f3854980d798e191a9868a6510c3ae",
62 | "sha256:3c23d72058647cee19b30452acc7895621e2de0a0bd5b8a1e34204b9ea9ed43c",
63 | "sha256:407b5f911a83daa285bbf1ef78a9909ee5957f257d3524b8606be37e8643c5f0",
64 | "sha256:4162953714a9212d373ac953c10e3329f1e830d3c7473f2a2e4f25dd6241eef0",
65 | "sha256:5fc7aba72a51b2c80605eba1c50dbf84224dcd206279d30a75c154e5652e1fe4",
66 | "sha256:713ac19307e11de4d9833af0c4bd6778bde0a3d967cafd2f0f347223711c1e31",
67 | "sha256:71b946d80ef3c3f12db157d7778b1fe74a517ca85e94809358b15580983c2ce2",
68 | "sha256:8cc4aed71e20d87e0a6f02094d718a95252f11f8ed143bc112d22167f08d4040",
69 | "sha256:9d41ca62daf36d6b6515ab8765e4c8c4388ee18e2a665701fef2b41563821002",
70 | "sha256:a744e13b000f234cd5a5b2a1f95816b819027c57f385da54ad2b7da1adace2f3",
71 | "sha256:b087ee01396c4b34e9dc41e3a6a0442158206d383c19c7d0396d52067b17c1cb",
72 | "sha256:b0f03af381d33306ce67d18275b61acb4ca111ced645381387a02c8a5ee1b796",
73 | "sha256:b9e4b8dfd587365bdd719ae178fa1b6c1231f81280b1375eef8626dfd8761bf3",
74 | "sha256:c5dd4ec75985b99166c045909e10f0534704d102848b1d9f0992720e908928e7",
75 | "sha256:d2b82f23cd862a9d05108fe99967e9edfa95c136f532a71cb3d28dc252771f50",
76 | "sha256:e58a25764472af07b7e1c4b10b0179c8ea726446c7141076286e41891bf3a563",
77 | "sha256:f3b49107fbfc77333fc2b1ef4d5de2abcd57e7ea3a1482455229494cf2da56ce"
78 | ],
79 | "index": "pypi",
80 | "version": "==2.9.0"
81 | },
82 | "httplib2": {
83 | "hashes": [
84 | "sha256:6901c8c0ffcf721f9ce270ad86da37bc2b4d32b8802d4a9cec38274898a64044",
85 | "sha256:cf6f9d5876d796539ec922a2c9b9a7cad9bfd90f04badcdc3bcfa537168052c3"
86 | ],
87 | "version": "==0.13.1"
88 | },
89 | "idna": {
90 | "hashes": [
91 | "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407",
92 | "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c"
93 | ],
94 | "version": "==2.8"
95 | },
96 | "imageio": {
97 | "hashes": [
98 | "sha256:1a2bbbb7cd38161340fa3b14d806dfbf914abf3ee6fd4592af2afb87d049f209",
99 | "sha256:42e65aadfc3d57a1043615c92bdf6319b67589e49a0aae2b985b82144aceacad"
100 | ],
101 | "index": "pypi",
102 | "version": "==2.5.0"
103 | },
104 | "kiwisolver": {
105 | "hashes": [
106 | "sha256:05b5b061e09f60f56244adc885c4a7867da25ca387376b02c1efc29cc16bcd0f",
107 | "sha256:26f4fbd6f5e1dabff70a9ba0d2c4bd30761086454aa30dddc5b52764ee4852b7",
108 | "sha256:3b2378ad387f49cbb328205bda569b9f87288d6bc1bf4cd683c34523a2341efe",
109 | "sha256:400599c0fe58d21522cae0e8b22318e09d9729451b17ee61ba8e1e7c0346565c",
110 | "sha256:47b8cb81a7d18dbaf4fed6a61c3cecdb5adec7b4ac292bddb0d016d57e8507d5",
111 | "sha256:53eaed412477c836e1b9522c19858a8557d6e595077830146182225613b11a75",
112 | "sha256:58e626e1f7dfbb620d08d457325a4cdac65d1809680009f46bf41eaf74ad0187",
113 | "sha256:5a52e1b006bfa5be04fe4debbcdd2688432a9af4b207a3f429c74ad625022641",
114 | "sha256:5c7ca4e449ac9f99b3b9d4693debb1d6d237d1542dd6a56b3305fe8a9620f883",
115 | "sha256:682e54f0ce8f45981878756d7203fd01e188cc6c8b2c5e2cf03675390b4534d5",
116 | "sha256:79bfb2f0bd7cbf9ea256612c9523367e5ec51d7cd616ae20ca2c90f575d839a2",
117 | "sha256:7f4dd50874177d2bb060d74769210f3bce1af87a8c7cf5b37d032ebf94f0aca3",
118 | "sha256:8944a16020c07b682df861207b7e0efcd2f46c7488619cb55f65882279119389",
119 | "sha256:8aa7009437640beb2768bfd06da049bad0df85f47ff18426261acecd1cf00897",
120 | "sha256:939f36f21a8c571686eb491acfffa9c7f1ac345087281b412d63ea39ca14ec4a",
121 | "sha256:9733b7f64bd9f807832d673355f79703f81f0b3e52bfce420fc00d8cb28c6a6c",
122 | "sha256:a02f6c3e229d0b7220bd74600e9351e18bc0c361b05f29adae0d10599ae0e326",
123 | "sha256:a0c0a9f06872330d0dd31b45607197caab3c22777600e88031bfe66799e70bb0",
124 | "sha256:acc4df99308111585121db217681f1ce0eecb48d3a828a2f9bbf9773f4937e9e",
125 | "sha256:b64916959e4ae0ac78af7c3e8cef4becee0c0e9694ad477b4c6b3a536de6a544",
126 | "sha256:d3fcf0819dc3fea58be1fd1ca390851bdb719a549850e708ed858503ff25d995",
127 | "sha256:d52e3b1868a4e8fd18b5cb15055c76820df514e26aa84cc02f593d99fef6707f",
128 | "sha256:db1a5d3cc4ae943d674718d6c47d2d82488ddd94b93b9e12d24aabdbfe48caee",
129 | "sha256:e3a21a720791712ed721c7b95d433e036134de6f18c77dbe96119eaf7aa08004",
130 | "sha256:e8bf074363ce2babeb4764d94f8e65efd22e6a7c74860a4f05a6947afc020ff2",
131 | "sha256:f16814a4a96dc04bf1da7d53ee8d5b1d6decfc1a92a63349bb15d37b6a263dd9",
132 | "sha256:f2b22153870ca5cf2ab9c940d7bc38e8e9089fa0f7e5856ea195e1cf4ff43d5a",
133 | "sha256:f790f8b3dff3d53453de6a7b7ddd173d2e020fb160baff578d578065b108a05f"
134 | ],
135 | "version": "==1.1.0"
136 | },
137 | "matplotlib": {
138 | "hashes": [
139 | "sha256:1febd22afe1489b13c6749ea059d392c03261b2950d1d45c17e3aed812080c93",
140 | "sha256:31a30d03f39528c79f3a592857be62a08595dec4ac034978ecd0f814fa0eec2d",
141 | "sha256:4442ce720907f67a79d45de9ada47be81ce17e6c2f448b3c64765af93f6829c9",
142 | "sha256:796edbd1182cbffa7e1e7a97f1e141f875a8501ba8dd834269ae3cd45a8c976f",
143 | "sha256:934e6243df7165aad097572abf5b6003c77c9b6c480c3c4de6f2ef1b5fdd4ec0",
144 | "sha256:bab9d848dbf1517bc58d1f486772e99919b19efef5dd8596d4b26f9f5ee08b6b",
145 | "sha256:c1fe1e6cdaa53f11f088b7470c2056c0df7d80ee4858dadf6cbe433fcba4323b",
146 | "sha256:e5b8aeca9276a3a988caebe9f08366ed519fff98f77c6df5b64d7603d0e42e36",
147 | "sha256:ec6bd0a6a58df3628ff269978f4a4b924a0d371ad8ce1f8e2b635b99e482877a"
148 | ],
149 | "index": "pypi",
150 | "version": "==3.1.1"
151 | },
152 | "numpy": {
153 | "hashes": [
154 | "sha256:03e311b0a4c9f5755da7d52161280c6a78406c7be5c5cc7facfbcebb641efb7e",
155 | "sha256:0cdd229a53d2720d21175012ab0599665f8c9588b3b8ffa6095dd7b90f0691dd",
156 | "sha256:312bb18e95218bedc3563f26fcc9c1c6bfaaf9d453d15942c0839acdd7e4c473",
157 | "sha256:464b1c48baf49e8505b1bb754c47a013d2c305c5b14269b5c85ea0625b6a988a",
158 | "sha256:5adfde7bd3ee4864536e230bcab1c673f866736698724d5d28c11a4d63672658",
159 | "sha256:7724e9e31ee72389d522b88c0d4201f24edc34277999701ccd4a5392e7d8af61",
160 | "sha256:8d36f7c53ae741e23f54793ffefb2912340b800476eb0a831c6eb602e204c5c4",
161 | "sha256:910d2272403c2ea8a52d9159827dc9f7c27fb4b263749dca884e2e4a8af3b302",
162 | "sha256:951fefe2fb73f84c620bec4e001e80a80ddaa1b84dce244ded7f1e0cbe0ed34a",
163 | "sha256:9588c6b4157f493edeb9378788dcd02cb9e6a6aeaa518b511a1c79d06cbd8094",
164 | "sha256:9ce8300950f2f1d29d0e49c28ebfff0d2f1e2a7444830fbb0b913c7c08f31511",
165 | "sha256:be39cca66cc6806652da97103605c7b65ee4442c638f04ff064a7efd9a81d50a",
166 | "sha256:c3ab2d835b95ccb59d11dfcd56eb0480daea57cdf95d686d22eff35584bc4554",
167 | "sha256:eb0fc4a492cb896346c9e2c7a22eae3e766d407df3eb20f4ce027f23f76e4c54",
168 | "sha256:ec0c56eae6cee6299f41e780a0280318a93db519bbb2906103c43f3e2be1206c",
169 | "sha256:f4e4612de60a4f1c4d06c8c2857cdcb2b8b5289189a12053f37d3f41f06c60d0"
170 | ],
171 | "index": "pypi",
172 | "version": "==1.17.0"
173 | },
174 | "oauth2client": {
175 | "hashes": [
176 | "sha256:b8a81cc5d60e2d364f0b1b98f958dbd472887acaf1a5b05e21c28c31a2d6d3ac",
177 | "sha256:d486741e451287f69568a4d26d70d9acd73a2bbfa275746c535b4209891cccc6"
178 | ],
179 | "index": "pypi",
180 | "version": "==4.1.3"
181 | },
182 | "pathlib": {
183 | "hashes": [
184 | "sha256:6940718dfc3eff4258203ad5021090933e5c04707d5ca8cc9e73c94a7894ea9f"
185 | ],
186 | "index": "pypi",
187 | "version": "==1.0.1"
188 | },
189 | "pillow": {
190 | "hashes": [
191 | "sha256:0804f77cb1e9b6dbd37601cee11283bba39a8d44b9ddb053400c58e0c0d7d9de",
192 | "sha256:0ab7c5b5d04691bcbd570658667dd1e21ca311c62dcfd315ad2255b1cd37f64f",
193 | "sha256:0b3e6cf3ea1f8cecd625f1420b931c83ce74f00c29a0ff1ce4385f99900ac7c4",
194 | "sha256:365c06a45712cd723ec16fa4ceb32ce46ad201eb7bbf6d3c16b063c72b61a3ed",
195 | "sha256:38301fbc0af865baa4752ddae1bb3cbb24b3d8f221bf2850aad96b243306fa03",
196 | "sha256:3aef1af1a91798536bbab35d70d35750bd2884f0832c88aeb2499aa2d1ed4992",
197 | "sha256:3fe0ab49537d9330c9bba7f16a5f8b02da615b5c809cdf7124f356a0f182eccd",
198 | "sha256:45a619d5c1915957449264c81c008934452e3fd3604e36809212300b2a4dab68",
199 | "sha256:49f90f147883a0c3778fd29d3eb169d56416f25758d0f66775db9184debc8010",
200 | "sha256:571b5a758baf1cb6a04233fb23d6cf1ca60b31f9f641b1700bfaab1194020555",
201 | "sha256:5ac381e8b1259925287ccc5a87d9cf6322a2dc88ae28a97fe3e196385288413f",
202 | "sha256:6153db744a743c0c8c91b8e3b9d40e0b13a5d31dbf8a12748c6d9bfd3ddc01ad",
203 | "sha256:6fd63afd14a16f5d6b408f623cc2142917a1f92855f0df997e09a49f0341be8a",
204 | "sha256:70acbcaba2a638923c2d337e0edea210505708d7859b87c2bd81e8f9902ae826",
205 | "sha256:70b1594d56ed32d56ed21a7fbb2a5c6fd7446cdb7b21e749c9791eac3a64d9e4",
206 | "sha256:76638865c83b1bb33bcac2a61ce4d13c17dba2204969dedb9ab60ef62bede686",
207 | "sha256:7b2ec162c87fc496aa568258ac88631a2ce0acfe681a9af40842fc55deaedc99",
208 | "sha256:7cee2cef07c8d76894ebefc54e4bb707dfc7f258ad155bd61d87f6cd487a70ff",
209 | "sha256:7d16d4498f8b374fc625c4037742fbdd7f9ac383fd50b06f4df00c81ef60e829",
210 | "sha256:b50bc1780681b127e28f0075dfb81d6135c3a293e0c1d0211133c75e2179b6c0",
211 | "sha256:bd0582f831ad5bcad6ca001deba4568573a4675437db17c4031939156ff339fa",
212 | "sha256:cfd40d8a4b59f7567620410f966bb1f32dc555b2b19f82a91b147fac296f645c",
213 | "sha256:e3ae410089de680e8f84c68b755b42bc42c0ceb8c03dbea88a5099747091d38e",
214 | "sha256:e9046e559c299b395b39ac7dbf16005308821c2f24a63cae2ab173bd6aa11616",
215 | "sha256:ef6be704ae2bc8ad0ebc5cb850ee9139493b0fc4e81abcc240fb392a63ebc808",
216 | "sha256:f8dc19d92896558f9c4317ee365729ead9d7bbcf2052a9a19a3ef17abbb8ac5b"
217 | ],
218 | "version": "==6.1.0"
219 | },
220 | "protobuf": {
221 | "hashes": [
222 | "sha256:00a1b0b352dc7c809749526d1688a64b62ea400c5b05416f93cfb1b11a036295",
223 | "sha256:01acbca2d2c8c3f7f235f1842440adbe01bbc379fa1cbdd80753801432b3fae9",
224 | "sha256:0a795bca65987b62d6b8a2d934aa317fd1a4d06a6dd4df36312f5b0ade44a8d9",
225 | "sha256:0ec035114213b6d6e7713987a759d762dd94e9f82284515b3b7331f34bfaec7f",
226 | "sha256:31b18e1434b4907cb0113e7a372cd4d92c047ce7ba0fa7ea66a404d6388ed2c1",
227 | "sha256:32a3abf79b0bef073c70656e86d5bd68a28a1fbb138429912c4fc07b9d426b07",
228 | "sha256:55f85b7808766e5e3f526818f5e2aeb5ba2edcc45bcccede46a3ccc19b569cb0",
229 | "sha256:64ab9bc971989cbdd648c102a96253fdf0202b0c38f15bd34759a8707bdd5f64",
230 | "sha256:64cf847e843a465b6c1ba90fb6c7f7844d54dbe9eb731e86a60981d03f5b2e6e",
231 | "sha256:917c8662b585470e8fd42f052661fc66d59fccaae450a60044307dcbf82a3335",
232 | "sha256:afed9003d7f2be2c3df20f64220c30faec441073731511728a2cb4cab4cd46a6",
233 | "sha256:bf8e05d638b585d1752c5a84247134a0350d3a8b73d3632489a014a9f6f1e758",
234 | "sha256:d831b047bd69becaf64019a47179eb22118a50dd008340655266a906c69c6417",
235 | "sha256:de2760583ed28749ff885789c1cbc6c9c06d6de92fc825740ab99deb2f25ea4d",
236 | "sha256:eabc4cf1bc19689af8022ba52fd668564a8d96e0d08f3b4732d26a64255216a4",
237 | "sha256:fcff6086c86fb1628d94ea455c7b9de898afc50378042927a59df8065a79a549"
238 | ],
239 | "version": "==3.9.1"
240 | },
241 | "pyasn1": {
242 | "hashes": [
243 | "sha256:3bb81821d47b17146049e7574ab4bf1e315eb7aead30efe5d6a9ca422c9710be",
244 | "sha256:b773d5c9196ffbc3a1e13bdf909d446cad80a039aa3340bcad72f395b76ebc86"
245 | ],
246 | "version": "==0.4.6"
247 | },
248 | "pyasn1-modules": {
249 | "hashes": [
250 | "sha256:43c17a83c155229839cc5c6b868e8d0c6041dba149789b6d6e28801c64821722",
251 | "sha256:e30199a9d221f1b26c885ff3d87fd08694dbbe18ed0e8e405a2a7126d30ce4c0"
252 | ],
253 | "version": "==0.2.6"
254 | },
255 | "pyparsing": {
256 | "hashes": [
257 | "sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80",
258 | "sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4"
259 | ],
260 | "version": "==2.4.2"
261 | },
262 | "python-dateutil": {
263 | "hashes": [
264 | "sha256:7e6584c74aeed623791615e26efd690f29817a27c73085b78e4bad02493df2fb",
265 | "sha256:c89805f6f4d64db21ed966fda138f8a5ed7a4fdbc1a8ee329ce1b74e3c74da9e"
266 | ],
267 | "version": "==2.8.0"
268 | },
269 | "requests": {
270 | "hashes": [
271 | "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4",
272 | "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31"
273 | ],
274 | "version": "==2.22.0"
275 | },
276 | "rsa": {
277 | "hashes": [
278 | "sha256:14ba45700ff1ec9eeb206a2ce76b32814958a98e372006c8fb76ba820211be66",
279 | "sha256:1a836406405730121ae9823e19c6e806c62bbad73f890574fff50efa4122c487"
280 | ],
281 | "version": "==4.0"
282 | },
283 | "scipy": {
284 | "hashes": [
285 | "sha256:0baa64bf42592032f6f6445a07144e355ca876b177f47ad8d0612901c9375bef",
286 | "sha256:243b04730d7223d2b844bda9500310eecc9eda0cba9ceaf0cde1839f8287dfa8",
287 | "sha256:2643cfb46d97b7797d1dbdb6f3c23fe3402904e3c90e6facfe6a9b98d808c1b5",
288 | "sha256:396eb4cdad421f846a1498299474f0a3752921229388f91f60dc3eda55a00488",
289 | "sha256:3ae3692616975d3c10aca6d574d6b4ff95568768d4525f76222fb60f142075b9",
290 | "sha256:435d19f80b4dcf67dc090cc04fde2c5c8a70b3372e64f6a9c58c5b806abfa5a8",
291 | "sha256:46a5e55850cfe02332998b3aef481d33f1efee1960fe6cfee0202c7dd6fc21ab",
292 | "sha256:75b513c462e58eeca82b22fc00f0d1875a37b12913eee9d979233349fce5c8b2",
293 | "sha256:7ccfa44a08226825126c4ef0027aa46a38c928a10f0a8a8483c80dd9f9a0ad44",
294 | "sha256:89dd6a6d329e3f693d1204d5562dd63af0fd7a17854ced17f9cbc37d5b853c8d",
295 | "sha256:a81da2fe32f4eab8b60d56ad43e44d93d392da228a77e229e59b51508a00299c",
296 | "sha256:a9d606d11eb2eec7ef893eb825017fbb6eef1e1d0b98a5b7fc11446ebeb2b9b1",
297 | "sha256:ac37eb652248e2d7cbbfd89619dce5ecfd27d657e714ed049d82f19b162e8d45",
298 | "sha256:cbc0611699e420774e945f6a4e2830f7ca2b3ee3483fca1aa659100049487dd5",
299 | "sha256:d02d813ec9958ed63b390ded463163685af6025cb2e9a226ec2c477df90c6957",
300 | "sha256:dd3b52e00f93fd1c86f2d78243dfb0d02743c94dd1d34ffea10055438e63b99d"
301 | ],
302 | "index": "pypi",
303 | "version": "==1.3.1"
304 | },
305 | "six": {
306 | "hashes": [
307 | "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c",
308 | "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
309 | ],
310 | "version": "==1.12.0"
311 | },
312 | "tensorboardx": {
313 | "hashes": [
314 | "sha256:13fe0abba27f407778a7321937190eedaf12bc8c544d9a4e294fcf0ba177fd76",
315 | "sha256:f52e59b38b4cdf83384f3fce067bcaf2d2847619f9f533394df0de3b5a71ab8e"
316 | ],
317 | "index": "pypi",
318 | "version": "==1.8"
319 | },
320 | "torch": {
321 | "hashes": [
322 | "sha256:0698d0a48014b9b8f36d93e69901eca2e7ec712cd2033908f7a77e7d86a4f0d7",
323 | "sha256:2ac8e58b069232f079bd289aa160366a9367ae1a4616a2c1007dceed19ff9bfa",
324 | "sha256:43a0e28c448ddeea65fb9e956bc743389592afac824095bdbc08e8a87364c639",
325 | "sha256:661ad06b4616663149bd504e8c0271196d0386712e21a92619d95ba88138794a",
326 | "sha256:880a0c22692eaebbce808a5bf2255ab7d345ab43c40795be0a421c6250ba0fb4",
327 | "sha256:a13bf6f78a49d844b85c142b8cd62d2e1833a11ed21ea0bc6b1ac73d24c76415",
328 | "sha256:a8c21f82fd03b67927078ea917040478c3263753fe1906fc19d0f5f0c7d9aa10",
329 | "sha256:b87fd224a7de3bc01ce87eb947698797b4514e27115b0aa60a56991515dd9dd6",
330 | "sha256:f63d489c54b4f170ce8335727bbb196ceb9acd0e7805477bbef8fabc914bc0f9"
331 | ],
332 | "index": "pypi",
333 | "version": "==1.2.0"
334 | },
335 | "torchvision": {
336 | "hashes": [
337 | "sha256:3a8e9403252fefdf6e8f9993ae111d28eb4ad1e73f696f03de485d7f77d88067",
338 | "sha256:6fff5a31d50de3a59dcceda2a48de9df33a5f43357dc3e0da0ffbb97699aec52",
339 | "sha256:740b3718470aa4ec0b389df876eb25117df1952dd2e8105b7828a02aa5bce73b",
340 | "sha256:8114c33b736ee430496eef4fe03b25be8b939b2abd2a968558737bb9aed1928b",
341 | "sha256:904ef213594672f2ed7fafa3ab010cbf2a4704a951a7bf221cf36b3d2e3acd62",
342 | "sha256:afff8e987564192bc7f139d8b089541d4471ad6fc99e977e8bc8dbb4e0873041",
343 | "sha256:d7939f2ca401de3067a30b6f4dcef63d13d24a4cd1ddc2d3a9af3413ce658d03",
344 | "sha256:d8c2402704ce8ef8e87e4922160388c7ca010ef27700082014d6bd694cf1cc51",
345 | "sha256:e00de7571d83f968f5aea7a59e84e3262669acef0a077ce4bd705eca2df68167"
346 | ],
347 | "index": "pypi",
348 | "version": "==0.4.0"
349 | },
350 | "urllib3": {
351 | "hashes": [
352 | "sha256:b246607a25ac80bedac05c6f282e3cdaf3afb65420fd024ac94435cabe6e18d1",
353 | "sha256:dbe59173209418ae49d485b87d1681aefa36252ee85884c31346debd19463232"
354 | ],
355 | "index": "pypi",
356 | "version": "==1.25.3"
357 | }
358 | },
359 | "develop": {}
360 | }
361 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-vae
2 |
3 | A minimal PyTorch implementation of [VAE](https://arxiv.org/abs/1312.6114), [IWAE](https://arxiv.org/abs/1509.00519), and [MIWAE](https://arxiv.org/abs/1802.04537).
4 | We follow the experimental details of the [IWAE paper](https://arxiv.org/abs/1509.00519).
5 |
6 |
7 |
8 |
9 | ## Usage
10 |
11 | You should be able to run experiments right away.
12 | First create a virtual environment using [pipenv](https://github.com/pypa/pipenv):
13 |
14 | ```pipenv install```
15 |
16 | To run experiments, simply use:
17 |
18 | ```pipenv run python main.py ```
19 |
20 |
21 |
22 |
23 | ## Example commands
24 |
25 | For original VAE:
26 |
27 | ```pipenv run python main.py ```
28 |
29 | To also make figures (reconstruction, samples):
30 |
31 | ```pipenv run python main.py --figs ```
32 |
33 | For IWAE with 5 importance samples:
34 |
35 | ```pipenv run python main.py --importance_num=5 ```
36 |
37 | For MIWAE(16, 4):
38 |
39 | ```pipenv run python main.py --mean_num=16 --importance_num=4 ```
40 |
41 | See [the config file](https://github.com/yoonholee/pytorch-generative/blob/master/utils/config.py) for more options.
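
To evaluate a model saved by a previous run, add `--eval` (and optionally `--figs`); reuse the flags of that run so the derived `result/` directory and `best_model.pt` path match:

```pipenv run python main.py --eval --figs ```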
42 |
43 |
44 |
45 |
46 | ## Results
47 | 
48 | 
49 |
50 | ### Quantitative results on dynamically binarized MNIST
51 | | Method | NLL (this repo) | NLL ([IWAE paper](https://arxiv.org/abs/1509.00519)) | NLL ([MIWAE paper](https://arxiv.org/abs/1802.04537)) | Comments |
52 | | ------------- | ------------- | ------------- | ------------- | ---- |
53 | | VAE | 87.01 | 86.76 | - | |
54 | | MIWAE(5, 1) | 86.45 | 86.47 | - | listed as VAE with k=5 |
55 | | MIWAE(1, 5) | 85.18 | 85.54 | - | listed as IWAE with k=5 |
56 | | MIWAE(64, 1) | 86.07 | - | 86.21 | listed as VAE |
57 | | MIWAE(16, 4) | 84.99 | - | - | |
58 | | MIWAE(8, 8) | 84.69 | - | 84.97 | |
59 | | MIWAE(4, 16) | 84.52 | - | 84.56 | |
60 | | MIWAE(1, 64) | 84.37 | - | 84.52 | listed as IWAE |
61 |
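Here MIWAE(M, k) corresponds to `--mean_num=M --importance_num=k`: MIWAE(1, k) is IWAE with k importance samples, and MIWAE(M, 1) is the vanilla VAE objective averaged over M samples. A minimal sketch of the training objective as computed in `model/vae_base.py` (values are illustrative):

```python
import math
import torch

# elbo: per-sample log-weights log p(x, z) - log q(z|x), shape (M, k, batch)
elbo = torch.randn(16, 4, 20)                  # illustrative, MIWAE(16, 4)
iwae = torch.logsumexp(elbo, 1) - math.log(4)  # k-sample IWAE bound, (M, batch)
loss = -iwae.mean(0)                           # average over M; one loss per example
```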
--------------------------------------------------------------------------------
/assets/mnist_recon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/pytorch-vae/9dc44aae64f0e2896427ce955a48733d6315bb2d/assets/mnist_recon.gif
--------------------------------------------------------------------------------
/assets/mnist_samples.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/pytorch-vae/9dc44aae64f0e2896427ce955a48733d6315bb2d/assets/mnist_samples.gif
--------------------------------------------------------------------------------
/assets/omni_recon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/pytorch-vae/9dc44aae64f0e2896427ce955a48733d6315bb2d/assets/omni_recon.gif
--------------------------------------------------------------------------------
/assets/omni_samples.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/pytorch-vae/9dc44aae64f0e2896427ce955a48733d6315bb2d/assets/omni_samples.gif
--------------------------------------------------------------------------------
/data_loader/cifar10.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets
2 |
3 | class cifar10(datasets.CIFAR10):
4 | def get_mean_img(self):
5 | return self.train_data.mean(0)
6 |
7 |
--------------------------------------------------------------------------------
/data_loader/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | from .stoch_mnist import stochMNIST
4 | from .omniglot import omniglot
5 | from .fixed_mnist import fixedMNIST
6 | from .cifar10 import cifar10
7 |
8 |
9 | def data_loaders(args):
10 | if args.dataset == 'omniglot':
11 | loader_fn, root = omniglot, './dataset/omniglot'
12 | elif args.dataset == 'fixedmnist':
13 | loader_fn, root = fixedMNIST, './dataset/fixedmnist'
14 | elif args.dataset == 'stochmnist':
15 | loader_fn, root = stochMNIST, './dataset/stochmnist'
16 | elif args.dataset == 'cifar10':
17 | loader_fn, root = cifar10, './dataset/cifar10'
18 |
19 | if args.dataset_dir != '': root = args.dataset_dir
20 | kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
21 | train_loader = torch.utils.data.DataLoader(
22 | loader_fn(root, train=True, download=True, transform=transforms.ToTensor()),
23 | batch_size=args.batch_size, shuffle=True, **kwargs)
24 | test_loader = torch.utils.data.DataLoader( # need test bs <=64 to make L_5000 tractable in one pass
25 | loader_fn(root, train=False, download=True, transform=transforms.ToTensor()),
26 | batch_size=args.test_batch_size, shuffle=False, **kwargs)
27 | return train_loader, test_loader
28 |
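A minimal sketch (not part of the repo) of driving `data_loaders` standalone; the `SimpleNamespace` is a hypothetical stand-in for `utils.config.get_args()`, filling in only the fields this function reads:

```python
from types import SimpleNamespace

from data_loader.data_loader import data_loaders

args = SimpleNamespace(dataset='stochmnist', dataset_dir='', cuda=False,
                       batch_size=20, test_batch_size=64)
train_loader, test_loader = data_loaders(args)
imgs, _ = next(iter(train_loader))
print(imgs.shape)  # torch.Size([20, 1, 28, 28])
```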
--------------------------------------------------------------------------------
/data_loader/fixed_mnist.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import torch
3 | import torch.utils.data as data
4 | from torchvision import transforms
5 | import os
6 | import numpy as np
7 | from PIL import Image
8 | import urllib.request
9 |
10 |
11 | class fixedMNIST(data.Dataset):
12 | """ Binarized MNIST dataset, proposed in
13 | http://proceedings.mlr.press/v15/larochelle11a/larochelle11a.pdf """
14 | train_file = 'binarized_mnist_train.amat'
15 | val_file = 'binarized_mnist_valid.amat'
16 | test_file = 'binarized_mnist_test.amat'
17 |
18 | def __init__(self, root, train=True, transform=None, download=False):
19 | # we ignore transform.
20 | self.root = os.path.expanduser(root)
21 | self.train = train # training set or test set
22 |
23 | if download: self.download()
24 | if not self._check_exists():
25 | raise RuntimeError('Dataset not found. You can use download=True to download it')
26 |
27 | self.data = self._get_data(train=train)
28 |
29 | def __getitem__(self, index):
30 | img = self.data[index]
31 | img = Image.fromarray(img)
32 | img = transforms.ToTensor()(img).type(torch.FloatTensor)
33 | return img, torch.tensor(-1) # Meaningless tensor instead of target
34 |
35 | def __len__(self):
36 | return len(self.data)
37 |
38 | def _get_data(self, train=True):
39 | with h5py.File(os.path.join(self.root, 'data.h5'), 'r') as hf:
40 | data = hf.get('train' if train else 'test')
41 | data = np.array(data)
42 | return data
43 |
44 | def get_mean_img(self):
45 | return self.data.mean(0).flatten()
46 |
47 | def download(self):
48 | if self._check_exists():
49 | return
50 | if not os.path.exists(self.root):
51 | os.makedirs(self.root)
52 |
53 | print('Downloading MNIST with fixed binarization...')
54 | for dataset in ['train', 'valid', 'test']:
55 | filename = 'binarized_mnist_{}.amat'.format(dataset)
56 | url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format(dataset)
57 | print('Downloading from {}...'.format(url))
58 | local_filename = os.path.join(self.root, filename)
59 | urllib.request.urlretrieve(url, local_filename)
60 | print('Saved to {}'.format(local_filename))
61 |
62 | def filename_to_np(filename):
63 | with open(filename) as f:
64 | lines = f.readlines()
65 | return np.array([[int(i) for i in line.split()] for line in lines]).astype('int8')
66 |
67 | train_data = np.concatenate([filename_to_np(os.path.join(self.root, self.train_file)),
68 | filename_to_np(os.path.join(self.root, self.val_file))])
69 | test_data = filename_to_np(os.path.join(self.root, self.test_file))
70 | with h5py.File(os.path.join(self.root, 'data.h5'), 'w') as hf:
71 | hf.create_dataset('train', data=train_data.reshape(-1, 28, 28))
72 | hf.create_dataset('test', data=test_data.reshape(-1, 28, 28))
73 | print('Done!')
74 |
75 | def _check_exists(self):
76 | return os.path.exists(os.path.join(self.root, 'data.h5'))
77 |
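A minimal usage sketch: the first call downloads the three `.amat` splits and caches them as `data.h5`; later calls read the cache directly:

```python
from data_loader.fixed_mnist import fixedMNIST

ds = fixedMNIST('./dataset/fixedmnist', train=False, download=True)
img, _ = ds[0]             # already-binarized digit as a float tensor
print(len(ds), img.shape)  # e.g. 10000 torch.Size([1, 28, 28])
```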
--------------------------------------------------------------------------------
/data_loader/omniglot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | from torchvision import transforms
4 | import os
5 | from PIL import Image
6 | import urllib.request
7 | import scipy.io
8 |
9 |
10 | class omniglot(data.Dataset):
11 | """ omniglot dataset """
12 | url = 'https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat'
13 |
14 | def __init__(self, root, train=True, transform=None, download=False):
15 | # we ignore transform.
16 | self.root = os.path.expanduser(root)
17 | self.train = train # training set or test set
18 |
19 | if download: self.download()
20 | if not self._check_exists():
21 | raise RuntimeError('Dataset not found. You can use download=True to download it')
22 |
23 | self.data = self._get_data(train=train)
24 |
25 | def __getitem__(self, index):
26 | img = self.data[index].reshape(28, 28)
27 | img = Image.fromarray(img)
28 | img = transforms.ToTensor()(img).type(torch.FloatTensor)
29 | img = torch.bernoulli(img) # stochastically binarize
30 | return img, torch.tensor(-1) # Meaningless tensor instead of target
31 |
32 | def __len__(self):
33 | return len(self.data)
34 |
35 | def _get_data(self, train=True):
36 | def reshape_data(data):
37 | return data.reshape((-1, 28, 28)).reshape((-1, 28*28), order='F')  # column-major flatten transposes each 28x28 image
38 |
39 | omni_raw = scipy.io.loadmat(os.path.join(self.root, 'chardata.mat'))
40 | data_str = 'data' if train else 'testdata'
41 | data = reshape_data(omni_raw[data_str].T.astype('float32'))
42 | return data
43 |
44 | def get_mean_img(self):
45 | return self.data.mean(0)
46 |
47 | def download(self):
48 | if self._check_exists():
49 | return
50 | if not os.path.exists(self.root):
51 | os.makedirs(self.root)
52 |
53 | print('Downloading from {}...'.format(self.url))
54 | local_filename = os.path.join(self.root, 'chardata.mat')
55 | urllib.request.urlretrieve(self.url, local_filename)
56 | print('Saved to {}'.format(local_filename))
57 |
58 | def _check_exists(self):
59 | return os.path.exists(os.path.join(self.root, 'chardata.mat'))
60 |
--------------------------------------------------------------------------------
/data_loader/stoch_mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 | from PIL import Image
4 |
5 |
6 | class stochMNIST(datasets.MNIST):
7 | """ Gets a new stochastic binarization of MNIST at each call. """
8 | def __getitem__(self, index):
9 | if self.train:
10 | img, target = self.train_data[index], self.train_labels[index]
11 | else:
12 | img, target = self.test_data[index], self.test_labels[index]
13 |
14 | img = Image.fromarray(img.numpy(), mode='L')
15 | img = transforms.ToTensor()(img)
16 | img = torch.bernoulli(img) # stochastically binarize
17 | return img, target
18 |
19 | def get_mean_img(self):
20 | imgs = self.train_data.type(torch.float) / 255
21 | mean_img = imgs.mean(0).reshape(-1).numpy()
22 | return mean_img
23 |
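Because the Bernoulli draw happens inside `__getitem__`, reading the same index twice generally yields two different binarizations; a quick illustration:

```python
from data_loader.stoch_mnist import stochMNIST

ds = stochMNIST('./dataset/stochmnist', train=True, download=True)
a, _ = ds[0]
b, _ = ds[0]
print((a != b).float().mean().item())  # fraction of differing pixels, > 0
```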
--------------------------------------------------------------------------------
/exp.sh:
--------------------------------------------------------------------------------
1 | run="pipenv run python main.py"
2 |
3 | $run --gpu=0 &
4 | $run --gpu=1 --importance_num=64 &
5 | $run --gpu=2 --importance_num=8 --mean_num=8 &
6 | $run --gpu=3 --no_iwae_lr &
7 | $run --gpu=4 --z=100 &
8 | wait
9 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import torch
6 | from tensorboardX import SummaryWriter
7 | from torch import optim
8 |
9 | from data_loader.data_loader import data_loaders
10 | from model.bernoulli_vae import BernoulliVAE
11 | from model.conv_vae import ConvVAE
12 | from utils.config import get_args
13 | from utils.draw_figs import draw_figs
14 |
15 | args = get_args()
16 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
17 | args.cuda = torch.cuda.is_available()
18 | device = torch.device("cuda:0" if args.cuda else "cpu")
19 | train_loader, test_loader = data_loaders(args)
20 | torch.manual_seed(args.seed)
21 | if args.cuda:
22 | torch.cuda.manual_seed_all(args.seed)
23 | writer = SummaryWriter(args.out_dir)
24 |
25 | model_class = BernoulliVAE if args.arch == "bernoulli" else ConvVAE
26 | mean_img = train_loader.dataset.get_mean_img()
27 | model = model_class(
28 | device=device,
29 | img_shape=args.img_shape,
30 | h_dim=args.h_dim,
31 | z_dim=args.z_dim,
32 | analytic_kl=args.analytic_kl,
33 | mean_img=mean_img,
34 | ).to(device)
35 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, eps=1e-4)
36 | if args.no_iwae_lr:
37 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
38 | optimizer, mode="min", patience=100, factor=10 ** (-1 / 7)
39 | )
40 | else:
41 | milestones = np.cumsum([3 ** i for i in range(8)])
42 | scheduler = optim.lr_scheduler.MultiStepLR(
43 | optimizer, milestones=milestones, gamma=10 ** (-1 / 7)
44 | )
45 |
46 |
47 | def train(epoch):
48 | for batch_idx, (data, _) in enumerate(train_loader):
49 | optimizer.zero_grad()
50 | outs = model(data, mean_n=args.mean_num, imp_n=args.importance_num)
51 | loss_1, loss = -outs["elbo"].cpu().data.numpy().mean(), outs["loss"].mean()
52 | loss.backward()
53 | optimizer.step()
54 | model.train_step += 1
55 | if model.train_step % args.log_interval == 0:
56 | print(
57 | "Train Epoch: {} ({:.0f}%)\tLoss: {:.6f}".format(
58 | epoch, 100.0 * batch_idx / len(train_loader), loss.item()
59 | )
60 | )
61 | writer.add_scalar("train/loss", loss.item(), model.train_step)
62 | writer.add_scalar("train/loss_1", loss_1, model.train_step)
63 |
64 |
65 | def test(epoch):
66 | elbos = [
67 | model(data, mean_n=1, imp_n=args.log_likelihood_k)["elbo"].squeeze(0)
68 | for data, _ in test_loader
69 | ]
70 |
71 | def get_loss_k(k):
72 | losses = [
73 | model.logmeanexp(elbo[:k], 0).cpu().numpy().flatten() for elbo in elbos
74 | ]
75 | return -np.concatenate(losses).mean()
76 |
77 | return map(get_loss_k, [args.importance_num, 1, 64, args.log_likelihood_k])  # NLL estimates at k = train k, 1, 64, and log_likelihood_k
78 |
79 |
80 | if args.eval:
81 | model.load_state_dict(torch.load(args.best_model_file))
82 | with torch.no_grad():
83 | print(list(test(0)))
84 | if args.figs:
85 | draw_figs(model, args, test_loader, 0)
86 | sys.exit()
87 |
88 | for epoch in range(1, args.epochs + 1):
89 | writer.add_scalar("learning_rate", optimizer.param_groups[0]["lr"], epoch)
90 | train(epoch)
91 | with torch.no_grad():
92 | if args.figs and epoch % 100 == 1:
93 | draw_figs(model, args, test_loader, epoch)
94 | test_loss, test_1, test_64, test_ll = test(epoch)
95 | if test_loss < model.best_loss:
96 | model.best_loss = test_loss
97 | torch.save(model.state_dict(), args.best_model_file)
98 | scheduler_args = {"metrics": test_loss} if args.no_iwae_lr else {}
99 | scheduler.step(**scheduler_args)
100 | writer.add_scalar("test/loss", test_loss, epoch)
101 | writer.add_scalar("test/loss_1", test_1, epoch)
102 | writer.add_scalar("test/loss_64", test_64, epoch)
103 | writer.add_scalar("test/LL", test_ll, epoch)
104 | print("==== Testing. LL: {:.4f} ====\n".format(test_ll))
105 |
106 | if args.to_gsheets:
107 | from utils.to_sheets import upload_to_google_sheets
108 |
109 | row_data = [args.exp_name, str(test_ll), str(test_64), str(test_64 - test_ll)]
110 | upload_to_google_sheets(row_data=row_data)
111 |
--------------------------------------------------------------------------------
/model/bernoulli_vae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.distributions.bernoulli import Bernoulli
5 | from torch.distributions.normal import Normal
6 |
7 | from .vae_base import VAE
8 |
9 |
10 | class BernoulliVAE(VAE):
11 | def __init__(self, device, img_shape, h_dim, z_dim, analytic_kl, mean_img):
12 | super().__init__(device, z_dim, analytic_kl)
13 | x_dim = np.prod(img_shape)
14 | self.img_shape = img_shape
15 | self.proc_data = lambda x: x.to(device).reshape(-1, x_dim)
16 | self.encoder = nn.Sequential(
17 | nn.Linear(x_dim, h_dim), nn.Tanh(), nn.Linear(h_dim, h_dim), nn.Tanh()
18 | )
19 | self.enc_mu = nn.Linear(h_dim, z_dim)
20 | self.enc_sig = nn.Linear(h_dim, z_dim)
21 | self.decoder = nn.Sequential(
22 | nn.Linear(z_dim, h_dim), nn.Tanh(),
23 | nn.Linear(h_dim, h_dim), nn.Tanh(),
24 | nn.Linear(h_dim, x_dim),
25 | ) # using Bern(logit) is equivalent to putting sigmoid here.
26 |
27 | self.apply(self.init)
28 | mean_img = np.clip(mean_img, 1e-8, 1.0 - 1e-7)
29 | mean_img_logit = np.log(mean_img / (1.0 - mean_img))
30 | self.decoder[-1].bias = torch.nn.Parameter(torch.Tensor(mean_img_logit))
31 |
32 | def init(self, module):
33 | if type(module) == nn.Linear:
34 | torch.nn.init.xavier_uniform_(
35 | module.weight, gain=nn.init.calculate_gain("tanh")
36 | )
37 | module.bias.data.fill_(0.01)
38 |
39 | def encode(self, x):
40 | x = self.proc_data(x)
41 | h = self.encoder(x)
42 | mu, _std = self.enc_mu(h), self.enc_sig(h)
43 | return Normal(mu, nn.functional.softplus(_std))  # softplus parameterization; an alternative is torch.exp(.5 * _std)
44 |
45 | def decode(self, z):
46 | x = self.decoder(z)
47 | return Bernoulli(logits=x)
48 |
49 | def lpxz(self, true_x, x_dist):
50 | return x_dist.log_prob(true_x).sum(-1)
51 |
52 | def sample(self, num_samples=64):
53 | z = self.prior.sample((num_samples,))
54 | x_dist = self.decode(z)
55 | return x_dist.sample().view(num_samples, *self.img_shape)
56 |
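The decoder's output bias is initialized to the logit of the dataset mean image, so an untrained model already decodes to roughly the mean digit. A small check of the identity this relies on (illustrative values, not part of the repo):

```python
import numpy as np

p = np.clip(np.random.rand(784), 1e-8, 1 - 1e-7)  # stand-in for mean_img
logit = np.log(p / (1 - p))                       # as in BernoulliVAE.__init__
assert np.allclose(1 / (1 + np.exp(-logit)), p)   # sigmoid(logit(p)) == p
```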
--------------------------------------------------------------------------------
/model/conv_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.distributions.normal import Normal
4 | from .vae_base import VAE
5 |
6 |
7 | class Flatten(nn.Module):
8 | def forward(self, input):
9 | return input.view(input.size(0), 16 * 8 * 8).contiguous()
10 |
11 |
12 | class UnFlatten(nn.Module):
13 | def forward(self, input):
14 | return input.view(input.size(0), 16, 8, 8).contiguous()
15 |
16 |
17 | class ConvVAE(VAE):
18 | # XXX: This class does not work at the moment
19 | def __init__(self, device, x_dim, h_dim, z_dim, analytic_kl, mean_img):
20 | # FIXME: integrate so that plot etc works.
21 | super().__init__(device, z_dim, analytic_kl)  # the base VAE only takes (device, z_dim, analytic_kl)
22 | self.proc_data = lambda x: x.to(device)
23 | self.encoder = nn.Sequential(
24 | nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
25 | nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
26 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
27 | nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
28 | Flatten())
29 | self.enc_mu = nn.Linear(16 * 8 * 8, z_dim)
30 | self.enc_sig = nn.Linear(16 * 8 * 8, z_dim)
31 | self.decoder = nn.Sequential(
32 | nn.Linear(z_dim, 16 * 8 * 8), nn.ReLU(),
33 | UnFlatten(),
34 | nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
35 | nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
36 | nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
37 | nn.ConvTranspose2d(32, 6, kernel_size=3, stride=1, padding=1))
38 |
39 | self.apply(self.init)
40 | self.decoder[-1].bias = torch.nn.Parameter(torch.cat(
41 | [torch.Tensor(mean_img.mean(0).mean(0)) / 256, .01 * torch.ones([3])]))
42 |
43 | def init(self, module):
44 | if type(module) in [nn.Conv2d, nn.ConvTranspose2d]:
45 | torch.nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))
46 | module.bias.data.fill_(.01)
47 |
48 | def encode(self, x):
49 | x = self.proc_data(x)
50 | h = self.encoder(x)
51 | mu, _std = self.enc_mu(h), self.enc_sig(h)
52 | return Normal(mu, nn.functional.softplus(_std))
53 |
54 | def decode(self, z):
55 | mean_n, imp_n, bs = z.size(0), z.size(1), z.size(2)
56 | z = z.view([mean_n * imp_n * bs, -1]).contiguous()
57 | x = self.decoder(z)
58 | x = x.view([mean_n, imp_n, bs, 6, 32, 32]).contiguous()
59 | x_mean, x_std = x[:, :, :, :3, :, :].contiguous(), nn.functional.softplus(x[:, :, :, 3:, :, :]).contiguous()
60 | return Normal(x_mean, x_std)
61 |
62 | def lpxz(self, true_x, x_dist):
63 | return x_dist.log_prob(true_x).sum([-1, -2, -3])
64 |
--------------------------------------------------------------------------------
/model/vae_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.distributions.normal import Normal
5 |
6 |
7 | class VAE(nn.Module):
8 | def __init__(self, device, z_dim, analytic_kl):
9 | super().__init__()
10 | self.train_step = 0
11 | self.best_loss = np.inf
12 | self.analytic_kl = analytic_kl
13 | self.prior = Normal(
14 | torch.zeros([z_dim]).to(device), torch.ones([z_dim]).to(device)
15 | )
16 |
17 | def proc_data(self, x):
18 | pass
19 |
20 | def encode(self, x):
21 | pass
22 |
23 | def decode(self, z):
24 | pass
25 |
26 | def lpxz(self, true_x, x_dist):
27 | pass
28 |
29 | def sample(self, num_samples=64):
30 | pass
31 |
32 | def elbo(self, true_x, z, x_dist, z_dist):
33 | true_x = self.proc_data(true_x)
34 | lpxz = self.lpxz(true_x, x_dist)
35 |
36 | if self.analytic_kl:
37 | # SGVB^B: -KL(q(z|x)||p(z)) + log p(x|z). Use when KL can be done analytically.
38 | assert z.size(0) == 1 and z.size(1) == 1
39 | kl = torch.distributions.kl.kl_divergence(z_dist, self.prior).sum(-1)
40 | else:
41 | # SGVB^A: log p(z) - log q(z|x) + log p(x|z)
42 | lpz = self.prior.log_prob(z).sum(-1)
43 | lqzx = z_dist.log_prob(z).sum(-1)
44 | kl = -lpz + lqzx
45 | return -kl + lpxz
46 |
47 | def logmeanexp(self, inputs, dim=1):
48 | if inputs.size(dim) == 1:
49 | return inputs
50 | else:
51 | input_max = inputs.max(dim, keepdim=True)[0]
52 | return (inputs - input_max).exp().mean(dim, keepdim=True).log() + input_max
53 |
54 | def forward(self, true_x, mean_n, imp_n):
55 | z_dist = self.encode(true_x)
56 | # mean_n, imp_n, batch_size, z_dim
57 | z = z_dist.rsample(torch.Size([mean_n, imp_n]))
58 | x_dist = self.decode(z)
59 |
60 | elbo = self.elbo(true_x, z, x_dist, z_dist) # mean_n, imp_n, batch_size
61 | elbo_iwae = self.logmeanexp(elbo, 1).squeeze(1) # mean_n, batch_size
62 | elbo_iwae_m = torch.mean(elbo_iwae, 0) # batch_size
63 | return {"elbo": elbo, "loss": -elbo_iwae_m}
64 |
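A quick numerical check (illustrative): `logmeanexp` is the numerically stable log-mean-exp, i.e. `torch.logsumexp` minus log k:

```python
import math
import torch

from model.vae_base import VAE

model = VAE(torch.device('cpu'), z_dim=50, analytic_kl=False)
x = torch.randn(2, 5, 3)  # (mean_n, imp_n, batch_size)
a = model.logmeanexp(x, 1)
b = torch.logsumexp(x, 1, keepdim=True) - math.log(5)
print(torch.allclose(a, b))  # True
```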
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -i https://pypi.org/simple
2 | certifi==2019.6.16
3 | chardet==3.0.4
4 | cycler==0.10.0
5 | gspread==3.1.0
6 | h5py==2.9.0
7 | httplib2==0.13.1
8 | idna==2.8
9 | imageio==2.5.0
10 | kiwisolver==1.1.0
11 | matplotlib==3.1.1
12 | numpy==1.17.0
13 | oauth2client==4.1.3
14 | pathlib==1.0.1
15 | pillow==6.1.0
16 | protobuf==3.9.1
17 | pyasn1-modules==0.2.6
18 | pyasn1==0.4.6
19 | pyparsing==2.4.2
20 | python-dateutil==2.8.0
21 | requests==2.22.0
22 | rsa==4.0
23 | scipy==1.3.1
24 | six==1.12.0
25 | tensorboardx==1.8
26 | torch==1.2.0
27 | torchvision==0.4.0
28 | urllib3==1.25.3
29 |
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--gpu', type=int, default=0)
6 | parser.add_argument('--seed', type=int, default=42)
7 | parser.add_argument('--log_interval', type=int, default=500)
8 | parser.add_argument('--eval', action='store_true')
9 | parser.add_argument('--figs', action='store_true')
10 | parser.add_argument('--to_gsheets', action='store_true')
11 | parser.add_argument('--arch', type=str, default='bernoulli', choices=['bernoulli']) # TODO: make conv work
12 |
13 | parser.add_argument('--dataset_dir', type=str, default='')
14 | parser.add_argument('--dataset', type=str, default='stochmnist',
15 | choices=['stochmnist', 'omniglot', 'fixedmnist']) # TODO: make cifar10 work
16 | parser.add_argument('--batch_size', type=int, default=20) # iwae uses 20
17 | parser.add_argument('--test_batch_size', type=int, default=64)
18 | parser.add_argument('--epochs', type=int, default=3280) # iwae uses 3280
19 |
20 | parser.add_argument('--learning_rate', type=float, default=1e-3)
21 | parser.add_argument('--no_iwae_lr', action='store_true')
22 | parser.add_argument('--mean_num', type=int, default=1) # M in "tighter variational bounds...". Use 1 for vanilla vae
23 | parser.add_argument('--importance_num', type=int, default=1) # k of iwae. Use 1 for vanilla vae
24 | parser.add_argument('--analytic_kl', action='store_true')
25 | parser.add_argument('--h_dim', type=int, default=200)
26 | parser.add_argument('--z_dim', type=int, default=50)
27 |
28 |
29 | def get_args():
30 | args = parser.parse_args()
31 |
32 | def cstr(arg, arg_name, default, custom_str=False):
33 | """ Get config str for arg, ignoring if set to default. """
34 | not_default = arg != default
35 | if not custom_str:
36 | custom_str = f'_{arg_name}{arg}'
37 | return custom_str if not_default else ''
38 |
39 | args.exp_name = (f'm{args.mean_num}_k{args.importance_num}'
40 | f'{cstr(args.dataset, "", "stochmnist")}{cstr(args.arch, "", "bernoulli")}'
41 | f'{cstr(args.seed, "seed", 42)}{cstr(args.batch_size, "bs", 20)}'
42 | f'{cstr(args.h_dim, "h", 200)}{cstr(args.z_dim, "z", 50)}'
43 | f'{cstr(args.learning_rate, "lr", 1e-3)}{cstr(args.analytic_kl, None, False, "_analytic")}'
44 | f'{cstr(args.no_iwae_lr, None, False, "_noiwae")}{cstr(args.epochs, "epoch", 3280)}')
45 |
46 | args.figs_dir = os.path.join('figs', args.exp_name)
47 | args.out_dir = os.path.join('result', args.exp_name)
48 | args.best_model_file = os.path.join('result', args.exp_name, 'best_model.pt')
49 | if not os.path.exists(args.out_dir):
50 | os.makedirs(args.out_dir)
51 | if not os.path.exists(args.figs_dir):
52 | os.makedirs(args.figs_dir)
53 |
54 | args.log_likelihood_k = 100 if args.dataset == 'cifar10' else 5000
55 | args.img_shape = (32, 32) if args.dataset == 'cifar10' else (28, 28)
56 | return args
57 |
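A worked instance of the `cstr` logic above, simulating the command line (note that `get_args` also creates the `result/` and `figs/` directories as a side effect):

```python
import sys

from utils.config import get_args

sys.argv = ['main.py', '--mean_num=16', '--importance_num=4']  # simulated CLI
args = get_args()
print(args.exp_name)         # m16_k4  (defaults are omitted from the name)
print(args.best_model_file)  # result/m16_k4/best_model.pt
```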
--------------------------------------------------------------------------------
/utils/draw_figs.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import matplotlib.pyplot as plt
4 | import imageio
5 | import pathlib
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def draw_gif(name, figs_dir, glob_str):
11 | files = sorted(pathlib.Path(figs_dir).glob(glob_str))
12 | images = [imageio.imread(str(file)) for file in files]
13 | imageio.mimsave('{}/{}'.format(figs_dir, name), images, duration=.5)
14 |
15 |
16 | def draw_figs(model, args, test_loader, epoch):
17 | samples = model.sample(num_samples=100).data.cpu().numpy()
18 | plt.figure(figsize=(5, 5))
19 | plt.suptitle('Samples, Epoch {}'.format(epoch), fontsize=20)
20 | plt.axis('square')
21 | plt.legend(frameon=True)
22 | for idx, im in enumerate(samples):
23 | plt.subplot(10, 10, idx+1)
24 | plt.imshow(im, cmap='Greys')
25 | plt.axis('off')
26 | plt.savefig('figs/{}/samples_{:04}.jpg'.format(args.exp_name, epoch))
27 | plt.clf()
28 | draw_gif('{}_samples.gif'.format(args.exp_name), args.figs_dir, 'samples*.jpg')
29 |
30 | for batch_idx, (data, _) in enumerate(test_loader):
31 | break
32 | z_dist = model.encode(data)
33 | z = z_dist.rsample()
34 | recon = model.decode(z).probs.view(args.test_batch_size, 28, 28)
35 | data = data.view(args.test_batch_size, 28, 28)
36 | plt.figure(figsize=(5, 5))
37 | plt.suptitle('Reconstruction, Epoch {}'.format(epoch), fontsize=20)
38 | plt.axis('square')
39 | plt.legend(frameon=True)
40 | for i in range(50):
41 | data_i = data[i].data.cpu().numpy()
42 | recon_i = recon[i].data.cpu().numpy()
43 | plt.subplot(10, 10, 2*i+1)
44 | plt.imshow(data_i, cmap='Greys')
45 | plt.axis('off')
46 | plt.subplot(10, 10, 2*i+2)
47 | plt.imshow(recon_i, cmap='Greys')
48 | plt.axis('off')
49 | plt.savefig('figs/{}/reconstruction_{:04}.jpg'.format(args.exp_name, epoch))
50 | plt.clf()
51 | draw_gif('{}_reconstruction.gif'.format(args.exp_name), args.figs_dir, 'reconstruction*.jpg')
52 |
53 | if args.z_dim == 2:
54 | latent_space, labels = [], []
55 | for batch_idx, (data, label) in enumerate(test_loader):
56 | latent_space.append(model.encode(data).loc.data.cpu().numpy())
57 | labels.append(label)
58 | latent_space, labels = np.concatenate(latent_space), np.concatenate(labels)
59 | plt.figure(figsize=(5, 5))
60 | for c in range(10):
61 | idx = (labels == c)
62 | plt.scatter(latent_space[idx, 0], latent_space[idx, 1],
63 | c=matplotlib.cm.get_cmap('tab10')(c), marker=',', label=str(c), alpha=.7)
64 | plt.suptitle('Latent representation, Epoch {}'.format(epoch), fontsize=20)
65 | plt.axis('square')
66 | plt.legend(frameon=True)
67 | plt.savefig('figs/{}/latent_{:04}.jpg'.format(args.exp_name, epoch))
68 | plt.clf()
69 | draw_gif('{}_latent.gif'.format(args.exp_name), args.figs_dir, 'latent*.jpg')
70 |
71 | plt.close('all')
72 |
--------------------------------------------------------------------------------
/utils/to_sheets.py:
--------------------------------------------------------------------------------
1 | import gspread
2 | from oauth2client.service_account import ServiceAccountCredentials
3 | get_credentials = ServiceAccountCredentials.from_json_keyfile_name
4 |
5 | scope = ['https://spreadsheets.google.com/feeds',
6 | 'https://www.googleapis.com/auth/drive']
7 | # To make this work, obtain credentials from Google Sheets API and save to
8 | # creds.json in current directory.
9 | credentials = get_credentials('creds.json', scope)
10 | gc = gspread.authorize(credentials)
11 | sheet_name = 'pytorch-generative'
12 |
13 |
14 | def upload_to_google_sheets(row_data, index=2):
15 | worksheet = gc.open(sheet_name).sheet1
16 | worksheet.insert_row(row_data, index=index)
17 |
--------------------------------------------------------------------------------