├── .gitignore
├── LICENSE
├── README.md
├── models
│   ├── __init__.py
│   ├── idf.py
│   └── realnvp.py
├── run.py
└── utils
    ├── __init__.py
    ├── datasets.py
    ├── evaluation.py
    ├── nn.py
    └── training.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
/results/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Jakub Tomczak

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# General Invertible Transformations for Flow-based Generative Modeling
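
Code for integer discrete flows (`IDF`, `IDF2`, `IDF4`, `IDF8`) and a `RealNVP`
baseline, trained on the scikit-learn Digits dataset. Assuming `torch`, `numpy`,
`scikit-learn` and `matplotlib` are installed, the experiments are reproduced with
`python run.py`; saved models, sample grids and validation-NLL curves are written
to `results/<model>_<repetition>/`.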
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmtomczak/git_flow/5a38d470d11f41f95683ab3fbcc2eae4dda9f746/models/__init__.py
--------------------------------------------------------------------------------
/models/idf.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

from utils.nn import RoundStraightThrough


class IDF(nn.Module):
    def __init__(self, nett, num_flows, D=2):
        super(IDF, self).__init__()

        print('IDF by JT.')

        self.t = torch.nn.ModuleList([nett() for _ in range(num_flows)])
        self.num_flows = num_flows

        self.round = RoundStraightThrough.apply

        self.p = nn.Parameter(torch.zeros(1, D))
        self.mu = nn.Parameter(torch.ones(1, D) * 0.5)

    def coupling(self, x, index, forward=True):
        # additive coupling: the rounded shift keeps integer inputs on the
        # integer lattice, so the transformation is exactly invertible
        (xa, xb) = torch.chunk(x, 2, 1)

        t = self.t[index](xa)

        if forward:
            yb = xb + self.round(t)
        else:
            yb = xb - self.round(t)

        return torch.cat((xa, yb), 1)

    def permute(self, x):
        return x.flip(1)

    def f(self, x):
        z = x
        for i in range(self.num_flows):
            z = self.coupling(z, i, forward=True)
            z = self.permute(z)

        return z

    def f_inv(self, z):
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x = self.coupling(x, i, forward=False)

        return x

    def forward(self, x):
        z = self.f(x)
        return self.log_prior(z)

    def sample(self, batchSize, D=2, intMax=100):
        # sample z:
        z = self.prior_sample(batchSize=batchSize, D=D, intMax=intMax)
        # x = f^-1(z)
        x = self.f_inv(z)
        return x.view(batchSize, 1, D)

    def log_integer_probability(self, x, p, mu):
        # Chakraborty & Chakravarty, "A new discrete probability distribution with integer support on (−∞, ∞)",
        # Communications in Statistics - Theory and Methods, 45:2, 492-505, DOI: 10.1080/03610926.2013.830743
        # i.e. P(X = x) = (1 - p) * p^(x - mu) / ((1 + p^(x - mu)) * (1 + p^(x - mu + 1)))
        log_p = torch.log(1. - p) + (x - mu) * torch.log(p) \
                - torch.log(1. + torch.exp((x - mu) * torch.log(p))) \
                - torch.log(1. + torch.exp((x - mu + 1.) * torch.log(p)))
        return log_p

    def log_prior(self, x):
        p = torch.sigmoid(self.p)
        log_p = self.log_integer_probability(x, p, self.mu)
        return log_p.sum(1)

    def prior_sample(self, batchSize, D=2, intMax=100):
        ints = np.expand_dims(np.arange(-intMax, intMax + 1), 0)
        for d in range(D):
            p = torch.sigmoid(self.p[:, [d]])
            mu = self.mu[:, d]
            log_p = self.log_integer_probability(torch.from_numpy(ints), p, mu)

            if d == 0:
                z = torch.from_numpy(np.random.choice(ints[0], (batchSize, 1),
                                                      p=torch.exp(log_p[0]).detach().numpy()).astype(np.float32))
            else:
                z_new = torch.from_numpy(np.random.choice(ints[0], (batchSize, 1),
                                                          p=torch.exp(log_p[0]).detach().numpy()).astype(np.float32))
                z = torch.cat((z, z_new), 1)
        return z

class IDF2(nn.Module):
    def __init__(self, nett_a, nett_b, num_flows, D=2):
        super(IDF2, self).__init__()

        print('IDF by JT.')

        self.t_a = torch.nn.ModuleList([nett_a() for _ in range(num_flows)])
        self.t_b = torch.nn.ModuleList([nett_b() for _ in range(num_flows)])
        self.num_flows = num_flows

        self.round = RoundStraightThrough.apply

        self.p = nn.Parameter(torch.zeros(1, D))
        self.mu = nn.Parameter(torch.ones(1, D) * 0.5)

    def coupling(self, x, index, forward=True):
        # both halves are transformed; each conditioner sees the most
        # recently updated version of the other half
        (xa, xb) = torch.chunk(x, 2, 1)

        if forward:
            ya = xa + self.round(self.t_a[index](xb))
            yb = xb + self.round(self.t_b[index](ya))
        else:
            yb = xb - self.round(self.t_b[index](xa))
            ya = xa - self.round(self.t_a[index](yb))

        return torch.cat((ya, yb), 1)

    def permute(self, x):
        return x.flip(1)

    def f(self, x):
        z = x
        for i in range(self.num_flows):
            z = self.coupling(z, i, forward=True)
            z = self.permute(z)

        return z

    def f_inv(self, z):
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x = self.coupling(x, i, forward=False)

        return x

    def forward(self, x):
        z = self.f(x)
        return self.log_prior(z)

    def sample(self, batchSize, D=2, intMax=100):
        # sample z:
        z = self.prior_sample(batchSize=batchSize, D=D, intMax=intMax)
        # x = f^-1(z)
        x = self.f_inv(z)
        return x.view(batchSize, 1, D)

    def log_integer_probability(self, x, p, mu):
        # Chakraborty & Chakravarty, "A new discrete probability distribution with integer support on (−∞, ∞)",
        # Communications in Statistics - Theory and Methods, 45:2, 492-505, DOI: 10.1080/03610926.2013.830743
        log_p = torch.log(1. - p) + (x - mu) * torch.log(p) \
                - torch.log(1. + torch.exp((x - mu) * torch.log(p))) \
                - torch.log(1. + torch.exp((x - mu + 1.) * torch.log(p)))
        return log_p

    def log_prior(self, x):
        p = torch.sigmoid(self.p)
        log_p = self.log_integer_probability(x, p, self.mu)
        return log_p.sum(1)  # sum over dimensions only, keeping per-example log-probs (as in IDF)

    def prior_sample(self, batchSize, D=2, intMax=100):
        ints = np.expand_dims(np.arange(-intMax, intMax + 1), 0)
        for d in range(D):
            p = torch.sigmoid(self.p[:, [d]])
            mu = self.mu[:, d]
            log_p = self.log_integer_probability(torch.from_numpy(ints), p, mu)

            if d == 0:
                z = torch.from_numpy(np.random.choice(ints[0], (batchSize, 1),
                                                      p=torch.exp(log_p[0]).detach().numpy()).astype(np.float32))
            else:
                z_new = torch.from_numpy(np.random.choice(ints[0], (batchSize, 1),
                                                          p=torch.exp(log_p[0]).detach().numpy()).astype(np.float32))
                z = torch.cat((z, z_new), 1)
        return z

class IDF4(nn.Module):
    def __init__(self, nett_a, nett_b, nett_c, nett_d, num_flows, D=2):
        super(IDF4, self).__init__()

        print('IDF by JT.')

        self.t_a = torch.nn.ModuleList([nett_a() for _ in range(num_flows)])
        self.t_b = torch.nn.ModuleList([nett_b() for _ in range(num_flows)])
        self.t_c = torch.nn.ModuleList([nett_c() for _ in range(num_flows)])
        self.t_d = torch.nn.ModuleList([nett_d() for _ in range(num_flows)])
        self.num_flows = num_flows

        self.round = RoundStraightThrough.apply

        self.p = nn.Parameter(torch.zeros(1, D))
        self.mu = nn.Parameter(torch.ones(1, D) * 0.5)

    def coupling(self, x, index, forward=True):
        # four parts, updated one after another, each conditioned on the
        # most recent versions of the remaining three
        (xa, xb, xc, xd) = torch.chunk(x, 4, 1)

        if forward:
            ya = xa + self.round(self.t_a[index](torch.cat((xb, xc, xd), 1)))
            yb = xb + self.round(self.t_b[index](torch.cat((ya, xc, xd), 1)))
            yc = xc + self.round(self.t_c[index](torch.cat((ya, yb, xd), 1)))
            yd = xd + self.round(self.t_d[index](torch.cat((ya, yb, yc), 1)))
        else:
            yd = xd - self.round(self.t_d[index](torch.cat((xa, xb, xc), 1)))
            yc = xc - self.round(self.t_c[index](torch.cat((xa, xb, yd), 1)))
            yb = xb - self.round(self.t_b[index](torch.cat((xa, yc, yd), 1)))
            ya = xa - self.round(self.t_a[index](torch.cat((yb, yc, yd), 1)))

        return torch.cat((ya, yb, yc, yd), 1)

    def permute(self, x):
        return x.flip(1)

    def f(self, x):
        z = x
        for i in range(self.num_flows):
            z = self.coupling(z, i, forward=True)
            z = self.permute(z)

        return z

    def f_inv(self, z):
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x = self.coupling(x, i, forward=False)

        return x

    def forward(self, x):
        z = self.f(x)
        return self.log_prior(z)

    def sample(self, batchSize, D=2, intMax=100):
        # sample z:
        z = self.prior_sample(batchSize=batchSize, D=D, intMax=intMax)
        # x = f^-1(z)
        x = self.f_inv(z)
        return x.view(batchSize, 1, D)

    def log_integer_probability(self, x, p, mu):
        # Chakraborty & Chakravarty, "A new discrete probability distribution with integer support on (−∞, ∞)",
        # Communications in Statistics - Theory and Methods, 45:2, 492-505, DOI: 10.1080/03610926.2013.830743
        log_p = torch.log(1. - p) + (x - mu) * torch.log(p) \
                - torch.log(1. + torch.exp((x - mu) * torch.log(p))) \
                - torch.log(1. + torch.exp((x - mu + 1.) * torch.log(p)))
        return log_p

    def log_prior(self, x):
        p = torch.sigmoid(self.p)
        log_p = self.log_integer_probability(x, p, self.mu)
        return log_p.sum(1)  # sum over dimensions only, keeping per-example log-probs (as in IDF)

    def prior_sample(self, batchSize, D=2, intMax=100):
        ints = np.expand_dims(np.arange(-intMax, intMax + 1), 0)
        for d in range(D):
            p = torch.sigmoid(self.p[:, [d]])
            mu = self.mu[:, d]
            log_p = self.log_integer_probability(torch.from_numpy(ints), p, mu)

            if d == 0:
                z = torch.from_numpy(np.random.choice(ints[0], (batchSize, 1),
                                                      p=torch.exp(log_p[0]).detach().numpy()).astype(np.float32))
            else:
                z_new = torch.from_numpy(np.random.choice(ints[0], (batchSize, 1),
                                                          p=torch.exp(log_p[0]).detach().numpy()).astype(np.float32))
                z = torch.cat((z, z_new), 1)
        return z

class IDF8(nn.Module):
    def __init__(self, nett_a, nett_b, nett_c, nett_d, nett_e, nett_f, nett_g, nett_h, num_flows, D=2):
        super(IDF8, self).__init__()

        print('IDF by JT.')

        self.t_a = torch.nn.ModuleList([nett_a() for _ in range(num_flows)])
        self.t_b = torch.nn.ModuleList([nett_b() for _ in range(num_flows)])
        self.t_c = torch.nn.ModuleList([nett_c() for _ in range(num_flows)])
        self.t_d = torch.nn.ModuleList([nett_d() for _ in range(num_flows)])
        self.t_e = torch.nn.ModuleList([nett_e() for _ in range(num_flows)])
        self.t_f = torch.nn.ModuleList([nett_f() for _ in range(num_flows)])
        self.t_g = torch.nn.ModuleList([nett_g() for _ in range(num_flows)])
        self.t_h = torch.nn.ModuleList([nett_h() for _ in range(num_flows)])
        self.num_flows = num_flows

        self.round = RoundStraightThrough.apply

        self.p = nn.Parameter(torch.zeros(1, D))
        self.mu = nn.Parameter(torch.ones(1, D) * 0.5)

    def coupling(self, x, index, forward=True):
        # eight parts, updated one after another, each conditioned on the
        # most recent versions of the remaining seven
        (xa, xb, xc, xd, xe, xf, xg, xh) = torch.chunk(x, 8, 1)

        if forward:
            ya = xa + self.round(self.t_a[index](torch.cat((xb, xc, xd, xe, xf, xg, xh), 1)))
            yb = xb + self.round(self.t_b[index](torch.cat((ya, xc, xd, xe, xf, xg, xh), 1)))
            yc = xc + self.round(self.t_c[index](torch.cat((ya, yb, xd, xe, xf, xg, xh), 1)))
            yd = xd + self.round(self.t_d[index](torch.cat((ya, yb, yc, xe, xf, xg, xh), 1)))
            ye = xe + self.round(self.t_e[index](torch.cat((ya, yb, yc, yd, xf, xg, xh), 1)))
            yf = xf + self.round(self.t_f[index](torch.cat((ya, yb, yc, yd, ye, xg, xh), 1)))
            yg = xg + self.round(self.t_g[index](torch.cat((ya, yb, yc, yd, ye, yf, xh), 1)))
            yh = xh + self.round(self.t_h[index](torch.cat((ya, yb, yc, yd, ye, yf, yg), 1)))
        else:
            yh = xh - self.round(self.t_h[index](torch.cat((xa, xb, xc, xd, xe, xf, xg), 1)))
            yg = xg - self.round(self.t_g[index](torch.cat((xa, xb, xc, xd, xe, xf, yh), 1)))
            yf = xf - self.round(self.t_f[index](torch.cat((xa, xb, xc, xd, xe, yg, yh), 1)))
            ye = xe - self.round(self.t_e[index](torch.cat((xa, xb, xc, xd, yf, yg, yh), 1)))
            yd = xd - self.round(self.t_d[index](torch.cat((xa, xb, xc, ye, yf, yg, yh), 1)))
            yc = xc - self.round(self.t_c[index](torch.cat((xa, xb, yd, ye, yf, yg, yh), 1)))
            yb = xb - self.round(self.t_b[index](torch.cat((xa, yc, yd, ye, yf, yg, yh), 1)))
            ya = xa - self.round(self.t_a[index](torch.cat((yb, yc, yd, ye, yf, yg, yh), 1)))

        return torch.cat((ya, yb, yc, yd, ye, yf, yg, yh), 1)

    def permute(self, x):
        return x.flip(1)

    def f(self, x):
        z = x
        for i in range(self.num_flows):
            z = self.coupling(z, i, forward=True)
            z = self.permute(z)

        return z

    def f_inv(self, z):
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x = self.coupling(x, i, forward=False)

        return x

    def forward(self, x):
        z = self.f(x)
        return self.log_prior(z)

    def sample(self, batchSize, D=2, intMax=100):
        # sample z:
        z = self.prior_sample(batchSize=batchSize, D=D, intMax=intMax)
        # x = f^-1(z)
        x = self.f_inv(z)
        return x.view(batchSize, 1, D)

    def log_integer_probability(self, x, p, mu):
        # Chakraborty & Chakravarty, "A new discrete probability distribution with integer support on (−∞, ∞)",
        # Communications in Statistics - Theory and Methods, 45:2, 492-505, DOI: 10.1080/03610926.2013.830743
        log_p = torch.log(1. - p) + (x - mu) * torch.log(p) \
                - torch.log(1. + torch.exp((x - mu) * torch.log(p))) \
                - torch.log(1. + torch.exp((x - mu + 1.) * torch.log(p)))
        return log_p

    def log_prior(self, x):
        p = torch.sigmoid(self.p)
        log_p = self.log_integer_probability(x, p, self.mu)
        return log_p.sum(1)  # sum over dimensions only, keeping per-example log-probs (as in IDF)

    def prior_sample(self, batchSize, D=2, intMax=100):
        ints = np.expand_dims(np.arange(-intMax, intMax + 1), 0)
        for d in range(D):
            p = torch.sigmoid(self.p[:, [d]])
            mu = self.mu[:, d]
            log_p = self.log_integer_probability(torch.from_numpy(ints), p, mu)

            if d == 0:
                z = torch.from_numpy(np.random.choice(ints[0], (batchSize, 1),
                                                      p=torch.exp(log_p[0]).detach().numpy()).astype(np.float32))
            else:
                z_new = torch.from_numpy(np.random.choice(ints[0], (batchSize, 1),
                                                          p=torch.exp(log_p[0]).detach().numpy()).astype(np.float32))
                z = torch.cat((z, z_new), 1)
        return z
--------------------------------------------------------------------------------
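What makes these couplings integer flows: the rounded shifts never leave the integer lattice, so `f_inv` undoes `f` exactly, with no floating-point drift. A minimal sanity check for the bipartite `IDF` (a sketch, assuming the repository root is on `PYTHONPATH`; the conditioner mirrors `run.py`, with fewer flows for brevity):

```python
import torch
import torch.nn as nn

from models.idf import IDF

D, M = 64, 256
nett = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                             nn.Linear(M, M), nn.LeakyReLU(),
                             nn.Linear(M, D // 2))
flow = IDF(nett, num_flows=4, D=D)

x = torch.randint(0, 17, (8, D)).float()  # integer-valued inputs, like Digits pixels
z = flow.f(x)                             # latents stay integer-valued
x_rec = flow.f_inv(z)
assert torch.equal(x_rec, x)              # exact reconstruction, not just allclose

log_px = flow(x)                          # per-example log-likelihood, shape (8,)
```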
--------------------------------------------------------------------------------
/models/realnvp.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class RealNVP(nn.Module):
    def __init__(self, nets, nett, num_flows, prior, dequantization=True):
        super(RealNVP, self).__init__()

        print('RealNVP by JT.')

        self.dequantization = dequantization

        self.prior = prior
        self.t = torch.nn.ModuleList([nett() for _ in range(num_flows)])
        self.s = torch.nn.ModuleList([nets() for _ in range(num_flows)])
        self.num_flows = num_flows

    def coupling(self, x, index, forward=True):
        # forward=True is the normalizing direction x -> z,
        # forward=False the generative direction z -> x
        (xa, xb) = torch.chunk(x, 2, 1)

        s = self.s[index](xa)
        t = self.t[index](xa)

        if forward:
            yb = (xb - t) * torch.exp(-s)
        else:
            yb = torch.exp(s) * xb + t

        return torch.cat((xa, yb), 1), s, t

    def permute(self, x):
        return x.flip(1)

    def f(self, x):
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in range(self.num_flows):
            z, s, _ = self.coupling(z, i, forward=True)
            z = self.permute(z)
            log_det_J = log_det_J - s.sum(dim=1)

        return z, log_det_J

    def f_inv(self, z):
        x = z
        for i in reversed(range(self.num_flows)):
            x = self.permute(x)
            x, _, _ = self.coupling(x, i, forward=False)

        return x

    def forward(self, x):
        z, log_det_J = self.f(x)
        return self.prior.log_prob(z) + log_det_J

    def sample(self, batchSize, D=2):
        # a batch of draws from the D-dimensional prior already has shape (batchSize, D)
        z = self.prior.sample((batchSize,))
        x = self.f_inv(z)
        return x.view(-1, D)
--------------------------------------------------------------------------------
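`RealNVP` works in continuous space instead: `forward` returns the change-of-variables log-likelihood log p(z) + log|det J|, and the inverse holds only up to floating-point error. A small usage sketch with the same conditioner architecture as `run.py` (variable names are illustrative):

```python
import torch
import torch.nn as nn

from models.realnvp import RealNVP

D, M = 64, 256
nets = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                             nn.Linear(M, M), nn.LeakyReLU(),
                             nn.Linear(M, D // 2), nn.Tanh())
nett = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                             nn.Linear(M, M), nn.LeakyReLU(),
                             nn.Linear(M, D // 2))
prior = torch.distributions.MultivariateNormal(torch.zeros(D), torch.eye(D))
flow = RealNVP(nets, nett, num_flows=8, prior=prior)

x = torch.rand(8, D)
z, log_det_J = flow.f(x)
print(torch.allclose(flow.f_inv(z), x, atol=1e-4))  # inverse up to float error

log_px = flow(x)             # log p(z) + log|det J|, shape (8,)
x_gen = flow.sample(8, D=D)  # draw z from the prior, then x = f_inv(z)
```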
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn

from torch.utils.data import DataLoader

from models.idf import IDF, IDF2, IDF4, IDF8
from models.realnvp import RealNVP

from utils.datasets import Digits
from utils.training import training
from utils.evaluation import evaluation, plot_curve, samples_real


if __name__ == '__main__':
    # DATA
    train_data = Digits(mode='train')
    val_data = Digits(mode='val')
    test_data = Digits(mode='test')

    training_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

    # SETUP
    # add 'idf2' to this list to also run the two-way bipartite variant below
    models = ['idf', 'idf4', 'idf8', 'realnvp']
    num_repetitions = range(5)

    D = 64
    M = 256

    lr = 1e-3
    num_epochs = 1000
    max_patience = 20

    # REPETITIONS
    for m in models:
        for r in num_repetitions:
            result_dir = 'results/' + m + '_' + str(r) + '/'

            # makedirs also creates the git-ignored results/ parent if needed
            os.makedirs(result_dir, exist_ok=True)

            # MODEL
            name = m

            if name == 'idf8':
                print(name + " initialized!")
                num_flows = 2
                # the eight conditioners share one architecture; a single factory
                # suffices because every call builds a fresh set of modules
                nett = lambda: nn.Sequential(nn.Linear(7 * (D // 8), M), nn.LeakyReLU(),
                                             nn.Linear(M, M), nn.LeakyReLU(),
                                             nn.Linear(M, D // 8))

                flow = IDF8(nett, nett, nett, nett, nett, nett, nett, nett, num_flows, D)

            elif name == 'idf4':
                print(name + " initialized!")
                num_flows = 4
                nett = lambda: nn.Sequential(nn.Linear(3 * (D // 4), M), nn.LeakyReLU(),
                                             nn.Linear(M, M), nn.LeakyReLU(),
                                             nn.Linear(M, D // 4))

                flow = IDF4(nett, nett, nett, nett, num_flows, D)

            elif name == 'idf2':
                print(name + " initialized!")
                num_flows = 8
                nett = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                             nn.Linear(M, M), nn.LeakyReLU(),
                                             nn.Linear(M, D // 2))

                flow = IDF2(nett, nett, num_flows, D)

            elif name == 'idf':
                print(name + " initialized!")
                num_flows = 16
                nett = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                             nn.Linear(M, M), nn.LeakyReLU(),
                                             nn.Linear(M, D // 2))

                flow = IDF(nett, num_flows, D)

            elif name == 'realnvp':
                print(name + " initialized!")
                num_flows = 8

                nets = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                             nn.Linear(M, M), nn.LeakyReLU(),
                                             nn.Linear(M, D // 2), nn.Tanh())

                nett = lambda: nn.Sequential(nn.Linear(D // 2, M), nn.LeakyReLU(),
                                             nn.Linear(M, M), nn.LeakyReLU(),
                                             nn.Linear(M, D // 2))

                prior = torch.distributions.MultivariateNormal(torch.zeros(D), torch.eye(D))
                flow = RealNVP(nets, nett, num_flows, prior, dequantization=True)

            # OPTIMIZER
            optimizer = torch.optim.Adamax([p for p in flow.parameters() if p.requires_grad], lr=lr)

            # TRAINING
            nll_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs,
                               flow=flow, optimizer=optimizer,
                               training_loader=training_loader, val_loader=val_loader)

            # EVALUATION
            test_loss = evaluation(name=result_dir + name, test_loader=test_loader)
            f = open(result_dir + name + '_test_loss.txt', "w")
            f.write(str(test_loss))
            f.close()

            samples_real(result_dir + name, test_loader)

            plot_curve(result_dir + name, nll_val)
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmtomczak/git_flow/5a38d470d11f41f95683ab3fbcc2eae4dda9f746/utils/__init__.py
--------------------------------------------------------------------------------
/utils/datasets.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset


class Digits(Dataset):
    """Scikit-Learn Digits dataset: 8x8 grayscale digits flattened to 64
    dimensions, with integer pixel values in [0, 16]."""

    def __init__(self, mode='train'):
        digits = load_digits()
        if mode == 'train':
            self.data = digits.data[:1000].astype(np.float32)
        elif mode == 'val':
            self.data = digits.data[1000:1350].astype(np.float32)
        else:
            self.data = digits.data[1350:].astype(np.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample


class TwoMoonDatasetInt(Dataset):
    """Two Moon dataset, scaled and rounded to integer coordinates."""

    def __init__(self, N=1000, noise=0.05, scale=10.):
        self.data = np.round(datasets.make_moons(n_samples=N, noise=noise)[0].astype(np.float32) * scale)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample
--------------------------------------------------------------------------------
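For reference, `Digits` splits the 1797 scikit-learn digits into 1000 train, 350 validation and 447 test examples. A quick look at what the loaders yield:

```python
from torch.utils.data import DataLoader

from utils.datasets import Digits

train_data = Digits(mode='train')
print(len(train_data), train_data[0].shape)      # 1000 (64,)
print(train_data[0].min(), train_data[0].max())  # pixel values within [0.0, 16.0]

loader = DataLoader(train_data, batch_size=64, shuffle=True)
batch = next(iter(loader))
print(batch.shape)                               # torch.Size([64, 64])
```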
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import matplotlib.pyplot as plt


def evaluation(name, test_loader):
    # EVALUATION
    # load the best-performing model
    flow_best = torch.load(name + '.model')

    flow_best.eval()
    loss_test = 0.
    N = 0.
    for indx_batch, test_batch in enumerate(test_loader):
        loss_t = -flow_best.forward(test_batch).sum()
        loss_test = loss_test + loss_t.item()
        N = N + test_batch.shape[0]
    loss_test = loss_test / N

    print(f'FINAL LOSS: nll={loss_test}')

    return loss_test


def samples_real(name, test_loader):
    # REAL-------
    num_x = 4
    num_y = 4
    x = next(iter(test_loader)).detach().numpy()

    fig, ax = plt.subplots(num_x, num_y)
    for i, ax in enumerate(ax.flatten()):
        plottable_image = np.reshape(x[i], (8, 8))
        ax.imshow(plottable_image, cmap='gray')
        ax.axis('off')

    plt.savefig(name + '_real_images.pdf', bbox_inches='tight')
    plt.close()


def samples_generated(name, data_loader, extra_name=''):
    # a real batch is only used to read off the data dimensionality
    x = next(iter(data_loader)).detach().numpy()

    # GENERATIONS-------
    flow_best = torch.load(name + '.model')
    flow_best.eval()

    num_x = 4
    num_y = 4
    x = flow_best.sample(num_x * num_y, D=x.shape[1]).detach().numpy()

    fig, ax = plt.subplots(num_x, num_y)
    for i, ax in enumerate(ax.flatten()):
        plottable_image = np.reshape(x[i], (8, 8))
        ax.imshow(plottable_image, cmap='gray')
        ax.axis('off')

    plt.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')
    plt.close()


def plot_curve(name, nll_val):
    plt.plot(np.arange(len(nll_val)), nll_val, linewidth=3)
    plt.xlabel('epochs')
    plt.ylabel('nll')
    plt.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')
    plt.close()
--------------------------------------------------------------------------------
/utils/nn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class RoundStraightThrough(torch.autograd.Function):
    """Rounds to the nearest integer in the forward pass and uses the
    straight-through estimator (identity) in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        rounded = torch.round(input)
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        # rounding has zero gradient almost everywhere, so pass the
        # incoming gradient through unchanged
        grad_input = grad_output.clone()
        return grad_input


class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


class Reshape3d(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, x):
        B = x.shape[0]
        return x.view(B, self.size[0], self.size[1], self.size[2])


class Flatten(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        B = x.shape[0]
        return x.view(B, -1)
--------------------------------------------------------------------------------
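`RoundStraightThrough` is what makes the IDF models trainable: rounding is piecewise constant, so its true gradient is zero almost everywhere, and the backward pass substitutes the identity instead. A minimal demonstration:

```python
import torch

from utils.nn import RoundStraightThrough

round_ste = RoundStraightThrough.apply

x = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)
y = round_ste(x)   # tensor([ 0.,  2., -2.])
y.sum().backward()
print(x.grad)      # tensor([1., 1., 1.]) -- the identity gradient
```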
--------------------------------------------------------------------------------
/utils/training.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

from utils.evaluation import samples_generated


def training(name, max_patience, num_epochs, flow, optimizer, training_loader, val_loader):
    nll_val = []
    best_nll = 1000.
    patience = 0

    # Main loop
    for e in range(num_epochs):
        # TRAINING
        flow.train()
        for indx_batch, batch in enumerate(training_loader):
            if hasattr(flow, 'dequantization'):
                if flow.dequantization:
                    # uniform dequantization: add U(0, 1) noise so that the
                    # continuous flow can model the discrete pixel values
                    batch = batch + torch.rand(batch.shape)
            loss = -flow.forward(batch).mean()

            optimizer.zero_grad()
            loss.backward()  # each step builds a fresh graph; retain_graph is not needed
            optimizer.step()

        # Validation
        flow.eval()
        loss_val = 0.
        N = 0.
        for indx_batch, val_batch in enumerate(val_loader):
            loss_v = -flow.forward(val_batch).sum()

            loss_val = loss_val + loss_v.item()

            N = N + val_batch.shape[0]
        loss_val = loss_val / N

        print(f'Epoch {e}: val nll={loss_val}')
        nll_val.append(loss_val)  # save for plotting

        if e == 0:
            print('saved!')
            torch.save(flow, name + '.model')
            best_nll = loss_val
        else:
            if loss_val < best_nll:
                print('saved!')
                torch.save(flow, name + '.model')
                best_nll = loss_val
                patience = 0

                samples_generated(name, val_loader, extra_name="_epoch_" + str(e))
            else:
                patience = patience + 1

        if patience > max_patience:
            break

    nll_val = np.asarray(nll_val)

    return nll_val
--------------------------------------------------------------------------------