├── lib
│   ├── __init__.py
│   ├── utils.py
│   ├── displayer.py
│   ├── utils_vis.py
│   └── CAModel.py
├── data
│   ├── demo.gif
│   └── emoji.png
├── models
│   └── remaster_1.pth
├── README.md
├── LICENSE
├── main_pygame_dl.py
└── training.ipynb

/lib/__init__.py:
--------------------------------------------------------------------------------
# init

--------------------------------------------------------------------------------
/data/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flobh70/growin/HEAD/data/demo.gif
--------------------------------------------------------------------------------
/data/emoji.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flobh70/growin/HEAD/data/emoji.png
--------------------------------------------------------------------------------
/models/remaster_1.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flobh70/growin/HEAD/models/remaster_1.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Growing-Neural-Cellular-Automata

An unofficial reproduction of Growing Neural Cellular Automata using PyTorch:

```
Mordvintsev, et al., "Growing Neural Cellular Automata", Distill, 2020.
```

Run `training.ipynb` to train a model, and `main_pygame_dl.py` to play the interactive demo (a pre-trained model is already provided in `./models`).
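
For a quick, window-free check of the pre-trained model, it can be loaded and stepped along the following lines (a minimal sketch distilled from `main_pygame_dl.py`; the 72x72 grid mirrors the demo script, and the 100 growth steps are an arbitrary choice):

```python
import torch
from lib.CAModel import CAModel
from lib.utils_vis import make_seed, to_rgb

CHANNEL_N = 16
CELL_FIRE_RATE = 0.5
device = torch.device("cpu")

model = CAModel(CHANNEL_N, CELL_FIRE_RATE, device).to(device)
model.load_state_dict(torch.load("models/remaster_1.pth", map_location=device))

# start from a single living seed cell and let the CA grow
x = torch.from_numpy(make_seed((72, 72), CHANNEL_N)[None, ...])
with torch.no_grad():
    x = model(x, steps=100)

rgb = to_rgb(x.detach().numpy()[0])  # (72, 72, 3) array in [0, 1], ready for plotting
```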

--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
import numpy as np

def tup_distance(node1, node2, mode="Euclidean"):
    """
    mode: "Manhattan", "Euclidean"
    """
    if mode=="Euclidean":
        return ((node1[0]-node2[0])**2+(node1[1]-node2[1])**2)**0.5
    elif mode=="Manhattan":
        return np.abs(node1[0]-node2[0])+np.abs(node1[1]-node2[1])
    else:
        raise ValueError("Unrecognized distance mode: "+mode)

def mat_distance(mat1, mat2, mode="Euclidean"):
    """
    mode: "Manhattan", "Euclidean"
    """
    if mode=="Euclidean":
        return np.sum((mat1-mat2)**2, axis=-1)**0.5
    elif mode=="Manhattan":
        return np.sum(np.abs(mat1-mat2), axis=-1)
    else:
        raise ValueError("Unrecognized distance mode: "+mode)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Ming

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/lib/displayer.py:
--------------------------------------------------------------------------------
import pygame
import numpy as np

class displayer:

    def __init__(self, _map_shape, pix_size, has_gap=False):
        """
        _map_shape: (rows, cols) of the cell grid
        pix_size: side length of one cell in screen pixels
        has_gap: if True, draw each cell slightly smaller so a gap shows between cells
        """
        pygame.init()
        clock = pygame.time.Clock()
        clock.tick(60)

        self.has_gap = has_gap
        self.pix_size = pix_size
        self.screen = pygame.display.set_mode((_map_shape[1]*self.pix_size,
                                               _map_shape[0]*self.pix_size))

    def update(self, _map):
        self.screen.fill((255,255,255))
        for i in range(_map.shape[0]):
            for j in range(_map.shape[1]):
                x = j * self.pix_size + int(self.pix_size/2)
                y = i * self.pix_size + int(self.pix_size/2)
                if self.has_gap:
                    size = min(int(self.pix_size * 0.75), self.pix_size-2)
                else:
                    size = self.pix_size
                s = pygame.Surface((size,size))
                c = (_map[i,j]*256).astype(int)[:3]
                s.fill(c)
                self.screen.blit(s, (x-int(size/2), y-int(size/2)))
        pygame.display.update()

--------------------------------------------------------------------------------
/lib/utils_vis.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class SamplePool:
    def __init__(self, *, _parent=None, _parent_idx=None, **slots):
        self._parent = _parent
        self._parent_idx = _parent_idx
        self._slot_names = slots.keys()
        self._size = None
        for k, v in slots.items():
            if self._size is None:
                self._size = len(v)
            assert self._size == len(v)
            setattr(self, k, np.asarray(v))

    def sample(self, n):
        idx = np.random.choice(self._size, n, False)
        batch = {k: getattr(self, k)[idx] for k in self._slot_names}
        batch = SamplePool(**batch, _parent=self, _parent_idx=idx)
        return batch

    def commit(self):
        for k in self._slot_names:
            getattr(self._parent, k)[self._parent_idx] = getattr(self, k)

def to_alpha(x):
    return np.clip(x[..., 3:4], 0, 0.9999)

def to_rgb(x):
    # assume rgb premultiplied by alpha
    rgb, a = x[..., :3], to_alpha(x)
    return np.clip(1.0-a+rgb, 0, 0.9999)

def get_living_mask(x):
    return nn.MaxPool2d(3, stride=1, padding=1)(x[:, 3:4, :, :])>0.1

def make_seeds(shape, n_channels, n=1):
    x = np.zeros([n, shape[0], shape[1], n_channels], np.float32)
    x[:, shape[0]//2, shape[1]//2, 3:] = 1.0
    return x

def make_seed(shape, n_channels):
    seed = np.zeros([shape[0], shape[1], n_channels], np.float32)
    seed[shape[0]//2, shape[1]//2, 3:] = 1.0
    return seed

def make_circle_masks(n, h, w):
    x = np.linspace(-1.0, 1.0, w)[None, None, :]
    y = np.linspace(-1.0, 1.0, h)[None, :, None]
    center = np.random.random([2, n, 1, 1])*1.0-0.5
    r = np.random.random([n, 1, 1])*0.3+0.1
    x, y = (x-center[0])/r, (y-center[1])/r
    mask = (x*x+y*y < 1.0).astype(np.float32)
    return mask
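
# Example usage (sketch): the sample/commit round trip used by the pattern pool
# in training.ipynb -- sample() copies a few pool entries, the training step
# overwrites batch.x, and commit() writes the result back into the parent pool.
# The pool/batch sizes below are simply the values used in the notebook.
if __name__ == "__main__":
    seed = make_seed((72, 72), 16)
    pool = SamplePool(x=np.repeat(seed[None, ...], 1024, 0))

    batch = pool.sample(8)        # copies of 8 randomly chosen pool states
    batch.x[:] = batch.x * 0.0    # stand-in for "run the CA and keep the result"
    batch.commit()                # write the updated states back into the pool
    print(pool.x.shape)           # (1024, 72, 72, 16)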

--------------------------------------------------------------------------------
/main_pygame_dl.py:
--------------------------------------------------------------------------------
import os
import pygame
import torch
import numpy as np

from lib.displayer import displayer
from lib.utils import mat_distance
from lib.CAModel import CAModel
from lib.utils_vis import to_rgb, make_seed

os.environ['KMP_DUPLICATE_LIB_OK']='True'

eraser_radius = 3
pix_size = 8
_map_shape = (72,72)
CHANNEL_N = 16
CELL_FIRE_RATE = 0.5
model_path = "models/remaster_1.pth"
device = torch.device("cpu")

_rows = np.arange(_map_shape[0]).repeat(_map_shape[1]).reshape([_map_shape[0],_map_shape[1]])
_cols = np.arange(_map_shape[1]).reshape([1,-1]).repeat(_map_shape[0],axis=0)
_map_pos = np.array([_rows,_cols]).transpose([1,2,0])

_map = make_seed(_map_shape, CHANNEL_N)

model = CAModel(CHANNEL_N, CELL_FIRE_RATE, device).to(device)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
output = model(torch.from_numpy(_map.reshape([1,_map_shape[0],_map_shape[1],CHANNEL_N]).astype(np.float32)), 1)

disp = displayer(_map_shape, pix_size)

isMouseDown = False
running = True
while running:

    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False

        elif event.type == pygame.MOUSEBUTTONDOWN:
            if event.button == 1:
                isMouseDown = True

        elif event.type == pygame.MOUSEBUTTONUP:
            if event.button == 1:
                isMouseDown = False

    if isMouseDown:
        try:
            mouse_pos = np.array([int(event.pos[1]/pix_size), int(event.pos[0]/pix_size)])
            should_keep = (mat_distance(_map_pos, mouse_pos)>eraser_radius).reshape([_map_shape[0],_map_shape[1],1])
            output = torch.from_numpy(output.detach().numpy()*should_keep)
        except AttributeError:
            pass

    output = model(output, 1)

    _map = to_rgb(output.detach().numpy()[0])
    disp.update(_map)

--------------------------------------------------------------------------------
/lib/CAModel.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class CAModel(nn.Module):
    def __init__(self, channel_n, fire_rate, device, hidden_size=128):
        super(CAModel, self).__init__()

        self.device = device
        self.channel_n = channel_n

        self.fc0 = nn.Linear(channel_n*3, hidden_size)
        self.fc1 = nn.Linear(hidden_size, channel_n, bias=False)
        with torch.no_grad():
            self.fc1.weight.zero_()

        self.fire_rate = fire_rate
        self.to(self.device)

    def alive(self, x):
        return F.max_pool2d(x[:, 3:4, :, :], kernel_size=3, stride=1, padding=1) > 0.1

    def perceive(self, x, angle):

        def _perceive_with(x, weight):
            conv_weights = torch.from_numpy(weight.astype(np.float32)).to(self.device)
            conv_weights = conv_weights.view(1,1,3,3).repeat(self.channel_n, 1, 1, 1)
            return F.conv2d(x, conv_weights, padding=1, groups=self.channel_n)

        dx = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0  # Sobel filter
        dy = dx.T
        c = np.cos(angle*np.pi/180)
        s = np.sin(angle*np.pi/180)
        w1 = c*dx-s*dy
        w2 = s*dx+c*dy

        y1 = _perceive_with(x, w1)
        y2 = _perceive_with(x, w2)
        y = torch.cat((x,y1,y2),1)
        return y

    def update(self, x, fire_rate, angle):
        x = x.transpose(1,3)
        pre_life_mask = self.alive(x)

        dx = self.perceive(x, angle)
        dx = dx.transpose(1,3)
        dx = self.fc0(dx)
        dx = F.relu(dx)
        dx = self.fc1(dx)

        if fire_rate is None:
            fire_rate=self.fire_rate
        stochastic = torch.rand([dx.size(0),dx.size(1),dx.size(2),1])>fire_rate
        stochastic = stochastic.float().to(self.device)
        dx = dx * stochastic

        x = x+dx.transpose(1,3)

        post_life_mask = self.alive(x)
        life_mask = (pre_life_mask & post_life_mask).float()
        x = x * life_mask
        return x.transpose(1,3)

    def forward(self, x, steps=1, fire_rate=None, angle=0.0):
        for step in range(steps):
            x = self.update(x, fire_rate, angle)
        return x
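
# Example usage (sketch): a quick shape check of the perceive/update plumbing.
# The 72x72 grid below is only an illustrative choice; the model is size-agnostic.
if __name__ == "__main__":
    model = CAModel(channel_n=16, fire_rate=0.5, device=torch.device("cpu"))

    # channels-last state (N, H, W, C) with a single living cell in the centre
    x = torch.zeros(1, 72, 72, 16)
    x[:, 36, 36, 3:] = 1.0

    # perceive() works on a channels-first view and returns the state plus two
    # Sobel-filtered copies, i.e. channel_n*3 features -- the input size of fc0
    y = model.perceive(x.transpose(1, 3), angle=0.0)
    print(y.shape, model.fc0.in_features)  # torch.Size([1, 48, 72, 72]) 48

    # one stochastic update step keeps the channels-last layout
    print(model(x, steps=1).shape)         # torch.Size([1, 72, 72, 16])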

--------------------------------------------------------------------------------
/training.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import time\n",
    "import imageio\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from IPython.display import clear_output\n",
    "\n",
    "from lib.CAModel import CAModel\n",
    "from lib.utils_vis import SamplePool, to_alpha, to_rgb, get_living_mask, make_seed, make_circle_masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_emoji(index, path=\"data/emoji.png\"):\n",
    "    im = imageio.imread(path)\n",
    "    emoji = np.array(im[:, index*40:(index+1)*40].astype(np.float32))\n",
    "    emoji /= 255.0\n",
    "    return emoji\n",
    "\n",
    "def visualize_batch(x0, x):\n",
    "    vis0 = to_rgb(x0)\n",
    "    vis1 = to_rgb(x)\n",
    "    print('batch (before/after):')\n",
    "    plt.figure(figsize=[15,5])\n",
    "    for i in range(x0.shape[0]):\n",
    "        plt.subplot(2,x0.shape[0],i+1)\n",
    "        plt.imshow(vis0[i])\n",
    "        plt.axis('off')\n",
    "    for i in range(x0.shape[0]):\n",
    "        plt.subplot(2,x0.shape[0],i+1+x0.shape[0])\n",
    "        plt.imshow(vis1[i])\n",
    "        plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "def plot_loss(loss_log):\n",
    "    plt.figure(figsize=(10, 4))\n",
    "    plt.title('Loss history (log10)')\n",
    "    plt.plot(np.log10(loss_log), '.', alpha=0.1)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\")\n",
    "model_path = \"models/remaster_1.pth\"\n",
    "\n",
    "CHANNEL_N = 16 # Number of CA state channels\n",
    "TARGET_PADDING = 16 # Number of pixels used to pad the target image border\n",
    "TARGET_SIZE = 40\n",
    "\n",
    "lr = 2e-3\n",
    "lr_gamma = 0.9999\n",
    "betas = (0.5, 0.5)\n",
    "n_epoch = 80000\n",
    "\n",
    "BATCH_SIZE = 8\n",
    "POOL_SIZE = 1024\n",
    "CELL_FIRE_RATE = 0.5\n",
    "\n",
    "TARGET_EMOJI = 0 #@param \"🦎\"\n",
    "\n",
    "EXPERIMENT_TYPE = \"Regenerating\"\n",
    "EXPERIMENT_MAP = {\"Growing\":0, \"Persistent\":1, \"Regenerating\":2}\n",
    "EXPERIMENT_N = EXPERIMENT_MAP[EXPERIMENT_TYPE]\n",
    "\n",
    "USE_PATTERN_POOL = [0, 1, 1][EXPERIMENT_N]\n",
    "DAMAGE_N = [0, 0, 3][EXPERIMENT_N] # Number of patterns to damage in a batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_img = load_emoji(TARGET_EMOJI)\n",
    "plt.figure(figsize=(4,4))\n",
    "plt.imshow(to_rgb(target_img))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = TARGET_PADDING\n",
    "pad_target = np.pad(target_img, [(p, p), (p, p), (0, 0)])\n",
    "h, w = pad_target.shape[:2]\n",
    "pad_target = np.expand_dims(pad_target, axis=0)\n",
    "pad_target = torch.from_numpy(pad_target.astype(np.float32)).to(device)\n",
    "\n",
    "seed = make_seed((h, w), CHANNEL_N)\n",
    "pool = SamplePool(x=np.repeat(seed[None, ...], POOL_SIZE, 0))\n",
    "batch = pool.sample(BATCH_SIZE).x\n",
    "\n",
    "ca = CAModel(CHANNEL_N, CELL_FIRE_RATE, device).to(device)\n",
    "ca.load_state_dict(torch.load(model_path))\n",
    "\n",
    "optimizer = optim.Adam(ca.parameters(), lr=lr, betas=betas)\n",
    "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, lr_gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_log = []\n",
    "\n",
    "def train(x, target, steps, optimizer, scheduler):\n",
    "    x = ca(x, steps=steps)\n",
    "    loss = F.mse_loss(x[:, :, :, :4], target)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    return x, loss\n",
    "\n",
    "def loss_f(x, target):\n",
    "    return torch.mean(torch.pow(x[..., :4]-target, 2), [-2,-3,-1])\n",
    "\n",
    "for i in range(n_epoch+1):\n",
    "    if USE_PATTERN_POOL:\n",
    "        batch = pool.sample(BATCH_SIZE)\n",
    "        x0 = torch.from_numpy(batch.x.astype(np.float32)).to(device)\n",
    "        loss_rank = loss_f(x0, pad_target).detach().cpu().numpy().argsort()[::-1]\n",
    "        x0 = batch.x[loss_rank]\n",
    "        x0[:1] = seed\n",
    "        if DAMAGE_N:\n",
    "            damage = 1.0-make_circle_masks(DAMAGE_N, h, w)[..., None]\n",
    "            x0[-DAMAGE_N:] *= damage\n",
    "    else:\n",
    "        x0 = np.repeat(seed[None, ...], BATCH_SIZE, 0)\n",
    "    x0 = torch.from_numpy(x0.astype(np.float32)).to(device)\n",
    "\n",
    "    x, loss = train(x0, pad_target, np.random.randint(64,96), optimizer, scheduler)\n",
    "\n",
    "    if USE_PATTERN_POOL:\n",
    "        batch.x[:] = x.detach().cpu().numpy()\n",
    "        batch.commit()\n",
    "\n",
    "    step_i = len(loss_log)\n",
    "    loss_log.append(loss.item())\n",
    "\n",
    "    if step_i%100 == 0:\n",
    "        clear_output()\n",
    "        print(step_i, \"loss =\", loss.item())\n",
    "        visualize_batch(x0.detach().cpu().numpy(), x.detach().cpu().numpy())\n",
    "        plot_loss(loss_log)\n",
    "        torch.save(ca.state_dict(), model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py37_torch] *",
   "language": "python",
   "name": "conda-env-py37_torch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

--------------------------------------------------------------------------------