├── J_Backwards_1_0.ipynb ├── backwards.py ├── generate.py └── helpers.py /J_Backwards_1_0.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "id": "ar1aNn7TxI8r", 18 | "colab": { 19 | "base_uri": "https://localhost:8080/" 20 | }, 21 | "outputId": "86d6a73c-7a75-42e6-ac3c-bd4f5fea795e" 22 | }, 23 | "outputs": [ 24 | { 25 | "output_type": "stream", 26 | "name": "stdout", 27 | "text": [ 28 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", 29 | "Requirement already satisfied: update in /usr/local/lib/python3.8/dist-packages (0.0.1)\n", 30 | "Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (4.25.1)\n", 31 | "Requirement already satisfied: style==1.1.0 in /usr/local/lib/python3.8/dist-packages (from update) (1.1.0)\n", 32 | "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers) (4.64.1)\n", 33 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)\n", 34 | "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (6.0)\n", 35 | "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.8.2)\n", 36 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (21.3)\n", 37 | "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (1.21.6)\n", 38 | "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /usr/local/lib/python3.8/dist-packages (from transformers) 
(0.11.1)\n", 39 | "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.13.2)\n", 40 | "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers) (2.25.1)\n", 41 | "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (4.4.0)\n", 42 | "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n", 43 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (1.24.3)\n", 44 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2.10)\n", 45 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2022.12.7)\n", 46 | "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (4.0.0)\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "!pip install update transformers\n", 52 | "from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils\n", 53 | "import torch\n", 54 | "from matplotlib import pyplot as plt\n", 55 | "%matplotlib inline\n", 56 | "from IPython import display\n", 57 | "import numpy as np\n", 58 | "from tqdm import tqdm\n", 59 | "utils.logging.set_verbosity_error()\n", 60 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 61 | "\n", 62 | "vocab_len= 50257\n", 63 | "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\", padding_side='left')\n", 64 | "model = GPT2LMHeadModel.from_pretrained(\"gpt2\",pad_token_id=tokenizer.eos_token_id, vocab_size=vocab_len).to(device)\n", 65 | "model.eval()\n", 66 | "# the model will be in evaluation, not training, mode throughout\n", 67 | "word_embeddings = 
model.transformer.wte.weight.to(device) \n", 68 | "# 'word_embeddings' tensor gives emeddings for each token in the vocab for this model,\n", 69 | "# has shape (vocab_len, embedding_dimension) which in this case = (50257, 768)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "id": "qCgMUuO_33Wp" 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "def normalise(x, min_max=[]): \n", 81 | "# normalises values of (array or tensor) x according to first (min) and second (max) values in list min_max. \n", 82 | "# This effectively defaults to [0,1] if the list doesn't contain exactly two elements. \n", 83 | "# The original code threw an error if min_max had length 1, so it's been changed slightly.\n", 84 | "\n", 85 | "# First normalise x to [0,1]\n", 86 | " rnge = x.max() - x.min()\n", 87 | " if rnge > 0:\n", 88 | " x = (x - x.min())/rnge\n", 89 | "\n", 90 | "# Now, if there's a min and max given in min_max list, multiply by difference and add minimum\n", 91 | " if len(min_max) > 1:\n", 92 | " rnge = min_max[1] - min_max[0]\n", 93 | " x = x * rnge + min_max[0]\n", 94 | "\n", 95 | " return x\n", 96 | "\n", 97 | "\n", 98 | "def closest_tokens(emb, n=1): \n", 99 | "# This finds the n tokens in the vocabulary that are closest in the embedding space (in terms of Euclidean distance) to a given word embedding (‘emb’).\n", 100 | "# Note that here 'emb' may or may not correspond to a token (i.e., it may or may not be a 'legal' embedding).\n", 101 | "# Function returns a 4-tuple (list of the n tokens, list of their indices, list of their distances from emb, and list of their embedding vectors)\n", 102 | " torch.cuda.empty_cache()\n", 103 | " dists = torch.linalg.norm(word_embeddings - emb, dim=1)\n", 104 | " sorted_dists, ix = torch.sort(dists)\t \n", 105 | " # sorted_dists is a list of all embedding distances from 'emb', across entire vocab, sorted in increasing order, \n", 106 | " # ix is a list of their corresponding 'vocab 
indices'\n", 107 | " tokens = [tokenizer.decode(i) for i in ix[:n]]\n", 108 | " # For each of the first n 'vocab indices' in ix, we decode it into the string version of the corresponding token. \n", 109 | " # These strings then constitute the list 'tokens'.\n", 110 | " ixs = ix[:n]\n", 111 | " dists = sorted_dists[:n]\n", 112 | " embs = word_embeddings[ixs] # Each of these n 'embeddings' is a tensor of shape (768,)\n", 113 | " return tokens, ixs, dists, embs \n", 114 | "\n", 115 | "\n", 116 | "def model_emb(inputs_embeds, output_len):\n", 117 | "# 'input_embeds' is a tensor of shape (batch_size, input_len, embedding_dim)\n", 118 | "# 'output_len' is an integer specifying the number of output tokens to generate\n", 119 | "# Note that this function doesn't involve a target output. It simply takes a tensor of input embeddings (based on input length),\n", 120 | "# calculates perplexities for that batch of input sequences,\n", 121 | "# and runs the batch of input sequences through GPT2, for each finding next tokens iteratively 'output_len' number of times\n", 122 | " embs = inputs_embeds # This is going to get expanded using 'output_embs'\n", 123 | " logits = []\n", 124 | " ixs = []\n", 125 | " input_logits = None\n", 126 | " for i in range(output_len):\n", 127 | " model_out = model(inputs_embeds=embs, return_dict=True)\n", 128 | " # Does a forward pass of GPT2 (or whichever model) on a batch of inputs (given as a tensor 'embs' of embeddings).\n", 129 | " # This 'embs' will expand along its 1st dimension with each iteration.\n", 130 | " # Outputs logits and more (hidden states, attention, etc.) 
as a dictionary 'model_out'.\n", 131 | " # But we'll only be concerned with model_out.logits.\n", 132 | "\n", 133 | " if i == 0:\n", 134 | " input_logits = model_out.logits \n", 135 | " # On first pass through loop, we simply use the logits of the model output\n", 136 | " # That's a tensor of shape (batch_size, input_len, vocab_size) giving logits for each input in each batch.\n", 137 | " # Presumably for each input, this is conditioned on the inputs that preceded it?\n", 138 | "\n", 139 | " # On every pass throught the loop (including the first), we defined this tensor of shape (batch_size, 1, vocab_size):\n", 140 | " last_logits = model_out.logits[:,-1].unsqueeze(1) \n", 141 | " # model_out.logits[:,-1] will be a 2D tensor of shape (batch_size, vocab_size), just giving logits for last input/embedding across all batches/tokens\n", 142 | " # unsqueezing, we get tensor of shape (batch_size, 1, vocab_size) also giving logits of last input/embedding, differently formatted \n", 143 | " logits.append(last_logits) # appends last_logits tensor to the 'logits' list \n", 144 | " ix = torch.argmax(last_logits, dim=-1) # for each batch, finds the vocab index of the token with the largest logit in last_logits\n", 145 | " ixs.append(ix) # ...and appends this tensor of shape (batch_size,) (containing indices) it to the list 'ixs'\n", 146 | " output_embs = word_embeddings[ix] # for each batch, finds embedding for the token with that index...\n", 147 | " embs = torch.cat([embs, output_embs], dim=1) #...concatenates that tensor of embeddings to the 'embs' tensor in the first dimension before next iteration\n", 148 | "\n", 149 | " # When the loop is completed 'embs' will be a tensor containing all of the input and output word embeddings produced by the model \n", 150 | " # ...so presumably of shape (batch_size, input_len + output_len, embedding_dim)\n", 151 | "\n", 152 | " logits = torch.cat(logits, dim=1) # this converts logits from a list of tensors to a single tensor, by 
concatenating all of the tensors in the list\n", 153 | " # it will have shape (batch_size, output_len, vocab_size)\n", 154 | " perp = perplexity(input_logits) # 'input_logits' was calculated on first pass through loop where only input embeddings were involved\n", 155 | " return logits, embs, perp \n", 156 | " # logits has shape (batch_size, output_len, vocab_size), CHECK THAT!\n", 157 | " # embs has shape (batch_size, input_len + output_len, embedding_dim)\n", 158 | " # perp has shape (batch_size,)\n", 159 | "\n", 160 | "\n", 161 | "def perplexity(logits):\n", 162 | " # logits is of shape (batch_size, 'sequence length', vocab_size)\n", 163 | " # for all current calls, 'sequence length' is going to be input_len\n", 164 | " probs, ix = torch.max(torch.softmax(logits, dim=-1), dim=-1)\n", 165 | " # torch.softmax(logits, dim=-1) will also be a tensor of shape (batch_size, 'sequence length', vocab_size), \n", 166 | " # but where the logits in the last dimension get converted into probabilities via softmax. 
torch.max() then pulls out the largest of these and its index\n", 167 | " # probs is a tensor that contains the maximum probability at each position in the embedding sequence, shape (batch_size, 'sequence length')\n", 168 | " # ix is a tensor that contains the corresponding indices, also with shape (batch_size, 'sequence length')\n", 169 | " perp = 1/ (torch.prod(probs, dim=-1)**(1/probs.shape[-1])) - 1\n", 170 | " # gives one value per batch element, shape (batch_size,): 1/(geometric mean of the max probs) - 1, which is 0 when every max prob is 1 and grows with uncertainty\n", 171 | " # probs.shape[-1] is the sequence length (input_len for current calls); the 1/'sequence length' power makes values comparable across sequence lengths. This uses the top-1 probability at each position, not the probability of the next input token.\n", 172 | " return perp\n", 173 | "\n", 174 | "\n", 175 | "# Here's the key function that optimises for a sequence of input embeddings, given a target_output string:\n", 176 | "def optimise_input(epochs=100, \n", 177 | " lr=0.1, \n", 178 | " rand_after=False, # Do we re-initialise inputs tensor with random entries when an optimal input is found?\n", 179 | " w_freq=10, # logging (write) frequency\n", 180 | " base_input=None, # If None, start_input will be random (rescaled to the embedding range); \n", 181 | " # otherwise it will be built by stacking this tensor and then gently \"noising\" all but the first copy\n", 182 | " batch_size=1, \n", 183 | " input_len=1, \n", 184 | " target_output=tokenizer.eos_token, # Default target output is the \"end-of-string\" token; this won't generally be used\n", 185 | " output_len=None,\n", 186 | " dist_reg=1, # distance regularisation coefficient\n", 187 | " perp_reg=0, # perplexity regularisation coefficient; setting to 0 disables the perplexity loss term\n", 188 | " plt_loss=False, # Do we plot loss?\n", 189 | " loss_type='log_prob_loss', \n", 190 | " seed=0,\n", 191 | " return_early=True, # return as soon as any optimised input is found\n", 192 | " verbose=0, # Controls how much info gets logged.\n", 193 | " lr_decay=False, # 
Use learning rate decay? If so, a scheduler gets invoked.\n", 194 | " noise_coeff = 0.01): # Introduced for generality in the construction of start_input[1:] below.\n", 195 | " torch.manual_seed(seed) # sets up PyTorch random number generator\n", 196 | "\n", 197 | " if plt_loss:\n", 198 | " plt.rcParams.update({'figure.figsize': (40,6)})\n", 199 | "\n", 200 | " total_losses = []\n", 201 | " losses = []\n", 202 | " dists = []\n", 203 | " perps = []\n", 204 | " optimised_inputs = set()\n", 205 | " done = None\n", 206 | "\n", 207 | " output_ix = tokenizer.encode(target_output, return_tensors='pt')[0].to(device)\n", 208 | " # output_ix is a 1-D tensor of shape (output_len,) that contains the indices of the tokens in the encoding of the string 'target_output'\n", 209 | " # tokenizer.encode(target_output, return_tensors='pt') is a list containing this one tensor, hence the need for the [0]\n", 210 | " # \"return_tensors='pt'\" ensures that we get a tensor in PyTorch format\n", 211 | "\n", 212 | " if output_len == None or output_len < output_ix.shape[0]: # This won't generally be the case, but if we don't specify output_len (i.e. 
it's == None), then...\n", 213 | " output_len = output_ix.shape[0] # ...it will be set to the number of tokens in the encoding of the string 'target_output' (this also raises any smaller output_len)\n", 214 | " # output_len is never lowered below the target length, but callers may request more output tokens than target_output has:\n", 215 | " # in that case the code below accepts target_output anywhere in the longer generated output (substring match).\n", 216 | "\n", 217 | " print('Optimising input of length {} to maximise output logits for \"{}\"'.format(input_len, target_output))\n", 218 | " # Typically this would print something like 'Optimising input of length 6 to maximise output logits for \"KILL ALL HUMANS!\"'.\n", 219 | "\n", 220 | " if base_input == None:\n", 221 | " start_input = torch.rand(batch_size, input_len, word_embeddings.shape[-1]).to(device)\n", 222 | " # If no base_input is provided, we construct start_input as a random tensor \n", 223 | " # of shape (batch_size, input_len, embedding_dim) (embedding_dim = 768 for this GPT-2 model).\n", 224 | " start_input = normalise(start_input,[word_embeddings.min(dim=0)[0], word_embeddings.max(dim=0)[0]])\n", 225 | " # We rescale this random tensor so that each embedding dimension lies within that dimension's min-max range over all of word_embeddings\n", 226 | " # This dispenses with whole swathes of \"input space\" which contain no legal token embeddings \n", 227 | " # (we're limiting ourselves to a kind of \"hull\" defined by the 50257 vocab tokens in the embedding space), \n", 228 | " # which is a sensible place to look for optimised inputs.\n", 229 | " else:\n", 230 | " start_input = base_input.repeat(batch_size, 1, 1)\n", 231 | " # If a base_input was given, it should be of shape (input_len, embedding_dim) or (1, input_len, embedding_dim), \n", 232 | " # and we build the start_input tensor by stacking 'batch_size' number of copies of this together...\n", 233 | "\n", 234 | " if batch_size > 1:\n", 235 | " start_input[1:] += (torch.rand_like(start_input[1:]) + torch.full_like(start_input[1:], -0.5)) * noise_coeff\n", 236 | " #...and 
if we have more than one element in our batch, we \"noise\" the rest. \n", 237 | " # This was originally done using \"*=\" (multiplying entries by small random numbers)\n", 238 | " # We've changed this to \"+=\" (adding small random numbers instead of multiplying by them).\n", 239 | " # The original code would have pushed everything in a positive direction, hence the use of a tensor full of -0.5's. \n", 240 | "\n", 241 | " \n", 242 | " input = torch.nn.Parameter(start_input, requires_grad=True)\n", 243 | " # input is not a tensor, it's a Parameter object that wraps a tensor and adds additional functionality. \n", 244 | " # 'input.data' is used below\n", 245 | " \n", 246 | " optimiser = torch.optim.Adam([input], lr=lr)\n", 247 | " # standard optimiser; note that it generally operates on a list of tensors, so we're giving it a list of one tensor; standard learning rate\n", 248 | " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=20, cooldown=20, factor=0.5)\n", 249 | " # this is used when loss hasn't improved for 20 timesteps; this scheduler will reduce the lr by a 'factor' of 0.5 when the \n", 250 | " # validation loss stops improving for 'patience' (here 20) epochs, and will wait 'cooldown' (here 20) epochs before resuming normal operation.\n", 251 | "\n", 252 | " # now we loop across training epochs\n", 253 | " for e in range(epochs):\n", 254 | "\n", 255 | " logits, emb, perp = model_emb(torch.clamp(input, word_embeddings.min(), word_embeddings.max()), output_len)\n", 256 | " # Does forward pass on a 'clamped' version of the 'input' tensor (done to contain it within the 'hull' of the vocabulary within 'input space').\n", 257 | " # Iterates to produce an output of output_len tokens, \n", 258 | " # returns: 'logits' = tensor of logits for output, of shape (batch_size, output_len, vocab_size)\n", 259 | " # 'emb': tensor of embeddings for input+output of shape (batch_size, input_len + output_len, embedding_dim); \n", 260 | " # 'perp': 
the input sequence perplexities tensor, of shape (batch_size,)\n", 261 | " probs = torch.softmax(logits, dim=-1)\n", 262 | " # For each batch element and output position, converts the sequence of logits (of length 'vocab_size') in the 'logits' tensor to probabilities, using softmax\n", 263 | "\n", 264 | " logits = (logits - logits.min(dim=-1)[0].unsqueeze(-1)) / (logits.max(dim=-1)[0].unsqueeze(-1) - logits.min(dim=-1)[0].unsqueeze(-1))\n", 265 | " # This min-max normalises the logits for each batch element/output position so they all lie between 0 and 1.\n", 266 | " # These normalised logits (not the raw ones) feed the 'logit_loss' and 'CE' losses below; probs above used the raw logits.\n", 267 | "\n", 268 | " perp_loss = perp.mean() * perp_reg\n", 269 | " # That's the mean perp value across the batch, scaled by perp_reg. With the default perp_reg=0, perp_loss = 0.\n", 270 | "\n", 271 | " if output_len > output_ix.shape[0]:\n", 272 | " target_logits = torch.stack([logits[:, :, ix] for ix in output_ix], dim=-1)\n", 273 | " target_logits = torch.max(target_logits, dim=-1)[0]\n", 274 | " # logits is shape (batch_size, output_len, vocab_size) \n", 275 | " # We throw out everything in the final dimension except those logits corresponding to indices of tokens in the target_output\n", 276 | " # This gives tensor with shape (batch_size, output_len, output_ix.shape[0])\n", 277 | " # We then take the maximum of those for each batch element and output position; this gives shape (batch_size, output_len)\n", 278 | " # The [0] returns just the max (torch.max returns max, indices tuple)\n", 279 | " target_probs = torch.stack([probs[:, :, ix] for ix in output_ix], dim=-1)\n", 280 | " target_probs = torch.max(target_probs, dim=-1)[0]\n", 281 | " # This does the analogous thing for probs.\n", 282 | "\n", 283 | " else:\n", 284 | " target_logits = torch.stack([logits[:,i, ix] for i, ix in enumerate(output_ix)], dim=-1)\n", 285 | " target_probs = torch.stack([probs[:,i, ix] for i, ix in enumerate(output_ix)], dim=-1)\n", 286 | " # This handles case where output_len == output_ix.shape[0]\n", 
287 | " # target_logits now of shape (batch_size, output_len)\n", 288 | " # output_len < output_ix.shape[0] cannot happen here: output_len was raised to output_ix.shape[0] earlier\n", 289 | " \n", 290 | " token_dist = torch.stack([torch.stack([closest_tokens(e)[2].squeeze(-1) for e in input[b]]) for b in range(batch_size)])\n", 291 | " # This creates a tensor of shape (batch_size, input_len) giving the distance to the nearest\n", 292 | " # legal token embedding for each (unclamped) input embedding in each batch element\n", 293 | " mean_token_dist = token_dist.mean() * dist_reg\n", 294 | " # A single scalar: the mean over all batch elements and input positions, scaled by dist_reg.\n", 295 | "\n", 296 | "\n", 297 | " # There are currently four loss types, many more could be introduced.\n", 298 | " # log_prob_loss is the current default. NOTE: 'CE' assumes output_len == output_ix.shape[0] (its target shape must match the logits).\n", 299 | " if loss_type == 'logit_loss':\n", 300 | " loss = torch.mean(1-target_logits)\n", 301 | " elif loss_type == 'log_prob_loss':\n", 302 | " loss = -torch.log(target_probs).mean()\n", 303 | " elif loss_type == 'prob_loss':\n", 304 | " loss = 1-torch.mean(target_probs)\n", 305 | " elif loss_type == 'CE':\n", 306 | " loss = torch.nn.functional.cross_entropy(logits.swapaxes(-1,-2), output_ix.repeat(batch_size, 1))\n", 307 | "\n", 308 | " else:\n", 309 | " print(loss_type + 'is not implemented.')\n", 310 | " return\n", 311 | "\n", 312 | " total_loss = torch.stack([mean_token_dist, loss, perp_loss]).mean()\n", 313 | " # This is just (mean_token_dist + loss + perp_loss)/3; each of the three terms is already a scalar.\n", 314 | "\n", 315 | " total_losses.append(total_loss.detach().cpu().data)\n", 316 | " losses.append(loss.detach().cpu().data)\n", 317 | " dists.append(mean_token_dist.detach().cpu().data)\n", 318 | " perps.append(perp_loss.detach().cpu().data)\n", 319 | " # these four lists were initialised above. We're appending to each list every epoch. 
All are scalars.\n", 320 | "\n", 321 | " closest_ix = torch.stack([torch.stack([closest_tokens(e)[1] for e in b]) for b in input]).squeeze(-1)\n", 322 | " # This is similar to above (also using the unclamped input), but building a tensor of indices of nearest embeddings, rather than distances.\n", 323 | " # Iterates over batches, and for each batch iterates over embeddings, giving tensor of shape (batch_size, input_len).\n", 324 | "\n", 325 | " model_outs = model.generate(closest_ix, max_length = output_len+input_len)\n", 326 | " # The 'closest_ix' tensor is passed as the initial input sequence to the model, \n", 327 | " # and the max_length parameter specifies the maximum length of the total sequence to generate.\n", 328 | " # The output sequence will be terminated either when the end-of-sequence token is generated \n", 329 | " # or when the maximum length is reached, whichever occurs first.\n", 330 | " # \n", 331 | " # With default arguments, model.generate returns a tensor of token ids, not a tuple with the model's internal states. \n", 332 | " # Each row holds the input ids (closest_ix) followed by the generated ids, so it has shape (batch_size, at most output_len+input_len). \n", 333 | " # Each element of the tensor will be a sequence of tokens with a length of at most output_len+input_len.\n", 334 | " \n", 335 | " for b in range(batch_size):\n", 336 | " # iterate over batch elements \n", 337 | " if output_len > output_ix.shape[0]:\n", 338 | " if target_output in tokenizer.decode(model_outs[b][input_len:]):\n", 339 | " done = tokenizer.decode(model_outs[b][:input_len])\n", 340 | " optimised_inputs.add(done)\n", 341 | " # model_outs[b][input_len:], for a batch b, is only looking at the *output* token ids \n", 342 | " # we decode these as tokens... 
is the target_output a substring?\n", 343 | " # if so, we record the input that produced it (nothing is printed at this point)\n", 344 | " # 'done' is the decoded string of the first input_len tokens, i.e. the nearest-token version of the input; we add this to set 'optimised_inputs'.\n", 345 | "\n", 346 | " if rand_after:\n", 347 | " input.data[b] = torch.rand_like(input[b])\n", 348 | " # NOTE: unlike start_input, this re-initialisation is not rescaled with normalise(), so its entries lie in [0, 1).\n", 349 | " # The idea here is to randomly re-initialise this batch element's input once it has yielded an optimised input,\n", 350 | " # input.data is the tensor version of the 'input' Parameter object. Current values, without gradient!\n", 351 | " # input.data has shape (batch_size, input_len, embedding_dim); input.data[b] is one batch element\n", 352 | "\n", 353 | " if tokenizer.decode(model_outs[b][input_len:]) == target_output:\n", 354 | " done = tokenizer.decode(model_outs[b][:input_len])\n", 355 | " optimised_inputs.add(done)\n", 356 | " # model_outs[b][input_len:], for a batch b, is only looking at the *output* token ids \n", 357 | " # we decode these as tokens... is the target_output equal to output string?\n", 358 | " # Nothing printed in this case.\n", 359 | " # 'done' is the decoded string of the first input_len tokens, i.e. the nearest-token version of the input; we add this to set 'optimised_inputs'.\n", 360 | " if rand_after:\n", 361 | " input.data[b] = torch.rand_like(input[b])\n", 362 | " # Random re-initialisation (if 'rand_after' set to True)\n", 363 | "\n", 364 | " \n", 365 | " if ((e+1) % w_freq == 0) or done and return_early:\n", 366 | " display.clear_output(wait=True) \n", 367 | " # Every w_freq epochs we write to the log; we also log straight away if an optimised input has been found and 'return_early' == True. 
\n", 368 | " # I'm still not entirely sure about the idea of 'return_early'.\n", 369 | "\n", 370 | " if plt_loss:\n", 371 | " plt.plot(range(len(total_losses)), total_losses, label='Total Loss', color='black')\n", 372 | " plt.plot(range(len(losses)), losses, label='Output Loss')\n", 373 | " plt.plot(range(len(dists)), dists, label='Emb Dist Loss')\n", 374 | " plt.plot(range(len(perps)), perps, label='Perp Loss')\n", 375 | " plt.yscale('log')\n", 376 | " plt.legend()\n", 377 | "\n", 378 | " plt.show()\n", 379 | "\n", 380 | " print('Inputs found: ', optimised_inputs)\n", 381 | " print('{}/{} Output Loss: {} Emb Dist Loss: {} Perp Loss: {} LR: {}'.format(e+1, epochs, loss, mean_token_dist, perp_loss, optimiser.param_groups[0]['lr']))\n", 382 | " if verbose == 3:\n", 383 | " print('Target Probs: {}\\nTarget Logits: {}\\nInput Dists: {}\\nInput Perplexity: {}\\n'.format(target_probs.detach().cpu().numpy(), target_logits.detach().cpu().numpy(), token_dist.detach().cpu().numpy(), perp.detach().reshape(-1).cpu().numpy()))\n", 384 | " # Optimised inputs and additional information are printed as part of log\n", 385 | "\n", 386 | " for b in range(batch_size):\n", 387 | " if verbose > 0:\n", 388 | " if verbose == 2:\n", 389 | " print(b, repr(' Raw embeddings: {}'.format(''.join([closest_tokens(e)[0][0] for e in emb[b]]))))\n", 390 | " # Change name to clarify (output of model if we just put in raw embeddings)\n", 391 | " # prints batch number; closest_tokens(e)[0] is a list of tokens, closest_tokens(e)[0] is the first (closest) of these\n", 392 | " # these get joined with separator '' (SHOULDN'T THAT BE ' '?) \n", 393 | " print(b, repr(' Closest embeddings: {}'.format(tokenizer.decode(model_outs[b]), '\\n')))\n", 394 | " # WON'T THIS give string decodings of the embeddings, rather than the embeddings themselves?\n", 395 | " else:\n", 396 | " print(repr(tokenizer.decode(model_outs[b])), end=' ')\n", 397 | " # The least verbose printed output. 
The 'end' parameter is used to specify the end-of-line string that is appended to the output. \n", 398 | " # By default, this is a newline character, but in this case it has been set to a single space character, \n", 399 | " # so the output will be separated by spaces rather than newlines.\n", 400 | "\n", 401 | " if done and return_early:\n", 402 | " print('\\nOptimised Input: \"{}\"'.format(done))\n", 403 | " return optimised_inputs\n", 404 | " # we know optimised_inputs set contains a single element in this case\n", 405 | " \n", 406 | " optimiser.zero_grad()\n", 407 | " total_loss.backward()\n", 408 | " optimiser.step()\n", 409 | " # I assume these three lines are standard NN optimisation stuff?\n", 410 | "\n", 411 | " if lr_decay:\n", 412 | " scheduler.step(total_loss)\n", 413 | " # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=20, cooldown=20, factor=0.5) gets used if lr_decay == True\n", 414 | " \n", 415 | " return optimised_inputs\n", 416 | " # that's a set of strings\n", 417 | "\n" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "metadata": { 424 | "colab": { 425 | "base_uri": "https://localhost:8080/" 426 | }, 427 | "id": "tev0cDbkdbxO", 428 | "outputId": "45b23d11-a2b5-4e46-d592-5f8f0076390b" 429 | }, 430 | "outputs": [ 431 | { 432 | "output_type": "stream", 433 | "name": "stdout", 434 | "text": [ 435 | "[767, 767, 767]\n", 436 | "[' 7', ' 7', ' 7']\n", 437 | "3\n", 438 | " 7 7 7 7 7 7 7\n" 439 | ] 440 | } 441 | ], 442 | "source": [ 443 | "ix = tokenizer.encode(\" 7 7 7\")\n", 444 | "# list of 'vocab indices'\n", 445 | "print(ix)\n", 446 | "print([tokenizer.decode(i) for i in ix])\n", 447 | "# prints reconstruction of input string\n", 448 | "print(len(ix))\n", 449 | "# prints number of tokens\n", 450 | "output_len=4\n", 451 | "model_out = model.generate(torch.tensor(ix).unsqueeze(0).to(device), max_length = output_len + len(ix))\n", 452 | "print(tokenizer.decode(model_out[0]))\n", 
453 | "# pushes input string throught GPT2 (or whichever model) iteratively producing output_len number of tokens, then prints input + output." 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "source": [ 459 | "from time import time\n", 460 | "target_output = \" the best player?\"\n", 461 | "input_len = 2\n", 462 | "\n", 463 | "base_input = word_embeddings[tokenizer.encode(target_output)].mean(dim=0)\n", 464 | "base_input = base_input.repeat(1, input_len, 1)\n", 465 | "\n", 466 | "tic = time()\n", 467 | "oi = optimise_input(base_input=base_input, \n", 468 | " plt_loss=False,\n", 469 | " verbose=2, \n", 470 | " epochs=500, \n", 471 | " lr_decay=False,\n", 472 | " return_early=False, \n", 473 | " lr=0.1, \n", 474 | " batch_size=20, \n", 475 | " target_output=target_output, \n", 476 | " output_len=4,\n", 477 | " input_len=input_len, \n", 478 | " w_freq=20, \n", 479 | " dist_reg=1, \n", 480 | " perp_reg=0,\n", 481 | " loss_type='log_prob_loss',\n", 482 | " noise_coeff = 0.75)\n", 483 | "toc = time()\n", 484 | "tt = toc - tic\n", 485 | "print('Time Taken: ', tt)\n" 486 | ], 487 | "metadata": { 488 | "colab": { 489 | "base_uri": "https://localhost:8080/", 490 | "height": 1000 491 | }, 492 | "id": "81r4gT4m2NQ1", 493 | "outputId": "3ebaad31-0828-43b3-fb29-5c8c9d0c39ae" 494 | }, 495 | "execution_count": null, 496 | "outputs": [ 497 | { 498 | "output_type": "stream", 499 | "name": "stdout", 500 | "text": [ 501 | "Inputs found: set()\n", 502 | "300/500 Output Loss: 2.0041980743408203 Emb Dist Loss: 5.514309883117676 Perp Loss: 0.0 LR: 0.1\n", 503 | "0 ' Raw embeddings: the an the best,\\n'\n", 504 | "0 ' Closest embeddings: the ancients, and the'\n", 505 | "1 ' Raw embeddings: play bruised the best player?'\n", 506 | "1 ' Closest embeddings: play bruised and battered.\\n'\n", 507 | "2 ' Raw embeddings: 80<|endoftext|> the best best.'\n", 508 | "2 ' Closest embeddings: 80<|endoftext|>The U.S'\n", 509 | "3 ' Raw embeddings: Main those the best player?'\n", 510 | "3 ' 
Closest embeddings: Main those who have been in'\n", 511 | "4 ' Raw embeddings: ( Continuing the best player?'\n", 512 | "4 ' Closest embeddings: ( Continuing )\\n\\n('\n", 513 | "5 ' Raw embeddings: 1 objects the the the the'\n", 514 | "5 ' Closest embeddings: 1 objects.\\n\\nThe'\n", 515 | "6 ' Raw embeddings: her Berlin the best best.'\n", 516 | "6 ' Closest embeddings: her Berlin Wall.\\n\\n'\n", 517 | "7 ' Raw embeddings: Obviously jobs the best player?'\n", 518 | "7 ' Closest embeddings: Obviously jobs are not just for'\n", 519 | "8 ' Raw embeddings: bestAmid the best team.'\n", 520 | "8 ' Closest embeddings: bestAmid the chaos, the'\n", 521 | "9 ' Raw embeddings: Cleveland if the best player?'\n", 522 | "9 \" Closest embeddings: Cleveland if he's healthy.\"\n", 523 | "10 ' Raw embeddings: linedSummary the best player?'\n", 524 | "10 ' Closest embeddings: linedSummary = \"\\n\\n'\n", 525 | "11 ' Raw embeddings: blocking has the the..'\n", 526 | "11 ' Closest embeddings: blocking has been a problem for'\n", 527 | "12 \" Raw embeddings: sensational s's team.\\n\"\n", 528 | "12 ' Closest embeddings: sensational siren song.\\n'\n", 529 | "13 ' Raw embeddings: glance family the best player?'\n", 530 | "13 ' Closest embeddings: glance family.\\n\\n\"'\n", 531 | "14 ' Raw embeddings: using game the best player?'\n", 532 | "14 ' Closest embeddings: using game-play.\\n'\n", 533 | "15 ' Raw embeddings: clothes( the best player?'\n", 534 | "15 ' Closest embeddings: clothes(s) of the'\n", 535 | "16 ' Raw embeddings: high at the best player.'\n", 536 | "16 ' Closest embeddings: high at the time.\\n'\n", 537 | "17 ' Raw embeddings: wielded getting the best player?'\n", 538 | "17 ' Closest embeddings: wielded getting a new job.'\n", 539 | "18 ' Raw embeddings: way Spot the best player?'\n", 540 | "18 \" Closest embeddings: way Spotty's been doing\"\n", 541 | "19 ' Raw embeddings: predictable best fact fact fact\\n'\n", 542 | "19 ' Closest embeddings: predictable best-case 
scenario for'\n" 543 | ] 544 | }, 545 | { 546 | "output_type": "error", 547 | "ename": "KeyboardInterrupt", 548 | "evalue": "ignored", 549 | "traceback": [ 550 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 551 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 552 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mtic\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m oi = optimise_input(base_input=base_input, \n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mplt_loss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 553 | "\u001b[0;32m\u001b[0m in \u001b[0;36moptimise_input\u001b[0;34m(epochs, lr, rand_after, w_freq, base_input, batch_size, input_len, target_output, output_len, dist_reg, perp_reg, plt_loss, loss_type, seed, return_early, verbose, lr_decay, noise_coeff)\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 327\u001b[0m \u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 328\u001b[0;31m \u001b[0mtotal_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 329\u001b[0m 
\u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;31m# I assume these three lines are standard NN optimisation stuff?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 554 | "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n", 555 | "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 197\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 198\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n", 556 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 557 | ] 558 | } 559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "source": [ 564 | " " 565 | ], 566 | "metadata": { 567 | "id": "2jZJMakszJ_U" 568 | }, 569 | "execution_count": null, 570 | "outputs": [] 571 | } 572 | ], 573 | "metadata": { 574 | "accelerator": "GPU", 575 | "colab": { 576 | "machine_shape": "hm", 577 | "provenance": [], 578 | "include_colab_link": true 579 | }, 580 | "gpuClass": "premium", 581 | "kernelspec": { 582 | "display_name": "Python 3", 583 | "name": "python3" 584 | }, 585 | "language_info": { 586 | "name": "python" 587 | } 588 | }, 589 | "nbformat": 4, 590 | "nbformat_minor": 0 591 | } -------------------------------------------------------------------------------- /backwards.py: -------------------------------------------------------------------------------- 1 | from helpers import * 2 | 3 | import torch 4 | from matplotlib import pyplot as plt 5 | from IPython import display 6 | import numpy as np 7 | import argparse 8 | import json 9 | import os 10 | from collections import Counter 11 | import random 12 | 13 | os.environ["WANDB_API_KEY"] = "YOUR_WANDB_API_KEY" 14 | os.environ["WANDB_SILENT"] = "true" 15 | 16 | 17 | def optimise_input(model, 18 | word_embeddings, 19 | tokenizer, 20 | device, 21 | epochs=100, 22 | lr=0.1, 23 | no_reinit=False, 24 | w_freq=10, 25 | rand_input=False, 26 | local_input=False, 27 | batch_size=20, 28 | input_len=10, 29 | target_output=' world', 30 | output_len=None, 31 | dist_reg=0.1, 32 | perp_reg=0, 33 | loss_type='log_prob_loss', 34 | seed=0, 35 | return_early=False, # 
finishes if single optimised input is found 36 | verbose=1, 37 | lr_decay=False, # Use learning rate decay? If so, a scheduler gets invoked. 38 | run_random=0, 39 | equal_clusters=False, 40 | penalise_repetition=False, 41 | optimiser='Adam', 42 | **kwargs): 43 | 44 | if run_random > 0: 45 | random_ix = (torch.rand(1) * word_embeddings.shape[0]).int() 46 | target_output = tokenizer.decode(random_ix) # Converts token index to string representation 47 | wandb.config.update({'target_output': target_output}, allow_val_change=True) 48 | 49 | print('Optimising input of length {} to maximise output logits for "{}"'.format(input_len, target_output)) 50 | done = None 51 | 52 | output_ix = tokenizer.encode(target_output, return_tensors='pt')[0].to(device) 53 | 54 | word_embeddings = word_embeddings / torch.sqrt(torch.sum(word_embeddings**2, dim=-1, keepdim=True)) 55 | 56 | optimised_inputs = set() 57 | optimised_tokens = [] 58 | metrics_table = wandb.Table(columns=['Input', 'Output', 'Loss', 'Perplexity', 'Distance', 'Probs']) 59 | 60 | if output_len == None or output_len < output_ix.shape[ 61 | 0]: 62 | output_len = output_ix.shape[ 63 | 0] 64 | else: 65 | possible_target_positions = torch.stack( 66 | [torch.arange(0, output_ix.shape[0]) + i for i in range(output_len - output_ix.shape[0] + 1)]) 67 | 68 | if rand_input == True: 69 | start_input = word_embeddings[torch.randperm(word_embeddings.shape[0])[:input_len * batch_size]].reshape( 70 | batch_size, input_len, -1) 71 | elif local_input == True: 72 | local_embs = closest_tokens(word_embeddings[output_ix].mean(dim=0), word_embeddings, tokenizer, n=batch_size)[-1].unsqueeze(1) 73 | start_input = local_embs.repeat(1, input_len, 1) 74 | else: 75 | num_clusters = batch_size * input_len 76 | _, centroids = kkmeans(word_embeddings.detach(), num_clusters, seed=seed, 77 | equal_clusters=equal_clusters) 78 | start_input = centroids.reshape(batch_size, input_len, -1) 79 | 80 | input = torch.nn.Parameter(start_input.to(device), 
requires_grad=True) 81 | 82 | if optimiser == 'Adam': 83 | optimiser = torch.optim.Adam([input], lr=lr, eps=0.0001) 84 | elif optimiser == 'SGD': 85 | optimiser = torch.optim.SGD([input], lr=lr) 86 | else: 87 | print('Unsupported optimiser: ', optimiser) 88 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=20, cooldown=20, factor=0.5) 89 | 90 | for e in range(epochs): 91 | norm_input = input / torch.sqrt(torch.sum(input**2, dim=-1, keepdim=True)) 92 | logits, emb, perp = model_emb(model, norm_input, 93 | word_embeddings, output_len) 94 | 95 | probs = torch.softmax(logits, dim=-1) 96 | 97 | perp_loss = perp.mean() # across all elements in the batch 98 | 99 | if output_len > output_ix.shape[0]: 100 | target_logits = logits[:, possible_target_positions, output_ix].max(dim=1)[0] 101 | target_probs = probs[:, possible_target_positions, output_ix].max(dim=1)[0] 102 | 103 | else: 104 | target_logits = logits[:, torch.arange(output_len), output_ix] 105 | target_probs = probs[:, torch.arange(output_len), output_ix] 106 | 107 | token_dist, closest_ix = [], [] 108 | for b in norm_input: 109 | tds, cixs = [], [] 110 | for be in b: 111 | _, cix, td, _ = closest_tokens(be, word_embeddings, tokenizer) 112 | tds.append(td) 113 | cixs.append(cix) 114 | token_dist.append(torch.stack(tds)) 115 | closest_ix.append(torch.stack(cixs)) 116 | 117 | token_dist, closest_ix = torch.stack(token_dist).squeeze(-1), torch.stack(closest_ix).squeeze(-1) 118 | 119 | mean_token_dist = token_dist.mean() 120 | 121 | if loss_type == 'log_prob_loss': 122 | loss = -torch.log(target_probs) 123 | elif loss_type == 'CE': 124 | if output_len > 1: 125 | print('CE not supported with output length > 1.') 126 | return 127 | loss = torch.nn.functional.cross_entropy(logits.swapaxes(-1, -2), output_ix.repeat(batch_size, 1), 128 | reduction='none') 129 | else: 130 | print(loss_type + 'is not implemented.') 131 | return 132 | 133 | batch_loss = loss.mean() 134 | 135 | total_loss = 
torch.stack([mean_token_dist * dist_reg, batch_loss, perp_loss * perp_reg]).mean() 136 | 137 | if penalise_repetition: 138 | rep_penalty = logits[:, :input_len, output_ix].sum() 139 | total_loss += rep_penalty 140 | else: 141 | rep_penalty = 0 142 | 143 | model_outs = model.generate(closest_ix, max_length=output_len + input_len) 144 | 145 | for b in range(batch_size): 146 | if target_output in tokenizer.decode(model_outs[b][input_len:]): 147 | if tokenizer.decode(model_outs[b]) not in optimised_inputs: 148 | optimised_tokens += [tokenizer.decode(t) for t in model_outs[b][:input_len]] 149 | 150 | counts = Counter(optimised_tokens) 151 | labels, values = zip(*counts.items()) 152 | 153 | data = [[label, val] for (label, val) in zip(labels, values)] 154 | table = wandb.Table(data=data, columns=["Token", "Count"]) 155 | wandb.log({"token_freqs": wandb.plot.bar(table, "Token", 156 | "Count", title="Token Freqs")}) 157 | 158 | done = tokenizer.decode(model_outs[b]) 159 | optimised_inputs.add(done) 160 | metrics_table.add_data(*[tokenizer.decode(model_outs[b][:input_len]), 161 | tokenizer.decode(model_outs[b][input_len:])] + torch.stack( 162 | [loss.squeeze(-1)[b].mean(), perp[b], token_dist.mean(dim=1)[b]], dim=-1).tolist() + [target_probs[ 163 | b].tolist()]) 164 | wandb.log({'Optimised Inputs': wandb.Html( 165 | ''.join(['

{}.{}

'.format(i, repr(s)) for i, s in enumerate(optimised_inputs)]))}) 166 | 167 | if no_reinit == False: 168 | if rand_input == True or local_input == True: 169 | input.data[b] = word_embeddings[torch.randperm(word_embeddings.shape[0])[:input_len]].reshape(1, 170 | input_len, 171 | -1).to( 172 | device) 173 | else: 174 | rand_centroids = centroids[np.random.randint(0, batch_size, size=input_len)].unsqueeze(0) 175 | input.data[b] = rand_centroids 176 | 177 | if ((e + 1) % w_freq == 0) or done and return_early: 178 | 179 | print("Optimised Inputs:", optimised_inputs) 180 | print('{}/{} Output Loss: {} Emb Dist Loss: {} Perp Loss: {} LR: {}'.format(e + 1, epochs, batch_loss, 181 | mean_token_dist, perp_loss, 182 | optimiser.param_groups[0][ 183 | 'lr'])) 184 | if verbose == 3: 185 | print('Target Probs: {}\nTarget Logits: {}\nInput Dists: {}\nInput Perplexity: {}\n'.format( 186 | target_probs.detach().cpu().numpy(), target_logits.detach().cpu().numpy(), 187 | token_dist.detach().cpu().numpy(), perp.detach().reshape(-1).cpu().numpy())) 188 | 189 | closest_embeddings = [] 190 | 191 | for b in range(batch_size): 192 | if verbose > 0: 193 | if verbose == 2: 194 | print(b, repr(' Raw embeddings: {}'.format(''.join([closest_tokens(e)[0][0] for e in emb[b]])))) 195 | 196 | print(b, repr(' Closest embeddings: {}'.format(tokenizer.decode(model_outs[b]), '\n'))) 197 | closest_embeddings.append(tokenizer.decode(model_outs[b])) 198 | 199 | wandb.log({'Closest Embeddings': wandb.Html( 200 | ''.join(['

{}.{}

'.format(i, repr(ce)) for i, ce in enumerate(closest_embeddings)])), 201 | 'Total Loss': total_loss, 'Mean Token Distance': mean_token_dist, 'Mean Loss': batch_loss, 202 | 'Mean Perplexity Loss': perp_loss, 'Epoch': e, 'LR': optimiser.param_groups[0]['lr'], 203 | 'Num Inputs Found': len(optimised_inputs), 'Repetition Penalty': rep_penalty}) 204 | 205 | if done and return_early: 206 | print('\nOptimised Input: "{}"'.format(done)) 207 | return {'Metrics': metrics_table} 208 | 209 | optimiser.zero_grad() 210 | total_loss.backward() 211 | optimiser.step() 212 | 213 | if lr_decay: 214 | scheduler.step(total_loss) 215 | done = None 216 | 217 | return {'Metrics': metrics_table} 218 | 219 | 220 | if __name__ == '__main__': 221 | parser = argparse.ArgumentParser() 222 | parser.add_argument('--wandb_user', type=str, default='jessicamarycooper') 223 | parser.add_argument('--model_name', type=str, default='gpt2') 224 | parser.add_argument('--epochs', type=int, default=100) 225 | parser.add_argument('--lr', type=float, default=0.1) 226 | parser.add_argument('--no_reinit', action='store_true') 227 | parser.add_argument('--w_freq', type=int, default=10) 228 | parser.add_argument('--rand_input', action='store_true') 229 | parser.add_argument('--local_input', action='store_true') 230 | parser.add_argument('--batch_size', type=int, default=20) 231 | parser.add_argument('--input_len', type=int, default=10) 232 | parser.add_argument('--target_output', type=str, default=' world') 233 | parser.add_argument('--output_len', type=int) 234 | parser.add_argument('--dist_reg', type=float, default=0.1) 235 | parser.add_argument('--perp_reg', type=float, default=0) 236 | parser.add_argument('--loss_type', type=str, default='log_prob_loss') 237 | parser.add_argument('--seed', type=int, default=0) 238 | parser.add_argument('--return_early', action='store_true') 239 | parser.add_argument('--verbose', type=int, default=1) 240 | parser.add_argument('--lr_decay', action='store_true') 241 | 
parser.add_argument('--note', type=str, default='') 242 | parser.add_argument('--run_test_set', type=int, default=-1) 243 | parser.add_argument('--run_random', type=int, default=0) 244 | parser.add_argument('--optimiser', type=str, default='Adam') 245 | parser.add_argument('--equal_clusters', action='store_true') 246 | parser.add_argument('--penalise_repetition', action='store_true') 247 | 248 | args = parser.parse_args() 249 | 250 | test_sets = [ 251 | [' externalToEVA', 'quickShip', ' TheNitrome', 'embedreportprint', 'rawdownload', 'reportprint', ' サーティ', 252 | ' RandomRedditor', 'oreAndOnline', 'InstoreAndOnline', ' externalTo', 'StreamerBot', 'ActionCode', 'Nitrome', ' SolidGoldMagikarp', 'PsyNetMessage'], 253 | [' girl', ' boy', 'good', ' evil', ' science', ' art', ' England', ' USA'], 254 | [' newcom', 'slaught', 'senal', 'imei']] 255 | 256 | torch.manual_seed(args.seed) 257 | random.seed(0) 258 | np.random.seed(0) 259 | 260 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 261 | 262 | print('Using {} device.'.format(args.device)) 263 | 264 | args.model, args.word_embeddings, args.tokenizer = load_all(args.model_name, args.device) 265 | 266 | if args.run_test_set > -1: 267 | for to in test_sets[args.run_test_set]: 268 | args.target_output = to 269 | run = wandb.init(config=args, project='backwards', entity=args.wandb_user, reinit=True) 270 | results = optimise_input(**vars(args)) 271 | wandb.log(results) 272 | run.finish() 273 | 274 | if args.run_random > 0: 275 | 276 | seeds = (torch.rand(args.run_random) * 60000).int() 277 | for r in range(args.run_random): 278 | args.seed = seeds[r] 279 | args.target_output = 'RANDOM' 280 | run = wandb.init(config=args, project='backwards', entity=args.wandb_user, reinit=True) 281 | results = optimise_input(**vars(args)) 282 | wandb.log(results) 283 | run.finish() 284 | 285 | if args.run_test_set == -1 and args.run_random == 0: 286 | run = wandb.init(config=args, project='backwards', entity=args.wandb_user, 
reinit=True) 287 | results = optimise_input(**vars(args)) 288 | wandb.log(results) 289 | run.finish() 290 | 291 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from helpers import * 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('-m', '--model', type=str, default='gpt2') 7 | parser.add_argument('-i', '--input', type=str) 8 | parser.add_argument('-o', '--output_length', type=int) 9 | args = parser.parse_args() 10 | 11 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 12 | 13 | model, embeddings, tokenizer = load_all(args.model, device) 14 | 15 | ix = tokenizer.encode(args.input) 16 | 17 | print('{} input tokens: {}'.format(len(ix), [tokenizer.decode(i) for i in ix])) 18 | 19 | model_out = model.generate(torch.tensor(ix).unsqueeze(0).to(device), max_length = args.output_length + len(ix)) 20 | 21 | print('\nOutput:\n{}'.format(tokenizer.decode(model_out[0]))) 22 | -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | import os 4 | import torch 5 | 6 | 7 | def install(package): 8 | subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--upgrade"]) 9 | 10 | try: 11 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils, AutoTokenizer, GPTJForCausalLM 12 | except: 13 | install('transformers') 14 | install('accelerate') 15 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils, AutoTokenizer, GPTJForCausalLM 16 | 17 | try: 18 | import wandb 19 | except: 20 | install('wandb') 21 | import wandb 22 | 23 | utils.logging.set_verbosity_error() 24 | 25 | 26 | def load_all(model_name="gpt2", device='cpu', save_dir=''): 27 | print(save_dir) 28 | if save_dir == '': 29 | cur_dir = 
os.listdir() 30 | else: 31 | cur_dir = os.listdir(save_dir) 32 | 33 | if model_name + '_tokenizer' in cur_dir: 34 | print('Loading tokenizer...') 35 | tokenizer = torch.load(save_dir + model_name + '_tokenizer') 36 | else: 37 | print('Downloading tokenizer...') 38 | if 'gpt-j' in model_name: 39 | print('WARNING: I haven\'t tuned hyperparameters for gpt-j. Don\'t expect it to work very well on defaults!') 40 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 41 | else: 42 | tokenizer = GPT2Tokenizer.from_pretrained(model_name, padding_side='left') 43 | torch.save(tokenizer, save_dir + model_name + '_tokenizer') 44 | pad_token_id = tokenizer.eos_token_id 45 | 46 | if model_name + '_model' in cur_dir: 47 | print('Loading model...') 48 | model = torch.load(save_dir + model_name + '_model').to(device) 49 | else: 50 | print('Downloading model...') 51 | 52 | if 'gpt-j' in model_name: 53 | model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", 54 | torch_dtype=torch.float16, low_cpu_mem_usage=True) 55 | else: 56 | model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id) 57 | torch.save(model, save_dir + model_name + '_model') 58 | model.eval() 59 | 60 | 61 | embeddings = model.transformer.wte.weight.detach() 62 | if model_name + '_embeddings' not in cur_dir: 63 | torch.save(embeddings, save_dir + model_name + '_embeddings') 64 | 65 | return model.to(device), embeddings.to(device), tokenizer 66 | 67 | 68 | def kkmeans(embeddings, num_clusters, threshold=0.00001, max_iter=1000, seed=0, overwrite=False, 69 | save_dir='', equal_clusters=False, cluster_dim=-1): 70 | 71 | if cluster_dim != -1 and equal_clusters: 72 | print('WARNING! 
Equal clusters not supported for dimension clustering.') 73 | embeddings = embeddings.detach() 74 | 75 | centroid_fname = str(embeddings.shape[0]) + '_' + str(embeddings.shape[1]) + '_' + str(num_clusters) + '_' + str( 76 | equal_clusters) + '_' + str(seed) + ' dim' + str(cluster_dim) + '_centroids' 77 | cluster_fname = str(embeddings.shape[0]) + '_' + str(embeddings.shape[1]) + '_' + str(num_clusters) + '_' + str( 78 | equal_clusters) + '_' + str(seed) + ' dim' + str(cluster_dim) + '_cluster' 79 | 80 | if not overwrite: 81 | cur_dir = os.listdir() 82 | if centroid_fname in cur_dir: 83 | print('Loading clusters...') 84 | return torch.load(cluster_fname), torch.load(centroid_fname) 85 | 86 | print('Finding clusters...') 87 | if seed != -1: 88 | torch.manual_seed(seed) 89 | cluster_size = embeddings.shape[0] // num_clusters 90 | # initial centroids is a set of random token embeddings (one for each cluster) 91 | centroids = embeddings[torch.randperm(embeddings.shape[0])[:num_clusters]] 92 | 93 | movement = 9999 # this will be used in each iteration step as mean centroid movement distance 94 | i = 0 95 | 96 | while movement > threshold and i < max_iter: 97 | i += 1 98 | 99 | print(embeddings.shape, centroids.shape) 100 | if cluster_dim > -1: 101 | distances = 1 - (embeddings[:, cluster_dim] @ centroids[cluster_dim].T) 102 | 103 | else: 104 | distances = 1 - (embeddings @ centroids.T) 105 | 106 | closest_distance, closest_centroid = torch.sort(distances, dim=-1) 107 | clusters = [embeddings[(closest_centroid[:, 0] == i)] for i in range(num_clusters)] 108 | 109 | if equal_clusters: 110 | for c in range(num_clusters): 111 | if clusters[c].shape[0] > cluster_size: 112 | # sort cluster embs by distance from centroid so spares are furthest away 113 | _, sorted_cluster_embs_ix = torch.sort( 114 | 1 - (clusters[c] @ clusters[c].mean(dim=0).unsqueeze(0).T).squeeze(-1)) 115 | 116 | clusters[c] = clusters[c][sorted_cluster_embs_ix] 117 | spare_embs = clusters[c][cluster_size:] 
118 | clusters[c] = clusters[c][:cluster_size] 119 | for cc in range(num_clusters): 120 | if clusters[cc].shape[0] < cluster_size: 121 | 122 | _, sorted_spare_embs_ix = torch.sort( 123 | 1 - (spare_embs @ clusters[cc].mean(dim=0).unsqueeze(0).T).squeeze(-1)) 124 | 125 | free_space = cluster_size - clusters[cc].shape[0] 126 | clusters[cc] = torch.cat([clusters[cc], spare_embs[sorted_spare_embs_ix][:free_space]]) 127 | spare_embs = spare_embs[free_space:] 128 | 129 | new_centroids = torch.stack([c.mean(dim=0)/torch.sqrt(torch.sum(c.mean(dim=0)**2, dim=-1, keepdim=True)) for c in clusters]) 130 | movement = torch.abs(new_centroids - centroids).mean() 131 | print('Movement :', movement) 132 | centroids = new_centroids 133 | 134 | centroids = torch.stack([c.mean(dim=0)/torch.sqrt(torch.sum(c.mean(dim=0)**2, dim=-1, keepdim=True)) for c in clusters]) 135 | print([c.shape[0] for c in clusters]) 136 | torch.save(clusters, save_dir + cluster_fname) 137 | torch.save(centroids, save_dir + centroid_fname) 138 | return clusters, centroids 139 | 140 | 141 | def normalise(x, min_max=[]): 142 | 143 | rnge = x.max() - x.min() 144 | if rnge > 0: 145 | x = (x - x.min()) / rnge 146 | 147 | if len(min_max) > 1: 148 | rnge = min_max[1] - min_max[0] 149 | x = x * rnge + min_max[0] 150 | 151 | return x 152 | 153 | 154 | def closest_tokens(emb, word_embeddings, tokenizer, n=1): 155 | torch.cuda.empty_cache() 156 | dists = 1 - (emb.unsqueeze(0) @ word_embeddings.T).squeeze(0) 157 | sorted_dists, ix = torch.sort(dists) 158 | 159 | tokens = [tokenizer.decode(i) for i in ix[:n]] 160 | ixs = ix[:n] 161 | dists = sorted_dists[:n] 162 | embs = word_embeddings[ixs] 163 | return tokens, ixs, dists, embs 164 | 165 | 166 | def model_emb(model, inputs_embeds, word_embeddings, output_len): 167 | 168 | embs = inputs_embeds 169 | logits = [] 170 | ixs = [] 171 | input_logits = None 172 | for i in range(output_len): 173 | model_out = model(inputs_embeds=embs, return_dict=True) 174 | 175 | if i == 0: 176 | 
input_logits = model_out.logits[:, :-1] 177 | 178 | last_logits = model_out.logits[:, -1].unsqueeze(1) 179 | logits.append(last_logits) 180 | ix = torch.argmax(last_logits, 181 | dim=-1) 182 | ixs.append(ix) 183 | output_embs = word_embeddings[ix] 184 | embs = torch.cat([embs, output_embs], 185 | dim=1) 186 | 187 | logits = torch.cat(logits, 188 | dim=1) 189 | perp = perplexity(torch.cat([input_logits, logits], dim=1)) 190 | return logits, embs, perp 191 | 192 | 193 | 194 | def perplexity(logits): 195 | probs, ix = torch.max(torch.softmax(logits, dim=-1), dim=-1) 196 | 197 | perp = 1 / (torch.prod(probs, dim=-1) ** (1 / probs.shape[-1])) - 1 198 | return perp 199 | 200 | --------------------------------------------------------------------------------