├── .gitignore ├── LICENSE.md ├── README.md ├── demo.ipynb ├── demo.pdf ├── environment.yml ├── scripts └── run.py └── universal_computation ├── __init__.py ├── datasets ├── __init__.py ├── bit_memory.py ├── bit_xor.py ├── cifar10.py ├── cifar10_gray.py ├── dataset.py ├── helpers │ └── listops.py ├── listops.py ├── mnist.py └── remote_homology.py ├── experiment.py ├── fpt.py ├── models ├── __init__.py └── lstm.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_STORE 2 | **/*.pyc 3 | **/*.swp 4 | MANIFEST 5 | *.egg-info 6 | \.idea/ 7 | /.idea 8 | /data 9 | /misc 10 | /wandb 11 | /models 12 | /.ipynb_checkpoints 13 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 FPT (Pretrained Transformers as Universal Computation Engines) Authors (https://arxiv.org/abs/2103.05247) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # universal-computation 2 | 3 | ## Overview 4 | Official codebase for [Pretrained Transformers as Universal Computation Engines](https://arxiv.org/abs/2103.05247). 5 | Contains demo notebook and scripts to reproduce experiments. 6 | 7 | ### Project Demo 8 | 9 | For a minimal demonstration of frozen pretrained transformers, see ```demo.ipynb```. 10 | You can run the notebook which reproduces the Bit XOR experiment in a couple minutes, and visualizes the learned 11 | attention maps. 12 | 13 | ### Status 14 | No updates are currently planned but there may be new features added in the future. 15 | 16 | Currently the repo supports the following tasks: 17 | ``` 18 | ['bit-memory', 'bit-xor', 'listops', 'mnist', 'cifar10', 'cifar10-gray', 'remote-homology'] 19 | ``` 20 | 21 | As well as the following models: 22 | ``` 23 | ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', 'vit', 'lstm'] 24 | ``` 25 | 26 | Note that CIFAR-10 LRA is ```cifar10-gray``` with a patch size of 1. 27 | 28 | ## Usage 29 | 30 | ### Installation 31 | 32 | 1. Install Anaconda environment: 33 | ``` 34 | $ conda env create -f environment.yml 35 | ``` 36 | 37 | 2. Add ```universal-computation/``` to your PYTHONPATH, i.e. 
add this line to your ```~/.bashrc```:
38 | ```
39 | export PYTHONPATH=~/universal-computation:$PYTHONPATH
40 | ```
41 | 
42 | ### Downloading datasets
43 | 
44 | Datasets are stored in ```data/```.
45 | MNIST and CIFAR-10 are downloaded automatically by PyTorch when an experiment starts.
46 | 
47 | #### Listops
48 | 
49 | Download the files for Listops from [Long Range Arena](https://github.com/google-research/long-range-arena).
50 | Move the ```.tsv``` files into ```data/listops```.
51 | There should be three files: ```basic_test, basic_train, basic_val```.
52 | The script evaluates on the validation set by default.
53 | 
54 | #### Remote homology
55 | 
56 | Install and download the files for Remote Homology from [TAPE](https://github.com/songlab-cal/tape).
57 | Move the files into ```data/tape```; i.e. there should exist a directory (along with its valid variant)
58 | ```
59 | data/tape/remote_homology/remote_homology_train.lmdb
60 | ```
61 | Inside, there should be two files, ```data.mdb``` and ```lock.mdb```.
62 | The script evaluates on the validation set by default.
63 | 
64 | ### Running experiments
65 | 
66 | You can run experiments with:
67 | ```
68 | python scripts/run.py
69 | ```
70 | 
71 | Adding ```-w True``` will log results to Weights and Biases; other options (e.g. ```--num_iters```, ```--device```, ```--gpu_batch_size```) are defined in ```universal_computation/experiment.py```.
72 | 
73 | ## Citation
74 | 
75 | ```
76 | @article{lu2021fpt,
77 |   title={Pretrained Transformers as Universal Computation Engines},
78 |   author={Kevin Lu and Aditya Grover and Pieter Abbeel and Igor Mordatch},
79 |   journal={arXiv preprint arXiv:2103.05247},
80 |   year={2021}
81 | }
82 | ```
83 | 
84 | ## License
85 | 
86 | MIT
87 | 
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Pretrained Transformers as Universal Computation Engines Demo\n",
8 |     "\n",
9 |     "This is a demo notebook illustrating how to create a Frozen Pretrained Transformer (FPT) and train it on the Bit XOR task, which converges within a couple of minutes.\n",
10 |     "\n",
11 |     "arXiv: https://arxiv.org/pdf/2103.05247.pdf\n",
12 |     "\n",
13 |     "Github: https://github.com/kzl/universal-computation"
14 |    ]
15 |   },
16 |   {
17 |    "cell_type": "code",
18 |    "execution_count": 1,
19 |    "metadata": {},
20 |    "outputs": [],
21 |    "source": [
22 |     "import matplotlib.pyplot as plt\n",
23 |     "import numpy as np\n",
24 |     "import torch\n",
25 |     "import torch.nn as nn\n",
26 |     "\n",
27 |     "from transformers.models.gpt2.modeling_gpt2 import GPT2Model"
28 |    ]
29 |   },
30 |   {
31 |    "cell_type": "markdown",
32 |    "metadata": {},
33 |    "source": [
34 |     "## Creating the dataset\n",
35 |     "\n",
36 |     "For this demo, we'll look at calculating the elementwise XOR between two randomly generated bitstrings.\n",
37 |     "If you want to play more with the model, feel free to try larger $n$, although it will take longer to train."
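,
    "\n",
    "As a quick worked example of the task itself: with $n = 5$, inputs $01100$ and $11010$ give $01100 \\oplus 11010 = 10110$, i.e. a 1 exactly where the two strings disagree."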
38 |    ]
39 |   },
40 |   {
41 |    "cell_type": "code",
42 |    "execution_count": 2,
43 |    "metadata": {},
44 |    "outputs": [],
45 |    "source": [
46 |     "def generate_example(n):\n",
47 |     "    bits = np.random.randint(low=0, high=2, size=(2, n))\n",
48 |     "    xor = np.logical_xor(bits[0], bits[1]).astype(np.int64)  # np.int64 rather than the deprecated np.long alias\n",
49 |     "    return bits.reshape((2*n)), xor"
50 |    ]
51 |   },
52 |   {
53 |    "cell_type": "code",
54 |    "execution_count": 3,
55 |    "metadata": {},
56 |    "outputs": [
57 |     {
58 |      "name": "stdout",
59 |      "output_type": "stream",
60 |      "text": [
61 |       "  String 1: [0 0 1 1 0]\n",
62 |       "  String 2: [0 0 1 0 0]\n",
63 |       "Output XOR: [0 0 0 1 0]\n"
64 |      ]
65 |     }
66 |    ],
67 |    "source": [
68 |     "n = 5\n",
69 |     "bits, xor = generate_example(n)\n",
70 |     "\n",
71 |     "print('  String 1:', bits[:n])\n",
72 |     "print('  String 2:', bits[n:])\n",
73 |     "print('Output XOR:', xor)"
74 |    ]
75 |   },
76 |   {
77 |    "cell_type": "markdown",
78 |    "metadata": {},
79 |    "source": [
80 |     "## Creating the frozen pretrained transformer\n",
81 |     "\n",
82 |     "We simply wrap a pretrained GPT-2 model with a small input embedding and a linear output layer, then freeze the weights of the self-attention and feedforward layers.\n",
83 |     "You can also see what happens using a randomly initialized model instead."
84 |    ]
85 |   },
86 |   {
87 |    "cell_type": "code",
88 |    "execution_count": 4,
89 |    "metadata": {},
90 |    "outputs": [],
91 |    "source": [
92 |     "if torch.cuda.is_available():\n",
93 |     "    device = 'cuda'\n",
94 |     "else:\n",
95 |     "    device = 'cpu'"
96 |    ]
97 |   },
98 |   {
99 |    "cell_type": "code",
100 |    "execution_count": 5,
101 |    "metadata": {},
102 |    "outputs": [
103 |     {
104 |      "name": "stderr",
105 |      "output_type": "stream",
106 |      "text": [
107 |       "Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias']\n",
108 |       "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
109 |      ]
110 |     }
111 |    ],
112 |    "source": [
113 |     "gpt2 = GPT2Model.from_pretrained('gpt2')  # loads a pretrained GPT-2 base model\n",
114 |     "in_layer = nn.Embedding(2, 768)  # map bit to GPT-2 embedding dim of 768\n",
115 |     "out_layer = nn.Linear(768, 2)  # predict logits"
116 |    ]
117 |   },
118 |   {
119 |    "cell_type": "code",
120 |    "execution_count": 6,
121 |    "metadata": {},
122 |    "outputs": [],
123 |    "source": [
124 |     "for name, param in gpt2.named_parameters():\n",
125 |     "    # freeze all parameters except the layernorm and positional embeddings\n",
126 |     "    if 'ln' in name or 'wpe' in name:\n",
127 |     "        param.requires_grad = True\n",
128 |     "    else:\n",
129 |     "        param.requires_grad = False"
130 |    ]
131 |   },
132 |   {
133 |    "cell_type": "markdown",
134 |    "metadata": {},
135 |    "source": [
136 |     "## Training loop\n",
137 |     "\n",
138 |     "We train the model on the Bit XOR task, taking one Adam gradient step per sample.\n",
139 |     "The model should converge within roughly 10,000 samples (about 7,000 in the run below)."
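,
    "\n",
    "To sanity-check what the freeze loop above actually leaves trainable, you can count parameters on the `gpt2` module defined earlier (a quick sketch, not needed for the training loop):\n",
    "\n",
    "```python\n",
    "trainable = sum(p.numel() for p in gpt2.parameters() if p.requires_grad)\n",
    "total = sum(p.numel() for p in gpt2.parameters())\n",
    "print(f'trainable GPT-2 params: {trainable} / {total}')\n",
    "```"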
140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 7, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "params = list(gpt2.parameters()) + list(in_layer.parameters()) + list(out_layer.parameters())\n", 149 | "optimizer = torch.optim.Adam(params)\n", 150 | "loss_fn = nn.CrossEntropyLoss()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 8, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "for layer in (gpt2, in_layer, out_layer):\n", 160 | " layer.to(device=device)\n", 161 | " layer.train()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 9, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "Samples: 500, Accuracy: 0.6320000129938126\n", 174 | "Samples: 1000, Accuracy: 0.600000016093254\n", 175 | "Samples: 1500, Accuracy: 0.6560000130534172\n", 176 | "Samples: 2000, Accuracy: 0.7320000123977661\n", 177 | "Samples: 2500, Accuracy: 0.6400000116229058\n", 178 | "Samples: 3000, Accuracy: 0.6960000142455101\n", 179 | "Samples: 3500, Accuracy: 0.7640000116825104\n", 180 | "Samples: 4000, Accuracy: 0.7520000123977661\n", 181 | "Samples: 4500, Accuracy: 0.7560000121593475\n", 182 | "Samples: 5000, Accuracy: 0.8280000087618827\n", 183 | "Samples: 5500, Accuracy: 0.9000000059604645\n", 184 | "Samples: 6000, Accuracy: 0.9440000027418136\n", 185 | "Samples: 6500, Accuracy: 0.9520000028610229\n", 186 | "Final accuracy: 0.9920000004768371\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "accuracies = [0]\n", 192 | "while sum(accuracies[-50:]) / len(accuracies[-50:]) < .99:\n", 193 | " x, y = generate_example(n)\n", 194 | " x = torch.from_numpy(x).to(device=device, dtype=torch.long)\n", 195 | " y = torch.from_numpy(y).to(device=device, dtype=torch.long)\n", 196 | " \n", 197 | " embeddings = in_layer(x.reshape(1, -1))\n", 198 | " hidden_state = gpt2(inputs_embeds=embeddings).last_hidden_state[:,n:]\n", 199 | " logits = out_layer(hidden_state)[0]\n", 200 | " \n", 201 | " loss = loss_fn(logits, y)\n", 202 | " accuracies.append((logits.argmax(dim=-1) == y).float().mean().item())\n", 203 | " \n", 204 | " optimizer.zero_grad()\n", 205 | " loss.backward()\n", 206 | " optimizer.step()\n", 207 | " \n", 208 | " if len(accuracies) % 500 == 0:\n", 209 | " accuracy = sum(accuracies[-50:]) / len(accuracies[-50:])\n", 210 | " print(f'Samples: {len(accuracies)}, Accuracy: {accuracy}')\n", 211 | "\n", 212 | "print(f'Final accuracy: {sum(accuracies[-50:]) / len(accuracies[-50:])}')" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "## Visualizing attention map\n", 220 | "\n", 221 | "We can visualize the attention map of the first layer: the model learns to attend to the relevant bits for each element in the XOR operation.\n", 222 | "Note the two consistent diagonal lines for output tokens 5-9 across samples, denoting each position of either string (the pattern is stronger if the model is allowed to train longer or evaluated on more samples)." 
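,
    "\n",
    "The plot averages attention over all 12 heads of the first layer. To inspect a single head instead, you can index the same tensor (a sketch reusing `transformer_outputs` from the prediction cell below):\n",
    "\n",
    "```python\n",
    "head = 0\n",
    "plt.imshow(transformer_outputs.attentions[0][0, head].cpu().numpy())  # layer 0, batch element 0\n",
    "```"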
223 |    ]
224 |   },
225 |   {
226 |    "cell_type": "code",
227 |    "execution_count": 10,
228 |    "metadata": {},
229 |    "outputs": [],
230 |    "source": [
231 |     "for layer in (gpt2, in_layer, out_layer):\n",
232 |     "    layer.eval()"
233 |    ]
234 |   },
235 |   {
236 |    "cell_type": "code",
237 |    "execution_count": 11,
238 |    "metadata": {},
239 |    "outputs": [
240 |     {
241 |      "name": "stdout",
242 |      "output_type": "stream",
243 |      "text": [
244 |       "  String 1: [0 1 1 0 0]\n",
245 |       "  String 2: [1 1 1 1 1]\n",
246 |       "Prediction: [1 0 0 1 1]\n",
247 |       "Output XOR: [1 0 0 1 1]\n"
248 |      ]
249 |     }
250 |    ],
251 |    "source": [
252 |     "bits, xor = generate_example(n)\n",
253 |     "\n",
254 |     "with torch.no_grad():\n",
255 |     "    x = torch.from_numpy(bits).to(device=device, dtype=torch.long)\n",
256 |     "    \n",
257 |     "    embeddings = in_layer(x)\n",
258 |     "    transformer_outputs = gpt2(\n",
259 |     "        inputs_embeds=embeddings,\n",
260 |     "        return_dict=True,\n",
261 |     "        output_attentions=True,\n",
262 |     "    )\n",
263 |     "    logits = out_layer(transformer_outputs.last_hidden_state[n:])\n",
264 |     "    predictions = logits.argmax(dim=-1).cpu().numpy()\n",
265 |     "\n",
266 |     "print('  String 1:', bits[:n])\n",
267 |     "print('  String 2:', bits[n:])\n",
268 |     "print('Prediction:', predictions)\n",
269 |     "print('Output XOR:', xor)"
270 |    ]
271 |   },
272 |   {
273 |    "cell_type": "code",
274 |    "execution_count": 12,
275 |    "metadata": {},
276 |    "outputs": [
277 |     {
278 |      "data": {
279 |       "text/plain": [
280 |        ""
281 |       ]
282 |      },
283 |      "execution_count": 12,
284 |      "metadata": {},
285 |      "output_type": "execute_result"
286 |     },
287 |     {
288 |      "data": {
289 |       "image/png": "[base64-encoded PNG of the attention-map heatmap omitted]",
290 |       "text/plain": [
291 |        "
" 292 | ] 293 | }, 294 | "metadata": { 295 | "needs_background": "light" 296 | }, 297 | "output_type": "display_data" 298 | } 299 | ], 300 | "source": [ 301 | "attentions = transformer_outputs.attentions[0][0] # first layer, first in batch\n", 302 | "mean_attentions = attentions.mean(dim=0) # take the mean over heads\n", 303 | "mean_attentions = mean_attentions.cpu().numpy()\n", 304 | "\n", 305 | "plt.xlabel('Input Tokens', size=16)\n", 306 | "plt.xticks(range(10), bits)\n", 307 | "plt.ylabel('Output Tokens', size=16)\n", 308 | "plt.yticks(range(10), ['*'] * 5 + list(predictions))\n", 309 | "\n", 310 | "plt.imshow(mean_attentions)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "## Sanity check\n", 318 | "\n", 319 | "As a sanity check, we can see that the model could solve this task without needing to finetune the self-attention layers! The XOR was computed using only the connections already present in GPT-2." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 13, 325 | "metadata": {}, 326 | "outputs": [ 327 | { 328 | "name": "stderr", 329 | "output_type": "stream", 330 | "text": [ 331 | "Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias']\n", 332 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 333 | ] 334 | }, 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "h.0.attn.c_attn.weight is unchanged\n", 340 | "h.0.attn.c_attn.bias is unchanged\n", 341 | "h.0.attn.c_proj.weight is unchanged\n", 342 | "h.0.attn.c_proj.bias is unchanged\n", 343 | "h.0.mlp.c_fc.weight is unchanged\n", 344 | "h.0.mlp.c_fc.bias is unchanged\n", 345 | "h.0.mlp.c_proj.weight is unchanged\n", 346 | "h.0.mlp.c_proj.bias is unchanged\n", 347 | "h.1.attn.c_attn.weight is unchanged\n", 348 | "h.1.attn.c_attn.bias is unchanged\n", 349 | "h.1.attn.c_proj.weight is unchanged\n", 350 | "h.1.attn.c_proj.bias is unchanged\n", 351 | "h.1.mlp.c_fc.weight is unchanged\n", 352 | "h.1.mlp.c_fc.bias is unchanged\n", 353 | "h.1.mlp.c_proj.weight is unchanged\n", 354 | "h.1.mlp.c_proj.bias is unchanged\n", 355 | "h.2.attn.c_attn.weight is unchanged\n", 356 | "h.2.attn.c_attn.bias is unchanged\n", 357 | "h.2.attn.c_proj.weight is unchanged\n", 358 | "h.2.attn.c_proj.bias is unchanged\n", 359 | "h.2.mlp.c_fc.weight is unchanged\n", 360 | "h.2.mlp.c_fc.bias is unchanged\n", 361 | "h.2.mlp.c_proj.weight is unchanged\n", 362 | "h.2.mlp.c_proj.bias is unchanged\n", 363 | "h.3.attn.c_attn.weight is unchanged\n", 364 | "h.3.attn.c_attn.bias is unchanged\n", 365 | "h.3.attn.c_proj.weight is unchanged\n", 366 | "h.3.attn.c_proj.bias is unchanged\n", 367 | "h.3.mlp.c_fc.weight is unchanged\n", 368 | "h.3.mlp.c_fc.bias is unchanged\n", 369 | "h.3.mlp.c_proj.weight is unchanged\n", 370 | "h.3.mlp.c_proj.bias is unchanged\n", 371 | "h.4.attn.c_attn.weight is unchanged\n", 372 | "h.4.attn.c_attn.bias is unchanged\n", 373 | "h.4.attn.c_proj.weight is unchanged\n", 374 | "h.4.attn.c_proj.bias is unchanged\n", 375 | "h.4.mlp.c_fc.weight is unchanged\n", 376 | "h.4.mlp.c_fc.bias is unchanged\n", 377 | "h.4.mlp.c_proj.weight is 
unchanged\n", 378 | "h.4.mlp.c_proj.bias is unchanged\n", 379 | "h.5.attn.c_attn.weight is unchanged\n", 380 | "h.5.attn.c_attn.bias is unchanged\n", 381 | "h.5.attn.c_proj.weight is unchanged\n", 382 | "h.5.attn.c_proj.bias is unchanged\n", 383 | "h.5.mlp.c_fc.weight is unchanged\n", 384 | "h.5.mlp.c_fc.bias is unchanged\n", 385 | "h.5.mlp.c_proj.weight is unchanged\n", 386 | "h.5.mlp.c_proj.bias is unchanged\n", 387 | "h.6.attn.c_attn.weight is unchanged\n", 388 | "h.6.attn.c_attn.bias is unchanged\n", 389 | "h.6.attn.c_proj.weight is unchanged\n", 390 | "h.6.attn.c_proj.bias is unchanged\n", 391 | "h.6.mlp.c_fc.weight is unchanged\n", 392 | "h.6.mlp.c_fc.bias is unchanged\n", 393 | "h.6.mlp.c_proj.weight is unchanged\n", 394 | "h.6.mlp.c_proj.bias is unchanged\n", 395 | "h.7.attn.c_attn.weight is unchanged\n", 396 | "h.7.attn.c_attn.bias is unchanged\n", 397 | "h.7.attn.c_proj.weight is unchanged\n", 398 | "h.7.attn.c_proj.bias is unchanged\n", 399 | "h.7.mlp.c_fc.weight is unchanged\n", 400 | "h.7.mlp.c_fc.bias is unchanged\n", 401 | "h.7.mlp.c_proj.weight is unchanged\n", 402 | "h.7.mlp.c_proj.bias is unchanged\n", 403 | "h.8.attn.c_attn.weight is unchanged\n", 404 | "h.8.attn.c_attn.bias is unchanged\n", 405 | "h.8.attn.c_proj.weight is unchanged\n", 406 | "h.8.attn.c_proj.bias is unchanged\n", 407 | "h.8.mlp.c_fc.weight is unchanged\n", 408 | "h.8.mlp.c_fc.bias is unchanged\n", 409 | "h.8.mlp.c_proj.weight is unchanged\n", 410 | "h.8.mlp.c_proj.bias is unchanged\n", 411 | "h.9.attn.c_attn.weight is unchanged\n", 412 | "h.9.attn.c_attn.bias is unchanged\n", 413 | "h.9.attn.c_proj.weight is unchanged\n", 414 | "h.9.attn.c_proj.bias is unchanged\n", 415 | "h.9.mlp.c_fc.weight is unchanged\n", 416 | "h.9.mlp.c_fc.bias is unchanged\n", 417 | "h.9.mlp.c_proj.weight is unchanged\n", 418 | "h.9.mlp.c_proj.bias is unchanged\n", 419 | "h.10.attn.c_attn.weight is unchanged\n", 420 | "h.10.attn.c_attn.bias is unchanged\n", 421 | "h.10.attn.c_proj.weight is unchanged\n", 422 | "h.10.attn.c_proj.bias is unchanged\n", 423 | "h.10.mlp.c_fc.weight is unchanged\n", 424 | "h.10.mlp.c_fc.bias is unchanged\n", 425 | "h.10.mlp.c_proj.weight is unchanged\n", 426 | "h.10.mlp.c_proj.bias is unchanged\n", 427 | "h.11.attn.c_attn.weight is unchanged\n", 428 | "h.11.attn.c_attn.bias is unchanged\n", 429 | "h.11.attn.c_proj.weight is unchanged\n", 430 | "h.11.attn.c_proj.bias is unchanged\n", 431 | "h.11.mlp.c_fc.weight is unchanged\n", 432 | "h.11.mlp.c_fc.bias is unchanged\n", 433 | "h.11.mlp.c_proj.weight is unchanged\n", 434 | "h.11.mlp.c_proj.bias is unchanged\n" 435 | ] 436 | } 437 | ], 438 | "source": [ 439 | "fresh_gpt2 = GPT2Model.from_pretrained('gpt2')\n", 440 | "\n", 441 | "gpt2.to(device='cpu')\n", 442 | "gpt2_state_dict = gpt2.state_dict()\n", 443 | "for name, param in fresh_gpt2.named_parameters():\n", 444 | " if 'attn' in name or 'mlp' in name:\n", 445 | " new_param = gpt2_state_dict[name]\n", 446 | " if torch.abs(param.data - new_param.data).sum() > 1e-8:\n", 447 | " print(f'{name} was modified')\n", 448 | " else:\n", 449 | " print(f'{name} is unchanged')" 450 | ] 451 | } 452 | ], 453 | "metadata": { 454 | "kernelspec": { 455 | "display_name": "Python 3", 456 | "language": "python", 457 | "name": "python3" 458 | }, 459 | "language_info": { 460 | "codemirror_mode": { 461 | "name": "ipython", 462 | "version": 3 463 | }, 464 | "file_extension": ".py", 465 | "mimetype": "text/x-python", 466 | "name": "python", 467 | "nbconvert_exporter": "python", 468 | "pygments_lexer": "ipython3", 469 | 
"version": "3.7.7" 470 | } 471 | }, 472 | "nbformat": 4, 473 | "nbformat_minor": 4 474 | } 475 | -------------------------------------------------------------------------------- /demo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kzl/universal-computation/23b80fd1ba3caee493f5fa8715fb70a3ff8a250e/demo.pdf -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: fpt 2 | dependencies: 3 | - python=3.7 4 | - pip: 5 | - boto3==1.17.102 6 | - einops==0.3.0 7 | - matplotlib==3.2.1 8 | - numpy==1.18.3 9 | - tape-proteins==0.4 10 | - tensorflow==2.3.0 11 | - tensorflow-datasets==4.0.1 12 | - torch==1.7.1 13 | - torchvision==0.8.2 14 | - transformers==4.1.1 15 | - tqdm==4.46.0 16 | - wandb==0.9.1 17 | -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | from universal_computation.experiment import run_experiment 2 | 3 | 4 | if __name__ == '__main__': 5 | 6 | experiment_name = 'fpt' 7 | 8 | experiment_params = dict( 9 | task='bit-memory', 10 | n=1000, # ignored if not a bit task 11 | num_patterns=5, # ignored if not a bit task 12 | patch_size=50, 13 | 14 | model_name='gpt2', 15 | pretrained=True, # if vit this is forced to true, if lstm this is forced to false 16 | 17 | freeze_trans=True, # if False, we don't check arguments other than in and out 18 | freeze_in=False, 19 | freeze_pos=False, 20 | freeze_ln=False, 21 | freeze_attn=True, 22 | freeze_ff=True, 23 | freeze_out=False, 24 | 25 | in_layer_sizes=None, # not in paper, but can specify layer sizes for an MLP, 26 | out_layer_sizes=None, # ex. 
[32, 32] creates a 2-layer MLP with dimension 32 27 | 28 | learning_rate=1e-3, 29 | batch_size=4, 30 | dropout=0.1, 31 | orth_gain=1.41, # orthogonal initialization of input layer 32 | ) 33 | 34 | run_experiment(experiment_name, experiment_params) 35 | -------------------------------------------------------------------------------- /universal_computation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kzl/universal-computation/23b80fd1ba3caee493f5fa8715fb70a3ff8a250e/universal_computation/__init__.py -------------------------------------------------------------------------------- /universal_computation/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kzl/universal-computation/23b80fd1ba3caee493f5fa8715fb70a3ff8a250e/universal_computation/datasets/__init__.py -------------------------------------------------------------------------------- /universal_computation/datasets/bit_memory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from universal_computation.datasets.dataset import Dataset 4 | 5 | 6 | class BitMemoryDataset(Dataset): 7 | 8 | def __init__(self, n=1000, num_patterns=5, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.n = n 11 | self.num_patterns = num_patterns 12 | 13 | def get_batch_np(self, batch_size, train): 14 | bits = np.random.randint(low=0, high=2, size=(batch_size, self.num_patterns, self.n)) 15 | bits = 2 * bits - 1 16 | query_inds = np.random.randint(low=0, high=self.num_patterns, size=batch_size) 17 | query_bits = bits[range(batch_size), query_inds] 18 | mask = np.random.randint(low=0, high=2, size=query_bits.shape) 19 | masked_query_bits = mask * query_bits 20 | masked_query_bits = masked_query_bits.reshape(batch_size, 1, self.n) 21 | x = np.concatenate([bits, masked_query_bits], axis=1) 22 | y = query_bits 23 | return x, y 24 | -------------------------------------------------------------------------------- /universal_computation/datasets/bit_xor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from universal_computation.datasets.dataset import Dataset 4 | 5 | 6 | class BitXORDataset(Dataset): 7 | 8 | def __init__(self, n=5, num_patterns=2, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.n = n 11 | self.num_patterns = num_patterns 12 | 13 | def get_batch_np(self, batch_size, train): 14 | bits = np.random.randint(low=0, high=2, size=(batch_size, self.num_patterns, self.n)) 15 | xored_bits = bits[:,0] 16 | for i in range(1, self.num_patterns): 17 | xored_bits = np.logical_xor(xored_bits, bits[:,i]) 18 | return bits, xored_bits 19 | -------------------------------------------------------------------------------- /universal_computation/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from torch.utils.data import DataLoader 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | from universal_computation.datasets.dataset import Dataset 7 | 8 | 9 | class CIFAR10Dataset(Dataset): 10 | 11 | def __init__(self, batch_size, patch_size=None, data_aug=True, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | 14 | self.batch_size = batch_size # we fix it so we can use dataloader 15 | self.patch_size = patch_size # grid of 
(patch_size x patch_size) 16 | 17 | if data_aug: 18 | transform = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.RandomCrop(32, padding=4), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.RandomRotation(20), 23 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 24 | ]) 25 | else: 26 | transform = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 29 | ]) 30 | 31 | val_transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 34 | ]) 35 | 36 | self.d_train = DataLoader( 37 | torchvision.datasets.CIFAR10('data/cifar', download=True, train=True, transform=transform), 38 | batch_size=batch_size, drop_last=True, shuffle=True, 39 | ) 40 | self.d_test = DataLoader( 41 | torchvision.datasets.CIFAR10('data/cifar', download=True, train=False, transform=val_transform), 42 | batch_size=batch_size, drop_last=True, shuffle=True, 43 | ) 44 | 45 | self.train_enum = enumerate(self.d_train) 46 | self.test_enum = enumerate(self.d_test) 47 | 48 | self.train_size = len(self.d_train) 49 | self.test_size = len(self.d_test) 50 | 51 | def reset_test(self): 52 | self.test_enum = enumerate(self.d_test) 53 | 54 | def get_batch(self, batch_size=None, train=True): 55 | if train: 56 | _, (x, y) = next(self.train_enum, (None, (None, None))) 57 | if x is None: 58 | self.train_enum = enumerate(self.d_train) 59 | _, (x, y) = next(self.train_enum) 60 | else: 61 | _, (x, y) = next(self.test_enum, (None, (None, None))) 62 | if x is None: 63 | self.test_enum = enumerate(self.d_test) 64 | _, (x, y) = next(self.test_enum) 65 | 66 | if self.patch_size is not None: 67 | x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size) 68 | 69 | x = x.to(device=self.device) 70 | y = y.to(device=self.device) 71 | 72 | self._ind += 1 73 | 74 | return x, y 75 | -------------------------------------------------------------------------------- /universal_computation/datasets/cifar10_gray.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from torch.utils.data import DataLoader 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | from universal_computation.datasets.dataset import Dataset 7 | 8 | 9 | class CIFAR10GrayDataset(Dataset): 10 | 11 | def __init__(self, batch_size, patch_size=None, data_aug=True, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | 14 | self.batch_size = batch_size # we fix it so we can use dataloader 15 | self.patch_size = patch_size # grid of (patch_size x patch_size) 16 | 17 | if data_aug: 18 | transform = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.Grayscale(num_output_channels=1), 21 | transforms.RandomCrop(32, padding=4), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.RandomRotation(20), 24 | transforms.Normalize((0.5,), (0.5,)), 25 | ]) 26 | else: 27 | transform = transforms.Compose([ 28 | transforms.ToTensor(), 29 | transforms.Grayscale(num_output_channels=1), 30 | transforms.Normalize((0.5,), (0.5,)), 31 | ]) 32 | 33 | val_transform = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Grayscale(num_output_channels=1), 36 | transforms.Normalize((0.5,), (0.5,)), 37 | ]) 38 | 39 | self.d_train = DataLoader( 40 | torchvision.datasets.CIFAR10('data/cifar', download=True, train=True, transform=transform), 41 | 
batch_size=batch_size, drop_last=True, shuffle=True, 42 | ) 43 | self.d_test = DataLoader( 44 | torchvision.datasets.CIFAR10('data/cifar', download=True, train=False, transform=val_transform), 45 | batch_size=batch_size, drop_last=True, shuffle=True, 46 | ) 47 | 48 | self.train_enum = enumerate(self.d_train) 49 | self.test_enum = enumerate(self.d_test) 50 | 51 | def get_batch(self, batch_size=None, train=True): 52 | if train: 53 | _, (x, y) = next(self.train_enum, (None, (None, None))) 54 | if x is None: 55 | self.train_enum = enumerate(self.d_train) 56 | _, (x, y) = next(self.train_enum) 57 | else: 58 | _, (x, y) = next(self.test_enum, (None, (None, None))) 59 | if x is None: 60 | self.test_enum = enumerate(self.d_test) 61 | _, (x, y) = next(self.test_enum) 62 | 63 | if self.patch_size is not None: 64 | x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size) 65 | 66 | x = x.to(device=self.device) 67 | y = y.to(device=self.device) 68 | 69 | self._ind += 1 70 | 71 | return x, y 72 | -------------------------------------------------------------------------------- /universal_computation/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Dataset: 5 | 6 | def __init__(self, device='cpu'): 7 | self.device = device 8 | self._ind = 0 9 | 10 | def get_batch(self, batch_size, train=True): 11 | x, y = self.get_batch_np(batch_size, train=train) 12 | x = torch.from_numpy(x).to(device=self.device, dtype=torch.float32) 13 | y = torch.from_numpy(y).to(device=self.device, dtype=torch.long) 14 | self._ind += 1 15 | return x, y 16 | 17 | def get_batch_np(self, batch_size, train): 18 | raise NotImplementedError 19 | 20 | def start_epoch(self): 21 | self._ind = 0 22 | -------------------------------------------------------------------------------- /universal_computation/datasets/helpers/listops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Input pipeline for the listops dataset.""" 16 | 17 | import numpy as np 18 | 19 | import tensorflow.compat.v1 as tf 20 | import tensorflow_datasets as tfds 21 | 22 | AUTOTUNE = tf.data.experimental.AUTOTUNE 23 | 24 | 25 | def preprocess_dataset(file_path, batch_size): 26 | """Preprocess dataset.""" 27 | tf.logging.info(file_path) 28 | sel_cols = ['Source', 'Target'] 29 | col_defaults = [tf.string, tf.int32] 30 | ds = tf.data.experimental.make_csv_dataset([file_path], 31 | batch_size, 32 | column_defaults=col_defaults, 33 | select_columns=sel_cols, 34 | field_delim='\t', 35 | header=True, 36 | num_epochs=1) 37 | ds = ds.unbatch() 38 | return ds 39 | 40 | 41 | def get_datasets(n_devices, 42 | task_name, 43 | data_dir=None, 44 | batch_size=256, 45 | max_length=512): 46 | """Get algorithmic datasets.""" 47 | if batch_size % n_devices: 48 | raise ValueError("Batch size %d isn't divided evenly by n_devices %d" % 49 | (batch_size, n_devices)) 50 | 51 | train_path = data_dir + task_name + '_train.tsv' 52 | val_path = data_dir + task_name + '_val.tsv' 53 | test_path = data_dir + task_name + '_test.tsv' 54 | 55 | train_dataset = preprocess_dataset(train_path, batch_size) 56 | val_dataset = preprocess_dataset(val_path, batch_size) 57 | test_dataset = preprocess_dataset(test_path, batch_size) 58 | 59 | tf.logging.info('Finished preprocessing') 60 | tf.logging.info('Building vocab') 61 | # build vocab 62 | vocab_set = set() 63 | tokenizer = tfds.deprecated.text.Tokenizer() 64 | 65 | for i, data in enumerate(train_dataset): 66 | examples = data['Source'] 67 | examples = tokenizer.tokenize(examples.numpy()) 68 | examples = np.reshape(examples, (-1)).tolist() 69 | vocab_set.update(examples) 70 | if i % 1000 == 0: 71 | tf.logging.info('Processed {}'.format(i)) 72 | if i > 1000: 73 | break 74 | vocab_set = list(set(vocab_set)) 75 | tf.logging.info('Finished processing vocab size={}'.format(len(vocab_set))) 76 | 77 | encoder = tfds.deprecated.text.TokenTextEncoder(vocab_set) 78 | 79 | def tf_encode(x): 80 | result = tf.py_function(lambda s: tf.constant(encoder.encode(s.numpy())), [ 81 | x, 82 | ], tf.int32) 83 | result.set_shape([None]) 84 | return result 85 | 86 | def tokenize(d): 87 | return { 88 | 'inputs': tf_encode(d['Source'])[:max_length], 89 | 'targets': d['Target'] 90 | } 91 | 92 | train_dataset = train_dataset.map(tokenize, num_parallel_calls=AUTOTUNE) 93 | val_dataset = val_dataset.map(tokenize, num_parallel_calls=AUTOTUNE) 94 | test_dataset = test_dataset.map(tokenize, num_parallel_calls=AUTOTUNE) 95 | 96 | max_shape = {'inputs': [max_length], 'targets': []} 97 | train_dataset = train_dataset.shuffle( 98 | buffer_size=1024, reshuffle_each_iteration=True).padded_batch( 99 | batch_size, padded_shapes=max_shape) 100 | val_dataset = val_dataset.padded_batch(batch_size, padded_shapes=max_shape) 101 | test_dataset = test_dataset.padded_batch(batch_size, padded_shapes=max_shape) 102 | 103 | train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE) 104 | val_dataset = val_dataset.prefetch(tf.data.experimental.AUTOTUNE) 105 | test_dataset = test_dataset.prefetch(tf.data.experimental.AUTOTUNE) 106 | 107 | return train_dataset, val_dataset, test_dataset, encoder 108 | -------------------------------------------------------------------------------- /universal_computation/datasets/listops.py: -------------------------------------------------------------------------------- 1 | import tensorflow_datasets 2 | import torch 3 | 4 | from universal_computation.datasets.dataset import Dataset 
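# get_datasets (imported below) returns (train_dataset, val_dataset, test_dataset, encoder);
# ListopsDataset keeps the first two, so its "test" batches come from the LRA validation
# split, matching the README's note that evaluation uses the validation set by default.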
5 | from universal_computation.datasets.helpers.listops import get_datasets
6 | 
7 | 
8 | class ListopsDataset(Dataset):
9 | 
10 |     def __init__(self, batch_size, *args, **kwargs):
11 |         super().__init__(*args, **kwargs)
12 | 
13 |         self.batch_size = batch_size  # we fix it so we can use dataloader
14 | 
15 |         self.d_train, self.d_test, *_ = get_datasets(1, 'basic', batch_size=batch_size, data_dir='data/listops/')
16 | 
17 |         self.train_enum = iter(tensorflow_datasets.as_numpy(self.d_train))
18 |         self.test_enum = iter(tensorflow_datasets.as_numpy(self.d_test))
19 | 
20 |     def reset_test(self):
21 |         # rebuild the numpy iterator, as above; enumerate(self.d_test) would yield
22 |         # (index, batch) tuples that get_batch below cannot index into
23 |         self.test_enum = iter(tensorflow_datasets.as_numpy(self.d_test))
24 | 
25 |     def get_batch(self, batch_size=None, train=True):
26 |         if train:
27 |             batch = next(self.train_enum, None)
28 |             if batch is None:
29 |                 self.train_enum = iter(tensorflow_datasets.as_numpy(self.d_train))
30 |                 batch = next(self.train_enum)
31 |         else:
32 |             batch = next(self.test_enum, None)
33 |             if batch is None:
34 |                 self.test_enum = iter(tensorflow_datasets.as_numpy(self.d_test))
35 |                 batch = next(self.test_enum)
36 | 
37 |         x, y = batch['inputs'], batch['targets']
38 |         x = torch.from_numpy(x).long()
39 |         y = torch.from_numpy(y).long()
40 | 
41 |         x = x.to(device=self.device)
42 |         y = y.to(device=self.device)
43 | 
44 |         self._ind += 1
45 | 
46 |         return x, y
47 | 
--------------------------------------------------------------------------------
/universal_computation/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange
2 | from torch.utils.data import DataLoader
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | 
6 | from universal_computation.datasets.dataset import Dataset
7 | 
8 | 
9 | class MNISTDataset(Dataset):
10 | 
11 |     def __init__(self, batch_size, patch_size=None, *args, **kwargs):
12 |         super().__init__(*args, **kwargs)
13 | 
14 |         self.batch_size = batch_size  # we fix it so we can use dataloader
15 |         self.patch_size = patch_size  # grid of (patch_size x patch_size)
16 | 
17 |         transform = transforms.Compose([
18 |             transforms.ToTensor(),
19 |             transforms.Normalize(mean=0., std=1.),
20 |         ])
21 | 
22 |         self.d_train = DataLoader(
23 |             torchvision.datasets.MNIST('data/mnist', download=True, train=True, transform=transform),
24 |             batch_size=batch_size, drop_last=True, shuffle=True,
25 |         )
26 |         self.d_test = DataLoader(
27 |             torchvision.datasets.MNIST('data/mnist', download=True, train=False, transform=transform),
28 |             batch_size=batch_size, drop_last=True, shuffle=True,
29 |         )
30 | 
31 |         self.train_enum = enumerate(self.d_train)
32 |         self.test_enum = enumerate(self.d_test)
33 | 
34 |     def get_batch(self, batch_size=None, train=True):
35 |         if train:
36 |             _, (x, y) = next(self.train_enum, (None, (None, None)))
37 |             if x is None:
38 |                 self.train_enum = enumerate(self.d_train)
39 |                 _, (x, y) = next(self.train_enum)
40 |         else:
41 |             _, (x, y) = next(self.test_enum, (None, (None, None)))
42 |             if x is None:
43 |                 self.test_enum = enumerate(self.d_test)
44 |                 _, (x, y) = next(self.test_enum)
45 | 
46 |         if self.patch_size is not None:
47 |             x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)
48 | 
49 |         x = x.to(device=self.device)
50 |         y = y.to(device=self.device)
51 | 
52 |         self._ind += 1
53 | 
54 |         return x, y
55 | 
--------------------------------------------------------------------------------
/universal_computation/datasets/remote_homology.py:
--------------------------------------------------------------------------------
1 | from tape import 
utils 2 | from tape.datasets import RemoteHomologyDataset as TAPERemoteHomologyDataset 3 | 4 | from universal_computation.datasets.dataset import Dataset 5 | 6 | 7 | class RemoteHomologyDataset(Dataset): 8 | 9 | """ 10 | Note that we clip all sequences less than max_seq_len = 1024 11 | for the sake of simplicity in our paper. This leaves a remaining 12 | 236224 examples (out of 242560 -- 97.39%). To correct the accuracy, 13 | we multiply reported accuracies for the paper by .9739. 14 | 15 | We pad lazily inside this dataset by assigning ID 28 to mean padding. 16 | This increases the input dimension of the model by 1, so the model 17 | should have an input dimension of 29. 18 | """ 19 | 20 | def __init__(self, data_subdir='tape', max_seq_len=1024, train_batch_size=2, test_batch_size=8, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | 23 | self.max_seq_len = max_seq_len 24 | 25 | data_dir = f'data/{data_subdir}' 26 | train_dataset = TAPERemoteHomologyDataset(data_dir, 'train') 27 | val_dataset = TAPERemoteHomologyDataset(data_dir, 'valid') 28 | 29 | self.d_train = utils.setup_loader(train_dataset, train_batch_size, -1, 1, 1, 0) 30 | self.d_test = utils.setup_loader(val_dataset, test_batch_size, -1, 1, 1, 0) 31 | 32 | self.train_enum = enumerate(self.d_train) 33 | self.test_enum = enumerate(self.d_test) 34 | 35 | def get_batch(self, batch_size=None, train=True): 36 | 37 | seq_len = self.max_seq_len + 1 38 | 39 | while seq_len > self.max_seq_len: 40 | if train: 41 | _, data = next(self.train_enum, (None, None)) 42 | if data is None: 43 | self.train_enum = enumerate(self.d_train) 44 | _, data = next(self.train_enum) 45 | else: 46 | _, data = next(self.test_enum, (None, None)) 47 | if data is None: 48 | self.test_enum = enumerate(self.d_test) 49 | _, data = next(self.test_enum) 50 | x, y = data['input_ids'], data['targets'] 51 | seq_len = x.shape[1] 52 | 53 | x = x.to(device=self.device) 54 | y = y.to(device=self.device) 55 | 56 | self._ind += 1 57 | return x, y 58 | -------------------------------------------------------------------------------- /universal_computation/experiment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import wandb 4 | 5 | import argparse 6 | from datetime import datetime 7 | import random 8 | import sys 9 | 10 | from universal_computation.fpt import FPT 11 | from universal_computation.trainer import Trainer 12 | 13 | 14 | def experiment( 15 | exp_name, 16 | exp_args, 17 | **kwargs 18 | ): 19 | 20 | """ 21 | Preliminary checks 22 | """ 23 | 24 | # Must be able to accumulate gradient if batch size is large 25 | assert 'batch_size' in kwargs 26 | assert kwargs['batch_size'] <= exp_args['gpu_batch_size'] or \ 27 | kwargs['batch_size'] % exp_args['gpu_batch_size'] == 0 28 | 29 | """ 30 | Create dataset, model, and trainer 31 | """ 32 | 33 | task = kwargs['task'] 34 | batch_size = kwargs['batch_size'] 35 | patch_size = kwargs['patch_size'] 36 | device = exp_args['device'] 37 | 38 | return_last_only = True 39 | 40 | if task == 'bit-memory': 41 | from universal_computation.datasets.bit_memory import BitMemoryDataset 42 | dataset = BitMemoryDataset(n=kwargs['n'], num_patterns=kwargs['num_patterns'], device=device) 43 | input_dim = kwargs['n'] if patch_size is None else patch_size 44 | output_dim = 2*kwargs['n'] if patch_size is None else 2 * patch_size 45 | use_embeddings = False 46 | experiment_type = 'classification' 47 | 48 | elif task == 'bit-xor': 49 | from 
universal_computation.datasets.bit_xor import BitXORDataset 50 | dataset = BitXORDataset(n=kwargs['n'], num_patterns=kwargs['num_patterns'], device=device) 51 | input_dim = kwargs['n'] if patch_size is None else patch_size 52 | output_dim = 2 * kwargs['n'] if patch_size is None else 2 * patch_size 53 | use_embeddings = False 54 | experiment_type = 'classification' 55 | 56 | elif task == 'listops': 57 | from universal_computation.datasets.listops import ListopsDataset 58 | dataset = ListopsDataset(batch_size=batch_size, device=device) 59 | input_dim, output_dim = 15, 10 60 | use_embeddings = True 61 | experiment_type = 'classification' 62 | 63 | elif task == 'mnist': 64 | from universal_computation.datasets.mnist import MNISTDataset 65 | dataset = MNISTDataset(batch_size=batch_size, patch_size=patch_size, device=device) 66 | input_dim, output_dim = patch_size ** 2, 10 67 | use_embeddings = False 68 | experiment_type = 'classification' 69 | 70 | elif task == 'cifar10': 71 | from universal_computation.datasets.cifar10 import CIFAR10Dataset 72 | dataset = CIFAR10Dataset(batch_size=batch_size, patch_size=patch_size, device=device) 73 | input_dim, output_dim = 3 * patch_size**2, 10 74 | use_embeddings = False 75 | experiment_type = 'classification' 76 | 77 | elif task == 'cifar10-gray': 78 | from universal_computation.datasets.cifar10_gray import CIFAR10GrayDataset 79 | dataset = CIFAR10GrayDataset(batch_size=batch_size, patch_size=patch_size, device=device) 80 | input_dim, output_dim = patch_size**2, 10 81 | use_embeddings = False 82 | experiment_type = 'classification' 83 | 84 | elif task == 'remote-homology': 85 | from universal_computation.datasets.remote_homology import RemoteHomologyDataset 86 | dataset = RemoteHomologyDataset(train_batch_size=batch_size, test_batch_size=4*batch_size, device=device) 87 | input_dim, output_dim = 30, 1200 88 | use_embeddings = True 89 | experiment_type = 'classification' 90 | 91 | else: 92 | raise NotImplementedError('dataset not implemented') 93 | 94 | if 'bit' in task: 95 | 96 | ce_loss = torch.nn.CrossEntropyLoss() 97 | 98 | def loss_fn(out, y, x=None): 99 | out = torch.reshape(out, (-1, kwargs['n'], 2)) 100 | ids = torch.zeros(y.shape).to(device=y.device).long() 101 | if task == 'bit-memory': 102 | ids[y < 0], ids[y > 0] = 0, 1 103 | else: 104 | ids[y < 0.5], ids[y > 0.5] = 0, 1 105 | out, ids = torch.reshape(out, (-1, 2)), torch.reshape(ids, (-1,)) 106 | return ce_loss(out, ids) 107 | 108 | def accuracy_fn(preds, true, x=None): 109 | if task == 'bit-memory': 110 | preds = preds.reshape(-1, kwargs['n'], 2).argmax(-1) * 2 - 1 111 | else: 112 | preds = preds.reshape(-1, kwargs['n'], 2).argmax(-1) 113 | if task == 'bit-memory': 114 | return (np.sign(preds) == np.sign(true)).mean() 115 | else: 116 | return ((preds > 0.5) == (true > 0.5)).mean() 117 | 118 | elif experiment_type == 'classification': 119 | 120 | ce_loss = torch.nn.CrossEntropyLoss() 121 | 122 | def loss_fn(out, y, x=None): 123 | out = out[:, 0] 124 | return ce_loss(out, y) 125 | 126 | def accuracy_fn(preds, true, x=None): 127 | preds = preds[:, 0].argmax(-1) 128 | return (preds == true).mean() 129 | 130 | else: 131 | raise NotImplementedError('experiment_type not recognized') 132 | 133 | model = FPT( 134 | input_dim=input_dim, 135 | output_dim=output_dim, 136 | model_name=kwargs.get('model_name', 'gpt2'), 137 | pretrained=kwargs.get('pretrained', True), 138 | return_last_only=return_last_only, 139 | use_embeddings_for_in=use_embeddings, 140 | in_layer_sizes=kwargs.get('in_layer_sizes', None), 
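        # leaving in_layer_sizes/out_layer_sizes as None gives a single linear layer on
        # each side; e.g. [32, 32] would instead build a small 2-layer MLP, as noted in
        # scripts/run.py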
141 | out_layer_sizes=kwargs.get('out_layer_sizes', None), 142 | freeze_trans=kwargs.get('freeze_trans', True), 143 | freeze_in=kwargs.get('freeze_in', False), 144 | freeze_pos=kwargs.get('freeze_pos', False), 145 | freeze_ln=kwargs.get('freeze_ln', False), 146 | freeze_attn=kwargs.get('freeze_attn', True), 147 | freeze_ff=kwargs.get('freeze_ff', True), 148 | freeze_out=kwargs.get('freeze_out', False), 149 | dropout=kwargs['dropout'], 150 | orth_gain=kwargs['orth_gain'], 151 | ) 152 | model.to(device) 153 | 154 | gpu_batch_size = exp_args['gpu_batch_size'] 155 | trainer = Trainer( 156 | model, 157 | dataset, 158 | loss_fn=loss_fn, 159 | accuracy_fn=accuracy_fn, 160 | steps_per_epoch=exp_args['steps_per_iter'], 161 | test_steps_per_epoch=exp_args['test_steps_per_iter'], 162 | learning_rate=kwargs['learning_rate'], 163 | batch_size=gpu_batch_size if batch_size > gpu_batch_size else batch_size, 164 | eval_batch_size=batch_size, 165 | grad_accumulate=batch_size // gpu_batch_size if batch_size > gpu_batch_size else 1, 166 | ) 167 | 168 | """ 169 | Set up logging 170 | """ 171 | 172 | log_to_wandb = exp_args['log_to_wandb'] 173 | save_models = exp_args['save_models'] 174 | wandb_project = exp_args['wandb_project'] 175 | 176 | short_name = str(random.randint(int(1e5), int(1e6) - 1)) 177 | run_name = f'{exp_name}-{task}-{short_name}' 178 | 179 | if log_to_wandb: 180 | config = dict( 181 | short_name=short_name, 182 | run_name=run_name, 183 | **exp_args, 184 | **kwargs, 185 | ) 186 | wandb.init( 187 | name=f'{exp_name}-{short_name}', 188 | group=f'{exp_name}-{task}', 189 | project=wandb_project, 190 | config=config, 191 | ) 192 | wandb.watch(model) 193 | 194 | for t in range(exp_args['num_iters']): 195 | trainer.train_epoch() 196 | 197 | print('=' * 57) 198 | print(f'| Iteration {" " * 15} | {t+1:25} |') 199 | for k, v in trainer.diagnostics.items(): 200 | print(f'| {k:25} | {v:25} |') 201 | 202 | if log_to_wandb: 203 | wandb.log(trainer.diagnostics) 204 | 205 | if save_models and ((t+1) % exp_args['save_models_every'] == 0 or 206 | (t+1) == exp_args['num_iters']): 207 | with open(f'models/{run_name}.pt', 'wb') as f: 208 | state_dict = dict(model=model.state_dict(), optim=trainer.optim.state_dict()) 209 | torch.save(state_dict, f) 210 | print(f'Saved model at {t+1} iters: {run_name}') 211 | 212 | 213 | def run_experiment( 214 | exp_name, 215 | experiment_params, 216 | ): 217 | parser = argparse.ArgumentParser() 218 | 219 | parser.add_argument('--num_iters', '-it', type=int, default=10000, 220 | help='Number of iterations for trainer') 221 | parser.add_argument('--steps_per_iter', type=int, default=100, 222 | help='Number of gradient steps per iteration') 223 | parser.add_argument('--test_steps_per_iter', type=int, default=25, 224 | help='Number of test gradient steps per iteration') 225 | 226 | parser.add_argument('--log_to_wandb', '-w', type=bool, default=False, 227 | help='Whether or not to log to Weights and Biases') 228 | parser.add_argument('--note', '-n', type=str, default='', 229 | help='An optional note to be logged to W&B') 230 | parser.add_argument('--wandb_project', type=str, default='my_project', 231 | help='Project name for W&B') 232 | parser.add_argument('--include_date', type=bool, default=True, 233 | help='Whether to include date in run name') 234 | 235 | parser.add_argument('--save_models', '-s', type=bool, default=False, 236 | help='Whether or not to save the model files locally') 237 | parser.add_argument('--save_models_every', '-int', type=int, default=25, 238 | help='How often 
to save models locally')
239 | 
240 |     parser.add_argument('--device', '-d', type=str, default='cuda',
241 |                         help='Which device for Pytorch to use')
242 |     parser.add_argument('--gpu_batch_size', '-gbs', type=int, default=16,
243 |                         help='Max batch size to put on GPU (used for gradient accumulation)')
244 | 
245 |     exp_args = parser.parse_args(sys.argv[1:])
246 | 
247 |     if exp_args.include_date:
248 |         timestamp = datetime.now().strftime('%m-%d')
249 |         exp_name = f'{timestamp}-{exp_name}'
250 | 
251 |     experiment_params['exp_name'] = exp_name
252 |     experiment_params['exp_args'] = vars(exp_args)
253 | 
254 |     experiment(**experiment_params)  # exp_name is already set in experiment_params above
255 | 
--------------------------------------------------------------------------------
/universal_computation/fpt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class FPT(nn.Module):
6 | 
7 |     def __init__(
8 |             self,
9 |             input_dim,
10 |             output_dim,
11 |             model_name='gpt2',
12 |             pretrained=False,
13 |             return_last_only=True,
14 |             use_embeddings_for_in=False,
15 |             in_layer_sizes=None,
16 |             out_layer_sizes=None,
17 |             freeze_trans=True,
18 |             freeze_in=False,
19 |             freeze_pos=False,
20 |             freeze_ln=False,
21 |             freeze_attn=True,
22 |             freeze_ff=True,
23 |             freeze_out=False,
24 |             dropout=0.1,
25 |             orth_gain=1.41,
26 |     ):
27 |         super().__init__()
28 | 
29 |         self.input_dim = input_dim
30 |         self.output_dim = output_dim
31 |         self.model_name = model_name
32 |         self.return_last_only = return_last_only
33 |         self.use_embeddings_for_in = use_embeddings_for_in
34 | 
35 |         self.in_layer_sizes = [] if in_layer_sizes is None else in_layer_sizes
36 |         self.out_layer_sizes = [] if out_layer_sizes is None else out_layer_sizes
37 |         self.dropout = dropout
38 | 
39 |         if 'gpt' in model_name:
40 |             assert model_name in ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
41 | 
42 |             from transformers import GPT2Model
43 | 
44 |             pretrained_transformer = GPT2Model.from_pretrained(model_name)  # weights kept only if pretrained=True; otherwise just the config is reused
45 |             if pretrained:
46 |                 self.sequence_model = pretrained_transformer
47 |             else:
48 |                 self.sequence_model = GPT2Model(pretrained_transformer.config)
49 | 
50 |             if model_name == 'gpt2':
51 |                 embedding_size = 768
52 |             elif model_name == 'gpt2-medium':
53 |                 embedding_size = 1024
54 |             elif model_name == 'gpt2-large':
55 |                 embedding_size = 1280
56 |             elif model_name == 'gpt2-xl':
57 |                 embedding_size = 1600
58 | 
59 |         elif model_name == 'vit':
60 | 
61 |             import timm
62 | 
63 |             self.sequence_model = timm.create_model(
64 |                 'vit_base_patch16_224', pretrained=pretrained, drop_rate=dropout, attn_drop_rate=dropout,
65 |             )
66 |             embedding_size = 768
67 | 
68 |             self.vit_pos_embed = nn.Parameter(torch.zeros(1, 1024, embedding_size))
69 |             if freeze_pos:
70 |                 self.vit_pos_embed.requires_grad = False
71 | 
72 |         elif model_name == 'lstm':
73 | 
74 |             from universal_computation.models.lstm import LNLSTM
75 | 
76 |             num_layers, embedding_size = 3, 768
77 | 
78 |             self.sequence_model = LNLSTM(
79 |                 input_size=embedding_size,
80 |                 hidden_size=embedding_size,
81 |                 num_layers=num_layers,
82 |                 batch_first=True,
83 |                 residual=False,
84 |                 dropout=dropout,
85 |                 bidirectional=0,
86 |             )
87 | 
88 |             # optionally:
89 |             # self.lstm_pos_embed = nn.Parameter(torch.zeros(1, 1024, embedding_size))
90 |             # if freeze_pos:
91 |             #     self.lstm_pos_embed.requires_grad = False
92 | 
93 |         else:
94 |             raise NotImplementedError('model_name not implemented')
95 | 
96 |         if use_embeddings_for_in:
97 |             self.in_net = nn.Embedding(input_dim, embedding_size)
98 |         else:
99 |             in_layers = []
100 | 
last_output_size = input_dim 101 | for size in self.in_layer_sizes: 102 | layer = nn.Linear(last_output_size, size) 103 | if orth_gain is not None: 104 | torch.nn.init.orthogonal_(layer.weight, gain=orth_gain) 105 | layer.bias.data.zero_() 106 | 107 | in_layers.append(layer) 108 | in_layers.append(nn.ReLU()) 109 | in_layers.append(nn.Dropout(dropout)) 110 | last_output_size = size 111 | 112 | final_linear = nn.Linear(last_output_size, embedding_size) 113 | if orth_gain is not None: 114 | torch.nn.init.orthogonal_(final_linear.weight, gain=orth_gain) 115 | final_linear.bias.data.zero_() 116 | 117 | in_layers.append(final_linear) 118 | in_layers.append(nn.Dropout(dropout)) 119 | 120 | self.in_net = nn.Sequential(*in_layers) 121 | 122 | out_layers = [] 123 | last_output_size = embedding_size 124 | for size in self.out_layer_sizes: 125 | out_layers.append(nn.Linear(last_output_size, size)) 126 | out_layers.append(nn.ReLU()) 127 | out_layers.append(nn.Dropout(dropout)) 128 | last_output_size = size 129 | out_layers.append(nn.Linear(last_output_size, output_dim)) 130 | self.out_net = nn.Sequential(*out_layers) 131 | 132 | if freeze_trans: 133 | for name, p in self.sequence_model.named_parameters(): 134 | name = name.lower() 135 | if 'ln' in name or 'norm' in name: 136 | p.requires_grad = not freeze_ln 137 | elif 'wpe' in name or 'position_embeddings' in name or 'pos_drop' in name: 138 | p.requires_grad = not freeze_pos 139 | elif 'mlp' in name: 140 | p.requires_grad = not freeze_ff 141 | elif 'attn' in name: 142 | p.requires_grad = not freeze_attn 143 | else: 144 | p.requires_grad = False 145 | if freeze_in: 146 | for p in self.in_net.parameters(): 147 | p.requires_grad = False 148 | if freeze_out: 149 | for p in self.out_net.parameters(): 150 | p.requires_grad = False 151 | 152 | def forward(self, x): 153 | 154 | # reshape x (batch_size, seq_len, dim) into patches (batch_size, seq_len*num_patches, patch_dim) 155 | orig_dim = x.shape[-1] 156 | if orig_dim != self.input_dim and not self.use_embeddings_for_in: 157 | if orig_dim % self.input_dim != 0: 158 | raise ValueError('dimension of x must be divisible by patch size') 159 | ratio = orig_dim // self.input_dim 160 | x = x.reshape(x.shape[0], x.shape[1] * ratio, self.input_dim) 161 | else: 162 | ratio = 1 163 | 164 | x = self.in_net(x) 165 | 166 | # ignore input layer that comes with model and use our own embeddings 167 | if self.model_name == 'vit': 168 | x = x + self.vit_pos_embed[:, :x.shape[1]] 169 | x = self.sequence_model.pos_drop(x) 170 | for blk in self.sequence_model.blocks: 171 | x = blk(x) 172 | x = self.sequence_model.norm(x) 173 | elif self.model_name == 'lstm': 174 | # x = x + self.lstm_pos_embed[:, :x.shape[1]] 175 | x, *_ = self.sequence_model(x) 176 | else: 177 | transformer_outputs = self.sequence_model( 178 | inputs_embeds=x, 179 | return_dict=True, 180 | ) 181 | x = transformer_outputs.last_hidden_state 182 | 183 | # take final hidden state of tokens corresponding to last patch 184 | if self.return_last_only: 185 | x = x[:,-ratio:] 186 | 187 | # single linear layer applied to last hidden state 188 | x = self.out_net(x) 189 | 190 | # if we did patch resizing above, return in the original shape (batch_size, seq_len, dim) 191 | if self.return_last_only and ratio > 1: 192 | x = x.reshape(x.shape[0], x.shape[1] // ratio, ratio * self.output_dim) 193 | 194 | return x 195 | -------------------------------------------------------------------------------- /universal_computation/models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kzl/universal-computation/23b80fd1ba3caee493f5fa8715fb70a3ff8a250e/universal_computation/models/__init__.py -------------------------------------------------------------------------------- /universal_computation/models/lstm.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2018 Alex 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | Code modified from Github repo: https://github.com/exe1023/LSTM_LN 25 | """ 26 | 27 | import math 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | from torch import Tensor 32 | from torch.nn import Parameter 33 | from torch.autograd import Variable 34 | 35 | 36 | use_cuda = torch.cuda.is_available() 37 | 38 | 39 | class LNLSTM(nn.Module): 40 | def __init__(self, 41 | input_size, 42 | hidden_size, 43 | num_layers=1, 44 | dropout=0., 45 | bidirectional=1, 46 | batch_first=False, 47 | residual=False, 48 | cln=False): 49 | super(LNLSTM, self).__init__() 50 | self.input_size = input_size 51 | self.hidden_size = hidden_size 52 | self.num_layers = num_layers 53 | self.direction = bidirectional + 1 54 | self.batch_first = batch_first 55 | self.residual = residual 56 | 57 | layers = [] 58 | for i in range(num_layers): 59 | for j in range(self.direction): 60 | layer = LayerNormLSTM(input_size*self.direction, 61 | hidden_size, 62 | dropout=dropout, 63 | cln=cln) 64 | layers.append(layer) 65 | input_size = hidden_size 66 | self.layers = layers 67 | self.params = nn.ModuleList(layers) 68 | 69 | def reset_parameters(self): 70 | for l in self.layers: 71 | l.reset_parameters() 72 | 73 | def init_hidden(self, batch_size): 74 | # Uses Xavier init here. 
75 | hiddens = [] 76 | for l in self.layers: 77 | std = math.sqrt(2.0 / (l.input_size + l.hidden_size)) 78 | h = Variable(Tensor(1, batch_size, l.hidden_size).normal_(0, std)) 79 | c = Variable(Tensor(1, batch_size, l.hidden_size).normal_(0, std)) 80 | if use_cuda: 81 | hiddens.append((h.cuda(), c.cuda())) 82 | else: 83 | hiddens.append((h, c)) 84 | return hiddens 85 | 86 | def layer_forward(self, l, xs, h, image_emb, reverse=False): 87 | ''' 88 | return: 89 | xs: (seq_len, batch, hidden) 90 | h: (1, batch, hidden) 91 | ''' 92 | if self.batch_first: 93 | xs = xs.permute(1, 0, 2).contiguous() 94 | ys = [] 95 | for i in range(xs.size(0)): 96 | if reverse: 97 | x = xs.narrow(0, (xs.size(0)-1)-i, 1) 98 | else: 99 | x = xs.narrow(0, i, 1) 100 | y, h = l(x, h, image_emb) 101 | ys.append(y) 102 | y = torch.cat(ys, 0) 103 | if self.batch_first: 104 | y = y.permute(1, 0, 2) 105 | return y, h 106 | 107 | def forward(self, x, hiddens=None, image_emb=None): 108 | if hiddens is None: 109 | hiddens = self.init_hidden(x.shape[0]) 110 | if self.direction > 1: 111 | x = torch.cat((x, x), 2) 112 | if type(hiddens) != list: 113 | # when the hidden feed is (direction * num_layer, batch, hidden) 114 | tmp = [] 115 | for idx in range(hiddens[0].size(0)): 116 | tmp.append((hiddens[0].narrow(0, idx, 1), 117 | (hiddens[1].narrow(0, idx, 1)))) 118 | hiddens = tmp 119 | 120 | new_hs = [] 121 | new_cs = [] 122 | for l_idx in range(0, len(self.layers), self.direction): 123 | l, h = self.layers[l_idx], hiddens[l_idx] 124 | f_x, f_h = self.layer_forward(l, x, h, image_emb) 125 | if self.direction > 1: 126 | l, h = self.layers[l_idx+1], hiddens[l_idx+1] 127 | r_x, r_h = self.layer_forward(l, x, h, image_emb, reverse=True) 128 | 129 | x = torch.cat((f_x, r_x), 2) 130 | h = torch.cat((f_h[0], r_h[0]), 0) 131 | c = torch.cat((f_h[1], r_h[1]), 0) 132 | else: 133 | if self.residual: 134 | x = x + f_x 135 | else: 136 | x = f_x 137 | h, c = f_h 138 | new_hs.append(h) 139 | new_cs.append(c) 140 | 141 | h = torch.cat(new_hs, 0) 142 | c = torch.cat(new_cs, 0) 143 | 144 | return x, (h, c) 145 | 146 | 147 | class CLN(nn.Module): 148 | """ 149 | Conditioned Layer Normalization 150 | """ 151 | def __init__(self, input_size, image_size, epsilon=1e-6): 152 | super(CLN, self).__init__() 153 | self.input_size = input_size 154 | self.image_size = image_size 155 | self.alpha = Tensor(1, input_size).fill_(1) 156 | self.beta = Tensor(1, input_size).fill_(0) 157 | self.epsilon = epsilon 158 | 159 | self.alpha = Parameter(self.alpha) 160 | self.beta = Parameter(self.beta) 161 | 162 | # MLP used to predict delta of alpha, beta 163 | self.fc_alpha = nn.Linear(self.image_size, self.input_size) 164 | self.fc_beta = nn.Linear(self.image_size, self.input_size) 165 | 166 | self.reset_parameters() 167 | 168 | def reset_parameters(self): 169 | std = 1.0 / math.sqrt(self.input_size) 170 | for w in self.parameters(): 171 | w.data.uniform_(-std, std) 172 | 173 | def create_cln_input(self, image_emb): 174 | delta_alpha = self.fc_alpha(image_emb) 175 | delta_beta = self.fc_beta(image_emb) 176 | return delta_alpha, delta_beta 177 | 178 | def forward(self, x, image_emb): 179 | if image_emb is None: 180 | return x 181 | # x: (batch, input_size) 182 | size = x.size() 183 | x = x.view(x.size(0), -1) 184 | x = (x - torch.mean(x, 1).unsqueeze(1).expand_as(x)) / torch.sqrt(torch.var(x, 1).unsqueeze(1).expand_as(x) + self.epsilon) 185 | 186 | delta_alpha, delta_beta = self.create_cln_input(image_emb) 187 | alpha = self.alpha.expand_as(x) + delta_alpha 188 | beta = 
self.beta.expand_as(x) + delta_beta 189 | x = alpha * x + beta 190 | return x.view(size) 191 | 192 | 193 | class LayerNorm(nn.Module): 194 | """ 195 | Layer Normalization based on Ba & al.: 196 | 'Layer Normalization' 197 | https://arxiv.org/pdf/1607.06450.pdf 198 | """ 199 | 200 | def __init__(self, input_size, learnable=True, epsilon=1e-6): 201 | super(LayerNorm, self).__init__() 202 | self.input_size = input_size 203 | self.learnable = learnable 204 | self.alpha = Tensor(1, input_size).fill_(1) 205 | self.beta = Tensor(1, input_size).fill_(0) 206 | self.epsilon = epsilon 207 | # Wrap as parameters if necessary 208 | if learnable: 209 | W = Parameter 210 | else: 211 | W = Variable 212 | self.alpha = W(self.alpha) 213 | self.beta = W(self.beta) 214 | self.reset_parameters() 215 | 216 | def reset_parameters(self): 217 | std = 1.0 / math.sqrt(self.input_size) 218 | for w in self.parameters(): 219 | w.data.uniform_(-std, std) 220 | 221 | def forward(self, x): 222 | size = x.size() 223 | x = x.view(x.size(0), -1) 224 | x = (x - torch.mean(x, 1).unsqueeze(1).expand_as(x)) / torch.sqrt(torch.var(x, 1).unsqueeze(1).expand_as(x) + self.epsilon) 225 | if self.learnable: 226 | x = self.alpha.expand_as(x) * x + self.beta.expand_as(x) 227 | return x.view(size) 228 | 229 | 230 | class LSTMcell(nn.Module): 231 | 232 | """ 233 | An implementation of Hochreiter & Schmidhuber: 234 | 'Long-Short Term Memory' 235 | http://www.bioinf.jku.at/publications/older/2604.pdf 236 | Special args: 237 | dropout_method: one of 238 | * pytorch: default dropout implementation 239 | * gal: uses GalLSTM's dropout 240 | * moon: uses MoonLSTM's dropout 241 | * semeniuta: uses SemeniutaLSTM's dropout 242 | """ 243 | 244 | def __init__(self, input_size, hidden_size, bias=True, dropout=0.0, dropout_method='pytorch'): 245 | super(LSTMcell, self).__init__() 246 | self.input_size = input_size 247 | self.hidden_size = hidden_size 248 | self.bias = bias 249 | self.dropout = dropout 250 | self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias) 251 | self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias) 252 | self.reset_parameters() 253 | assert(dropout_method.lower() in ['pytorch', 'gal', 'moon', 'semeniuta']) 254 | self.dropout_method = dropout_method 255 | 256 | def sample_mask(self): 257 | keep = 1.0 - self.dropout 258 | self.mask = Variable(torch.bernoulli(Tensor(1, self.hidden_size).fill_(keep))) 259 | 260 | def reset_parameters(self): 261 | std = 1.0 / math.sqrt(self.hidden_size) 262 | for w in self.parameters(): 263 | w.data.uniform_(-std, std) 264 | 265 | def forward(self, x, hidden): 266 | do_dropout = self.training and self.dropout > 0.0 267 | h, c = hidden 268 | h = h.view(h.size(1), -1) 269 | c = c.view(c.size(1), -1) 270 | x = x.view(x.size(1), -1) 271 | 272 | # Linear mappings 273 | preact = self.i2h(x) + self.h2h(h) 274 | 275 | # activations 276 | gates = preact[:, :3 * self.hidden_size].sigmoid() 277 | g_t = preact[:, 3 * self.hidden_size:].tanh() 278 | i_t = gates[:, :self.hidden_size] 279 | f_t = gates[:, self.hidden_size:2 * self.hidden_size] 280 | o_t = gates[:, -self.hidden_size:] 281 | 282 | # cell computations 283 | if do_dropout and self.dropout_method == 'semeniuta': 284 | g_t = F.dropout(g_t, p=self.dropout, training=self.training) 285 | 286 | c_t = torch.mul(c, f_t) + torch.mul(i_t, g_t) 287 | 288 | if do_dropout and self.dropout_method == 'moon': 289 | c_t.data.set_(torch.mul(c_t, self.mask).data) 290 | c_t.data *= 1.0/(1.0 - self.dropout) 291 | 292 | h_t = torch.mul(o_t, c_t.tanh()) 293 | 294 
| # Reshape for compatibility
295 |         if do_dropout:
296 |             if self.dropout_method == 'pytorch':
297 |                 F.dropout(h_t, p=self.dropout, training=self.training, inplace=True)
298 |             if self.dropout_method == 'gal':
299 |                 h_t.data.set_(torch.mul(h_t, self.mask).data)  # fixed: was 'th.mul', an undefined name (only 'torch' is imported)
300 |                 h_t.data *= 1.0/(1.0 - self.dropout)
301 | 
302 |         h_t = h_t.view(1, h_t.size(0), -1)
303 |         c_t = c_t.view(1, c_t.size(0), -1)
304 |         return h_t, (h_t, c_t)
305 | 
306 | 
307 | class LayerNormLSTM(LSTMcell):
308 | 
309 |     """
310 |     Layer Normalization LSTM, based on Ba & al.:
311 |     'Layer Normalization'
312 |     https://arxiv.org/pdf/1607.06450.pdf
313 |     Special args:
314 |         ln_preact: whether to Layer Normalize the pre-activations.
315 |         learnable: whether the LN alpha & gamma should be used.
316 |     """
317 | 
318 |     def __init__(self,
319 |                  input_size,
320 |                  hidden_size,
321 |                  bias=True,
322 |                  dropout=0.0,
323 |                  dropout_method='pytorch',
324 |                  ln_preact=True,
325 |                  learnable=True,
326 |                  cln=True):
327 |         super(LayerNormLSTM, self).__init__(input_size=input_size,
328 |                                             hidden_size=hidden_size,
329 |                                             bias=bias,
330 |                                             dropout=dropout,
331 |                                             dropout_method=dropout_method)
332 |         self.cln = cln
333 |         if ln_preact:
334 |             if self.cln:
335 |                 self.ln_i2h = CLN(4*hidden_size, 1024)
336 |                 self.ln_h2h = CLN(4*hidden_size, 1024)
337 |             else:
338 |                 self.ln_h2h = LayerNorm(4*hidden_size, learnable=learnable)
339 |                 self.ln_i2h = LayerNorm(4*hidden_size, learnable=learnable)
340 |         self.ln_preact = ln_preact
341 |         if self.cln:
342 |             self.ln_cell = CLN(hidden_size, 1024)
343 |         else:
344 |             self.ln_cell = LayerNorm(hidden_size, learnable=learnable)
345 | 
346 |     def forward(self, x, hidden, image_emb=None):
347 |         do_dropout = self.training and self.dropout > 0.0
348 |         h, c = hidden
349 |         h = h.view(h.size(1), -1)
350 |         c = c.view(c.size(1), -1)
351 |         x = x.view(x.size(1), -1)
352 | 
353 |         # Linear mappings
354 |         i2h = self.i2h(x)
355 |         h2h = self.h2h(h)
356 |         if self.ln_preact:
357 |             if self.cln:
358 |                 i2h = self.ln_i2h(i2h, image_emb)
359 |                 h2h = self.ln_h2h(h2h, image_emb)
360 |             else:
361 |                 i2h = self.ln_i2h(i2h)
362 |                 h2h = self.ln_h2h(h2h)
363 |         preact = i2h + h2h
364 | 
365 |         # activations
366 |         gates = preact[:, :3 * self.hidden_size].sigmoid()
367 |         g_t = preact[:, 3 * self.hidden_size:].tanh()
368 |         i_t = gates[:, :self.hidden_size]
369 |         f_t = gates[:, self.hidden_size:2 * self.hidden_size]
370 |         o_t = gates[:, -self.hidden_size:]
371 | 
372 |         # cell computations
373 |         if do_dropout and self.dropout_method == 'semeniuta':
374 |             g_t = F.dropout(g_t, p=self.dropout, training=self.training)
375 | 
376 |         c_t = torch.mul(c, f_t) + torch.mul(i_t, g_t)
377 | 
378 |         if do_dropout and self.dropout_method == 'moon':
379 |             c_t.data.set_(torch.mul(c_t, self.mask).data)
380 |             c_t.data *= 1.0/(1.0 - self.dropout)
381 | 
382 |         if self.cln:
383 |             c_t = self.ln_cell(c_t, image_emb)
384 |         else:
385 |             c_t = self.ln_cell(c_t)
386 |         h_t = torch.mul(o_t, c_t.tanh())
387 | 
388 |         # Reshape for compatibility
389 |         if do_dropout:
390 |             if self.dropout_method == 'pytorch':
391 |                 F.dropout(h_t, p=self.dropout, training=self.training, inplace=True)
392 |             if self.dropout_method == 'gal':
393 |                 h_t.data.set_(torch.mul(h_t, self.mask).data)
394 |                 h_t.data *= 1.0/(1.0 - self.dropout)
395 | 
396 |         h_t = h_t.view(1, h_t.size(0), -1)
397 |         c_t = c_t.view(1, c_t.size(0), -1)
398 |         return h_t, (h_t, c_t)
399 | 
--------------------------------------------------------------------------------
/universal_computation/trainer.py:
--------------------------------------------------------------------------------
1 | import torch 2 | from tqdm import tqdm 3 | 4 | import time 5 | 6 | 7 | class Trainer: 8 | 9 | def __init__( 10 | self, 11 | model, 12 | dataset, 13 | loss_fn, 14 | accuracy_fn=None, 15 | steps_per_epoch=100, 16 | test_steps_per_epoch=20, 17 | learning_rate=1e-3, 18 | batch_size=2, 19 | eval_batch_size=8, 20 | grad_accumulate=1, 21 | ): 22 | self.model = model 23 | self.dataset = dataset 24 | self.loss_fn = loss_fn 25 | self.acc_fn = accuracy_fn 26 | self.steps_per_epoch = steps_per_epoch 27 | self.test_steps_per_epoch = test_steps_per_epoch 28 | self.batch_size = batch_size 29 | self.eval_batch_size = eval_batch_size 30 | self.grad_accumulate = grad_accumulate 31 | 32 | self.optim = torch.optim.Adam(model.parameters(), lr=learning_rate) 33 | 34 | self.diagnostics = {'Gradient Steps': 0} 35 | 36 | def get_loss(self, x, y, return_acc=False): 37 | out = self.model(x) 38 | loss = self.loss_fn(out, y, x=x) 39 | if return_acc: 40 | if self.acc_fn is None: 41 | raise NotImplementedError('accuracy function not specified') 42 | accs = self.acc_fn( 43 | out.detach().cpu().numpy(), 44 | y.detach().cpu().numpy(), 45 | x=x.detach().cpu().numpy(), 46 | ) 47 | return loss, accs 48 | return loss 49 | 50 | def train_epoch(self, test_steps=None): 51 | self.dataset.start_epoch() 52 | 53 | train_losses, tr_accuracy = [], 0. 54 | self.model.train() 55 | start_train_time = time.time() 56 | for _ in tqdm(range(self.steps_per_epoch)): 57 | step_loss = 0 58 | for _ in range(self.grad_accumulate): 59 | x, y = self.dataset.get_batch(self.batch_size, train=True) 60 | loss, acc = self.get_loss(x, y, return_acc=True) 61 | loss = loss / self.grad_accumulate 62 | loss.backward() 63 | step_loss += loss.detach().cpu().item() 64 | tr_accuracy += acc 65 | 66 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.) 67 | self.optim.step() 68 | self.optim.zero_grad() 69 | 70 | self.diagnostics['Gradient Steps'] += 1 71 | 72 | train_losses.append(step_loss) 73 | end_train_time = time.time() 74 | 75 | test_steps = self.test_steps_per_epoch if test_steps is None else test_steps 76 | 77 | test_loss, accuracy = 0., 0. 78 | self.model.eval() 79 | start_test_time = time.time() 80 | with torch.no_grad(): 81 | for _ in range(test_steps): 82 | x, y = self.dataset.get_batch(self.eval_batch_size, train=False) 83 | loss, acc = self.get_loss(x, y, return_acc=True) 84 | test_loss += loss.detach().cpu().item() / test_steps 85 | accuracy += acc / test_steps 86 | end_test_time = time.time() 87 | 88 | self.diagnostics['Average Train Loss'] = sum(train_losses) / self.steps_per_epoch 89 | self.diagnostics['Start Train Loss'] = train_losses[0] 90 | self.diagnostics['Final Train Loss'] = train_losses[-1] 91 | self.diagnostics['Test Loss'] = test_loss 92 | self.diagnostics['Test Accuracy'] = accuracy 93 | self.diagnostics['Train Accuracy'] = tr_accuracy / (self.steps_per_epoch * self.grad_accumulate) 94 | self.diagnostics['Time Training'] = end_train_time - start_train_time 95 | self.diagnostics['Time Testing'] = end_test_time - start_test_time 96 | --------------------------------------------------------------------------------
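
For orientation, here is a minimal end-to-end sketch of how the pieces above fit together: a frozen pretrained GPT-2 `FPT` trained by `Trainer` on a toy elementwise-XOR task. The `FPT` and `Trainer` signatures match the files above, but `ToyXORDataset` and the hyperparameter values are illustrative assumptions, not part of the codebase (the repo's real datasets live in `universal_computation/datasets/`).

```
import torch
import torch.nn.functional as F

from universal_computation.fpt import FPT
from universal_computation.trainer import Trainer


class ToyXORDataset:
    """Hypothetical stand-in exposing the minimal interface Trainer expects."""

    def __init__(self, n_bits=5, seq_len=5, device='cpu'):
        self.n_bits, self.seq_len, self.device = n_bits, seq_len, device

    def start_epoch(self):
        pass  # data is sampled fresh each batch; nothing to shuffle or reset

    def get_batch(self, batch_size, train=True):
        # two random bitstrings per position; the target is their elementwise XOR
        a = torch.randint(0, 2, (batch_size, self.seq_len, self.n_bits)).float()
        b = torch.randint(0, 2, (batch_size, self.seq_len, self.n_bits)).float()
        x = torch.cat([a, b], dim=-1).to(self.device)  # (batch, seq_len, 2 * n_bits)
        y = (a != b).float().to(self.device)           # (batch, seq_len, n_bits)
        return x, y


def loss_fn(out, y, x=None):
    # Trainer calls loss_fn(out, y, x=x)
    return F.binary_cross_entropy_with_logits(out, y)


def accuracy_fn(preds, true, x=None):
    # Trainer passes numpy arrays here; threshold the logits at 0
    return ((preds > 0) == (true > 0.5)).mean()


device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = FPT(
    input_dim=10,              # 2 * n_bits input features per position
    output_dim=5,              # n_bits XOR predictions per position
    model_name='gpt2',
    pretrained=True,
    return_last_only=False,    # predict at every position, not just the last patch
    freeze_trans=True, freeze_attn=True, freeze_ff=True,  # the FPT recipe: frozen core,
    freeze_in=False, freeze_out=False, freeze_ln=False,   # trainable I/O and layer norm
    dropout=0.1,
    orth_gain=1.41,
).to(device)

trainer = Trainer(
    model,
    ToyXORDataset(n_bits=5, seq_len=5, device=device),
    loss_fn=loss_fn,
    accuracy_fn=accuracy_fn,
    steps_per_epoch=100,
    test_steps_per_epoch=20,
    learning_rate=1e-3,
    batch_size=8,
    eval_batch_size=8,
    grad_accumulate=1,
)

for it in range(5):
    trainer.train_epoch()
    print(it, trainer.diagnostics['Test Accuracy'])
```

This mirrors what `experiment()` assembles from its kwargs: only the input layer, output layer, and (by default) the layer norm and positional parameters receive gradients, while the self-attention and feedforward weights stay frozen; `gpu_batch_size` and `grad_accumulate` would split larger effective batches in the same way as in `scripts/run.py`.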