├── src
├── plot_tools
│ ├── __init__.py
│ ├── pics_tools.py
│ └── cmap_tools.py
├── data_tools
│ ├── __init__.py
│ └── mnist_dataset.py
├── kl_tools.py
└── test.py
├── .gitignore
├── output_1.png
├── output_2.png
├── README.md
└── tutorial.ipynb
/src/plot_tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *__pycache__/
3 |
--------------------------------------------------------------------------------
/src/data_tools/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dcmoyer/invariance-tutorial/HEAD/output_1.png
--------------------------------------------------------------------------------
/output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dcmoyer/invariance-tutorial/HEAD/output_2.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Non-adversarial Invariant Representations Tutorial
2 |
3 | This repo contains tutorial content, example code, and working loss function
4 | calculators for
5 | [Invariant Representations without Adversarial Training](https://arxiv.org/abs/1805.09458).
6 |
7 | A transcription of the tutorial content in non-interactive HTML form is hosted
8 | on my website [in a test blogpost](https://dcmoyer.github.io/selfhosted/blag.html).
9 |
10 | You may also be interested in the [Echo Noise Distribution](https://arxiv.org/abs/1904.07199),
11 | a noise model for encodings _z_ that provides closed form rate terms I(z,x),
12 | with nicely computable derivatives. This is exactly the term we would like to
13 | penalize in the Inv. Rep. paper, but instead had to bound.
14 |
15 |
16 |
--------------------------------------------------------------------------------
/src/data_tools/mnist_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import tensorflow as tf
3 | import numpy as np
4 | import sklearn.preprocessing as prep
5 |
IMG_DIM = 28
NUM_LABELS = 10


def get_data():
    """Load MNIST as flat, [0, 1]-scaled images and float32 one-hot labels.

    Returns:
        ``((train_x, train_y), (test_x, test_y))`` where each ``*_x`` array is
        float32 with shape ``(n, IMG_DIM ** 2)`` (pixels divided by 255) and
        each ``*_y`` array is a float32 one-hot encoding with shape
        ``(n, NUM_LABELS)``.
    """
    # TODO: move to a preprocessing step
    (train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

    def flatten(images):
        # Cast before dividing so the scaling happens in float32.
        return images.astype(np.float32).reshape((images.shape[0], IMG_DIM ** 2)) / 255.0

    def one_hot(labels):
        # Row k of the identity matrix is the one-hot vector for label k;
        # fancy-indexing with the label array selects one row per sample.
        return np.eye(NUM_LABELS, dtype=np.float32)[labels]

    return (flatten(train_x), one_hot(train_y)), (flatten(test_x), one_hot(test_y))
44 |
45 |
--------------------------------------------------------------------------------
/src/plot_tools/pics_tools.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 |
def plot_image_grid(x_data, img_shape, grid_shape):
    """Draw a batch of flattened images on a grid of subplots.

    Parameters
    ----------
    x_data : array of shape (n_pics, ...)
        Images to draw; ``x_data[k]`` is reshaped to ``img_shape``.
    img_shape : sequence of int
        Target shape of each image, e.g. ``[28, 28]``.
    grid_shape : (int, int)
        Requested ``(rows, cols)``. If the grid has fewer cells than
        ``n_pics``, rows are appended until every image fits.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the grid. Images are placed in row-major order;
        cells beyond ``n_pics`` are left blank.
    """
    n_pics = x_data.shape[0]
    n_rows, n_cols = grid_shape

    if n_rows * n_cols < n_pics:
        # Ceil division: enough rows for every image, not just one extra row.
        needed_rows = -(-n_pics // n_cols)
        print(f"adding {needed_rows - n_rows} extra row(s) to grid")
        n_rows = needed_rows

    # sizes grid automatically (in inches).
    # squeeze=False always returns a 2-D array of Axes, so (1, k), (k, 1) and
    # (1, 1) grids are handled exactly like any other grid.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)

    for data_idx, ax in enumerate(axes.flat):
        if data_idx < n_pics:
            img_data = x_data[data_idx].reshape(*img_shape)
            ax.imshow(img_data, interpolation="nearest", vmin=0, vmax=1)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        else:
            # Unused trailing cells: hide frame and ticks entirely.
            ax.axis("off")

    plt.tight_layout()

    return fig
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/src/plot_tools/cmap_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | import matplotlib.pyplot as plt
4 | from mpl_toolkits.axes_grid1 import AxesGrid
5 |
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max and you want the
    middle of the colormap's dynamic range to be at zero.

    Input
    -----
      cmap : The matplotlib colormap to be altered
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax / (vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15)) or 0.75
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name under which the new colormap is created and registered.

    Output
    ------
      A `LinearSegmentedColormap`, which is also registered with matplotlib
      under `name` (re-using a name overwrites the earlier registration).
    '''
    cdict = {
        'red': [],
        'green': [],
        'blue': [],
        'alpha': []
    }

    # regular index to compute the colors
    reg_index = np.linspace(start, stop, 257)

    # shifted index to match the data: 128 samples in [0, midpoint) and
    # 129 samples in [midpoint, 1], i.e. the same 257 points as reg_index
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)

        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)

    # plt.register_cmap was deprecated in Matplotlib 3.6 and removed in 3.9.
    # Prefer the colormap registry when present; force=True lets repeated
    # calls with the same name overwrite the previous entry instead of raising.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "register"):
        registry.register(newcmap, force=True)
    else:
        plt.register_cmap(cmap=newcmap)

    return newcmap
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/src/kl_tools.py:
--------------------------------------------------------------------------------
1 |
2 | import tensorflow as tf
3 | import math
4 |
# KL(N_0 || N_1) for k-dimensional Gaussians:
#   0.5 * [ tr(\Sigma_1^{-1} \Sigma_0)
#           + (\mu_1 - \mu_0)^T \Sigma_1^{-1} (\mu_1 - \mu_0) - k
#           + \log( \det \Sigma_1 / \det \Sigma_0 ) ]
def all_pairs_gaussian_kl(mu, sigma, add_third_term=False, eps=1e-8):
    """Pairwise KL divergences between the diagonal Gaussians of a batch.

    Entry ``[i, j]`` of the result is ``KL(N_i || N_j)`` where
    ``N_b = N(mu[b], diag(sigma[b] ** 2 + eps))``, EXCLUDING the constant
    ``-k / 2`` term (callers subtract it themselves).

    Args:
        mu: ``[batch, dim_z]`` tensor of means.
        sigma: ``[batch, dim_z]`` tensor of standard deviations (not variances).
        add_third_term: include the log-determinant term. That term is
            antisymmetric in (i, j), so it averages to zero over all pairs and
            may be skipped when only the mean over pairs is needed.
        eps: added to the variances for numerical stability.

    Returns:
        ``[batch, batch]`` tensor of pairwise KL values (without ``-k / 2``).
    """
    sigma_sq = tf.square(sigma) + eps
    sigma_sq_inv = tf.math.reciprocal(sigma_sq)

    # Trace term: sum_d sigma_sq[i, d] / sigma_sq[j, d]. The dot product of a
    # variance vector with an inverse-variance vector is the trace of the
    # product of the corresponding diagonal matrices.
    first_term = tf.matmul(sigma_sq, tf.transpose(sigma_sq_inv))

    # Mahalanobis term sum_d (mu[i,d] - mu[j,d])^2 / sigma_sq[j,d], expanded as
    #   sum_d mu[i,d]^2 / s[j,d] - 2 sum_d mu[i,d] mu[j,d] / s[j,d]
    #   + sum_d mu[j,d]^2 / s[j,d]
    # (note the variance always comes from N_1, i.e. column j).
    mu_sq_scaled = tf.matmul(mu * mu, tf.transpose(sigma_sq_inv))   # [i, j]
    cross = 2 * tf.matmul(mu, tf.transpose(mu * sigma_sq_inv))      # [i, j]
    self_term = tf.reduce_sum(mu * mu * sigma_sq_inv, 1)            # [j]
    second_term = mu_sq_scaled - cross + tf.reshape(self_term, [1, -1])

    if add_third_term:
        # For diagonal Sigma, log det Sigma = sum_d log sigma_sq[d].
        log_det = tf.reshape(tf.reduce_sum(tf.math.log(sigma_sq), 1), [-1, 1])
        # log det Sigma_1 - log det Sigma_0, with N_0 = row i and N_1 = column j.
        third_term = tf.transpose(log_det) - log_det
    else:
        third_term = 0

    return 0.5 * (first_term + second_term + third_term)
73 |
def kl_conditional_and_marg(z_mean, z_log_sigma_sq, dim_z):
    """Average pairwise KL between the batch's Gaussian posteriors q(z|x).

    The posteriors are parameterized by their means and log-variances. For
    every ordered pair (i, j) in the batch (including i == j) this computes
    KL[q(z|x_i) || q(z|x_j)] with all_pairs_gaussian_kl, including the
    log-determinant term, and returns the mean over all pairs. This is the
    pairwise-Gaussian approximation used for the KL[q(z|x) || q(z)] penalty.

    Args:
        z_mean: means of q(z|x), one row per batch element.
        z_log_sigma_sq: log-variances of q(z|x), same shape as z_mean.
        dim_z: latent dimensionality k.

    Returns:
        Scalar tensor: the mean pairwise KL.
    """
    # Convert log-variance to standard deviation: exp(0.5 * log sigma^2) = sigma.
    z_sigma = tf.exp(0.5 * z_log_sigma_sq)
    pairwise_kl = all_pairs_gaussian_kl(z_mean, z_sigma, add_third_term=True)
    # all_pairs_gaussian_kl omits the constant -k/2 of the Gaussian KL; add it here.
    pairwise_kl = pairwise_kl - 0.5 * dim_z
    return tf.reduce_mean(pairwise_kl)
84 |
85 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 |
"""Train the invariant conditional VAE ("ICVAE") on MNIST and plot its outputs."""
import numpy as np

import keras
# For portability, keras.backend alone should be enough ...
import keras.backend as K
# ... but kl_tools (and the echo-noise variant) rely on TensorFlow directly.
import tensorflow as tf

import kl_tools
from data_tools import mnist_dataset

params = {
    "beta": 0.1,
    "lambda": 1.0,
}

(train_x, train_y), (test_x, test_y) = mnist_dataset.get_data()

##
## VAE network: sizes and the encoder trunk
##

DIM_Z = 16
DIM_C = mnist_dataset.NUM_LABELS
INPUT_SHAPE = mnist_dataset.IMG_DIM ** 2
ACTIVATION = "tanh"

input_x = keras.layers.Input(shape=[INPUT_SHAPE], name="x")

enc_hidden_1 = keras.layers.Dense(512, activation=ACTIVATION, name="enc_h1")(input_x)
enc_hidden_2 = keras.layers.Dense(512, activation=ACTIVATION, name="enc_h2")(enc_hidden_1)
33 |
# Adapted from the Keras VAE example:
# https://keras.io/examples/variational_autoencoder/
def sampling(args):
    """Reparameterization trick: draw z = mean + exp(0.5 * log_var) * eps.

    Args:
        args: pair ``(z_mean, z_log_var)`` holding the mean and the log of the
            variance of Q(z|X).

    Returns:
        A ``(batch, dim)`` tensor of sampled latent vectors, where ``eps`` is
        drawn from an isotropic standard normal.
    """
    mean, log_var = args
    noise_shape = (K.shape(mean)[0], K.int_shape(mean)[1])
    # random_normal defaults to mean 0 and standard deviation 1
    eps = K.random_normal(shape=noise_shape)
    return mean + K.exp(0.5 * log_var) * eps
52 |
53 |
# Encoder heads: mean and log-variance of q(z|x), then a reparameterized sample.
z_mean = keras.layers.Dense(DIM_Z, activation="tanh")(enc_hidden_2)
z_log_sigma_sq = keras.layers.Dense(DIM_Z, activation="linear")(enc_hidden_2)

z = keras.layers.Lambda(sampling, output_shape=(DIM_Z,), name='z')([z_mean, z_log_sigma_sq])

# Conditional decoder p(x | z, c): condition by concatenating c onto z.
input_c = keras.layers.Input(shape=[DIM_C], name="c")
z_with_c = keras.layers.concatenate([z, input_c])
z_mean_with_c = keras.layers.concatenate([z_mean, input_c])

# Decoder layers are created once and applied twice (to the sampled z and to
# z_mean) so that both models below share the same weights.
dec_h1 = keras.layers.Dense(512, activation=ACTIVATION, name="dec_h1")
dec_h2 = keras.layers.Dense(512, activation=ACTIVATION, name="dec_h2")
output_layer = keras.layers.Dense(INPUT_SHAPE, name="x_hat")

dec_hidden_1 = dec_h1(z_with_c)
dec_hidden_2 = dec_h2(dec_hidden_1)
x_hat = output_layer(dec_hidden_2)

cvae = keras.models.Model(inputs=[input_x, input_c], outputs=x_hat, name="ICVAE")

# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None".
cvae.summary()

##
## deterministic "mean" model for outputs: decodes z_mean instead of a sample
##

mean_dec_hidden_1 = dec_h1(z_mean_with_c)
mean_dec_hidden_2 = dec_h2(mean_dec_hidden_1)
mean_x_hat = output_layer(mean_dec_hidden_2)

mean_cvae = keras.models.Model(
    inputs=[input_x, input_c],
    outputs=mean_x_hat,
    name="mean_VAE",
)

##
## losses
##

# Reconstruction: keras' mse averages over pixels; scaling by INPUT_SHAPE
# turns it into a per-image sum.
recon_loss = keras.losses.mse(input_x, x_hat)
recon_loss *= INPUT_SHAPE  # optional, in the tutorial code though

# Closed-form KL[q(z|x) || N(0, I)] for a diagonal Gaussian q.
kl_loss = 1 + z_log_sigma_sq - K.square(z_mean) - K.exp(z_log_sigma_sq)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5

# Pairwise-Gaussian approximation of KL[q(z|x) || q(z)].
kl_qzx_qz_loss = kl_tools.kl_conditional_and_marg(z_mean, z_log_sigma_sq, DIM_Z)

# beta weights the prior KL (beta-VAE, Higgins et al.). lambda weights the
# bound on I(z, c), which itself contains a reconstruction term -- hence the
# (1 + lambda) factor on recon_loss.
cvae_loss = K.mean(
    (1 + params["lambda"]) * recon_loss
    + params["beta"] * kl_loss
    + params["lambda"] * kl_qzx_qz_loss
)

cvae.add_loss(cvae_loss)

##
## training (or reuse of previously saved weights)
##

learning_rate = 0.0005
# NOTE(review): `lr` is the legacy Keras argument name; newer Keras releases
# expect `learning_rate` instead -- switch when upgrading Keras.
opt = keras.optimizers.Adam(lr=learning_rate)

cvae.compile(optimizer=opt)

import os
if not os.path.exists("mnist_icvae.h5"):
    cvae.fit({"x": train_x, "c": train_y}, epochs=50)
    cvae.save_weights("mnist_icvae.h5")
else:
    cvae.load_weights("mnist_icvae.h5")

##
## plot the first N test digits and their mean reconstructions
##

from plot_tools import pics_tools as pic
import matplotlib.pyplot as plt

n_plot_samps = 10
img_shape = [mnist_dataset.IMG_DIM, mnist_dataset.IMG_DIM]

test_x_hat = mean_cvae.predict(
    {"x": test_x[:n_plot_samps], "c": test_y[:n_plot_samps]}
)

# Row 0: inputs; row 1: their reconstructions.
fig = pic.plot_image_grid(
    np.concatenate([test_x[:n_plot_samps], test_x_hat], axis=0),
    img_shape,
    (2, n_plot_samps),
)
plt.show()

##
## decode each test digit under its true c and under every possible c
##

X_test_set = []
Y_test_set = []

for i in range(n_plot_samps):
    X_test_set.append(test_x[i:(i + 1), :])
    X_test_set.append(np.tile(test_x[i], [DIM_C, 1]))

    Y_test_set.append(test_y[i:(i + 1), :])
    Y_test_set.append(np.eye(DIM_C))

X_test_set = np.concatenate(X_test_set, axis=0)
Y_test_set = np.concatenate(Y_test_set, axis=0)

X_test_hat = mean_cvae.predict({"x": X_test_set, "c": Y_test_set})

# Each sample produced 1 + DIM_C decodings (true label, then each label).
n_per_sample = 1 + DIM_C

plot_collection = []
for i in range(n_plot_samps):
    plot_collection.append(test_x[i:(i + 1), :])
    plot_collection.append(X_test_hat[i * n_per_sample:(i + 1) * n_per_sample, :])

plot_collection = np.concatenate(plot_collection, axis=0)

# One row per sample: the input followed by its 1 + DIM_C decodings.
fig = pic.plot_image_grid(
    plot_collection,
    img_shape,
    (n_plot_samps, 1 + n_per_sample),
)
plt.show()
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
--------------------------------------------------------------------------------
/tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Invariant Representations Tutorial\n",
8 | "\n",
9 | "\n",
10 | "This is part of an Invariant Representations tutorial series. There's a short intro here to the theory behind what's going on; more in-depth coverage will be in a second part that has not been written yet (so there is no link for it yet). This tutorial requires basic variational auto-encoder knowledge and some information theory knowledge, but otherwise should hopefully be pretty accessible.\n",
11 | "\n",
12 | "If you're not familiar with VAE, there's a great tutorial from Lilian Weng [here](https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html).\n",
13 | "I should find an information theory (IT) for ML tutorial, but I haven't yet, so if you know of a good one, please link me!\n",
14 | "\n",
15 | "PLEASE leave comments/advice/requests. I'm new at this, so this is probably sub-optimal, whatever its current state. I'm also not a regular .ipynb user, so even minor things can be PR'd in happily.\n",
16 | "\n",
17 | "This tutorial is done in Keras/TF.\n"
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "## Theory summary\n",
25 | "\n",
26 | "**Three sentence setup from VAE:** Regular auto-encoders learn functions $q(z|x)$ and $p(x|z)$, which respectively encode and decode our data $x$ to a lower-dimensional variable $z$. If the auto-encoder does well, then $p(x|z = q(z|x))$ is close to $x$. VAE builds on this, letting $q$ be a probabilistic mapping, so that $q(z|x)$ is a distribution, usually a Gaussian.\n",
27 | "\n",
28 | "**The invariant auto-encoder objective:** We want to learn $q$ and $p$ in the same setup that\n",
29 | "1. accurately describes $x$ (just like regular AEs)\n",
30 | "2. has a $z$ that is invariant to outside factor $c$ (the invariance condition)\n",
31 | "\n",
32 | "The first part can be described using maximum likelihood from e.g. [Kingma and Welling 2013](https://arxiv.org/abs/1312.6114):\n",
33 | "\n",
34 | "$$ \\max \\mathbb{E}_{q(z|x)p(x)}[ \\log p(x|z) ] $$\n",
35 | "\n",
36 | "We can write down a Markov chain that describes this setup: $ c \\rightarrow x \\rightarrow z $. We want to enforce the constraint $z\\perp c$, i.e. $z$ independent of $c$, i.e. that $p(z|c) = p(z)$. Basically: *Make sure $z$ doesn't change if you change $c$*.\n",
37 | "\n",
38 | "Okay, so what if we can't do that? $z \\perp c$ is a really harsh constraint.\n",
39 | "\n",
40 | "One idea: relax it to minimal mutual information, $\\min I(z,c)$. If $I(z,c)$ is close to $0$, then we're good.\n",
41 | "\n",
42 | "How do we minimize $I(z,c)$? This, it turns out, is also hard (are you sensing a theme) for arbitrary encodings $z$. Doing some* math we end up at this bound:\n",
43 | "\n",
44 | "$$I(z, c) \\leq\n",
45 | "\\underbrace{- \\mathbb{E}_{x,c,z\\sim q}[ \\log p(x|z,c)]}_{\\text{Reconstruction}}\n",
46 | "+ \\underbrace{\\mathbb{E}_{x}[~KL[~q(z|x)~\\|~q(z)~]~ ]}_{\\text{Compression}}\n",
47 | "- \\underbrace{\\vphantom{\\mathbb{E}_{x,q}[]}H(x|c)}_{\\text{Const}}$$\n",
48 | "\\*\"some\" meaning section 2 of [our paper](https://arxiv.org/abs/1805.09458).\n",
49 | "\n",
50 | "This has two parts we can optimize, $p(x|z,c)$ and $q(z|x)$. These look like an encoder and decoder pair from the auto-encoders introduced at the start, except our decoder is conditional. We also have an additional KL term. Otherwise it's the same as an auto-encoder.\n",
51 | "\n",
52 | "\n",
53 | "**TL;DR:** Use a conditional decoder and an additional KL term to make your auto-encoder invariant to whatever you conditioned on.\n",
54 | "\n",
55 | "The \"TODO List\":\n",
56 | "* Make a generic encoder, $q(z|x)$.\n",
57 | "* Make a conditional decoder, $p(x|z,c)$.\n",
58 | "* Implement that KL loss (or an approximation thereof): $KL[~q(z|x)~\\|~q(z)~]$\n",
59 | "\n",
60 | "That's what we're going to do here, using neural networks for $p$ and $q$."
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "### Additional Motivations/Reference Notes\n",
68 | "\n",
69 | "In classical CV/ML, we might have designed these analytically for things like rotation, translation, scaling, etc.; these transformations themselves have analytic descriptions, and we can view the removal of their effects as [quotienting out a group operation](https://arxiv.org/abs/1602.07576).\n",
70 | "\n",
71 | "For more general factors $c$, this isn't always possible. One example in [algorithmic fairness](https://arxiv.org/abs/1511.00830) requires the removal of sensitive or protected attributes (Race, Religion, Gender, Orientation) from $x$, which is a) probably not a group operation, and b) definitely has no easy analytic description. This also occurs in [instrument bias](https://arxiv.org/abs/1904.05375) for observational studies. Further, for generative models it can be helpful to \"modulo out\" certain factors, only to add them back in a [controlled manner](https://arxiv.org/abs/1706.00409).\n",
72 | "\n",
73 | "We summarized our theory objective with: \"*Make sure $z$ doesn't change if you change $c$*\". Here, \"*you*\" is often \"*a mysterious force which we cannot control*\", since if we could change $c$ and get new $(x,c)$ pairs, we could perform data augmentation. In fact, data augmentation is regularly used for things like rotation, translation, etc., and for some effects which aren't groups/semi-groups (occlusions, additive noise). It's simple and it \"completes the orbit\". In other words: it adds an $x$ to our dataset for each possible value of $c$. Usually on the fly."
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
80 | "## Programming Setup\n",
81 | "\n",
82 | "Okay, so we'll need Keras, Numpy, and MNIST data. I'm going to hide all that, but we're going to get $x$ as a big flat vector and $y$ as a one-hot categorical variable."
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 1,
88 | "metadata": {
89 | "tags": [
90 | "hide_input"
91 | ]
92 | },
93 | "outputs": [
94 | {
95 | "name": "stderr",
96 | "output_type": "stream",
97 | "text": [
98 | "Using TensorFlow backend.\n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "import keras\n",
104 | "import keras.backend as K\n",
105 | "import tensorflow as tf\n",
106 | "import numpy as np\n",
107 | "\n",
108 | "IMG_DIM = 28\n",
109 | "NUM_LABELS = 10\n",
110 | "\n",
111 | "#data comes as images and 1-dim {0,...,9} categorical variable\n",
112 | "(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()\n",
113 | " \n",
114 | "#cast and flatten images, renormalizing to [0,1]\n",
115 | "train_x = train_x.astype(np.float32).reshape( (train_x.shape[0], IMG_DIM**2) ) / 255.0\n",
116 | "test_x = test_x.astype(np.float32).reshape( (test_x.shape[0], IMG_DIM**2) ) / 255.0\n",
117 | "\n",
118 | "#copypaste\n",
119 | "def one_hot(labels):\n",
120 | " num_labels_data = labels.shape[0]\n",
121 | " one_hot_encoding = np.zeros((num_labels_data,NUM_LABELS))\n",
122 | " one_hot_encoding[np.arange(num_labels_data),labels] = 1\n",
123 | " one_hot_encoding = np.reshape(one_hot_encoding, [-1, NUM_LABELS])\n",
124 | " return one_hot_encoding\n",
125 | "\n",
126 | "train_y = one_hot(train_y).astype(np.float32)\n",
127 | "test_y = one_hot(test_y).astype(np.float32)"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 2,
133 | "metadata": {},
134 | "outputs": [
135 | {
136 | "name": "stdout",
137 | "output_type": "stream",
138 | "text": [
139 | "(60000, 784)\n",
140 | "(60000, 10)\n"
141 | ]
142 | }
143 | ],
144 | "source": [
145 | "print(train_x.shape)\n",
146 | "print(train_y.shape)"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "## Building a Conditional VAE architecture \n",
154 | "\n",
155 | "\n",
156 | "The Conditional VAE is, as its name suggests, a variational auto-encoder with conditional output. This means it should \n",
157 | "1. Take in $x$ and map it to a latent variable $z$ using the encoder\n",
158 | "2. Map $z$ to a new $\\hat{x}$ using the decoder.\n",
159 | "3. Condition (a.k.a. control) the output $\\hat{x}$ by another input $c$, representing specific other factors.\n",
160 | "\n",
161 | "In order to do this we'll also set a few constants. We're using ```DIM_Z=16```, but feel free to come back and play around with it later. We're also using \"tanh\" activations, but you can come back and change this as well (e.g. to \"relu\").\n",
162 | "\n",
163 | "For some this is pretty basic stuff. Skip to [Loss Construction](#loss-construction) if you can do this on your own."
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 3,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "DIM_Z = 16\n",
173 | "DIM_C = NUM_LABELS # just as an example. Sometimes we also have another y\n",
174 | "INPUT_SHAPE = IMG_DIM ** 2\n",
175 | "ACTIVATION = \"tanh\""
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | "Then we'll build the encoder, which is a two-layer fully-connected feed-forward network with two outputs, $z_{mean}$ and $z_{sigma}$. To avoid domain problems, we'll actually output the log-variance $\\log z_{sigma}^2$ (`z_log_sigma_sq` in the code). This means we can use a linear layer instead of having to choose a non-negative activation."
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 4,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "#declare inputs to the encoder, which is just x\n",
192 | "input_x = keras.layers.Input( shape = [INPUT_SHAPE], name=\"x\" )\n",
193 | "\n",
194 | "#first hidden layer\n",
195 | "enc_hidden_1 = keras.layers.Dense(512, activation=ACTIVATION, name=\"enc_h1\")(input_x)\n",
196 | "#second hidden layer\n",
197 | "enc_hidden_2 = keras.layers.Dense(512, activation=ACTIVATION, name=\"enc_h2\")(enc_hidden_1)\n",
198 | "\n",
199 | "#first output, z_mean\n",
200 | "z_mean = keras.layers.Dense(DIM_Z, activation=ACTIVATION)(enc_hidden_2)\n",
201 | "#second hidden output, z_log_sigma_sq.\n",
202 | "z_log_sigma_sq = keras.layers.Dense(DIM_Z, activation=\"linear\")(enc_hidden_2)"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "Creating the latent variable $z$ using a Gaussian layer can be a bit tricky, but luckily is covered in the [Keras documentation](https://keras.io/examples/variational_autoencoder/), which this part follows almost exactly. If you haven't seen the reparameterization trick before, this is to create a Gaussian distributed layer $z$ which can be differentiated w.r.t. its parameters. It was popularized by [Kingma and Welling 2013](https://arxiv.org/abs/1312.6114). "
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 5,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "#stolen straight from the docs\n",
219 | "#https://keras.io/examples/variational_autoencoder/\n",
220 | "def sampling(args):\n",
221 | " \"\"\"Reparameterization trick by sampling from an isotropic unit Gaussian.\n",
222 | "\n",
223 | " # Arguments\n",
224 | " args (tensor): mean and log of variance of Q(z|X)\n",
225 | "\n",
226 | " # Returns\n",
227 | " z (tensor): sampled latent vector\n",
228 | " \"\"\"\n",
229 | "\n",
230 | " z_mean, z_log_var = args\n",
231 | " batch = K.shape(z_mean)[0]\n",
232 | " dim = K.int_shape(z_mean)[1]\n",
233 | " # by default, random_normal has mean = 0 and std = 1.0\n",
234 | " epsilon = K.random_normal(shape=(batch, dim))\n",
235 | " return z_mean + K.exp(0.5 * z_log_var) * epsilon\n",
236 | "\n",
237 | "\n",
238 | "z_mean = keras.layers.Dense(DIM_Z, activation=\"tanh\")(enc_hidden_2)\n",
239 | "z_log_sigma_sq = keras.layers.Dense(DIM_Z, activation=\"linear\")(enc_hidden_2)\n",
240 | "\n",
241 | "z = keras.layers.Lambda(sampling, output_shape=(DIM_Z,), name='z')([z_mean, z_log_sigma_sq])"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "metadata": {},
247 | "source": [
248 | "Great, that's the Encoder from $x$ to $z$ done. Let's build the conditional decoder. There isn't an established standard way to condition the output, but concatenating $z$ and $c$ before inputting into the decoder is good enough for now."
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 6,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "#declare any additional inputs to our decoder, in this case c\n",
258 | "input_c = keras.layers.Input( shape = [DIM_C], name=\"c\")\n",
259 | "#this is the concat operation!\n",
260 | "z_with_c = keras.layers.concatenate([z,input_c])\n",
261 | "\n",
262 | "#first hidden layer\n",
263 | "dec_hidden_1 = keras.layers.Dense(512, activation=ACTIVATION, name=\"dec_h1\")(z_with_c)\n",
264 | "#second hidden layer\n",
265 | "dec_hidden_2 = keras.layers.Dense(512, activation=ACTIVATION, name=\"dec_h2\")(dec_hidden_1)\n",
266 | "\n",
267 | "#output, should be same domain as x_hat\n",
268 | "#could also use sigmoid activation\n",
269 | "x_hat = keras.layers.Dense( INPUT_SHAPE, name=\"x_hat\", activation=\"linear\" )(dec_hidden_2)"
270 | ]
271 | },
272 | {
273 | "cell_type": "markdown",
274 | "metadata": {},
275 | "source": [
276 | "Loss Construction for Invariance (...and also all that VAE stuff)\n",
277 | "----\n",
278 | "\n",
279 | "\n",
280 | "Okay, so we have both the encoder and the conditional decoder now, so the next step is building the loss function. There are three sub-components:\n",
281 | "\n",
282 | "1. Reconstruction (how far is $\\hat{x}$ from $x$), usually $\\|x - \\hat{x}\\|_2^2$.\n",
283 | "2. \"Distance\" to the Prior (from the original VAE definition), $KL[q(z|x) \\| p(z)]$\n",
284 | "3. \"Distance\" to the Empirical Marginal Posterior (from our paper, among others), $KL[q(z|x) \\| q(z)]$.\n",
285 | "\n",
286 | "The third one we'll have to approximate, so we'll deal with that second. First, the two easy ones, straight from Keras Docs, plus defining some hyper parameters:"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": 7,
292 | "metadata": {},
293 | "outputs": [],
294 | "source": [
295 | "params = {\n",
296 | " \"beta\" : 0.1,\n",
297 | " \"lambda\" : 1.0,\n",
298 | "}\n",
299 | "\n",
300 | "recon_loss = keras.losses.mse(input_x, x_hat)\n",
301 | "recon_loss *= INPUT_SHAPE #optional, in the tutorial code though\n",
302 | "\n",
303 | "prior_loss = 1 + z_log_sigma_sq - K.square(z_mean) - K.exp(z_log_sigma_sq)\n",
304 | "prior_loss = K.sum(prior_loss, axis=-1)\n",
305 | "prior_loss *= -0.5"
306 | ]
307 | },
308 | {
309 | "cell_type": "markdown",
310 | "metadata": {},
311 | "source": [
312 | "Now for $KL[ q(z|x) \\| q(z) ]$. Since we're using the Gaussian $z$ layer, there's an approximation we can make using the pairwise Gaussian KL divergences. **In the original version of the paper there is an erroneous extra term at this part.** The corrected version is implemented below, but before we get into it, we should remember that we actually want to compute $KL[ q(z|x) \\| q(z) ]$, not this bound, and that this could also be solved using:\n",
313 | "\n",
314 | "* Direct approximation (e.g. use a neural network to approximate posterior marginal $q(z)$ term given $q(z|x)$ parameters and samples $x$).\n",
315 | "* Sampling\n",
316 | "* Use a different $z$ layer with analytic divergence to its marginal.\n",
317 | "\n",
318 | "The first and second options appear in the literature; see e.g. [Fixing a Broken ELBO (Alemi et al. 2017)](https://arxiv.org/abs/1711.00464) and [Structured Disentangled Representations (Esmaeili et al. 2018)](https://arxiv.org/abs/1804.02086). The third option appears in our paper [Echo Noise](https://arxiv.org/abs/1904.07199), which we'll demonstrate at the end.\n",
319 | "\n",
320 | "Knowing this, here's one way to do it *for Gaussian $z$ layers* using pairwise KL divergence:"
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": 8,
326 | "metadata": {},
327 | "outputs": [],
328 | "source": [
329 | "\n",
330 | "#KL(N_0|N_1) = tr(\\sigma_1^{-1} \\sigma_0) + \n",
331 | "# (\\mu_1 - \\mu_0)\\sigma_1^{-1}(\\mu_1 - \\mu_0) - k +\n",
332 | "# \\log( \\frac{\\det \\sigma_1}{\\det \\sigma_0} )\n",
333 | "def all_pairs_gaussian_kl(mu, sigma, add_third_term=False):\n",
334 | " sigma_sq = tf.square(sigma) + 1e-8\n",
335 | " sigma_sq_inv = tf.math.reciprocal(sigma_sq)\n",
336 | "\n",
337 | " #dot product of all sigma_inv vectors with sigma is the same as a matrix mult of diag\n",
338 | " first_term = tf.matmul(sigma_sq,tf.transpose(sigma_sq_inv))\n",
339 | " \n",
340 | " r = tf.matmul(mu * mu,tf.transpose(sigma_sq_inv))\n",
341 | " r2 = mu * mu * sigma_sq_inv \n",
342 | " r2 = tf.reduce_sum(r2,1)\n",
343 | " \n",
344 | " #squared distance\n",
345 | " #(mu[i] - mu[j])\\sigma_inv(mu[i] - mu[j]) = r[i] - 2*mu[i]*mu[j] + r[j]\n",
346 | " #uses broadcasting\n",
347 | " second_term = 2*tf.matmul(mu, tf.transpose(mu*sigma_sq_inv))\n",
348 | " second_term = r - second_term + tf.transpose(r2)\n",
349 | "\n",
350 | " # log det A = tr log A\n",
351 | " # log \\frac{ det \\Sigma_1 }{ det \\Sigma_0 } =\n",
352 | " # \\tr\\log \\Sigma_1 - \\tr\\log \\Sigma_0 \n",
353 | " # for each sample, we have B comparisons to B other samples...\n",
354 | " # so this cancels out\n",
355 | "\n",
356 | " if(add_third_term):\n",
357 | " r = tf.reduce_sum(tf.math.log(sigma_sq),1)\n",
358 | " r = tf.reshape(r,[-1,1])\n",
359 | " third_term = r - tf.transpose(r)\n",
360 | " else:\n",
361 | " third_term = 0\n",
362 | "\n",
363 | " #- tf.reduce_sum(tf.log(1e-8 + tf.square(sigma)))\\\n",
364 | " # the dim_z ** 3 term comes fro\n",
365 | " # -the k in the original expression\n",
366 | " # -this happening k times in for each sample\n",
367 | " # -this happening for k samples\n",
368 | " #return 0.5 * ( first_term + second_term + third_term - dim_z )\n",
369 | " return 0.5 * ( first_term + second_term + third_term )\n",
370 | "\n",
371 | "#\n",
372 | "# kl_conditional_and_marg\n",
373 | "# \\sum_{x'} KL[ q(z|x) \\| q(z|x') ] + (B-1) H[q(z|x)]\n",
374 | "#\n",
375 | "\n",
376 | "#def kl_conditional_and_marg(args):\n",
377 | "def kl_conditional_and_marg(z_mean, z_log_sigma_sq, dim_z):\n",
378 | " z_sigma = tf.exp( 0.5 * z_log_sigma_sq )\n",
379 | " all_pairs_GKL = all_pairs_gaussian_kl(z_mean, z_sigma, True) - 0.5*dim_z\n",
380 | " return tf.reduce_mean(all_pairs_GKL)"
381 | ]
382 | },
383 | {
384 | "cell_type": "markdown",
385 | "metadata": {},
386 | "source": [
387 | "So after all that, we can create our third loss term. We'll then add them all up, and create our model!"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 9,
393 | "metadata": {},
394 | "outputs": [
395 | {
396 | "name": "stderr",
397 | "output_type": "stream",
398 | "text": [
399 | "/home/dcmoyer/anaconda3/lib/python3.7/site-packages/keras/engine/training_utils.py:819: UserWarning: Output x_hat missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to x_hat.\n",
400 | " 'be expecting any data to be passed to {0}.'.format(name))\n"
401 | ]
402 | }
403 | ],
404 | "source": [
405 | "kl_qzx_qz_loss = kl_conditional_and_marg(z_mean, z_log_sigma_sq, DIM_Z)\n",
406 | "\n",
407 | "#Invariant Conditional Variational Autoencoder (ICVAE).\n",
408 | "#I think the name game is meh (what if someone else got there first? or defines a different one later?)\n",
409 | "#so maybe it'd be easier to say Moyer et al. 2018? too narcissistic? Conditional VAE with additional compression?\n",
410 | "#\n",
411 | "# ...the point is, this one can induce invariance. So can the others sometimes, too, in practice.\n",
412 | "\n",
413 | "icvae_loss = K.mean((1 + params[\"lambda\"]) * recon_loss + params[\"beta\"]*prior_loss + params[\"lambda\"]*kl_qzx_qz_loss)\n",
414 | "\n",
415 | "icvae = keras.models.Model(inputs=[input_x,input_c], outputs=x_hat, name=\"ICVAE\")\n",
416 | "\n",
417 | "icvae.add_loss(icvae_loss)\n",
418 | "\n",
419 | "learning_rate = 0.0005\n",
420 | "opt = keras.optimizers.Adam(lr=learning_rate)\n",
421 | "\n",
422 | "icvae.compile( optimizer=opt, )\n"
423 | ]
424 | },
425 | {
426 | "cell_type": "markdown",
427 | "metadata": {},
428 | "source": [
429 | "## Training\n",
430 | "\n",
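431 | "Only run this bit once. It takes a while (though not too long). The cell below trains the model only if `mnist_icvae.h5` doesn't already exist; otherwise it loads the saved weights from that file. If this isn't your Jupyter server, you may want to remove that file-caching logic. I just wanted to save you from repetitive GPU time/coffee breaks."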
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": 10,
437 | "metadata": {},
438 | "outputs": [
439 | {
440 | "name": "stdout",
441 | "output_type": "stream",
442 | "text": [
443 | "loading from file\n"
444 | ]
445 | }
446 | ],
447 | "source": [
448 | "import os\n",
449 | "if not os.path.exists(\"mnist_icvae.h5\"):\n",
450 | " print(\"training\")\n",
451 | " icvae.fit(\n",
452 | " { \"x\" : train_x, \"c\" : train_y }, epochs=100\n",
453 | " )\n",
454 | " icvae.save_weights(\"mnist_icvae.h5\")\n",
455 | "else:\n",
456 | " print(\"loading from file\")\n",
457 | " icvae.load_weights(\"mnist_icvae.h5\")"
458 | ]
459 | },
460 | {
461 | "cell_type": "markdown",
462 | "metadata": {},
463 | "source": [
464 | "# Plots and Evaluation\n",
465 | "\n",
466 | "Okay, so now that we've trained it, let's see how well it does."
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "execution_count": 13,
472 | "metadata": {},
473 | "outputs": [
474 | {
475 | "data": {
476 | "image/png": "\n",
477 | "text/plain": [
478 | ""
479 | ]
480 | },
481 | "metadata": {},
482 | "output_type": "display_data"
483 | }
484 | ],
485 | "source": [
486 | "n_plot_samps = 10\n",
487 | "test_x_hat = icvae.predict(\n",
488 | " { \"x\" : test_x[:n_plot_samps], \"c\" : test_y[:n_plot_samps] }\n",
489 | ")\n",
490 | "\n",
491 | "##\n",
492 | "## plot first N\n",
493 | "##\n",
494 | "\n",
495 | "from src.plot_tools import pics_tools as pic\n",
496 | "import matplotlib.pyplot as plt\n",
497 | "plt.style.use('grayscale')\n",
498 | "\n",
499 | "fig = pic.plot_image_grid( \\\n",
500 | " 1-np.concatenate([test_x[:n_plot_samps],test_x_hat], axis=0),\n",
501 | " [IMG_DIM, IMG_DIM], \\\n",
502 | " (2,n_plot_samps) \\\n",
503 | ")\n",
504 | "plt.show()"
505 | ]
506 | },
507 | {
508 | "cell_type": "markdown",
509 | "metadata": {},
510 | "source": [
511 | "Probably a bit fuzzy, right?\n",
512 | "\n",
513 | "One thing we could do immediately, though, is force our network to use the mean encoding of $z$. Currently it's actually sampling $z$, so if you ran the above code block multiple times you'd get different outputs.\n",
514 | "\n",
515 | "Still, fuzzy output is a common complaint about VAEs, and autoencoders in general (in comparison with, e.g., GANs). You can somewhat ameliorate this with more epochs, a bigger latent space, more layers, convolutions, etc., but in the end it might be that $\\ell_2$ isn't a great image metric (a.k.a. distortion measure). Designing one by hand seems very hard, but 5+ years of GAN research suggests that adversaries do well (replacing the $\\ell_2$ likelihood, i.e. $-\\log p(x|z,c) \\propto \\|x - \\hat{x}\\|_2^2$ up to an additive constant). This is, in effect, training another network $r$ to produce proxy likelihoods $r(x)$ for the outputs of $p(x|z,c)$. There's nothing stopping us from using an adversarial replacement for $\\log p(x|z,c)$.\n",
516 | "\n",
517 | "Anyway, in this next section we'll do the famous remapping trick, taking one digit and re-rendering it as each of the ten digit classes (including its own). This is done by manipulating $c$ (at test time, obviously)."
518 | ]
519 | },
520 | {
521 | "cell_type": "code",
522 | "execution_count": 14,
523 | "metadata": {},
524 | "outputs": [
525 | {
526 | "data": {
527 | "image/png": "\n",
528 | "text/plain": [
529 | ""
530 | ]
531 | },
532 | "metadata": {},
533 | "output_type": "display_data"
534 | }
535 | ],
536 | "source": [
537 | "X_test_set = []\n",
538 | "Y_test_set = []\n",
539 | "\n",
540 | "for i in range(n_plot_samps):\n",
541 | " tmp_tile_array = np.tile(test_x[i],[10,1])\n",
542 | " X_test_set.append(test_x[i:(i+1),:])\n",
543 | " X_test_set.append(tmp_tile_array)\n",
544 | "\n",
545 | " Y_test_set.append(test_y[i:(i+1),:])\n",
546 | " #Y_test_set.append(np.array([[0],[1]]))\n",
547 | " Y_test_set.append(np.eye(10))\n",
548 | "\n",
549 | "X_test_set = np.concatenate(X_test_set, axis=0)\n",
550 | "Y_test_set = np.concatenate(Y_test_set, axis=0)\n",
551 | "\n",
552 | "X_test_hat = icvae.predict(\n",
553 | " { \"x\" : X_test_set, \"c\" : Y_test_set }\n",
554 | ")\n",
555 | "\n",
556 | "plot_collection = []\n",
557 | "for i in range(n_plot_samps):\n",
558 | " plot_collection.append( test_x[i:(i+1),:] )\n",
559 | " plot_collection.append( X_test_hat[i*11:(i+1)*11,:] )\n",
560 | "\n",
561 | "plot_collection = np.concatenate( plot_collection, axis=0 )\n",
562 | "\n",
563 | "fig = pic.plot_image_grid( \\\n",
564 | " 1-plot_collection,\n",
565 | " [IMG_DIM, IMG_DIM], \\\n",
566 | " (n_plot_samps,12) \\\n",
567 | ")\n",
568 | "plt.show()"
569 | ]
570 | },
571 | {
572 | "cell_type": "markdown",
573 | "metadata": {},
574 | "source": [
575 | "Okay, so while it doesn't work amazingly well (...some of those 9s are questionable), it **is** doing two things correctly:\n",
576 | "1. Removing the input digit class, e.g. we can't see residual 0 bits when mapping 0 to other digits.\n",
577 | "2. Transferring some elements of the style, notably how dark digits are, how tilted each digit is, and their approximate vertical alignment.\n",
578 | "\n",
579 | "With better and more expressive decoders we can improve the reconstruction performance, but the general idea here is clear. Next up: more theory connections to adversaries, and an invariant prediction method."
580 | ]
581 | }
582 | ],
583 | "metadata": {
584 | "celltoolbar": "Tags",
585 | "kernelspec": {
586 | "display_name": "Python 3",
587 | "language": "python",
588 | "name": "python3"
589 | },
590 | "language_info": {
591 | "codemirror_mode": {
592 | "name": "ipython",
593 | "version": 3
594 | },
595 | "file_extension": ".py",
596 | "mimetype": "text/x-python",
597 | "name": "python",
598 | "nbconvert_exporter": "python",
599 | "pygments_lexer": "ipython3",
600 | "version": "3.7.3"
601 | }
602 | },
603 | "nbformat": 4,
604 | "nbformat_minor": 2
605 | }
606 |
--------------------------------------------------------------------------------