├── .gitignore
├── Pipfile
├── Pipfile.lock
├── README.md
├── assets
│   ├── mnist_recon.gif
│   ├── mnist_samples.gif
│   ├── omni_recon.gif
│   └── omni_samples.gif
├── data_loader
│   ├── cifar10.py
│   ├── data_loader.py
│   ├── fixed_mnist.py
│   ├── omniglot.py
│   └── stoch_mnist.py
├── exp.sh
├── main.py
├── model
│   ├── bernoulli_vae.py
│   ├── conv_vae.py
│   └── vae_base.py
├── requirements.txt
└── utils
    ├── config.py
    ├── draw_figs.py
    └── to_sheets.py

/.gitignore:
--------------------------------------------------------------------------------
1 | dataset
2 | result
3 | dfc_exp.sh
4 |
5 |
6 | ###### python ######
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # Environments
96 | .env
97 | .venv
98 | env/
99 | venv/
100 | ENV/
101 | env.bak/
102 | venv.bak/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 | .dmypy.json
117 | dmypy.json
118 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | name = "pypi"
3 | url = "https://pypi.org/simple"
4 | verify_ssl = true
5 |
6 | [dev-packages]
7 |
8 | [requires]
9 | python_version = "3.6"
10 |
11 | [packages]
12 | numpy = "*"
13 | "h5py" = "*"
14 | scipy = "*"
15 | matplotlib = "*"
16 | imageio = "*"
17 | pathlib = "*"
18 | gspread = "*"
19 | "oauth2client" = "*"
20 | tensorboardx = "*"
21 | torch = "*"
22 | torchvision = "*"
23 | urllib3 = ">=1.24.2"
--------------------------------------------------------------------------------
/Pipfile.lock:
--------------------------------------------------------------------------------
1 | {
2 |     "_meta": {
3 |         "hash": {
4 |             "sha256": "de41289fb28f71e8475cba544066bcd5ea9a0353f7fb699cd317ecba70a4f749"
5 |         },
6 |         "pipfile-spec": 6,
7 |         "requires": {
8 |             "python_version": "3.6"
9 |         },
10 |         "sources": [
11 |             {
12 |                 "name": "pypi",
13 |                 "url": "https://pypi.org/simple",
14 |                 "verify_ssl": true
15 |             }
16 |         ]
17 |     },
18 |     "default": {
19 |
"certifi": { 20 | "hashes": [ 21 | "sha256:046832c04d4e752f37383b628bc601a7ea7211496b4638f6514d0e5b9acc4939", 22 | "sha256:945e3ba63a0b9f577b1395204e13c3a231f9bc0223888be653286534e5873695" 23 | ], 24 | "version": "==2019.6.16" 25 | }, 26 | "chardet": { 27 | "hashes": [ 28 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae", 29 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691" 30 | ], 31 | "version": "==3.0.4" 32 | }, 33 | "cycler": { 34 | "hashes": [ 35 | "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d", 36 | "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8" 37 | ], 38 | "version": "==0.10.0" 39 | }, 40 | "gspread": { 41 | "hashes": [ 42 | "sha256:dd945e3ae5d3d0325ad9982e0d5667f79ca121d0bb6f35274dc84371bbb79dd5", 43 | "sha256:f7ce6c06250f694976c3cd4944e3b607b0810b93383839e5b67c7199ce2f0d3d" 44 | ], 45 | "index": "pypi", 46 | "version": "==3.1.0" 47 | }, 48 | "h5py": { 49 | "hashes": [ 50 | "sha256:05750b91640273c69989c657eaac34b091abdd75efc8c4824c82aaf898a2da0a", 51 | "sha256:082a27208aa3a2286e7272e998e7e225b2a7d4b7821bd840aebf96d50977abbb", 52 | "sha256:08e2e8297195f9e813e894b6c63f79372582787795bba2014a2db6a2de95f713", 53 | "sha256:0dd2adeb2e9de5081eb8dcec88874e7fd35dae9a21557be3a55a3c7d491842a4", 54 | "sha256:0f94de7a10562b991967a66bbe6dda9808e18088676834c0a4dcec3fdd3bcc6f", 55 | "sha256:106e42e2e01e486a3d32eeb9ba0e3a7f65c12fa8998d63625fa41fb8bdc44cdb", 56 | "sha256:1606c66015f04719c41a9863c156fc0e6b992150de21c067444bcb82e7d75579", 57 | "sha256:1854c4beff9961e477e133143c5e5e355dac0b3ebf19c52cf7cc1b1ef757703c", 58 | "sha256:1e9fb6f1746500ea91a00193ce2361803c70c6b13f10aae9a33ad7b5bd28e800", 59 | "sha256:2cca17e80ddb151894333377675db90cd0279fa454776e0a4f74308376afd050", 60 | "sha256:30e365e8408759db3778c361f1e4e0fe8e98a875185ae46c795a85e9bafb9cdf", 61 | "sha256:3206bac900e16eda81687d787086f4ffd4f3854980d798e191a9868a6510c3ae", 62 | "sha256:3c23d72058647cee19b30452acc7895621e2de0a0bd5b8a1e34204b9ea9ed43c", 63 | "sha256:407b5f911a83daa285bbf1ef78a9909ee5957f257d3524b8606be37e8643c5f0", 64 | "sha256:4162953714a9212d373ac953c10e3329f1e830d3c7473f2a2e4f25dd6241eef0", 65 | "sha256:5fc7aba72a51b2c80605eba1c50dbf84224dcd206279d30a75c154e5652e1fe4", 66 | "sha256:713ac19307e11de4d9833af0c4bd6778bde0a3d967cafd2f0f347223711c1e31", 67 | "sha256:71b946d80ef3c3f12db157d7778b1fe74a517ca85e94809358b15580983c2ce2", 68 | "sha256:8cc4aed71e20d87e0a6f02094d718a95252f11f8ed143bc112d22167f08d4040", 69 | "sha256:9d41ca62daf36d6b6515ab8765e4c8c4388ee18e2a665701fef2b41563821002", 70 | "sha256:a744e13b000f234cd5a5b2a1f95816b819027c57f385da54ad2b7da1adace2f3", 71 | "sha256:b087ee01396c4b34e9dc41e3a6a0442158206d383c19c7d0396d52067b17c1cb", 72 | "sha256:b0f03af381d33306ce67d18275b61acb4ca111ced645381387a02c8a5ee1b796", 73 | "sha256:b9e4b8dfd587365bdd719ae178fa1b6c1231f81280b1375eef8626dfd8761bf3", 74 | "sha256:c5dd4ec75985b99166c045909e10f0534704d102848b1d9f0992720e908928e7", 75 | "sha256:d2b82f23cd862a9d05108fe99967e9edfa95c136f532a71cb3d28dc252771f50", 76 | "sha256:e58a25764472af07b7e1c4b10b0179c8ea726446c7141076286e41891bf3a563", 77 | "sha256:f3b49107fbfc77333fc2b1ef4d5de2abcd57e7ea3a1482455229494cf2da56ce" 78 | ], 79 | "index": "pypi", 80 | "version": "==2.9.0" 81 | }, 82 | "httplib2": { 83 | "hashes": [ 84 | "sha256:6901c8c0ffcf721f9ce270ad86da37bc2b4d32b8802d4a9cec38274898a64044", 85 | "sha256:cf6f9d5876d796539ec922a2c9b9a7cad9bfd90f04badcdc3bcfa537168052c3" 86 | ], 87 | "version": "==0.13.1" 88 | }, 89 | "idna": { 
90 | "hashes": [ 91 | "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", 92 | "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" 93 | ], 94 | "version": "==2.8" 95 | }, 96 | "imageio": { 97 | "hashes": [ 98 | "sha256:1a2bbbb7cd38161340fa3b14d806dfbf914abf3ee6fd4592af2afb87d049f209", 99 | "sha256:42e65aadfc3d57a1043615c92bdf6319b67589e49a0aae2b985b82144aceacad" 100 | ], 101 | "index": "pypi", 102 | "version": "==2.5.0" 103 | }, 104 | "kiwisolver": { 105 | "hashes": [ 106 | "sha256:05b5b061e09f60f56244adc885c4a7867da25ca387376b02c1efc29cc16bcd0f", 107 | "sha256:26f4fbd6f5e1dabff70a9ba0d2c4bd30761086454aa30dddc5b52764ee4852b7", 108 | "sha256:3b2378ad387f49cbb328205bda569b9f87288d6bc1bf4cd683c34523a2341efe", 109 | "sha256:400599c0fe58d21522cae0e8b22318e09d9729451b17ee61ba8e1e7c0346565c", 110 | "sha256:47b8cb81a7d18dbaf4fed6a61c3cecdb5adec7b4ac292bddb0d016d57e8507d5", 111 | "sha256:53eaed412477c836e1b9522c19858a8557d6e595077830146182225613b11a75", 112 | "sha256:58e626e1f7dfbb620d08d457325a4cdac65d1809680009f46bf41eaf74ad0187", 113 | "sha256:5a52e1b006bfa5be04fe4debbcdd2688432a9af4b207a3f429c74ad625022641", 114 | "sha256:5c7ca4e449ac9f99b3b9d4693debb1d6d237d1542dd6a56b3305fe8a9620f883", 115 | "sha256:682e54f0ce8f45981878756d7203fd01e188cc6c8b2c5e2cf03675390b4534d5", 116 | "sha256:79bfb2f0bd7cbf9ea256612c9523367e5ec51d7cd616ae20ca2c90f575d839a2", 117 | "sha256:7f4dd50874177d2bb060d74769210f3bce1af87a8c7cf5b37d032ebf94f0aca3", 118 | "sha256:8944a16020c07b682df861207b7e0efcd2f46c7488619cb55f65882279119389", 119 | "sha256:8aa7009437640beb2768bfd06da049bad0df85f47ff18426261acecd1cf00897", 120 | "sha256:939f36f21a8c571686eb491acfffa9c7f1ac345087281b412d63ea39ca14ec4a", 121 | "sha256:9733b7f64bd9f807832d673355f79703f81f0b3e52bfce420fc00d8cb28c6a6c", 122 | "sha256:a02f6c3e229d0b7220bd74600e9351e18bc0c361b05f29adae0d10599ae0e326", 123 | "sha256:a0c0a9f06872330d0dd31b45607197caab3c22777600e88031bfe66799e70bb0", 124 | "sha256:acc4df99308111585121db217681f1ce0eecb48d3a828a2f9bbf9773f4937e9e", 125 | "sha256:b64916959e4ae0ac78af7c3e8cef4becee0c0e9694ad477b4c6b3a536de6a544", 126 | "sha256:d3fcf0819dc3fea58be1fd1ca390851bdb719a549850e708ed858503ff25d995", 127 | "sha256:d52e3b1868a4e8fd18b5cb15055c76820df514e26aa84cc02f593d99fef6707f", 128 | "sha256:db1a5d3cc4ae943d674718d6c47d2d82488ddd94b93b9e12d24aabdbfe48caee", 129 | "sha256:e3a21a720791712ed721c7b95d433e036134de6f18c77dbe96119eaf7aa08004", 130 | "sha256:e8bf074363ce2babeb4764d94f8e65efd22e6a7c74860a4f05a6947afc020ff2", 131 | "sha256:f16814a4a96dc04bf1da7d53ee8d5b1d6decfc1a92a63349bb15d37b6a263dd9", 132 | "sha256:f2b22153870ca5cf2ab9c940d7bc38e8e9089fa0f7e5856ea195e1cf4ff43d5a", 133 | "sha256:f790f8b3dff3d53453de6a7b7ddd173d2e020fb160baff578d578065b108a05f" 134 | ], 135 | "version": "==1.1.0" 136 | }, 137 | "matplotlib": { 138 | "hashes": [ 139 | "sha256:1febd22afe1489b13c6749ea059d392c03261b2950d1d45c17e3aed812080c93", 140 | "sha256:31a30d03f39528c79f3a592857be62a08595dec4ac034978ecd0f814fa0eec2d", 141 | "sha256:4442ce720907f67a79d45de9ada47be81ce17e6c2f448b3c64765af93f6829c9", 142 | "sha256:796edbd1182cbffa7e1e7a97f1e141f875a8501ba8dd834269ae3cd45a8c976f", 143 | "sha256:934e6243df7165aad097572abf5b6003c77c9b6c480c3c4de6f2ef1b5fdd4ec0", 144 | "sha256:bab9d848dbf1517bc58d1f486772e99919b19efef5dd8596d4b26f9f5ee08b6b", 145 | "sha256:c1fe1e6cdaa53f11f088b7470c2056c0df7d80ee4858dadf6cbe433fcba4323b", 146 | "sha256:e5b8aeca9276a3a988caebe9f08366ed519fff98f77c6df5b64d7603d0e42e36", 147 | 
"sha256:ec6bd0a6a58df3628ff269978f4a4b924a0d371ad8ce1f8e2b635b99e482877a" 148 | ], 149 | "index": "pypi", 150 | "version": "==3.1.1" 151 | }, 152 | "numpy": { 153 | "hashes": [ 154 | "sha256:03e311b0a4c9f5755da7d52161280c6a78406c7be5c5cc7facfbcebb641efb7e", 155 | "sha256:0cdd229a53d2720d21175012ab0599665f8c9588b3b8ffa6095dd7b90f0691dd", 156 | "sha256:312bb18e95218bedc3563f26fcc9c1c6bfaaf9d453d15942c0839acdd7e4c473", 157 | "sha256:464b1c48baf49e8505b1bb754c47a013d2c305c5b14269b5c85ea0625b6a988a", 158 | "sha256:5adfde7bd3ee4864536e230bcab1c673f866736698724d5d28c11a4d63672658", 159 | "sha256:7724e9e31ee72389d522b88c0d4201f24edc34277999701ccd4a5392e7d8af61", 160 | "sha256:8d36f7c53ae741e23f54793ffefb2912340b800476eb0a831c6eb602e204c5c4", 161 | "sha256:910d2272403c2ea8a52d9159827dc9f7c27fb4b263749dca884e2e4a8af3b302", 162 | "sha256:951fefe2fb73f84c620bec4e001e80a80ddaa1b84dce244ded7f1e0cbe0ed34a", 163 | "sha256:9588c6b4157f493edeb9378788dcd02cb9e6a6aeaa518b511a1c79d06cbd8094", 164 | "sha256:9ce8300950f2f1d29d0e49c28ebfff0d2f1e2a7444830fbb0b913c7c08f31511", 165 | "sha256:be39cca66cc6806652da97103605c7b65ee4442c638f04ff064a7efd9a81d50a", 166 | "sha256:c3ab2d835b95ccb59d11dfcd56eb0480daea57cdf95d686d22eff35584bc4554", 167 | "sha256:eb0fc4a492cb896346c9e2c7a22eae3e766d407df3eb20f4ce027f23f76e4c54", 168 | "sha256:ec0c56eae6cee6299f41e780a0280318a93db519bbb2906103c43f3e2be1206c", 169 | "sha256:f4e4612de60a4f1c4d06c8c2857cdcb2b8b5289189a12053f37d3f41f06c60d0" 170 | ], 171 | "index": "pypi", 172 | "version": "==1.17.0" 173 | }, 174 | "oauth2client": { 175 | "hashes": [ 176 | "sha256:b8a81cc5d60e2d364f0b1b98f958dbd472887acaf1a5b05e21c28c31a2d6d3ac", 177 | "sha256:d486741e451287f69568a4d26d70d9acd73a2bbfa275746c535b4209891cccc6" 178 | ], 179 | "index": "pypi", 180 | "version": "==4.1.3" 181 | }, 182 | "pathlib": { 183 | "hashes": [ 184 | "sha256:6940718dfc3eff4258203ad5021090933e5c04707d5ca8cc9e73c94a7894ea9f" 185 | ], 186 | "index": "pypi", 187 | "version": "==1.0.1" 188 | }, 189 | "pillow": { 190 | "hashes": [ 191 | "sha256:0804f77cb1e9b6dbd37601cee11283bba39a8d44b9ddb053400c58e0c0d7d9de", 192 | "sha256:0ab7c5b5d04691bcbd570658667dd1e21ca311c62dcfd315ad2255b1cd37f64f", 193 | "sha256:0b3e6cf3ea1f8cecd625f1420b931c83ce74f00c29a0ff1ce4385f99900ac7c4", 194 | "sha256:365c06a45712cd723ec16fa4ceb32ce46ad201eb7bbf6d3c16b063c72b61a3ed", 195 | "sha256:38301fbc0af865baa4752ddae1bb3cbb24b3d8f221bf2850aad96b243306fa03", 196 | "sha256:3aef1af1a91798536bbab35d70d35750bd2884f0832c88aeb2499aa2d1ed4992", 197 | "sha256:3fe0ab49537d9330c9bba7f16a5f8b02da615b5c809cdf7124f356a0f182eccd", 198 | "sha256:45a619d5c1915957449264c81c008934452e3fd3604e36809212300b2a4dab68", 199 | "sha256:49f90f147883a0c3778fd29d3eb169d56416f25758d0f66775db9184debc8010", 200 | "sha256:571b5a758baf1cb6a04233fb23d6cf1ca60b31f9f641b1700bfaab1194020555", 201 | "sha256:5ac381e8b1259925287ccc5a87d9cf6322a2dc88ae28a97fe3e196385288413f", 202 | "sha256:6153db744a743c0c8c91b8e3b9d40e0b13a5d31dbf8a12748c6d9bfd3ddc01ad", 203 | "sha256:6fd63afd14a16f5d6b408f623cc2142917a1f92855f0df997e09a49f0341be8a", 204 | "sha256:70acbcaba2a638923c2d337e0edea210505708d7859b87c2bd81e8f9902ae826", 205 | "sha256:70b1594d56ed32d56ed21a7fbb2a5c6fd7446cdb7b21e749c9791eac3a64d9e4", 206 | "sha256:76638865c83b1bb33bcac2a61ce4d13c17dba2204969dedb9ab60ef62bede686", 207 | "sha256:7b2ec162c87fc496aa568258ac88631a2ce0acfe681a9af40842fc55deaedc99", 208 | "sha256:7cee2cef07c8d76894ebefc54e4bb707dfc7f258ad155bd61d87f6cd487a70ff", 209 | 
"sha256:7d16d4498f8b374fc625c4037742fbdd7f9ac383fd50b06f4df00c81ef60e829", 210 | "sha256:b50bc1780681b127e28f0075dfb81d6135c3a293e0c1d0211133c75e2179b6c0", 211 | "sha256:bd0582f831ad5bcad6ca001deba4568573a4675437db17c4031939156ff339fa", 212 | "sha256:cfd40d8a4b59f7567620410f966bb1f32dc555b2b19f82a91b147fac296f645c", 213 | "sha256:e3ae410089de680e8f84c68b755b42bc42c0ceb8c03dbea88a5099747091d38e", 214 | "sha256:e9046e559c299b395b39ac7dbf16005308821c2f24a63cae2ab173bd6aa11616", 215 | "sha256:ef6be704ae2bc8ad0ebc5cb850ee9139493b0fc4e81abcc240fb392a63ebc808", 216 | "sha256:f8dc19d92896558f9c4317ee365729ead9d7bbcf2052a9a19a3ef17abbb8ac5b" 217 | ], 218 | "version": "==6.1.0" 219 | }, 220 | "protobuf": { 221 | "hashes": [ 222 | "sha256:00a1b0b352dc7c809749526d1688a64b62ea400c5b05416f93cfb1b11a036295", 223 | "sha256:01acbca2d2c8c3f7f235f1842440adbe01bbc379fa1cbdd80753801432b3fae9", 224 | "sha256:0a795bca65987b62d6b8a2d934aa317fd1a4d06a6dd4df36312f5b0ade44a8d9", 225 | "sha256:0ec035114213b6d6e7713987a759d762dd94e9f82284515b3b7331f34bfaec7f", 226 | "sha256:31b18e1434b4907cb0113e7a372cd4d92c047ce7ba0fa7ea66a404d6388ed2c1", 227 | "sha256:32a3abf79b0bef073c70656e86d5bd68a28a1fbb138429912c4fc07b9d426b07", 228 | "sha256:55f85b7808766e5e3f526818f5e2aeb5ba2edcc45bcccede46a3ccc19b569cb0", 229 | "sha256:64ab9bc971989cbdd648c102a96253fdf0202b0c38f15bd34759a8707bdd5f64", 230 | "sha256:64cf847e843a465b6c1ba90fb6c7f7844d54dbe9eb731e86a60981d03f5b2e6e", 231 | "sha256:917c8662b585470e8fd42f052661fc66d59fccaae450a60044307dcbf82a3335", 232 | "sha256:afed9003d7f2be2c3df20f64220c30faec441073731511728a2cb4cab4cd46a6", 233 | "sha256:bf8e05d638b585d1752c5a84247134a0350d3a8b73d3632489a014a9f6f1e758", 234 | "sha256:d831b047bd69becaf64019a47179eb22118a50dd008340655266a906c69c6417", 235 | "sha256:de2760583ed28749ff885789c1cbc6c9c06d6de92fc825740ab99deb2f25ea4d", 236 | "sha256:eabc4cf1bc19689af8022ba52fd668564a8d96e0d08f3b4732d26a64255216a4", 237 | "sha256:fcff6086c86fb1628d94ea455c7b9de898afc50378042927a59df8065a79a549" 238 | ], 239 | "version": "==3.9.1" 240 | }, 241 | "pyasn1": { 242 | "hashes": [ 243 | "sha256:3bb81821d47b17146049e7574ab4bf1e315eb7aead30efe5d6a9ca422c9710be", 244 | "sha256:b773d5c9196ffbc3a1e13bdf909d446cad80a039aa3340bcad72f395b76ebc86" 245 | ], 246 | "version": "==0.4.6" 247 | }, 248 | "pyasn1-modules": { 249 | "hashes": [ 250 | "sha256:43c17a83c155229839cc5c6b868e8d0c6041dba149789b6d6e28801c64821722", 251 | "sha256:e30199a9d221f1b26c885ff3d87fd08694dbbe18ed0e8e405a2a7126d30ce4c0" 252 | ], 253 | "version": "==0.2.6" 254 | }, 255 | "pyparsing": { 256 | "hashes": [ 257 | "sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80", 258 | "sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4" 259 | ], 260 | "version": "==2.4.2" 261 | }, 262 | "python-dateutil": { 263 | "hashes": [ 264 | "sha256:7e6584c74aeed623791615e26efd690f29817a27c73085b78e4bad02493df2fb", 265 | "sha256:c89805f6f4d64db21ed966fda138f8a5ed7a4fdbc1a8ee329ce1b74e3c74da9e" 266 | ], 267 | "version": "==2.8.0" 268 | }, 269 | "requests": { 270 | "hashes": [ 271 | "sha256:11e007a8a2aa0323f5a921e9e6a2d7e4e67d9877e85773fba9ba6419025cbeb4", 272 | "sha256:9cf5292fcd0f598c671cfc1e0d7d1a7f13bb8085e9a590f48c010551dc6c4b31" 273 | ], 274 | "version": "==2.22.0" 275 | }, 276 | "rsa": { 277 | "hashes": [ 278 | "sha256:14ba45700ff1ec9eeb206a2ce76b32814958a98e372006c8fb76ba820211be66", 279 | "sha256:1a836406405730121ae9823e19c6e806c62bbad73f890574fff50efa4122c487" 280 | ], 281 | "version": "==4.0" 282 | }, 283 | "scipy": { 284 
| "hashes": [ 285 | "sha256:0baa64bf42592032f6f6445a07144e355ca876b177f47ad8d0612901c9375bef", 286 | "sha256:243b04730d7223d2b844bda9500310eecc9eda0cba9ceaf0cde1839f8287dfa8", 287 | "sha256:2643cfb46d97b7797d1dbdb6f3c23fe3402904e3c90e6facfe6a9b98d808c1b5", 288 | "sha256:396eb4cdad421f846a1498299474f0a3752921229388f91f60dc3eda55a00488", 289 | "sha256:3ae3692616975d3c10aca6d574d6b4ff95568768d4525f76222fb60f142075b9", 290 | "sha256:435d19f80b4dcf67dc090cc04fde2c5c8a70b3372e64f6a9c58c5b806abfa5a8", 291 | "sha256:46a5e55850cfe02332998b3aef481d33f1efee1960fe6cfee0202c7dd6fc21ab", 292 | "sha256:75b513c462e58eeca82b22fc00f0d1875a37b12913eee9d979233349fce5c8b2", 293 | "sha256:7ccfa44a08226825126c4ef0027aa46a38c928a10f0a8a8483c80dd9f9a0ad44", 294 | "sha256:89dd6a6d329e3f693d1204d5562dd63af0fd7a17854ced17f9cbc37d5b853c8d", 295 | "sha256:a81da2fe32f4eab8b60d56ad43e44d93d392da228a77e229e59b51508a00299c", 296 | "sha256:a9d606d11eb2eec7ef893eb825017fbb6eef1e1d0b98a5b7fc11446ebeb2b9b1", 297 | "sha256:ac37eb652248e2d7cbbfd89619dce5ecfd27d657e714ed049d82f19b162e8d45", 298 | "sha256:cbc0611699e420774e945f6a4e2830f7ca2b3ee3483fca1aa659100049487dd5", 299 | "sha256:d02d813ec9958ed63b390ded463163685af6025cb2e9a226ec2c477df90c6957", 300 | "sha256:dd3b52e00f93fd1c86f2d78243dfb0d02743c94dd1d34ffea10055438e63b99d" 301 | ], 302 | "index": "pypi", 303 | "version": "==1.3.1" 304 | }, 305 | "six": { 306 | "hashes": [ 307 | "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", 308 | "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" 309 | ], 310 | "version": "==1.12.0" 311 | }, 312 | "tensorboardx": { 313 | "hashes": [ 314 | "sha256:13fe0abba27f407778a7321937190eedaf12bc8c544d9a4e294fcf0ba177fd76", 315 | "sha256:f52e59b38b4cdf83384f3fce067bcaf2d2847619f9f533394df0de3b5a71ab8e" 316 | ], 317 | "index": "pypi", 318 | "version": "==1.8" 319 | }, 320 | "torch": { 321 | "hashes": [ 322 | "sha256:0698d0a48014b9b8f36d93e69901eca2e7ec712cd2033908f7a77e7d86a4f0d7", 323 | "sha256:2ac8e58b069232f079bd289aa160366a9367ae1a4616a2c1007dceed19ff9bfa", 324 | "sha256:43a0e28c448ddeea65fb9e956bc743389592afac824095bdbc08e8a87364c639", 325 | "sha256:661ad06b4616663149bd504e8c0271196d0386712e21a92619d95ba88138794a", 326 | "sha256:880a0c22692eaebbce808a5bf2255ab7d345ab43c40795be0a421c6250ba0fb4", 327 | "sha256:a13bf6f78a49d844b85c142b8cd62d2e1833a11ed21ea0bc6b1ac73d24c76415", 328 | "sha256:a8c21f82fd03b67927078ea917040478c3263753fe1906fc19d0f5f0c7d9aa10", 329 | "sha256:b87fd224a7de3bc01ce87eb947698797b4514e27115b0aa60a56991515dd9dd6", 330 | "sha256:f63d489c54b4f170ce8335727bbb196ceb9acd0e7805477bbef8fabc914bc0f9" 331 | ], 332 | "index": "pypi", 333 | "version": "==1.2.0" 334 | }, 335 | "torchvision": { 336 | "hashes": [ 337 | "sha256:3a8e9403252fefdf6e8f9993ae111d28eb4ad1e73f696f03de485d7f77d88067", 338 | "sha256:6fff5a31d50de3a59dcceda2a48de9df33a5f43357dc3e0da0ffbb97699aec52", 339 | "sha256:740b3718470aa4ec0b389df876eb25117df1952dd2e8105b7828a02aa5bce73b", 340 | "sha256:8114c33b736ee430496eef4fe03b25be8b939b2abd2a968558737bb9aed1928b", 341 | "sha256:904ef213594672f2ed7fafa3ab010cbf2a4704a951a7bf221cf36b3d2e3acd62", 342 | "sha256:afff8e987564192bc7f139d8b089541d4471ad6fc99e977e8bc8dbb4e0873041", 343 | "sha256:d7939f2ca401de3067a30b6f4dcef63d13d24a4cd1ddc2d3a9af3413ce658d03", 344 | "sha256:d8c2402704ce8ef8e87e4922160388c7ca010ef27700082014d6bd694cf1cc51", 345 | "sha256:e00de7571d83f968f5aea7a59e84e3262669acef0a077ce4bd705eca2df68167" 346 | ], 347 | "index": "pypi", 348 | "version": "==0.4.0" 349 | }, 
350 |         "urllib3": {
351 |             "hashes": [
352 |                 "sha256:b246607a25ac80bedac05c6f282e3cdaf3afb65420fd024ac94435cabe6e18d1",
353 |                 "sha256:dbe59173209418ae49d485b87d1681aefa36252ee85884c31346debd19463232"
354 |             ],
355 |             "index": "pypi",
356 |             "version": "==1.25.3"
357 |         }
358 |     },
359 |     "develop": {}
360 | }
361 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-vae
2 |
3 | A minimal PyTorch implementation of [VAE](https://arxiv.org/abs/1312.6114), [IWAE](https://arxiv.org/abs/1509.00519), and [MIWAE](https://arxiv.org/abs/1802.04537).
4 | We follow the experimental details of the [IWAE paper](https://arxiv.org/abs/1509.00519).
5 |

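For reference, MIWAE(M, k) maximizes the average of M independent k-sample importance-weighted (IWAE) bounds, which is what `forward` in `model/vae_base.py` computes. In the notation of the MIWAE paper, VAE is MIWAE(1, 1) and IWAE with k importance samples is MIWAE(1, k):

```
\mathcal{L}_{M,k}(x) = \frac{1}{M} \sum_{m=1}^{M} \log \frac{1}{k} \sum_{i=1}^{k} \frac{p_\theta(x, z_{m,i})}{q_\phi(z_{m,i} \mid x)}, \qquad z_{m,i} \sim q_\phi(z \mid x)
```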
6 |
7 |
8 |
9 | ## Usage
10 |
11 | You should be able to run experiments right away.
12 | First, create a virtual environment using [pipenv](https://github.com/pypa/pipenv):
13 |
14 | ```pipenv install```
15 |
16 | Then run an experiment with:
17 |
18 | ```pipenv run python main.py ```
19 |

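During training, the best model so far (by test loss) is saved to `result/<exp_name>/best_model.pt`. Passing `--eval` together with the same model flags reloads that checkpoint and re-runs evaluation:

```pipenv run python main.py --eval ```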
20 |
21 |
22 |
23 | ## Example commands
24 |
25 | For the original VAE:
26 |
27 | ```pipenv run python main.py ```
28 |
29 | To also make figures (reconstruction, samples):
30 |
31 | ```pipenv run python main.py --figs ```
32 |
33 | For IWAE with 5 importance samples:
34 |
35 | ```pipenv run python main.py --importance_num=5 ```
36 |
37 | For MIWAE(16, 4):
38 |
39 | ```pipenv run python main.py --mean_num=16 --importance_num=4 ```
40 |
41 | See [the config file](https://github.com/yoonholee/pytorch-generative/blob/master/utils/config.py) for more options.
42 |

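As a minimal sketch of what `--mean_num` (M) and `--importance_num` (k) control: the model returns an `elbo` tensor holding one log-weight per sample, and the loss combines them as in `model/vae_base.py`. The random `elbo` below is only a stand-in for those log-weights, with shapes for MIWAE(16, 4) at batch size 20:

```python
import torch

def logmeanexp(inputs, dim):
    # numerically stable log(mean(exp(inputs))) along `dim`, keeping the dim
    input_max = inputs.max(dim, keepdim=True)[0]
    return (inputs - input_max).exp().mean(dim, keepdim=True).log() + input_max

elbo = torch.randn(16, 4, 20)          # (mean_num, importance_num, batch_size)
elbo_iwae = logmeanexp(elbo, dim=1)    # k-sample IWAE bound per mean sample: (16, 1, 20)
loss = -elbo_iwae.squeeze(1).mean(0)   # average the M bounds, negate: one loss per example
```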
43 |
44 |
45 |
46 | ## Results
47 |
48 |
49 |
50 | ### Quantitative results on dynamically binarized MNIST
51 | | Method | NLL (this repo) | NLL ([IWAE paper](https://arxiv.org/abs/1509.00519)) | NLL ([MIWAE paper](https://arxiv.org/abs/1802.04537)) | Comments |
52 | | ------------- | ------------- | ------------- | ------------- | ---- |
53 | | VAE | 87.01 | 86.76 | - | |
54 | | MIWAE(5, 1) | 86.45 | 86.47 | - | listed as VAE with k=5 |
55 | | MIWAE(1, 5) | 85.18 | 85.54 | - | listed as IWAE with k=5 |
56 | | MIWAE(64, 1) | 86.07 | - | 86.21 | listed as VAE |
57 | | MIWAE(16, 4) | 84.99 | - | - | |
58 | | MIWAE(8, 8) | 84.69 | - | 84.97 | |
59 | | MIWAE(4, 16) | 84.52 | - | 84.56 | |
60 | | MIWAE(1, 64) | 84.37 | - | 84.52 | listed as IWAE |
61 |
--------------------------------------------------------------------------------
/assets/mnist_recon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/pytorch-vae/9dc44aae64f0e2896427ce955a48733d6315bb2d/assets/mnist_recon.gif
--------------------------------------------------------------------------------
/assets/mnist_samples.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/pytorch-vae/9dc44aae64f0e2896427ce955a48733d6315bb2d/assets/mnist_samples.gif
--------------------------------------------------------------------------------
/assets/omni_recon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/pytorch-vae/9dc44aae64f0e2896427ce955a48733d6315bb2d/assets/omni_recon.gif
--------------------------------------------------------------------------------
/assets/omni_samples.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoonholee/pytorch-vae/9dc44aae64f0e2896427ce955a48733d6315bb2d/assets/omni_samples.gif
--------------------------------------------------------------------------------
/data_loader/cifar10.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets
2 |
3 | class cifar10(datasets.CIFAR10):
4 |     def get_mean_img(self):
5 |         return self.train_data.mean(0)
6 |
7 |
--------------------------------------------------------------------------------
/data_loader/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | from .stoch_mnist import stochMNIST
4 | from .omniglot import omniglot
5 | from .fixed_mnist import fixedMNIST
6 | from .cifar10 import cifar10
7 |
8 |
9 | def data_loaders(args):
10 |     if args.dataset == 'omniglot':
11 |         loader_fn, root = omniglot, './dataset/omniglot'
12 |     elif args.dataset == 'fixedmnist':
13 |         loader_fn, root = fixedMNIST, './dataset/fixedmnist'
14 |     elif args.dataset == 'stochmnist':
15 |         loader_fn, root = stochMNIST, './dataset/stochmnist'
16 |     elif args.dataset == 'cifar10':
17 |         loader_fn, root = cifar10, './dataset/cifar10'
18 |
19 |     if args.dataset_dir != '': root = args.dataset_dir
20 |     kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
21 |     train_loader = torch.utils.data.DataLoader(
22 |         loader_fn(root, train=True, download=True, transform=transforms.ToTensor()),
23 |         batch_size=args.batch_size, shuffle=True, **kwargs)
24 |     test_loader = torch.utils.data.DataLoader(  # need test bs <=64 to make L_5000 tractable in one pass
25 |         loader_fn(root, train=False, download=True, transform=transforms.ToTensor()),
26 |         batch_size=args.test_batch_size, shuffle=False, **kwargs)
27 |     return train_loader, test_loader
28 |
--------------------------------------------------------------------------------
/data_loader/fixed_mnist.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import torch
3 | import torch.utils.data as data
4 | from torchvision import transforms
5 | import os
6 | import numpy as np
7 | from PIL import Image
8 | import urllib.request
9 |
10 |
11 | class fixedMNIST(data.Dataset):
12 |     """ Binarized MNIST dataset, proposed in
13 |     http://proceedings.mlr.press/v15/larochelle11a/larochelle11a.pdf """
14 |     train_file = 'binarized_mnist_train.amat'
15 |     val_file = 'binarized_mnist_valid.amat'
16 |     test_file = 'binarized_mnist_test.amat'
17 |
18 |     def __init__(self, root, train=True, transform=None, download=False):
19 |         # we ignore transform.
20 |         self.root = os.path.expanduser(root)
21 |         self.train = train  # training set or test set
22 |
23 |         if download: self.download()
24 |         if not self._check_exists():
25 |             raise RuntimeError('Dataset not found. You can use download=True to download it')
26 |
27 |         self.data = self._get_data(train=train)
28 |
29 |     def __getitem__(self, index):
30 |         img = self.data[index]
31 |         img = Image.fromarray(img)
32 |         img = transforms.ToTensor()(img).type(torch.FloatTensor)
33 |         return img, torch.tensor(-1)  # Meaningless tensor instead of target
34 |
35 |     def __len__(self):
36 |         return len(self.data)
37 |
38 |     def _get_data(self, train=True):
39 |         with h5py.File(os.path.join(self.root, 'data.h5'), 'r') as hf:
40 |             data = hf.get('train' if train else 'test')
41 |             data = np.array(data)
42 |         return data
43 |
44 |     def get_mean_img(self):
45 |         return self.data.mean(0).flatten()
46 |
47 |     def download(self):
48 |         if self._check_exists():
49 |             return
50 |         if not os.path.exists(self.root):
51 |             os.makedirs(self.root)
52 |
53 |         print('Downloading MNIST with fixed binarization...')
54 |         for dataset in ['train', 'valid', 'test']:
55 |             filename = 'binarized_mnist_{}.amat'.format(dataset)
56 |             url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format(dataset)
57 |             print('Downloading from {}...'.format(url))
58 |             local_filename = os.path.join(self.root, filename)
59 |             urllib.request.urlretrieve(url, local_filename)
60 |             print('Saved to {}'.format(local_filename))
61 |
62 |         def filename_to_np(filename):
63 |             with open(filename) as f:
64 |                 lines = f.readlines()
65 |             return np.array([[int(i) for i in line.split()] for line in lines]).astype('int8')
66 |
67 |         train_data = np.concatenate([filename_to_np(os.path.join(self.root, self.train_file)),
68 |                                      filename_to_np(os.path.join(self.root, self.val_file))])
69 |         test_data = filename_to_np(os.path.join(self.root, self.test_file))
70 |         with h5py.File(os.path.join(self.root, 'data.h5'), 'w') as hf:
71 |             hf.create_dataset('train', data=train_data.reshape(-1, 28, 28))
72 |             hf.create_dataset('test', data=test_data.reshape(-1, 28, 28))
73 |         print('Done!')
74 |
75 |     def _check_exists(self):
76 |         return os.path.exists(os.path.join(self.root, 'data.h5'))
77 |
--------------------------------------------------------------------------------
/data_loader/omniglot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | from torchvision import transforms
4 | import os
5 | from PIL import Image
6 | import urllib.request
7 | import scipy.io
8 |
9 |
10 | class omniglot(data.Dataset):
11 |     """ omniglot dataset """
12 |     url = 'https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat'
13 |
14 |     def __init__(self, root, train=True, transform=None, download=False):
15 |         # we ignore transform.
16 |         self.root = os.path.expanduser(root)
17 |         self.train = train  # training set or test set
18 |
19 |         if download: self.download()
20 |         if not self._check_exists():
21 |             raise RuntimeError('Dataset not found. You can use download=True to download it')
22 |
23 |         self.data = self._get_data(train=train)
24 |
25 |     def __getitem__(self, index):
26 |         img = self.data[index].reshape(28, 28)
27 |         img = Image.fromarray(img)
28 |         img = transforms.ToTensor()(img).type(torch.FloatTensor)
29 |         img = torch.bernoulli(img)  # stochastically binarize
30 |         return img, torch.tensor(-1)  # Meaningless tensor instead of target
31 |
32 |     def __len__(self):
33 |         return len(self.data)
34 |
35 |     def _get_data(self, train=True):
36 |         def reshape_data(data):
37 |             return data.reshape((-1, 28, 28)).reshape((-1, 28*28), order='F')
38 |
39 |         omni_raw = scipy.io.loadmat(os.path.join(self.root, 'chardata.mat'))
40 |         data_str = 'data' if train else 'testdata'
41 |         data = reshape_data(omni_raw[data_str].T.astype('float32'))
42 |         return data
43 |
44 |     def get_mean_img(self):
45 |         return self.data.mean(0)
46 |
47 |     def download(self):
48 |         if self._check_exists():
49 |             return
50 |         if not os.path.exists(self.root):
51 |             os.makedirs(self.root)
52 |
53 |         print('Downloading from {}...'.format(self.url))
54 |         local_filename = os.path.join(self.root, 'chardata.mat')
55 |         urllib.request.urlretrieve(self.url, local_filename)
56 |         print('Saved to {}'.format(local_filename))
57 |
58 |     def _check_exists(self):
59 |         return os.path.exists(os.path.join(self.root, 'chardata.mat'))
60 |
--------------------------------------------------------------------------------
/data_loader/stoch_mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 | from PIL import Image
4 |
5 |
6 | class stochMNIST(datasets.MNIST):
7 |     """ Gets a new stochastic binarization of MNIST at each call.
""" 8 | def __getitem__(self, index): 9 | if self.train: 10 | img, target = self.train_data[index], self.train_labels[index] 11 | else: 12 | img, target = self.test_data[index], self.test_labels[index] 13 | 14 | img = Image.fromarray(img.numpy(), mode='L') 15 | img = transforms.ToTensor()(img) 16 | img = torch.bernoulli(img) # stochastically binarize 17 | return img, target 18 | 19 | def get_mean_img(self): 20 | imgs = self.train_data.type(torch.float) / 255 21 | mean_img = imgs.mean(0).reshape(-1).numpy() 22 | return mean_img 23 | -------------------------------------------------------------------------------- /exp.sh: -------------------------------------------------------------------------------- 1 | run="pipenv run python main.py" 2 | 3 | $run --gpu=0 & 4 | $run --gpu=1 --importance_num=64 & 5 | $run --gpu=2 --importance_num=8 --mean_num=8 & 6 | $run --gpu=3 --no_iwae_lr & 7 | $run --gpu=4 --z=100 & 8 | wait 9 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | from tensorboardX import SummaryWriter 7 | from torch import optim 8 | 9 | from data_loader.data_loader import data_loaders 10 | from model.bernoulli_vae import BernoulliVAE 11 | from model.conv_vae import ConvVAE 12 | from utils.config import get_args 13 | from utils.draw_figs import draw_figs 14 | 15 | args = get_args() 16 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 17 | args.cuda = torch.cuda.is_available() 18 | device = torch.device("cuda:0" if args.cuda else "cpu") 19 | train_loader, test_loader = data_loaders(args) 20 | torch.manual_seed(args.seed) 21 | if args.cuda: 22 | torch.cuda.manual_seed_all(args.seed) 23 | writer = SummaryWriter(args.out_dir) 24 | 25 | model_class = BernoulliVAE if args.arch == "bernoulli" else ConvVAE 26 | mean_img = train_loader.dataset.get_mean_img() 27 | model = model_class( 28 | device=device, 29 | img_shape=args.img_shape, 30 | h_dim=args.h_dim, 31 | z_dim=args.z_dim, 32 | analytic_kl=args.analytic_kl, 33 | mean_img=mean_img, 34 | ).to(device) 35 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, eps=1e-4) 36 | if args.no_iwae_lr: 37 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 38 | optimizer, mode="min", patience=100, factor=10 ** (-1 / 7) 39 | ) 40 | else: 41 | milestones = np.cumsum([3 ** i for i in range(8)]) 42 | scheduler = optim.lr_scheduler.MultiStepLR( 43 | optimizer, milestones=milestones, gamma=10 ** (-1 / 7) 44 | ) 45 | 46 | 47 | def train(epoch): 48 | for batch_idx, (data, _) in enumerate(train_loader): 49 | optimizer.zero_grad() 50 | outs = model(data, mean_n=args.mean_num, imp_n=args.importance_num) 51 | loss_1, loss = -outs["elbo"].cpu().data.numpy().mean(), outs["loss"].mean() 52 | loss.backward() 53 | optimizer.step() 54 | model.train_step += 1 55 | if model.train_step % args.log_interval == 0: 56 | print( 57 | "Train Epoch: {} ({:.0f}%)\tLoss: {:.6f}".format( 58 | epoch, 100.0 * batch_idx / len(train_loader), loss.item() 59 | ) 60 | ) 61 | writer.add_scalar("train/loss", loss.item(), model.train_step) 62 | writer.add_scalar("train/loss_1", loss_1, model.train_step) 63 | 64 | 65 | def test(epoch): 66 | elbos = [ 67 | model(data, mean_n=1, imp_n=args.log_likelihood_k)["elbo"].squeeze(0) 68 | for data, _ in test_loader 69 | ] 70 | 71 | def get_loss_k(k): 72 | losses = [ 73 | model.logmeanexp(elbo[:k], 0).cpu().numpy().flatten() for elbo in elbos 74 
| ] 75 | return -np.concatenate(losses).mean() 76 | 77 | return map(get_loss_k, [args.importance_num, 1, 64, args.log_likelihood_k]) 78 | 79 | 80 | if args.eval: 81 | model.load_state_dict(torch.load(args.best_model_file)) 82 | with torch.no_grad(): 83 | print(list(test(0))) 84 | if args.figs: 85 | draw_figs(model, args, test_loader, 0) 86 | sys.exit() 87 | 88 | for epoch in range(1, args.epochs + 1): 89 | writer.add_scalar("learning_rate", optimizer.param_groups[0]["lr"], epoch) 90 | train(epoch) 91 | with torch.no_grad(): 92 | if args.figs and epoch % 100 == 1: 93 | draw_figs(model, args, test_loader, epoch) 94 | test_loss, test_1, test_64, test_ll = test(epoch) 95 | if test_loss < model.best_loss: 96 | model.best_loss = test_loss 97 | torch.save(model.state_dict(), args.best_model_file) 98 | scheduler_args = {"metrics": test_loss} if args.no_iwae_lr else {} 99 | scheduler.step(**scheduler_args) 100 | writer.add_scalar("test/loss", test_loss, epoch) 101 | writer.add_scalar("test/loss_1", test_1, epoch) 102 | writer.add_scalar("test/loss_64", test_64, epoch) 103 | writer.add_scalar("test/LL", test_ll, epoch) 104 | print("==== Testing. LL: {:.4f} ====\n".format(test_ll)) 105 | 106 | if args.to_gsheets: 107 | from utils.to_sheets import upload_to_google_sheets 108 | 109 | row_data = [args.exp_name, str(test_ll), str(test_64), str(test_64 - test_ll)] 110 | upload_to_google_sheets(row_data=row_data) 111 | -------------------------------------------------------------------------------- /model/bernoulli_vae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.distributions.bernoulli import Bernoulli 5 | from torch.distributions.normal import Normal 6 | 7 | from .vae_base import VAE 8 | 9 | 10 | class BernoulliVAE(VAE): 11 | def __init__(self, device, img_shape, h_dim, z_dim, analytic_kl, mean_img): 12 | super().__init__(device, z_dim, analytic_kl) 13 | x_dim = np.prod(img_shape) 14 | self.img_shape = img_shape 15 | self.proc_data = lambda x: x.to(device).reshape(-1, x_dim) 16 | self.encoder = nn.Sequential( 17 | nn.Linear(x_dim, h_dim), nn.Tanh(), nn.Linear(h_dim, h_dim), nn.Tanh() 18 | ) 19 | self.enc_mu = nn.Linear(h_dim, z_dim) 20 | self.enc_sig = nn.Linear(h_dim, z_dim) 21 | self.decoder = nn.Sequential( 22 | nn.Linear(z_dim, h_dim), nn.Tanh(), 23 | nn.Linear(h_dim, h_dim), nn.Tanh(), 24 | nn.Linear(h_dim, x_dim), 25 | ) # using Bern(logit) is equivalent to putting sigmoid here. 
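# (the logits parameterization is also numerically safer: log_prob is computed via BCE-with-logits rather than log(sigmoid).)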
26 |
27 |         self.apply(self.init)
28 |         mean_img = np.clip(mean_img, 1e-8, 1.0 - 1e-7)
29 |         mean_img_logit = np.log(mean_img / (1.0 - mean_img))
30 |         self.decoder[-1].bias = torch.nn.Parameter(torch.Tensor(mean_img_logit))
31 |
32 |     def init(self, module):
33 |         if type(module) == nn.Linear:
34 |             torch.nn.init.xavier_uniform_(
35 |                 module.weight, gain=nn.init.calculate_gain("tanh")
36 |             )
37 |             module.bias.data.fill_(0.01)
38 |
39 |     def encode(self, x):
40 |         x = self.proc_data(x)
41 |         h = self.encoder(x)
42 |         mu, _std = self.enc_mu(h), self.enc_sig(h)
43 |         return Normal(mu, nn.functional.softplus(_std))  # alternative: torch.exp(0.5 * _std), treating _std as a log-variance
44 |
45 |     def decode(self, z):
46 |         x = self.decoder(z)
47 |         return Bernoulli(logits=x)
48 |
49 |     def lpxz(self, true_x, x_dist):
50 |         return x_dist.log_prob(true_x).sum(-1)
51 |
52 |     def sample(self, num_samples=64):
53 |         z = self.prior.sample((num_samples,))
54 |         x_dist = self.decode(z)
55 |         return x_dist.sample().view(num_samples, *self.img_shape)
56 |
--------------------------------------------------------------------------------
/model/conv_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.distributions.normal import Normal
4 | from .vae_base import VAE
5 |
6 |
7 | class Flatten(nn.Module):
8 |     def forward(self, input):
9 |         return input.view(input.size(0), 16 * 8 * 8).contiguous()
10 |
11 |
12 | class UnFlatten(nn.Module):
13 |     def forward(self, input):
14 |         return input.view(input.size(0), 16, 8, 8).contiguous()
15 |
16 |
17 | class ConvVAE(VAE):
18 |     # XXX: This class does not work at the moment
19 |     def __init__(self, device, x_dim, h_dim, z_dim, analytic_kl, mean_img):
20 |         # FIXME: integrate so that plot etc works.
21 |         VAE.__init__(self, device, z_dim, analytic_kl)
22 |         self.proc_data = lambda x: x.to(device)
23 |         self.encoder = nn.Sequential(
24 |             nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
25 |             nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
26 |             nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
27 |             nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
28 |             Flatten())
29 |         self.enc_mu = nn.Linear(16 * 8 * 8, z_dim)
30 |         self.enc_sig = nn.Linear(16 * 8 * 8, z_dim)
31 |         self.decoder = nn.Sequential(
32 |             nn.Linear(z_dim, 16 * 8 * 8), nn.ReLU(),
33 |             UnFlatten(),
34 |             nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
35 |             nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
36 |             nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
37 |             nn.ConvTranspose2d(32, 6, kernel_size=3, stride=1, padding=1))
38 |
39 |         self.apply(self.init)
40 |         self.decoder[-1].bias = torch.nn.Parameter(torch.cat(
41 |             [torch.Tensor(mean_img.mean(0).mean(0)) / 256, .01 * torch.ones([3])]))
42 |
43 |     def init(self, module):
44 |         if type(module) in [nn.Conv2d, nn.ConvTranspose2d]:
45 |             torch.nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))
46 |             module.bias.data.fill_(.01)
47 |
48 |     def encode(self, x):
49 |         x = self.proc_data(x)
50 |         h = self.encoder(x)
51 |         mu, _std = self.enc_mu(h), self.enc_sig(h)
52 |         return Normal(mu, nn.functional.softplus(_std))
53 |
54 |     def decode(self, z):
55 |         mean_n, imp_n, bs = z.size(0), z.size(1), z.size(2)
56 |         z = z.view([mean_n * imp_n * bs, -1]).contiguous()
57 |         x = self.decoder(z)
58 |         x = x.view([mean_n, imp_n, bs, 6, 32, 32]).contiguous()
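# channels 0-2 hold the per-pixel Gaussian mean, channels 3-5 its pre-softplus std, unpacked below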
59 |         x_mean, x_std = x[:, :, :, :3, :, :].contiguous(), nn.functional.softplus(x[:, :, :, 3:, :, :]).contiguous()
60 |         return Normal(x_mean, x_std)
61 |
62 |     def lpxz(self, true_x, x_dist):
63 |         return x_dist.log_prob(true_x).sum([-1, -2, -3])
64 |
--------------------------------------------------------------------------------
/model/vae_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.distributions.normal import Normal
5 |
6 |
7 | class VAE(nn.Module):
8 |     def __init__(self, device, z_dim, analytic_kl):
9 |         super().__init__()
10 |         self.train_step = 0
11 |         self.best_loss = np.inf
12 |         self.analytic_kl = analytic_kl
13 |         self.prior = Normal(
14 |             torch.zeros([z_dim]).to(device), torch.ones([z_dim]).to(device)
15 |         )
16 |
17 |     def proc_data(self, x):
18 |         pass
19 |
20 |     def encode(self, x):
21 |         pass
22 |
23 |     def decode(self, z):
24 |         pass
25 |
26 |     def lpxz(self, true_x, x_dist):
27 |         pass
28 |
29 |     def sample(self, num_samples=64):
30 |         pass
31 |
32 |     def elbo(self, true_x, z, x_dist, z_dist):
33 |         true_x = self.proc_data(true_x)
34 |         lpxz = self.lpxz(true_x, x_dist)
35 |
36 |         if self.analytic_kl:
37 |             # SGVB^B: -KL(q(z|x)||p(z)) + log p(x|z). Use when KL can be done analytically.
38 |             assert z.size(0) == 1 and z.size(1) == 1
39 |             kl = torch.distributions.kl.kl_divergence(z_dist, self.prior).sum(-1)
40 |         else:
41 |             # SGVB^A: log p(z) - log q(z|x) + log p(x|z)
42 |             lpz = self.prior.log_prob(z).sum(-1)
43 |             lqzx = z_dist.log_prob(z).sum(-1)
44 |             kl = -lpz + lqzx
45 |         return -kl + lpxz
46 |
47 |     def logmeanexp(self, inputs, dim=1):
48 |         if inputs.size(dim) == 1:
49 |             return inputs
50 |         else:
51 |             input_max = inputs.max(dim, keepdim=True)[0]
52 |             return (inputs - input_max).exp().mean(dim, keepdim=True).log() + input_max  # keepdim so the result broadcasts against input_max
53 |
54 |     def forward(self, true_x, mean_n, imp_n):
55 |         z_dist = self.encode(true_x)
56 |         # mean_n, imp_n, batch_size, z_dim
57 |         z = z_dist.rsample(torch.Size([mean_n, imp_n]))
58 |         x_dist = self.decode(z)
59 |
60 |         elbo = self.elbo(true_x, z, x_dist, z_dist)  # mean_n, imp_n, batch_size
61 |         elbo_iwae = self.logmeanexp(elbo, 1).squeeze(1)  # mean_n, batch_size
62 |         elbo_iwae_m = torch.mean(elbo_iwae, 0)  # batch_size
63 |         return {"elbo": elbo, "loss": -elbo_iwae_m}
64 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -i https://pypi.org/simple
2 | certifi==2019.6.16
3 | chardet==3.0.4
4 | cycler==0.10.0
5 | gspread==3.1.0
6 | h5py==2.9.0
7 | httplib2==0.13.1
8 | idna==2.8
9 | imageio==2.5.0
10 | kiwisolver==1.1.0
11 | matplotlib==3.1.1
12 | numpy==1.17.0
13 | oauth2client==4.1.3
14 | pathlib==1.0.1
15 | pillow==6.1.0
16 | protobuf==3.9.1
17 | pyasn1-modules==0.2.6
18 | pyasn1==0.4.6
19 | pyparsing==2.4.2
20 | python-dateutil==2.8.0
21 | requests==2.22.0
22 | rsa==4.0
23 | scipy==1.3.1
24 | six==1.12.0
25 | tensorboardx==1.8
26 | torch==1.2.0
27 | torchvision==0.4.0
28 | urllib3==1.25.3
29 |
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--gpu', type=int, default=0)
6 | parser.add_argument('--seed', type=int, default=42)
7 | parser.add_argument('--log_interval', type=int, default=500)
8 |
parser.add_argument('--eval', action='store_true') 9 | parser.add_argument('--figs', action='store_true') 10 | parser.add_argument('--to_gsheets', action='store_true') 11 | parser.add_argument('--arch', type=str, default='bernoulli', choices=['bernoulli']) # TODO: make conv work 12 | 13 | parser.add_argument('--dataset_dir', type=str, default='') 14 | parser.add_argument('--dataset', type=str, default='stochmnist', 15 | choices=['stochmnist', 'omniglot', 'fixedmnist']) # TODO: make cifar10 work 16 | parser.add_argument('--batch_size', type=int, default=20) # iwae uses 20 17 | parser.add_argument('--test_batch_size', type=int, default=64) 18 | parser.add_argument('--epochs', type=int, default=3280) # iwae uses 3280 19 | 20 | parser.add_argument('--learning_rate', type=float, default=1e-3) 21 | parser.add_argument('--no_iwae_lr', action='store_true') 22 | parser.add_argument('--mean_num', type=int, default=1) # M in "tighter variational bounds...". Use 1 for vanilla vae 23 | parser.add_argument('--importance_num', type=int, default=1) # k of iwae. Use 1 for vanilla vae 24 | parser.add_argument('--analytic_kl', action='store_true') 25 | parser.add_argument('--h_dim', type=int, default=200) 26 | parser.add_argument('--z_dim', type=int, default=50) 27 | 28 | 29 | def get_args(): 30 | args = parser.parse_args() 31 | 32 | def cstr(arg, arg_name, default, custom_str=False): 33 | """ Get config str for arg, ignoring if set to default. """ 34 | not_default = arg != default 35 | if not custom_str: 36 | custom_str = f'_{arg_name}{arg}' 37 | return custom_str if not_default else '' 38 | 39 | args.exp_name = (f'm{args.mean_num}_k{args.importance_num}' 40 | f'{cstr(args.dataset, "", "stochmnist")}{cstr(args.arch, "", "bernoulli")}' 41 | f'{cstr(args.seed, "seed", 42)}{cstr(args.batch_size, "bs", 20)}' 42 | f'{cstr(args.h_dim, "h", 200)}{cstr(args.z_dim, "z", 50)}' 43 | f'{cstr(args.learning_rate, "lr", 1e-3)}{cstr(args.analytic_kl, None, False, "_analytic")}' 44 | f'{cstr(args.no_iwae_lr, None, False, "_noiwae")}{cstr(args.epochs, "epoch", 3280)}') 45 | 46 | args.figs_dir = os.path.join('figs', args.exp_name) 47 | args.out_dir = os.path.join('result', args.exp_name) 48 | args.best_model_file = os.path.join('result', args.exp_name, 'best_model.pt') 49 | if not os.path.exists(args.out_dir): 50 | os.makedirs(args.out_dir) 51 | if not os.path.exists(args.figs_dir): 52 | os.makedirs(args.figs_dir) 53 | 54 | args.log_likelihood_k = 100 if args.dataset == 'cifar10' else 5000 55 | args.img_shape = (32, 32) if args.dataset == 'cifar10' else (28, 28) 56 | return args 57 | -------------------------------------------------------------------------------- /utils/draw_figs.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | import imageio 5 | import pathlib 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def draw_gif(name, figs_dir, glob_str): 11 | files = [file for file in pathlib.Path(figs_dir).glob(glob_str)] 12 | images = [imageio.imread(str(file)) for file in sorted(files)] 13 | imageio.mimsave('{}/{}'.format(figs_dir, name), images, duration=.5) 14 | 15 | 16 | def draw_figs(model, args, test_loader, epoch): 17 | samples = model.sample(num_samples=100).data.cpu().numpy() 18 | plt.figure(figsize=(5, 5)) 19 | plt.suptitle('Samples, Epoch {}'.format(epoch), fontsize=20) 20 | plt.axis('square') 21 | plt.legend(frameon=True) 22 | for idx, im in enumerate(samples): 23 | plt.subplot(10, 10, idx+1) 
24 | plt.imshow(im, cmap='Greys') 25 | plt.axis('off') 26 | plt.savefig('figs/{}/samples_{:04}.jpg'.format(args.exp_name, epoch)) 27 | plt.clf() 28 | draw_gif('{}_samples.gif'.format(args.exp_name), args.figs_dir, 'samples*.jpg') 29 | 30 | for batch_idx, (data, _) in enumerate(test_loader): 31 | break 32 | z_dist = model.encode(data) 33 | z = z_dist.rsample() 34 | recon = model.decode(z).probs.view(args.test_batch_size, 28, 28) 35 | data = data.view(args.test_batch_size, 28, 28) 36 | plt.figure(figsize=(5, 5)) 37 | plt.suptitle('Reconstruction, Epoch {}'.format(epoch), fontsize=20) 38 | plt.axis('square') 39 | plt.legend(frameon=True) 40 | for i in range(50): 41 | data_i = data[i].data.cpu().numpy() 42 | recon_i = recon[i].data.cpu().numpy() 43 | plt.subplot(10, 10, 2*i+1) 44 | plt.imshow(data_i, cmap='Greys') 45 | plt.axis('off') 46 | plt.subplot(10, 10, 2*i+2) 47 | plt.imshow(recon_i, cmap='Greys') 48 | plt.axis('off') 49 | plt.savefig('figs/{}/reconstruction_{:04}.jpg'.format(args.exp_name, epoch)) 50 | plt.clf() 51 | draw_gif('{}_reconstruction.gif'.format(args.exp_name), args.figs_dir, 'reconstruction*.jpg') 52 | 53 | if args.z_dim == 2: 54 | latent_space, labels = [], [] 55 | for batch_idx, (data, label) in enumerate(test_loader): 56 | latent_space.append(model.encode(data).loc.data.cpu().numpy()) 57 | labels.append(label) 58 | latent_space, labels = np.concatenate(latent_space), np.concatenate(labels) 59 | plt.figure(figsize=(5, 5)) 60 | for c in range(10): 61 | idx = (labels == c) 62 | plt.scatter(latent_space[idx, 0], latent_space[idx, 1], 63 | c=matplotlib.cm.get_cmap('tab10')(c), marker=',', label=str(c), alpha=.7) 64 | plt.suptitle('Latent representation, Epoch {}'.format(epoch), fontsize=20) 65 | plt.axis('square') 66 | plt.legend(frameon=True) 67 | plt.savefig('figs/{}/latent_{:04}.jpg'.format(args.exp_name, epoch)) 68 | plt.clf() 69 | draw_gif('{}_latent.gif'.format(args.exp_name), args.figs_dir, 'latent*.jpg') 70 | 71 | plt.close('all') 72 | -------------------------------------------------------------------------------- /utils/to_sheets.py: -------------------------------------------------------------------------------- 1 | import gspread 2 | from oauth2client.service_account import ServiceAccountCredentials 3 | get_credentials = ServiceAccountCredentials.from_json_keyfile_name 4 | 5 | scope = ['https://spreadsheets.google.com/feeds', 6 | 'https://www.googleapis.com/auth/drive'] 7 | # To make this work, obtain credentials from Google Sheets API and save to 8 | # creds.json in current directory. 9 | credentials = get_credentials('creds.json', scope) 10 | gc = gspread.authorize(credentials) 11 | sheet_name = 'pytorch-generative' 12 | 13 | 14 | def upload_to_google_sheets(row_data, index=2): 15 | worksheet = gc.open(sheet_name).sheet1 16 | worksheet.insert_row(row_data, index=index) 17 | --------------------------------------------------------------------------------