├── README.md ├── cppn.py ├── examples ├── .DS_Store ├── color_tanh_1.png ├── color_tanh_1b.png ├── color_tanh_2.png ├── color_tanh_2b.png ├── color_tanh_3.png ├── color_tanh_3b.png ├── color_tanh_4.png ├── color_tanh_4b.png ├── color_tanh_5.png ├── color_tanh_5b.png ├── color_tanh_6b.png ├── contours_1.png ├── countour_2.png ├── cppn.gif ├── output.gif ├── tanh_01.png ├── tanh_02.png ├── tanh_03.png ├── tanh_04.png ├── tanh_05.png ├── tanh_06.png ├── tanh_anim_0.gif ├── tanh_anim_1.gif ├── tanh_anim_end.png ├── tanh_anim_start.png ├── tanh_anim_twitter.gif ├── tanh_color.png ├── tanh_softplus_01.png ├── tanh_softplus_02.png ├── tanh_softplus_03.png ├── tanh_softplus_04.png ├── tanh_softplus_05.png ├── tanh_softplus_06.png ├── tanh_softplus_nature_01.png └── tanh_softplus_nature_02.png ├── main.py ├── ops.py ├── rppn.py └── sampler.py /README.md: -------------------------------------------------------------------------------- 1 | # cppn-tensorflow 2 | 3 | Simplified implementation of [Compositional Pattern Producing Network](https://en.wikipedia.org/wiki/Compositional_pattern-producing_network) in TensorFlow for the purpose of abstract art generation and for future deep learning work in generative algorithms. 4 | 5 | Examples of images generated by the simplified CPPN with tanh activation. 6 | ![color_tanh_2](https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/master/examples/color_tanh_2.png) 7 | ![color_tanh_4](https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/master/examples/color_tanh_4.png) 8 | ![color_tanh_5](https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/master/examples/color_tanh_5.png) 9 | 10 | Examples of animations generated by the simplified CPPN interpolating between two different _z_ embeddings.
11 | ![Morphing](https://cdn.rawgit.com/hardmaru/cppn-tensorflow/master/examples/cppn.gif) 12 | ![Morphing](https://cdn.rawgit.com/hardmaru/cppn-tensorflow/master/examples/output.gif) 13 | 14 | Examples of animations generated by the simplified RPPN gradually increasing the _k_ steps of recursion. 15 | ![Recurring](http://www.w4nderlu.st/content/2-projects/16-rppn/img2_anim_k0_k24_512.gif) 16 | 17 | See Otoro's blog post at [blog.otoro.net](http://blog.otoro.net/2016/03/25/generating-abstract-patterns-with-tensorflow/) for more details. 18 | See my blogpost at [w4nderlu.st](http://www.w4nderlu.st/projects/rppn) for details on RPPN. 19 | 20 | My contribution: 21 | - added a new model, RPPN (Recursive Pattern Producing Network) 22 | - added a generalized activation function strategy 23 | - added layer norm normalization 24 | - added cosine similary as an alternative linear layer 25 | - porting to Python 3 and TensorFlow 1.0 26 | - mp4 video generation 27 | 28 | Requirements: 29 | - TensorFlow 1.0.0+ 30 | - imageio for image generation 31 | - ffmpeg for video generation 32 | 33 | # License 34 | 35 | MIT 36 | -------------------------------------------------------------------------------- /cppn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of Compositional Pattern Producing Networks in Tensorflow 3 | 4 | https://en.wikipedia.org/wiki/Compositional_pattern-producing_network 5 | 6 | @hardmaru, 2016 7 | @w4nderlust, 2017 8 | 9 | ''' 10 | from __future__ import absolute_import, division, print_function, unicode_literals 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | from ops import * 15 | 16 | 17 | class CPPN(): 18 | def __init__(self, batch_size=1, z_dim=32, c_dim=1, scale=8.0, net_size=32, **kwargs): 19 | """ 20 | Args: 21 | z_dim: how many dimensions of the latent space vector (R^z_dim) 22 | c_dim: 1 for mono, 3 for rgb. dimension for output space. 
you can modify code to do HSV rather than RGB. 23 | net_size: number of nodes for each fully connected layer of cppn 24 | scale: the bigger, the more zoomed out the picture becomes 25 | 26 | """ 27 | self.batch_size = batch_size 28 | self.net_size = net_size 29 | x_dim = 256 30 | y_dim = 256 31 | self.x_dim = x_dim 32 | self.y_dim = y_dim 33 | self.scale = scale 34 | self.c_dim = c_dim 35 | self.z_dim = z_dim 36 | 37 | # tf Graph batch of image (batch_size, height, width, depth) 38 | self.batch = tf.placeholder(tf.float32, [batch_size, x_dim, y_dim, c_dim]) 39 | 40 | n_points = x_dim * y_dim 41 | self.n_points = n_points 42 | 43 | self.x_vec, self.y_vec, self.r_vec = self._coordinates(x_dim, y_dim, scale) 44 | 45 | # latent vector 46 | self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim]) 47 | # inputs to cppn, like coordinates and radius from centre 48 | self.x = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 49 | self.y = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 50 | self.r = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 51 | 52 | # builds the generator network 53 | self.G = self.generator(x_dim=self.x_dim, y_dim=self.y_dim) 54 | 55 | self.init() 56 | 57 | def init(self): 58 | # Initializing the tensor flow variables 59 | init = tf.global_variables_initializer() 60 | # Launch the session 61 | self.sess = tf.Session() 62 | self.sess.run(init) 63 | 64 | def reinit(self): 65 | init = tf.variables_initializer(tf.trainable_variables()) 66 | self.sess.run(init) 67 | 68 | def _coordinates(self, x_dim=32, y_dim=32, scale=1.0): 69 | ''' 70 | calculates and returns a vector of x and y coordintes, and corresponding radius from the centre of image. 
71 | ''' 72 | n_points = x_dim * y_dim 73 | # creates x and y ranges of x/y_dim numbers from -scale to +scale 74 | x_range = scale * ((np.arange(x_dim) - (x_dim - 1) / 2.0) / (x_dim - 1) * 2) 75 | y_range = scale * ((np.arange(y_dim) - (y_dim - 1) / 2.0) / (y_dim - 1) * 2) 76 | # create all r distances from center for any combination of coordinates on x and y 77 | x_mat = np.matmul(np.ones((y_dim, 1)), x_range.reshape((1, x_dim))) 78 | y_mat = np.matmul(y_range.reshape((y_dim, 1)), np.ones((1, x_dim))) 79 | r_mat = np.sqrt(x_mat * x_mat + y_mat * y_mat) 80 | # transform the x x y matrices tiling as many of them as the batch size 81 | # and reshaping to obtain a èbatch_size, x*y, 1+ tensor 82 | x_mat = np.tile(x_mat.flatten(), self.batch_size).reshape(self.batch_size, n_points, 1) 83 | y_mat = np.tile(y_mat.flatten(), self.batch_size).reshape(self.batch_size, n_points, 1) 84 | r_mat = np.tile(r_mat.flatten(), self.batch_size).reshape(self.batch_size, n_points, 1) 85 | return x_mat, y_mat, r_mat 86 | 87 | def generator(self, x_dim, y_dim, reuse=False): 88 | if reuse: 89 | tf.get_variable_scope().reuse_variables() 90 | 91 | net_size = self.net_size 92 | n_points = x_dim * y_dim 93 | 94 | # note that latent vector z is scaled to self.scale factor. 
95 | z_scaled = tf.reshape(self.z, [self.batch_size, 1, self.z_dim]) * \ 96 | tf.ones([n_points, 1], dtype=tf.float32) * self.scale 97 | z_unroll = tf.reshape(z_scaled, [self.batch_size * n_points, self.z_dim]) 98 | x_unroll = tf.reshape(self.x, [self.batch_size * n_points, 1]) 99 | y_unroll = tf.reshape(self.y, [self.batch_size * n_points, 1]) 100 | r_unroll = tf.reshape(self.r, [self.batch_size * n_points, 1]) 101 | 102 | U = fully_connected(z_unroll, net_size, 'g_0_z') + \ 103 | fully_connected(x_unroll, net_size, 'g_0_x', with_bias=False) + \ 104 | fully_connected(y_unroll, net_size, 'g_0_y', with_bias=False) + \ 105 | fully_connected(r_unroll, net_size, 'g_0_r', with_bias=False) 106 | 107 | ''' 108 | Below are a bunch of examples of different CPPN configurations. 109 | Feel free to comment out and experiment! 110 | ''' 111 | 112 | ### 113 | ### Example: 3 layers of tanh() layers, with net_size = 32 activations/layer 114 | ### 115 | # ''' 116 | H = tf.nn.tanh(U) 117 | for i in range(3): 118 | H = tf.nn.tanh(fully_connected(H, net_size, 'g_tanh_' + str(i))) 119 | output = tf.sigmoid(fully_connected(H, self.c_dim, 'g_final')) 120 | # ''' 121 | 122 | ### 123 | ### Similar to example above, but instead the output is 124 | ### a weird function rather than just the sigmoid 125 | ''' 126 | H = tf.nn.tanh(U) 127 | for i in range(3): 128 | H = tf.nn.tanh(fully_connected(H, net_size, 'g_tanh_'+str(i))) 129 | output = tf.sqrt(1.0-tf.abs(tf.tanh(fully_connected(H, self.c_dim, 'g_final')))) 130 | ''' 131 | 132 | ### 133 | ### Example: mixing softplus and tanh layers, with net_size = 32 activations/layer 134 | ### 135 | ''' 136 | H = tf.nn.tanh(U) 137 | H = tf.nn.softplus(fully_connected(H, net_size, 'g_softplus_1')) 138 | H = tf.nn.tanh(fully_connected(H, net_size, 'g_tanh_2')) 139 | H = tf.nn.softplus(fully_connected(H, net_size, 'g_softplus_2')) 140 | H = tf.nn.tanh(fully_connected(H, net_size, 'g_tanh_2')) 141 | H = tf.nn.softplus(fully_connected(H, net_size, 
'g_softplus_2')) 142 | output = tf.sigmoid(fully_connected(H, self.c_dim, 'g_final')) 143 | ''' 144 | 145 | ### 146 | ### Example: mixing sinusoids, tanh and multiple softplus layers 147 | ### 148 | ''' 149 | H = tf.nn.tanh(U) 150 | H = tf.nn.softplus(fully_connected(H, net_size, 'g_softplus_1')) 151 | H = tf.nn.tanh(fully_connected(H, net_size, 'g_tanh_2')) 152 | H = tf.nn.softplus(fully_connected(H, net_size, 'g_softplus_2')) 153 | output = 0.5 * tf.sin(fully_connected(H, self.c_dim, 'g_final')) + 0.5 154 | ''' 155 | 156 | ### 157 | ### Example: residual network of 4 tanh() layers 158 | ### 159 | ''' 160 | H = tf.nn.tanh(U) 161 | for i in range(3): 162 | H = H+tf.nn.tanh(fully_connected(H, net_size, g_tanh_'+str(i))) 163 | output = tf.sigmoid(fully_connected(H, self.c_dim, 'g_final')) 164 | ''' 165 | 166 | ''' 167 | The final hidden later is pass through a fully connected sigmoid later, so outputs -> (0, 1) 168 | Also, the output has a dimension of c_dim, so can be monotone or RGB 169 | ''' 170 | result = tf.reshape(output, [self.batch_size, y_dim, x_dim, self.c_dim]) 171 | 172 | return result 173 | 174 | def generate(self, z=None, x_dim=26, y_dim=26, scale=8.0, **kwargs): 175 | """ Generate data by sampling from latent space. 176 | If z is not None, data for this point in latent space is 177 | generated. Otherwise, z is drawn from prior in latent 178 | space. 
179 | """ 180 | if z is None: 181 | z = np.random.uniform(-1.0, 1.0, size=(self.batch_size, self.z_dim)).astype(np.float32) 182 | # Note: This maps to mean of distribution, we could alternatively 183 | # sample from Gaussian distribution 184 | 185 | G = self.generator(x_dim=x_dim, y_dim=y_dim, reuse=True) 186 | x_vec, y_vec, r_vec = self._coordinates(x_dim, y_dim, scale=scale) 187 | image = self.sess.run(G, feed_dict={self.z: z, self.x: x_vec, self.y: y_vec, self.r: r_vec}) 188 | return image 189 | 190 | def close(self): 191 | self.sess.close() 192 | -------------------------------------------------------------------------------- /examples/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/.DS_Store -------------------------------------------------------------------------------- /examples/color_tanh_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_1.png -------------------------------------------------------------------------------- /examples/color_tanh_1b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_1b.png -------------------------------------------------------------------------------- /examples/color_tanh_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_2.png -------------------------------------------------------------------------------- /examples/color_tanh_2b.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_2b.png -------------------------------------------------------------------------------- /examples/color_tanh_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_3.png -------------------------------------------------------------------------------- /examples/color_tanh_3b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_3b.png -------------------------------------------------------------------------------- /examples/color_tanh_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_4.png -------------------------------------------------------------------------------- /examples/color_tanh_4b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_4b.png -------------------------------------------------------------------------------- /examples/color_tanh_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_5.png -------------------------------------------------------------------------------- /examples/color_tanh_5b.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_5b.png -------------------------------------------------------------------------------- /examples/color_tanh_6b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/color_tanh_6b.png -------------------------------------------------------------------------------- /examples/contours_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/contours_1.png -------------------------------------------------------------------------------- /examples/countour_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/countour_2.png -------------------------------------------------------------------------------- /examples/cppn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/cppn.gif -------------------------------------------------------------------------------- /examples/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/output.gif -------------------------------------------------------------------------------- /examples/tanh_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_01.png 
-------------------------------------------------------------------------------- /examples/tanh_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_02.png -------------------------------------------------------------------------------- /examples/tanh_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_03.png -------------------------------------------------------------------------------- /examples/tanh_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_04.png -------------------------------------------------------------------------------- /examples/tanh_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_05.png -------------------------------------------------------------------------------- /examples/tanh_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_06.png -------------------------------------------------------------------------------- /examples/tanh_anim_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_anim_0.gif -------------------------------------------------------------------------------- /examples/tanh_anim_1.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_anim_1.gif -------------------------------------------------------------------------------- /examples/tanh_anim_end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_anim_end.png -------------------------------------------------------------------------------- /examples/tanh_anim_start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_anim_start.png -------------------------------------------------------------------------------- /examples/tanh_anim_twitter.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_anim_twitter.gif -------------------------------------------------------------------------------- /examples/tanh_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_color.png -------------------------------------------------------------------------------- /examples/tanh_softplus_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_softplus_01.png -------------------------------------------------------------------------------- /examples/tanh_softplus_02.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_softplus_02.png -------------------------------------------------------------------------------- /examples/tanh_softplus_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_softplus_03.png -------------------------------------------------------------------------------- /examples/tanh_softplus_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_softplus_04.png -------------------------------------------------------------------------------- /examples/tanh_softplus_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_softplus_05.png -------------------------------------------------------------------------------- /examples/tanh_softplus_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_softplus_06.png -------------------------------------------------------------------------------- /examples/tanh_softplus_nature_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_softplus_nature_01.png -------------------------------------------------------------------------------- /examples/tanh_softplus_nature_02.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/w4nderlust/cppn-tensorflow/857539b5fd1e6cddc988df55317f4fdfd7a4ef19/examples/tanh_softplus_nature_02.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from sampler import Sampler 2 | 3 | if __name__ == '__main__': 4 | ########## Test CPPN ########## 5 | # sampler = Sampler(model_type='CPPN', z_dim=4, c_dim=3, scale=8.0, net_size=32) 6 | 7 | ########## Test different resolutions ########## 8 | # z1 = sampler.generate_z() 9 | # img = sampler.generate(z1, 256, 256) 10 | # sampler.save_png(img, 'img_256.png') 11 | # img = sampler.generate(z1, 512, 512) 12 | # sampler.save_png(img, 'img_512.png') 13 | 14 | ########## Generate several wallpapers ########## 15 | # for i in range(10): 16 | # sampler.reinit() 17 | # img = sampler.generate(None, 2880, 1800) 18 | # sampler.save_png(img, 'img{}_2880_1800.png'.format(i + 1)) 19 | 20 | ########## Test RPNN ########## 21 | sampler = Sampler(model_type='RPPN', z_dim=4, c_dim=3, scale=4.0, net_size=32) 22 | 23 | ########## Test different resolutions ########## 24 | # z1 = sampler.generate_z() 25 | # img1 = sampler.generate(z1, 256, 256) 26 | # sampler.save_png(img1, 'img_k3_256.png') 27 | # img2 = sampler.generate(z1, 512, 512) 28 | # sampler.save_png(img2, 'img_k3_512.png') 29 | 30 | ########## Test different number of repetitions ########## 31 | # z1 = sampler.generate_z() 32 | # img1 = sampler.generate(z1, 256, 256, k=0) 33 | # sampler.save_png(img1, 'img_k0_256.png') 34 | # img2 = sampler.generate(z1, 256, 256, k=1) 35 | # sampler.save_png(img2, 'img_k1_256.png') 36 | # img3 = sampler.generate(z1, 256, 256, k=2) 37 | # sampler.save_png(img3, 'img_k2_256.png') 38 | # img4 = sampler.generate(z1, 256, 256, k=3) 39 | # sampler.save_png(img4, 'img_k3_256.png') 40 | # img5 = 
sampler.generate(z1, 256, 256, k=4) 41 | # sampler.save_png(img5, 'img_k4_256.png') 42 | # img6 = sampler.generate(z1, 256, 256, k=5) 43 | # sampler.save_png(img6, 'img_k5_256.png') 44 | 45 | ########## Test gif with different number of repetitions ########## 46 | z1 = sampler.generate_z() 47 | sampler.save_anim_gif(z1, z1, 'anim_k0_k24_256.gif', n_frame=24, duration1=0.5, \ 48 | duration2=1.0, duration=0.2, x_dim=256, y_dim=256, scale=4.0, k1=0, k2=24, reverse=True) 49 | 50 | ########## Test mp4 with different number of repetitions ########## 51 | # z1 = sampler.generate_z() 52 | # sampler.save_anim_mp4(z1, z1, 'anim_k0_k24_256.mp4', n_frame=24, fps=12, x_dim=256, y_dim=256, scale=10.0, k1=0, k2=24) 53 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import tensorflow as tf 4 | 5 | ''' 6 | some helper code borrowed from: 7 | https://github.com/carpedm20/DCGAN-tensorflow 8 | ''' 9 | 10 | 11 | def linear(input_, output_size, scope=None, stddev=1.0, bias_start=0.0, with_w=False): 12 | shape = input_.get_shape().as_list() 13 | 14 | with tf.variable_scope(scope or "Linear"): 15 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 16 | tf.random_normal_initializer(stddev=stddev)) 17 | bias = tf.get_variable("bias", [output_size], 18 | initializer=tf.constant_initializer(bias_start)) 19 | if with_w: 20 | return tf.matmul(input_, matrix) + bias, matrix, bias 21 | else: 22 | return tf.matmul(input_, matrix) + bias 23 | 24 | 25 | def fully_connected(input_, output_size, scope=None, stddev=1.0, with_bias=True): 26 | shape = input_.get_shape().as_list() 27 | 28 | with tf.variable_scope(scope or "FC"): 29 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 30 | tf.random_normal_initializer(stddev=stddev)) 31 | 32 | 
result = tf.matmul(input_, matrix) 33 | 34 | if with_bias: 35 | bias = tf.get_variable("bias", [1, output_size], 36 | initializer=tf.random_normal_initializer(stddev=stddev)) 37 | result += bias * tf.ones([shape[0], 1], dtype=tf.float32) 38 | 39 | return result 40 | -------------------------------------------------------------------------------- /rppn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of Compositional Pattern Producing Networks in Tensorflow 3 | 4 | https://en.wikipedia.org/wiki/Compositional_pattern-producing_network 5 | 6 | @hardmaru, 2016 7 | @w4nderlust, 2017 8 | 9 | ''' 10 | from __future__ import absolute_import, division, print_function, unicode_literals 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | from ops import * 15 | import math 16 | 17 | class RPPN(): 18 | def __init__(self, batch_size=1, z_dim=32, c_dim=1, scale=8.0, net_size=32, act='tanh', **kwargs): 19 | """ 20 | Args: 21 | z_dim: how many dimensions of the latent space vector (R^z_dim) 22 | c_dim: 1 for mono, 3 for rgb. dimension for output space. you can modify code to do HSV rather than RGB. 
23 | net_size: number of nodes for each fully connected layer of cppn 24 | scale: the bigger, the more zoomed out the picture becomes 25 | 26 | """ 27 | self.batch_size = batch_size 28 | self.net_size = net_size 29 | self.act = act 30 | x_dim = 256 31 | y_dim = 256 32 | self.x_dim = x_dim 33 | self.y_dim = y_dim 34 | self.scale = scale 35 | self.c_dim = c_dim 36 | self.z_dim = z_dim 37 | 38 | # tf Graph batch of image (batch_size, height, width, depth) 39 | self.batch = tf.placeholder(tf.float32, [batch_size, x_dim, y_dim, c_dim]) 40 | 41 | n_points = x_dim * y_dim 42 | self.n_points = n_points 43 | 44 | self.x_vec, self.y_vec, self.r_vec = self._coordinates(x_dim, y_dim, scale) 45 | 46 | # latent vector 47 | self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim]) 48 | # inputs to cppn, like coordinates and radius from centre 49 | self.x = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 50 | self.y = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 51 | self.r = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 52 | # input for number of repetitions 53 | self.k = tf.placeholder(tf.int32) 54 | 55 | # builds the generator network 56 | self.G = self.generator(x_dim=self.x_dim, y_dim=self.y_dim, act=act) 57 | 58 | self.init() 59 | 60 | def init(self): 61 | # Initializing the tensor flow variables 62 | init = tf.global_variables_initializer() 63 | # Launch the session 64 | self.sess = tf.Session() 65 | self.sess.run(init) 66 | 67 | def reinit(self): 68 | init = tf.variables_initializer(tf.trainable_variables()) 69 | self.sess.run(init) 70 | 71 | def _coordinates(self, x_dim=32, y_dim=32, scale=1.0): 72 | ''' 73 | calculates and returns a vector of x and y coordintes, and corresponding radius from the centre of image. 
74 | ''' 75 | n_points = x_dim * y_dim 76 | # creates x and y ranges of x/y_dim numbers from -scale to +scale 77 | x_range = scale * ((np.arange(x_dim) - (x_dim - 1) / 2.0) / (x_dim - 1) * 2) 78 | y_range = scale * ((np.arange(y_dim) - (y_dim - 1) / 2.0) / (y_dim - 1) * 2) 79 | # create all r distances from center for any combination of coordinates on x and y 80 | x_mat = np.matmul(np.ones((y_dim, 1)), x_range.reshape((1, x_dim))) 81 | y_mat = np.matmul(y_range.reshape((y_dim, 1)), np.ones((1, x_dim))) 82 | r_mat = np.sqrt(x_mat * x_mat + y_mat * y_mat) 83 | # transform the x x y matrices tiling as many of them as the batch size 84 | # and reshaping to obtain a èbatch_size, x*y, 1+ tensor 85 | x_mat = np.tile(x_mat.flatten(), self.batch_size).reshape(self.batch_size, n_points, 1) 86 | y_mat = np.tile(y_mat.flatten(), self.batch_size).reshape(self.batch_size, n_points, 1) 87 | r_mat = np.tile(r_mat.flatten(), self.batch_size).reshape(self.batch_size, n_points, 1) 88 | return x_mat, y_mat, r_mat 89 | 90 | def generator(self, x_dim, y_dim, act='tanh', reuse=False): 91 | if reuse: 92 | tf.get_variable_scope().reuse_variables() 93 | 94 | if not hasattr(tf.nn, act): 95 | print("No activation {} available, using default tanh".format(act)) 96 | act = 'tanh' 97 | 98 | net_size = self.net_size 99 | n_points = x_dim * y_dim 100 | 101 | # note that latent vector z is scaled to self.scale factor. 
102 | z_scaled = tf.reshape(self.z, [self.batch_size, 1, self.z_dim]) * \ 103 | tf.ones([n_points, 1], dtype=tf.float32) * self.scale 104 | z_unroll = tf.reshape(z_scaled, [self.batch_size * n_points, self.z_dim]) 105 | x_unroll = tf.reshape(self.x, [self.batch_size * n_points, 1]) 106 | y_unroll = tf.reshape(self.y, [self.batch_size * n_points, 1]) 107 | r_unroll = tf.reshape(self.r, [self.batch_size * n_points, 1]) 108 | 109 | sum_input = fully_connected(z_unroll, net_size, 'g_0_z') + \ 110 | fully_connected(x_unroll, net_size, 'g_0_x', with_bias=False) + \ 111 | fully_connected(y_unroll, net_size, 'g_0_y', with_bias=False) + \ 112 | fully_connected(r_unroll, net_size, 'g_0_r', with_bias=False) 113 | 114 | hidden = tf.nn.tanh(sum_input) 115 | 116 | i = tf.constant(0) 117 | zero = tf.constant(0) 118 | 119 | def condition(i, k, H): 120 | return tf.less(i, k) 121 | 122 | def body(i, k, H): 123 | hidden = tf.cond(tf.equal(i, zero), 124 | lambda: single_iteration(H, net_size, act, reuse=False), 125 | lambda: single_iteration(H, net_size, act, reuse=True)) 126 | i = tf.add(i, 1) 127 | return i, k, hidden 128 | 129 | i, k, hidden = tf.while_loop(condition, body, [i, self.k, hidden]) 130 | 131 | output = tf.sigmoid(fully_connected(hidden, self.c_dim, 'g_final')) 132 | 133 | ''' 134 | The final hidden later is pass through a fully connected sigmoid later, so outputs -> (0, 1) 135 | Also, the output has a dimension of c_dim, so can be monotone or RGB 136 | ''' 137 | result = tf.reshape(output, [self.batch_size, y_dim, x_dim, self.c_dim]) 138 | 139 | return result 140 | 141 | 142 | def generate(self, z=None, x_dim=26, y_dim=26, scale=8.0, k=3, act=None, **kwargs): 143 | """ Generate data by sampling from latent space. 144 | If z is not None, data for this point in latent space is 145 | generated. Otherwise, z is drawn from prior in latent 146 | space. 
147 | """ 148 | if z is None: 149 | z = np.random.uniform(-1.0, 1.0, size=(self.batch_size, self.z_dim)).astype(np.float32) 150 | # Note: This maps to mean of distribution, we could alternatively 151 | # sample from Gaussian distribution 152 | 153 | if act is None: 154 | if self.act is not None: 155 | act = self.act 156 | else: 157 | act = 'tanh' 158 | 159 | G = self.generator(x_dim=x_dim, y_dim=y_dim, act=act, reuse=True) 160 | x_vec, y_vec, r_vec = self._coordinates(x_dim, y_dim, scale=scale) 161 | image = self.sess.run(G, feed_dict={self.z: z, self.x: x_vec, self.y: y_vec, self.r: r_vec, self.k: k}) 162 | return image 163 | 164 | def close(self): 165 | self.sess.close() 166 | 167 | 168 | def single_iteration(H, net_size, act='tanh', reuse=False): 169 | with tf.variable_scope("iter_norm", reuse=reuse): 170 | H = tf.contrib.layers.layer_norm(H) 171 | with tf.variable_scope("iter_linear", reuse=reuse): 172 | #H = fully_connected(H, net_size) 173 | H = cos_sim(H, net_size) 174 | H = getattr(tf.nn, act)(H) 175 | return H 176 | 177 | 178 | def cos_sim(x, net_size, name=None): 179 | with tf.name_scope(name): 180 | weights = tf.get_variable("weights", [net_size, net_size], 181 | initializer=tf.random_normal_initializer()) 182 | biases = tf.get_variable("biases", [net_size], 183 | initializer=tf.constant_initializer(0.001)) 184 | wat = 0.001 185 | w_norm = tf.sqrt(tf.reduce_sum(weights**2, axis=0, keep_dims=True) + biases**2) 186 | x_norm = tf.sqrt(tf.reduce_sum(x**2, axis=1, keep_dims=True) + wat**2) 187 | cos_sim = (tf.matmul(x, weights) + wat * biases) / w_norm / x_norm 188 | return cos_sim -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of Compositional Pattern Producing Networks in Tensorflow 3 | 4 | https://en.wikipedia.org/wiki/Compositional_pattern-producing_network 5 | 6 | @hardmaru, 2016 7 | @w4nderlust, 2017 
8 | 9 | Sampler Class 10 | 11 | This file is meant to be run inside an IPython session, as it is meant 12 | to be used interacively for experimentation. 13 | 14 | It shouldn't be that hard to take bits of this code into a normal 15 | command line environment though if you want to use outside of IPython. 16 | 17 | usage: 18 | 19 | %run -i sampler.py 20 | 21 | sampler = Sampler(z_dim = 4, c_dim = 1, scale = 8.0, net_size = 32) 22 | 23 | ''' 24 | from __future__ import absolute_import, division, print_function, unicode_literals 25 | 26 | import numpy as np 27 | import tensorflow as tf 28 | import math 29 | import random 30 | import PIL 31 | from PIL import Image 32 | import pylab 33 | from cppn import CPPN 34 | from rppn import RPPN 35 | import matplotlib 36 | import matplotlib.pyplot as plt 37 | import imageio 38 | 39 | try: 40 | mgc = get_ipython().magic 41 | mgc(u'matplotlib inline') 42 | pylab.rcParams['figure.figsize'] = (10.0, 10.0) 43 | except NameError: 44 | pass 45 | 46 | 47 | class Sampler(): 48 | def __init__(self, model_type='CPPN', z_dim=8, c_dim=1, scale=10.0, net_size=32, act='tanh'): 49 | if model_type.lower() == 'cppn': 50 | self.model = CPPN(z_dim=z_dim, c_dim=c_dim, scale=scale, net_size=net_size) 51 | elif model_type.lower() == 'rppn': 52 | self.model = RPPN(z_dim=z_dim, c_dim=c_dim, scale=scale, net_size=net_size, act=act) 53 | else: 54 | self.model = CPPN(z_dim=z_dim, c_dim=c_dim, scale=scale, net_size=net_size) 55 | self.z = self.generate_z() # saves most recent z here, in case we find a nice image and want the z-vec 56 | 57 | def reinit(self): 58 | self.model.reinit() 59 | 60 | def generate_z(self): 61 | z = np.random.uniform(-1.0, 1.0, size=(1, self.model.z_dim)).astype(np.float32) 62 | return z 63 | 64 | def generate(self, z=None, x_dim=1080, y_dim=1060, scale=10.0, k=3, act=None): 65 | if z is None: 66 | z = self.generate_z() 67 | else: 68 | z = np.reshape(z, (1, self.model.z_dim)) 69 | self.z = z 70 | return self.model.generate(z, x_dim, 
y_dim, scale, k=k, act=act)[0] 71 | 72 | def show_image(self, image_data): 73 | ''' 74 | image_data is a tensor, in [height width depth] 75 | image_data is NOT an instance of the PIL.Image class 76 | ''' 77 | matplotlib.interactive(False) 78 | # plt.subplot(1, 1, 1) 79 | y_dim = image_data.shape[0] 80 | x_dim = image_data.shape[1] 81 | c_dim = self.model.c_dim 82 | if c_dim > 1: 83 | plt.imshow(image_data, interpolation='nearest') 84 | else: 85 | plt.imshow(image_data.reshape(y_dim, x_dim), cmap='Greys', interpolation='nearest') 86 | # plt.axis('off') 87 | plt.show(block=True) 88 | 89 | def to_np_image(self, image_data): 90 | # convert to PIL.Image format from np array (0, 1) 91 | img_data = np.array(1 - image_data) 92 | y_dim = image_data.shape[0] 93 | x_dim = image_data.shape[1] 94 | c_dim = self.model.c_dim 95 | if c_dim > 1: 96 | img_data = np.array(img_data.reshape((y_dim, x_dim, c_dim)) * 255.0, dtype=np.uint8) 97 | else: 98 | img_data = np.array(img_data.reshape((y_dim, x_dim)) * 255.0, dtype=np.uint8) 99 | return img_data 100 | 101 | def to_image(self, image_data): 102 | # convert to PIL.Image format from np array (0, 1) 103 | return Image.fromarray(self.to_np_image(image_data)) 104 | 105 | def save_png(self, image_data, filename): 106 | if not filename.endswith(".png"): 107 | filename += ".png" 108 | self.to_image(image_data).save(filename) 109 | 110 | def save_anim_gif(self, z1, z2, filename, n_frame=10, duration1=0.5, duration2=1.0, duration=0.1, x_dim=512, 111 | y_dim=512, scale=10.0, k1=3, k2=3, act=None, reverse=True): 112 | ''' 113 | this saves an animated gif from two latent states z1 and z2 114 | n_frame: number of states in between z1 and z2 morphing effect, exclusive of z1 and z2 115 | duration1, duration2, control how long z1 and z2 are shown. 
duration controls frame speed, in seconds 116 | ''' 117 | delta_z = (z2 - z1) / (n_frame + 1) 118 | delta_k = (k2 - k1) / (n_frame + 1) 119 | total_frames = n_frame + 2 120 | images = [] 121 | for i in range(total_frames): 122 | z = z1 + delta_z * float(i) 123 | k = int(k1 + delta_k * float(i)) 124 | images.append(self.to_np_image(self.generate(z, x_dim, y_dim, scale, k=k, act=act))) 125 | print("processing image ", i) 126 | durations = [duration1] + [duration] * n_frame + [duration2] 127 | if reverse == True: # go backwards in time back to the first state 128 | revImages = list(images) 129 | revImages.reverse() 130 | revImages = revImages[1:] 131 | images = images + revImages 132 | durations = durations + [duration] * n_frame + [duration1] 133 | print("writing gif file...") 134 | if not filename.endswith(".gif"): 135 | filename += ".gif" 136 | imageio.mimsave(filename, images, duration=durations) 137 | 138 | def save_anim_mp4(self, z1, z2, filename, n_frame=10, fps=10, x_dim=512, y_dim=512, scale=10.0, k1=3, k2=3, 139 | act=None, reverse=True): 140 | ''' 141 | this saves an animated mp4 from two latent states z1 and z2 142 | n_frame: number of states in between z1 and z2 morphing effect, exclusive of z1 and z2 143 | fps: number of frames displayed in a second 144 | ''' 145 | delta_z = (z2 - z1) / (n_frame + 1) 146 | delta_k = (k2 - k1) / (n_frame + 1) 147 | total_frames = n_frame + 2 148 | images = [] 149 | for i in range(total_frames): 150 | z = z1 + delta_z * float(i) 151 | k = int(k1 + delta_k * float(i)) 152 | images.append(self.to_np_image(self.generate(z, x_dim, y_dim, scale, k=k, act=act))) 153 | print("processing image ", i) 154 | if reverse == True: # go backwards in time back to the first state 155 | revImages = list(images) 156 | revImages.reverse() 157 | revImages = revImages[1:] 158 | images = images + revImages 159 | if not filename.endswith(".mp4"): 160 | filename += ".mp4" 161 | imageio.mimsave(filename, images, fps=fps) 162 | 
--------------------------------------------------------------------------------