├── .gitignore ├── README.md ├── configs ├── debug.yaml └── train.yaml ├── modules ├── datasets.py ├── models.py ├── networks.py └── utils.py ├── notebooks └── ca.ipynb ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Images 132 | *.gif 133 | *.jpg 134 | *.jpeg 135 | 136 | # Default outputs 137 | outputs/* 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cellular-automata-pytorch 2 | 3 | WiP 4 | 5 | An attempt to recreate the results of [Growing Neural Cellular Automata 6 | ](https://distill.pub/2020/growing-ca/) and go beyond 7 | 8 | Related: 9 | 10 | * [Growing Neural Cellular Automata](https://distill.pub/2020/growing-ca/) 11 | * [Cellular automata as convolutional neural networks](https://arxiv.org/abs/1809.02942) 12 | * [Growing Cellular Automata Pytorch Repo](https://github.com/PWhiddy/Growing-Neural-Cellular-Automata-Pytorch?files=1) 13 | 14 | ## HOWTO 15 | 16 | ### Install 17 | 18 | ``` 19 | git clone https://github.com/belkakari/cellular-automata-pytorch.git 20 | cd cellular-automata-pytorch && pip install -r requirements.txt 21 | ``` 22 | 23 | ### Run training 24 | 25 | ``` 26 | python train.py -c ./configs/train.yaml 27 | ``` 28 | -------------------------------------------------------------------------------- /configs/debug.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: 'time' 2 | device: 'cuda:0' 3 | output_folder: ./outputs/ 4 | n_steps_interval: [64, 92] 5 | split_rate_interval: ~ #[60, 90] 6 | test_frequency: 1 7 | logging_level: DEBUG 8 | emoji: 🦋 9 | model: 10 | use_coords: False # whether to pass meshgrid as a feature 11 | random_spawn: False # randomly select initial cell coordinates 12 | stochastic_prob: 0.5 13 | norm_kernel: False 14 | interm_dim: 128 15 | bias: False 16 | train: 17 | batch_size: 4 18 | num_epochs: 10 19 | optim: 20 | milestones: [2000] 21 | gamma: 0.1 22 | lr: 0.002 23 | grad_clip: 20 24 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: 'time' 2 | device: 'cuda:0' 3 | output_folder: ./outputs/ 4 | n_steps_interval: [64, 92] 5 | split_rate_interval: ~ #[60, 90] 6 | test_frequency: 100 7 | logging_level: DEBUG 8 | data: 9 | pad: 16 10 | target_size: 40 11 | emoji: 🦋 12 | model: 13 | use_coords: False # whether to pass meshgrid as a feature 14 | random_spawn: True # randomly select initial cell coordinates 15 | stochastic_prob: 0.5 16 | norm_kernel: False 17 | interm_dim: 128 18 | bias: True 19 | train: 20 | batch_size: 4 21 | num_epochs: 8000 22 | optim: 23 | milestones: [2000] 24 | gamma: 0.1 25 | lr: 0.002 26 | grad_clip: 20 27 | -------------------------------------------------------------------------------- /modules/datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torchvision import transforms 6 | from torch.utils.data import Dataset 7 | 8 | from modules.utils import load_emoji 9 | 10 | 11 | class StateGridSet(Dataset): 12 | def __init__(self, emoji='🦎', use_coords=False, 13 | batch_size=10, random_spawn=True, 14 | pad=16, target_size=40): 15 | # emojis = ['🦎', '😀', '💥', '👁', '🐠', '🦋', '🐞', '🕸', '🥨', '🎄'] 16 | 17 | self.target = torch.from_numpy(load_emoji(emoji=emoji)).permute(2, 0, 1).unsqueeze(0) 18 | self.target = F.pad(self.target, (pad, pad, pad, pad), value=0) 19 | self.target = F.interpolate(self.target, target_size)[0] 20 | 21 | self.use_coords = use_coords 22 | self.batch_size = batch_size 23 | self.random_spawn = random_spawn 24 | 25 | def __len__(self): 26 | return self.batch_size 27 | 28 | def __getitem__(self, idx): 29 | state_grid = torch.zeros((16, self.target.shape[-2], 30 | self.target.shape[-1]), 31 | requires_grad=False) 32 | if self.random_spawn: 33 | center = random.randint(int(0.2 * (self.target.shape[-2] - 1)), 34 | int(0.8 * (self.target.shape[-2] - 1))) 35 | else: 36 | center = state_grid.shape[2] // 2 37 | state_grid[3:, center, center] = 1. 38 | 39 | if self.use_coords: 40 | xv, yv = torch.meshgrid([torch.linspace(-1, 1, steps=state_grid.shape[1]), 41 | torch.linspace(-1, 1, steps=state_grid.shape[2])]) 42 | state_grid[-1] = xv 43 | state_grid[-2] = yv 44 | 45 | return state_grid, self.target 46 | -------------------------------------------------------------------------------- /modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules import utils 6 | 7 | 8 | class AbstractCAModel(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def get_input(self, state_grid, target): 13 | self.state_grid = state_grid 14 | self.target = target 15 | 16 | def forward(self): 17 | raise NotImplementedError 18 | 19 | def optimize_parameters(self): 20 | raise NotImplementedError 21 | 22 | 23 | class SimpleCA(AbstractCAModel): 24 | def __init__(self, perception, policy, config, 25 | logger=None, grad_clip=3): 26 | super().__init__() 27 | self.perception = perception 28 | self.policy = policy 29 | self.config = config 30 | self.use_coords = config['model']['use_coords'] 31 | self.stochastic_prob = config['model']['stochastic_prob'] 32 | self.optim = torch.optim.Adam(list(self.policy.parameters()), 33 | lr=config['optim']['lr']) 34 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optim, 35 | config['optim']['milestones'], 36 | gamma=config['optim']['gamma']) 37 | self.loss_fn = nn.MSELoss() 38 | self.logger = logger 39 | self.grad_clip = grad_clip 40 | 41 | def forward(self): 42 | alive_pre = utils.alive_mask(self.state_grid, thr=0.1) 43 | perception_grid = self.perception(self.state_grid) 44 | ds_grid = self.policy(perception_grid) 45 | mask = utils.stochastic_update_mask(ds_grid, 46 | prob=self.stochastic_prob) 47 | self.state_grid = self.state_grid + ds_grid * mask 48 | alive_post = utils.alive_mask(self.state_grid, thr=0.1) 49 | final_mask = (alive_post.bool() & alive_pre.bool()).float() 50 | self.state_grid = self.state_grid * final_mask 51 | 52 | if self.use_coords: 53 | xgrid = torch.linspace(-1, 1, steps=self.target.shape[-1]) 54 | ygrid = torch.linspace(-1, 1, steps=self.target.shape[-2]) 55 | xv, yv = torch.meshgrid([xgrid, ygrid]) 56 | self.state_grid[:, -1, ...] = xv[None, :, :] 57 | self.state_grid[:, -2, ...] = yv[None, :, :] 58 | 59 | return final_mask 60 | 61 | 62 | def optimize_parameters(self): 63 | loss_value = self.loss_fn(self.target[:, :4, ...], 64 | self.state_grid[:, :4, ...]) 65 | self.optim.zero_grad() 66 | loss_value.backward() 67 | if self.logger: 68 | norm = [] 69 | for p in self.policy.parameters(): 70 | param_norm = p.grad.data.norm(2) 71 | norm.append(param_norm.item()) 72 | self.logger.debug(f'norm before clipping, {norm}') 73 | 74 | torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 75 | max_norm=self.grad_clip) 76 | 77 | if self.logger: 78 | norm = [] 79 | for p in self.policy.parameters(): 80 | param_norm = p.grad.data.norm(2) 81 | norm.append(param_norm.item()) 82 | self.logger.debug(f'norm after clipping, {norm}') 83 | self.optim.step() 84 | self.scheduler.step() 85 | return loss_value 86 | -------------------------------------------------------------------------------- /modules/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Perception(nn.Module): 7 | def __init__(self, channels=16, norm_kernel=False): 8 | super().__init__() 9 | self.channels = channels 10 | sobel_x = torch.tensor([[-1.0, 0.0, 1.0], 11 | [-2.0, 0.0, 2.0], 12 | [-1.0, 0.0, 1.0]]) / 8 13 | sobel_y = torch.tensor([[1.0, 2.0, 1.0], 14 | [0.0, 0.0, 0.0], 15 | [-1.0, -2.0, -1.0]]) / 8 16 | identity = torch.tensor([[0.0, 0.0, 0.0], 17 | [0.0, 1.0, 0.0], 18 | [0.0, 0.0, 0.0]]) 19 | 20 | self.kernel = torch.stack((identity, sobel_x, sobel_y)).repeat(channels, 1, 1).unsqueeze(1) 21 | if norm_kernel: 22 | self.kernel /= channels 23 | 24 | def forward(self, state_grid): 25 | return F.conv2d(state_grid, 26 | self.kernel.to(state_grid.device), 27 | groups=self.channels, 28 | padding=1) # thanks https://github.com/PWhiddy/Growing-Neural-Cellular-Automata-Pytorch?files=1 for the group parameter 29 | 30 | 31 | class Policy(nn.Module): 32 | def __init__(self, state_dim=16, interm_dim=128, 33 | use_embedding=True, kernel=1, padding=0, 34 | bias=False): 35 | super().__init__() 36 | dim = state_dim * 3 37 | if use_embedding: 38 | dim += 1 39 | self.conv1 = nn.Conv2d(dim, interm_dim, kernel, padding=padding) 40 | self.conv2 = nn.Conv2d(interm_dim, state_dim, kernel, padding=padding, 41 | bias=bias) 42 | nn.init.constant_(self.conv2.weight, 0.) 43 | if bias: 44 | nn.init.constant_(self.conv2.bias, 0.) 45 | 46 | def forward(self, state): 47 | interm = self.conv1(state) 48 | interm = torch.relu(interm) 49 | return self.conv2(interm) 50 | 51 | 52 | class Embedder(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | self.embedder = nn.Sequential(*[nn.Conv2d(4, 16, 3), 56 | nn.ReLU(), 57 | nn.Conv2d(16, 16, 3), 58 | nn.ReLU(), 59 | nn.Conv2d(16, 1, 3)]) 60 | 61 | def forward(self, img): 62 | return nn.AdaptiveAvgPool2d((1, 1))(self.embedder(img)) 63 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import logging 3 | import os 4 | import random 5 | from time import gmtime, strftime 6 | 7 | import numpy as np 8 | import requests 9 | import torch 10 | import torch.nn as nn 11 | from PIL import Image 12 | from torchvision import transforms 13 | 14 | 15 | def stochastic_update_mask(ds_grid, prob=0.5): 16 | # Generate mask for zero out a random fraction of the updates. 17 | bern = torch.distributions.Bernoulli(prob) 18 | rand_mask = bern.sample((ds_grid.shape[2] * ds_grid.shape[3],)) 19 | rand_mask = rand_mask.view(ds_grid.shape[2:]).float() 20 | return rand_mask.to(ds_grid.device)[None, None] 21 | 22 | 23 | def alive_mask(state_grid, thr=0.1): 24 | # Take the alpha channel as the measure of “life”. 25 | alpha = state_grid[:, [3], :, :].clamp(0, 1) 26 | alive = (nn.MaxPool2d(3, stride=1, padding=1)(alpha) > thr).float()#.unsqueeze(1) 27 | return alive 28 | 29 | 30 | def load_image(url, max_size=128): 31 | r = requests.get(url) 32 | img = Image.open(io.BytesIO(r.content)) 33 | #img.thumbnail((max_size, max_size), Image.ANTIALIAS) 34 | img = np.float32(img) / 255. 35 | ## premultiply RGB by Alpha 36 | img[..., :3] *= img[..., 3:] 37 | return img 38 | 39 | 40 | def load_emoji(emoji): 41 | code = hex(ord(emoji))[2:].lower() 42 | url = f'https://github.com/googlefonts/noto-emoji/raw/master/png/128/emoji_u{code}.png' 43 | return load_image(url) 44 | 45 | 46 | def get_timestamp(): 47 | return strftime("%Y-%m-%d-%H:%M:%S", gmtime()) 48 | 49 | 50 | def setup_logger(logger_name, root, level=logging.INFO, 51 | screen=False, tofile=False): 52 | lg = logging.getLogger(logger_name) 53 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 54 | datefmt='%y-%m-%d %H:%M:%S') 55 | lg.setLevel(level) 56 | if tofile: 57 | log_file = os.path.join(root, f'_{get_timestamp()}.log') 58 | fh = logging.FileHandler(log_file, mode='w') 59 | fh.setFormatter(formatter) 60 | lg.addHandler(fh) 61 | if screen: 62 | sh = logging.StreamHandler() 63 | sh.setFormatter(formatter) 64 | lg.addHandler(sh) 65 | 66 | 67 | def set_random_seed(seed): 68 | random.seed(seed) 69 | np.random.seed(seed) 70 | torch.manual_seed(seed) 71 | torch.cuda.manual_seed_all(seed) 72 | -------------------------------------------------------------------------------- /notebooks/ca.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import random\n", 11 | "import sys\n", 12 | "\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "from torch.utils.data import DataLoader\n", 17 | "from torchvision import transforms\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "from PIL import Image\n", 20 | "\n", 21 | "sys.path.append('../')\n", 22 | "\n", 23 | "from modules.datasets import StateGridSet\n", 24 | "from modules.networks import Perception, Policy\n", 25 | "from modules.utils import alive_mask, load_emoji, stochastic_update_mask\n", 26 | "\n", 27 | "%matplotlib inline" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "plt.imshow(load_emoji('🦋'))" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "plt.imshow(F.pad(torch.from_numpy(load_emoji('🦋')).unsqueeze(0), (0, 0, 10, 10), value=0)[0].data.numpy())" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "torch.from_numpy(load_emoji('🦋')).permute(2, 0, 1).shape" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "F.pad(torch.from_numpy(load_emoji('🦋')).permute(2, 0, 1).unsqueeze(0), (10, 10, 10, 10), value=0)[0].data.numpy().shape" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "alpha = torch.rand((1, 1, 128, 128))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "F.pad(alpha, (10, 10, 10, 10), value = 0).shape" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "F.interpolate(alpha[0], (1, 64, 64)).shape" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "(nn.MaxPool2d(3, stride=1, padding=1)(alpha) > 0.1).shape" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "colab": { 107 | "base_uri": "https://localhost:8080/", 108 | "height": 34 109 | }, 110 | "colab_type": "code", 111 | "id": "fST5pquzz234", 112 | "outputId": "750c5417-6b03-4644-e065-82b877ec40e1" 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "torch.__version__" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": { 123 | "colab": {}, 124 | "colab_type": "code", 125 | "id": "BTdSwxfgzy4k" 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "device = 'cuda:0'\n", 130 | "stochastic_prob = 0.1\n", 131 | "batch_size = 4\n", 132 | "num_epochs = 2000\n", 133 | "output_folder = './outputs/'" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "colab": { 141 | "base_uri": "https://localhost:8080/", 142 | "height": 302 143 | }, 144 | "colab_type": "code", 145 | "id": "BHaieXH8zy4x", 146 | "outputId": "4215a65d-cd22-4020-ce9f-48d2c0016ed3" 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "img = load_emoji(emoji='🦎')\n", 151 | "img = transforms.ToTensor()(img)\n", 152 | "img = transforms.Normalize(tuple(0.5 for _ in range(img.shape[0])),\n", 153 | " tuple(0.5 for _ in range(img.shape[0])))(img)\n", 154 | "img = img.to(device)\n", 155 | "print(img.max(), img.min(), img.mean())\n", 156 | "plt.imshow((img.cpu().permute(1, 2, 0).data.numpy() + 1.) / 2.)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "colab": {}, 164 | "colab_type": "code", 165 | "id": "sT7HWlREzy47" 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "policy = Policy(use_embedding=False, kernel=1, padding=0).to(device)\n", 170 | "perception = Perception(channels=16).to(device)\n", 171 | "optim = torch.optim.Adam(list(policy.parameters()) + list(perception.parameters()), lr=2e-3)\n", 172 | "scheduler = torch.optim.lr_scheduler.StepLR(optim, 100, gamma=0.7)\n", 173 | "loss_fn = nn.MSELoss()" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "for p in policy.parameters():\n", 183 | " p.grad.data /= (p.grad.data.norm(2) + 1e-8)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "dset = StateGridSet(emoji='🦎', use_coords=True, batch_size=batch_size, random_spawn=False)\n", 193 | "dset_test = StateGridSet(emoji='🦎', use_coords=True, batch_size=1, random_spawn=False)\n", 194 | "dloader = DataLoader(dset, batch_size=batch_size)\n", 195 | "dloader_test = DataLoader(dset, batch_size=1)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": { 202 | "colab": { 203 | "base_uri": "https://localhost:8080/", 204 | "height": 561 205 | }, 206 | "colab_type": "code", 207 | "id": "AdRLrySLzy49", 208 | "outputId": "a9dac145-ec8d-45ba-abf1-5a5d885ff6de", 209 | "scrolled": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "xv, yv = torch.meshgrid([torch.linspace(-1, 1, steps=img.shape[-1]),\n", 214 | " torch.linspace(-1, 1, steps=img.shape[-2])])\n", 215 | "\n", 216 | "for epoch in range(num_epochs):\n", 217 | " n_steps = random.randint(100, 150)\n", 218 | " split_rate = random.randint(30, 40)\n", 219 | " for state_grid, target in dloader:\n", 220 | " state_grid, target = state_grid.to(device), target.to(device)\n", 221 | " for k in range(n_steps):\n", 222 | " alive_pre = alive_mask((state_grid + 1.) / 2., thr=0.1)\n", 223 | " perception_grid = perception(state_grid)\n", 224 | " ds_grid = policy(perception_grid)\n", 225 | " mask = stochastic_update_mask(ds_grid,\n", 226 | " prob=stochastic_prob)\n", 227 | " state_grid = state_grid + ds_grid * mask\n", 228 | " alive_post = alive_mask((state_grid + 1.) / 2., thr=0.1)\n", 229 | " final_mask = (alive_post.bool() & alive_pre.bool()).float()\n", 230 | " state_grid = state_grid * final_mask\n", 231 | "\n", 232 | " if dset.use_coords:\n", 233 | " state_grid[:, -1, ...] = xv[None, :, :]\n", 234 | " state_grid[:, -2, ...] = yv[None, :, :]\n", 235 | "\n", 236 | " if k % split_rate == 0:\n", 237 | " loss_value = loss_fn(target[:, :4, ...],\n", 238 | " state_grid[:, :4, ...])\n", 239 | " optim.zero_grad()\n", 240 | " loss_value.backward()\n", 241 | " optim.step()\n", 242 | " state_grid = state_grid.detach()\n", 243 | "\n", 244 | " if k % split_rate == 0:\n", 245 | " pass\n", 246 | " else:\n", 247 | " loss_value = loss_fn(target[:, :4, ...], state_grid[:, :4, ...])\n", 248 | " optim.zero_grad()\n", 249 | " loss_value.backward()\n", 250 | " optim.step()\n", 251 | " scheduler.step()\n", 252 | " print(f'{loss_value.item():.2f}, {n_steps} steps, ',\n", 253 | " f'{split_rate} split rate, {epoch} epoch')\n", 254 | " if epoch % 50 == 0:\n", 255 | " print('Testing')\n", 256 | " output_path = os.path.join(output_folder, f'{epoch}/')\n", 257 | " os.makedirs(output_path, exist_ok=True)\n", 258 | " test(policy, perception, dloader_test,\n", 259 | " output_path, num_steps=150,\n", 260 | " stochastic_prob=stochastic_prob)\n" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": { 267 | "colab": {}, 268 | "colab_type": "code", 269 | "id": "R32N1M58zy5S" 270 | }, 271 | "outputs": [], 272 | "source": [] 273 | } 274 | ], 275 | "metadata": { 276 | "accelerator": "GPU", 277 | "colab": { 278 | "name": "ca.ipynb", 279 | "provenance": [] 280 | }, 281 | "kernelspec": { 282 | "display_name": "Python 3", 283 | "language": "python", 284 | "name": "python3" 285 | }, 286 | "language_info": { 287 | "codemirror_mode": { 288 | "name": "ipython", 289 | "version": 3 290 | }, 291 | "file_extension": ".py", 292 | "mimetype": "text/x-python", 293 | "name": "python", 294 | "nbconvert_exporter": "python", 295 | "pygments_lexer": "ipython3", 296 | "version": "3.7.4" 297 | } 298 | }, 299 | "nbformat": 4, 300 | "nbformat_minor": 1 301 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.2 2 | torchvision 3 | numpy 4 | pillow 5 | yaml 6 | requests -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import random 7 | import shutil 8 | import yaml 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import yaml 14 | from torch.optim import lr_scheduler 15 | from torch.utils.data import DataLoader 16 | from torchvision import transforms 17 | 18 | from modules.datasets import StateGridSet 19 | from modules.models import SimpleCA 20 | from modules.networks import Perception, Policy 21 | from modules.utils import get_timestamp, set_random_seed, setup_logger 22 | 23 | parser = argparse.ArgumentParser(description='Train neural cellular automata') 24 | parser.add_argument('-c', '--config', type=str, 25 | help='path to config .yaml') 26 | args = parser.parse_args() 27 | config_path = args.config 28 | 29 | with open(config_path) as f: 30 | config = yaml.load(f, Loader=yaml.FullLoader) 31 | 32 | device = config['device'] 33 | batch_size = config['train']['batch_size'] 34 | num_epochs = config['train']['num_epochs'] 35 | n_steps_interval = config['n_steps_interval'] 36 | split_rate_interval = config['split_rate_interval'] 37 | test_frequency = config['test_frequency'] 38 | use_coords = config['model']['use_coords'] 39 | random_spawn = config['model']['random_spawn'] 40 | norm_kernel = config['model']['norm_kernel'] 41 | interm_dim = config['model']['interm_dim'] 42 | bias = config['model']['bias'] 43 | 44 | set_random_seed(10) 45 | 46 | if (config['experiment_name'] == 'time') or \ 47 | ('experiment_name' not in config.keys()): 48 | start_time = get_timestamp() 49 | experiment_name = start_time 50 | else: 51 | experiment_name = config['experiment_name'] 52 | 53 | output_folder = os.path.join(config['output_folder'], experiment_name) 54 | os.makedirs(output_folder, exist_ok=True) 55 | shutil.copy(config_path, os.path.join(output_folder, 'config.yaml')) 56 | 57 | logging_level = logging.DEBUG if config['logging_level'] == 'DEBUG' else logging.INFO 58 | setup_logger('base', output_folder, 59 | level=logging_level, screen=True, tofile=True) 60 | 61 | logger = logging.getLogger('base') 62 | 63 | perception = Perception(channels=16, 64 | norm_kernel=norm_kernel).to(device) 65 | policy = Policy(use_embedding=False, kernel=1, padding=0, 66 | interm_dim=interm_dim, bias=bias).to(device) 67 | 68 | model = SimpleCA(perception, policy, config, logger=logger, 69 | grad_clip=config['optim']['grad_clip']) 70 | 71 | dset = StateGridSet(emoji=config['data']['emoji'], use_coords=use_coords, 72 | batch_size=batch_size, 73 | random_spawn=random_spawn, 74 | pad=config['data']['pad'], 75 | target_size=config['data']['target_size']) 76 | dset_test = StateGridSet(emoji=config['data']['emoji'], use_coords=use_coords, 77 | batch_size=1, 78 | random_spawn=False, 79 | pad=config['data']['pad'], 80 | target_size=config['data']['target_size']) 81 | dloader = DataLoader(dset, batch_size=batch_size) 82 | dloader_test = DataLoader(dset, batch_size=1) 83 | 84 | xv, yv = torch.meshgrid([torch.linspace(-1, 1, steps=dset.target.shape[-1]), 85 | torch.linspace(-1, 1, steps=dset.target.shape[-2])]) 86 | 87 | for epoch in range(num_epochs): 88 | n_steps = random.randint(*n_steps_interval) 89 | split_rate = None 90 | if split_rate_interval: 91 | split_rate = random.randint(*split_rate_interval) 92 | for state_grid, target in dloader: 93 | state_grid, target = state_grid.to(device), target.to(device) 94 | model.get_input(state_grid, target) 95 | for k in range(n_steps): 96 | final_mask = model.forward() 97 | if split_rate and (k % split_rate == 0): # truncated bptt 98 | loss_value = model.optimize_parameters() 99 | state_grid = model.state_grid.detach() 100 | model.get_input(state_grid, target) 101 | 102 | if split_rate and (k % split_rate == 0): 103 | pass 104 | else: 105 | loss_value = model.optimize_parameters() 106 | 107 | logger.info(f'{loss_value.item():.2f}, {n_steps} steps, {split_rate} split rate, {epoch} epoch') 108 | 109 | if epoch % test_frequency == 0: 110 | output_path = os.path.join(output_folder, f'{epoch}/') 111 | logger.info(f'writing gif to {output_path}') 112 | os.makedirs(output_path, exist_ok=True) 113 | topil = transforms.ToPILImage() 114 | with torch.no_grad(): 115 | for k, (state_grid, target) in enumerate(dloader_test): 116 | state_grid, target = state_grid.to(device), target.to(device) 117 | topil(target[0].cpu()).save(os.path.join(output_folder, 118 | f'target.png')) 119 | imgs = [] 120 | masks = [] 121 | model.get_input(state_grid, target) 122 | for _ in range(150): 123 | final_mask = model.forward() 124 | imgs.append(topil(model.state_grid[0, :4, ...].cpu())) 125 | masks.append(topil(final_mask[0, :, ...].cpu())) 126 | imgs[0].save(os.path.join(output_path, f'{k}.gif'), 127 | save_all=True, append_images=imgs[1:]) 128 | masks[0].save(os.path.join(output_path, f'{k}_mask.gif'), 129 | save_all=True, append_images=masks[1:]) 130 | --------------------------------------------------------------------------------