├── J_Backwards_1_0.ipynb
├── backwards.py
├── generate.py
└── helpers.py
/J_Backwards_1_0.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {
17 | "id": "ar1aNn7TxI8r",
18 | "colab": {
19 | "base_uri": "https://localhost:8080/"
20 | },
21 | "outputId": "86d6a73c-7a75-42e6-ac3c-bd4f5fea795e"
22 | },
23 | "outputs": [
24 | {
25 | "output_type": "stream",
26 | "name": "stdout",
27 | "text": [
28 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
29 | "Requirement already satisfied: update in /usr/local/lib/python3.8/dist-packages (0.0.1)\n",
30 | "Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (4.25.1)\n",
31 | "Requirement already satisfied: style==1.1.0 in /usr/local/lib/python3.8/dist-packages (from update) (1.1.0)\n",
32 | "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers) (4.64.1)\n",
33 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)\n",
34 | "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (6.0)\n",
35 | "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.8.2)\n",
36 | "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (21.3)\n",
37 | "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (1.21.6)\n",
38 | "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.11.1)\n",
39 | "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.13.2)\n",
40 | "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers) (2.25.1)\n",
41 | "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (4.4.0)\n",
42 | "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n",
43 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (1.24.3)\n",
44 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2.10)\n",
45 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2022.12.7)\n",
46 | "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (4.0.0)\n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "!pip install update transformers\n",
52 | "from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils\n",
53 | "import torch\n",
54 | "from matplotlib import pyplot as plt\n",
55 | "%matplotlib inline\n",
56 | "from IPython import display\n",
57 | "import numpy as np\n",
58 | "from tqdm import tqdm\n",
59 | "utils.logging.set_verbosity_error()\n",
60 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
61 | "\n",
62 | "vocab_len= 50257\n",
63 | "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\", padding_side='left')\n",
64 | "model = GPT2LMHeadModel.from_pretrained(\"gpt2\",pad_token_id=tokenizer.eos_token_id, vocab_size=vocab_len).to(device)\n",
65 | "model.eval()\n",
66 | "# the model will be in evaluation, not training, mode throughout\n",
67 | "word_embeddings = model.transformer.wte.weight.to(device) \n",
68 | "# 'word_embeddings' tensor gives emeddings for each token in the vocab for this model,\n",
69 | "# has shape (vocab_len, embedding_dimension) which in this case = (50257, 768)"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {
76 | "id": "qCgMUuO_33Wp"
77 | },
78 | "outputs": [],
79 | "source": [
80 | "def normalise(x, min_max=[]): \n",
81 | "# normalises values of (array or tensor) x according to first (min) and second (max) values in list min_max. \n",
82 | "# This effectively defaults to [0,1] if the list doesn't contain exactly two elements. \n",
83 | "# The original code threw an error if min_max had length 1, so it's been changed slightly.\n",
84 | "\n",
85 | "# First normalise x to [0,1]\n",
86 | " rnge = x.max() - x.min()\n",
87 | " if rnge > 0:\n",
88 | " x = (x - x.min())/rnge\n",
89 | "\n",
90 | "# Now, if there's a min and max given in min_max list, multiply by difference and add minimum\n",
91 | " if len(min_max) > 1:\n",
92 | " rnge = min_max[1] - min_max[0]\n",
93 | " x = x * rnge + min_max[0]\n",
94 | "\n",
95 | " return x\n",
96 | "\n",
97 | "\n",
98 | "def closest_tokens(emb, n=1): \n",
99 | "# This finds the n tokens in the vocabulary that are closest in the embedding space (in terms of Euclidean distance) to a given word embedding (‘emb’).\n",
100 | "# Note that here 'emb' may or may not correspond to a token (i.e., it may or may not be a 'legal' embedding).\n",
101 | "# Function returns a 4-tuple (list of the n tokens, list of their indices, list of their distances from emb, and list of their embedding vectors)\n",
102 | " torch.cuda.empty_cache()\n",
103 | " dists = torch.linalg.norm(word_embeddings - emb, dim=1)\n",
104 | " sorted_dists, ix = torch.sort(dists)\t \n",
105 | " # sorted_dists is a list of all embedding distances from 'emb', across entire vocab, sorted in increasing order, \n",
106 | " # ix is a list of their corresponding 'vocab indices'\n",
107 | " tokens = [tokenizer.decode(i) for i in ix[:n]]\n",
108 | " # For each of the first n 'vocab indices' in ix, we decode it into the string version of the corresponding token. \n",
109 | " # These strings then constitute the list 'tokens'.\n",
110 | " ixs = ix[:n]\n",
111 | " dists = sorted_dists[:n]\n",
112 | " embs = word_embeddings[ixs] # Each of these n 'embeddings' is a tensor of shape (768,)\n",
113 | " return tokens, ixs, dists, embs \n",
114 | "\n",
115 | "\n",
116 | "def model_emb(inputs_embeds, output_len):\n",
117 | "# 'input_embeds' is a tensor of shape (batch_size, input_len, embedding_dim)\n",
118 | "# 'output_len' is an integer specifying the number of output tokens to generate\n",
119 | "# Note that this function doesn't involve a target output. It simply takes a tensor of input embeddings (based on input length),\n",
120 | "# calculates perplexities for that batch of input sequences,\n",
121 | "# and runs the batch of input sequences through GPT2, for each finding next tokens iteratively 'output_len' number of times\n",
122 | " embs = inputs_embeds # This is going to get expanded using 'output_embs'\n",
123 | " logits = []\n",
124 | " ixs = []\n",
125 | " input_logits = None\n",
126 | " for i in range(output_len):\n",
127 | " model_out = model(inputs_embeds=embs, return_dict=True)\n",
128 | " # Does a forward pass of GPT2 (or whichever model) on a batch of inputs (given as a tensor 'embs' of embeddings).\n",
129 | " # This 'embs' will expand along its 1st dimension with each iteration.\n",
130 | " # Outputs logits and more (hidden states, attention, etc.) as a dictionary 'model_out'.\n",
131 | " # But we'll only be concerned with model_out.logits.\n",
132 | "\n",
133 | " if i == 0:\n",
134 | " input_logits = model_out.logits \n",
135 | " # On first pass through loop, we simply use the logits of the model output\n",
136 | " # That's a tensor of shape (batch_size, input_len, vocab_size) giving logits for each input in each batch.\n",
137 | " # Presumably for each input, this is conditioned on the inputs that preceded it?\n",
138 | "\n",
139 | " # On every pass throught the loop (including the first), we defined this tensor of shape (batch_size, 1, vocab_size):\n",
140 | " last_logits = model_out.logits[:,-1].unsqueeze(1) \n",
141 | " # model_out.logits[:,-1] will be a 2D tensor of shape (batch_size, vocab_size), just giving logits for last input/embedding across all batches/tokens\n",
142 | " # unsqueezing, we get tensor of shape (batch_size, 1, vocab_size) also giving logits of last input/embedding, differently formatted \n",
143 | " logits.append(last_logits) # appends last_logits tensor to the 'logits' list \n",
144 | " ix = torch.argmax(last_logits, dim=-1) # for each batch, finds the vocab index of the token with the largest logit in last_logits\n",
145 | " ixs.append(ix) # ...and appends this tensor of shape (batch_size,) (containing indices) it to the list 'ixs'\n",
146 | " output_embs = word_embeddings[ix] # for each batch, finds embedding for the token with that index...\n",
147 | " embs = torch.cat([embs, output_embs], dim=1) #...concatenates that tensor of embeddings to the 'embs' tensor in the first dimension before next iteration\n",
148 | "\n",
149 | " # When the loop is completed 'embs' will be a tensor containing all of the input and output word embeddings produced by the model \n",
150 | " # ...so presumably of shape (batch_size, input_len + output_len, embedding_dim)\n",
151 | "\n",
152 | " logits = torch.cat(logits, dim=1) # this converts logits from a list of tensors to a single tensor, by concatenating all of the tensors in the list\n",
153 | " # it will have shape (batch_size, output_len, vocab_size)\n",
154 | " perp = perplexity(input_logits) # 'input_logits' was calculated on first pass through loop where only input embeddings were involved\n",
155 | " return logits, embs, perp \n",
156 | " # logits has shape (batch_size, output_len, vocab_size), CHECK THAT!\n",
157 | " # embs has shape (batch_size, input_len + output_len, embedding_dim)\n",
158 | " # perp has shape (batch_size,)\n",
159 | "\n",
160 | "\n",
161 | "def perplexity(logits):\n",
162 | " # logits is of shape (batch_size, 'sequence length', vocab_size)\n",
163 | " # for all current calls, 'sequence length' is going to be input_len\n",
164 | " probs, ix = torch.max(torch.softmax(logits, dim=-1), dim=-1)\n",
165 | " # torch.softmax(logits, dim=-1) will also be a tensor of shape (batch_size, 'sequence length', vocab_size), \n",
166 | " # but where the logits in the last dimension get converted into probabilities via softmax. torch.max() then pull out the largest of these and its index\n",
167 | " # probs is a tensor that contains the maximum probability for each token in the embedding sequence, shape (batch_size, 'sequence length')\n",
168 | " # ix is a tensor that contains the corresponding indices, also with shape (batch_size, 'sequence length')\n",
169 | " perp = 1/ (torch.prod(probs, dim=-1)**(1/probs.shape[-1])) - 1\n",
170 | " # defines a scalar that's larger with greater uncertainty (so if the probs are small, their product is small, the reciprocal of some power is large)\n",
171 | " # probs.shape[-1] is output_len; the idea of raising the probs product to power 1/output_len is to make perplexities comparable across different output lengths\n",
172 | " return perp\n",
173 | "\n",
174 | "\n",
175 | "# Here's the key function that optimises for a sequence of input embeddings, given a target_output string:\n",
176 | "def optimise_input(epochs=100, \n",
177 | " lr=0.1, \n",
178 | " rand_after=False, # Do we re-initialise inputs tensor with random entries when an optimal input is found?\n",
179 | " w_freq=10, # logging (write) frequency\n",
180 | " base_input=None, # If none, start_inputs will be entirely random; \n",
181 | " # otherwise it will be built by stacking this tensor and then gently \"noising\" all but the first copies\n",
182 | " batch_size=1, \n",
183 | " input_len=1, \n",
184 | " target_output=tokenizer.eos_token, # Default target output is the \"end-of-string\" token; this won't generally be used\n",
185 | " output_len=None,\n",
186 | " dist_reg=1, # distance regularisation coefficient\n",
187 | " perp_reg=0, # perplexity regularisation coefficient; setting to 0 means perplexity loss isn't a thing\n",
188 | " plt_loss=False, # Do we plot loss?\n",
189 | " loss_type='log_prob_loss', \n",
190 | " seed=0,\n",
191 | " return_early=True, # finishes if single optimised input is found\n",
192 | " verbose=0, # Controls how much info gets logged.\n",
193 | " lr_decay=False, # Use learning rate decay? If so, a scheduler gets invoked.\n",
194 | " noise_coeff = 0.01): # Introduced for generality in the construction of start_input[1:] below.\n",
195 | " torch.manual_seed(seed) # sets up PyTorch random number generator\n",
196 | "\n",
197 | " if plt_loss:\n",
198 | " plt.rcParams.update({'figure.figsize': (40,6)})\n",
199 | "\n",
200 | " total_losses = []\n",
201 | " losses = []\n",
202 | " dists = []\n",
203 | " perps = []\n",
204 | " optimised_inputs = set()\n",
205 | " done = None\n",
206 | "\n",
207 | " output_ix = tokenizer.encode(target_output, return_tensors='pt')[0].to(device)\n",
208 | " # output_ix is a 1-D tensor of shape (output_len,) that contains the indices of the tokens in the encoding of the string 'target_output'\n",
209 | " # tokenizer.encode(target_output, return_tensors='pt') is a list containing this one tensor, hence the need for the [0]\n",
210 | " # \"return_tensors='pt'\" ensures that we get a tensor in PyTorch format\n",
211 | "\n",
212 | " if output_len == None or output_len < output_ix.shape[0]: # This won't generally be the case, but if we don't specify output_len (i.e. it's == None), then...\n",
213 | " output_len = output_ix.shape[0] # ...it will be set to the number of tokens in the encoding of the string 'target_output'\n",
214 | " # Why not just set output_len = output_ix.shape[0] in all cases?\n",
215 | " # Will there be situations where we want output_len to be of a different size to the number of tokens in target_output?\n",
216 | "\n",
217 | " print('Optimising input of length {} to maximise output logits for \"{}\"'.format(input_len, target_output))\n",
218 | " # Typically this would print something like 'Optimising input of length 6 to maximise output logits for \"KILL ALL HUMANS!\"'.\n",
219 | "\n",
220 | " if base_input == None:\n",
221 | " start_input = torch.rand(batch_size, input_len, word_embeddings.shape[-1]).to(device)\n",
222 | " # If no base_input is provided, we construct start_input as a random tensor \n",
223 | " # of shape (batch_size, input_len, embedding_dim) (embedding_dim = 768 for this GPT-2 model).\n",
224 | " start_input = normalise(start_input,[word_embeddings.min(dim=0)[0], word_embeddings.max(dim=0)[0]])\n",
225 | " # We normalise this random tensor so that its minimum and maximum values correspond to those in the entire word_embeddings tensor\n",
226 | " # This dispenses with whole swathes of \"input space\" which contain no legal token embeddings \n",
227 | " # (we're limiting ourselves to a kind of \"hull\" defined by the 50527 vocab tokens in the embedding space), \n",
228 | " # which is a sensible place to look for optimised inputs.\n",
229 | " else:\n",
230 | " start_input = base_input.repeat(batch_size, 1, 1)\n",
231 | " # If a base_input was given, it should be of shape (input_len, embedding_dim), \n",
232 | " # and we build the start_input tensor by stacking 'batch_size' number of copies of this together...\n",
233 | "\n",
234 | " if batch_size > 1:\n",
235 | " start_input[1:] += (torch.rand_like(start_input[1:]) + torch.full_like(start_input[1:], -0.5)) * noise_coeff\n",
236 | " #...and if we have more than one element in our batch, we \"noise\" the rest. \n",
237 | " # This was originally done using \"*=\" (multiplying entries by small random numbers)\n",
238 | " # We've changed this to \"+=\" (adding small random numbers instead of multiplying by them).\n",
239 | " # The original code would have pushed everything in a positive direction, hence the use of a tensor full of -0.5's. \n",
240 | "\n",
241 | " \n",
242 | " input = torch.nn.Parameter(start_input, requires_grad=True)\n",
243 | " # input is not a tensor, it's a Parameter object that wraps a tensor and adds additional functionality. \n",
244 | " # 'input.data' is used below\n",
245 | " \n",
246 | " optimiser = torch.optim.Adam([input], lr=lr)\n",
247 | " # standard optimiser; note that it generally operates on a list of tensors, so we're giving it a list of one tensor; standard learning rate\n",
248 | " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=20, cooldown=20, factor=0.5)\n",
249 | " # this is used when loss hasn't improved for 20 timesteps; this scheduler will reduce the lr by a 'factor' of 0.5 when the \n",
250 | " # validation loss stops improving for 'patience' (here 20) epochs, and will wait 'cooldown' (here 20) epochs before resuming normal operation.\n",
251 | "\n",
252 | " # now we loop across training epochs\n",
253 | " for e in range(epochs):\n",
254 | "\n",
255 | " logits, emb, perp = model_emb(torch.clamp(input, word_embeddings.min(), word_embeddings.max()), output_len)\n",
256 | " # Does forward pass on a 'clamped' version of the 'input' tensor (done to contain it within the 'hull' of the vocabulary within 'input space').\n",
257 | " # Iterates to produce an output of output_len tokens, \n",
258 | " # returns: 'logits' = tensor of logits for output, of shape (batch_size, output_len, vocab_size)\n",
259 | " # 'emb': tensor of embeddings for input+output of shape (batch_size, input_len + output_len, embedding_dim); \n",
260 | " # 'perp': the input sequence perplexities tensor, of shape (batch_size,)\n",
261 | " probs = torch.softmax(logits, dim=-1)\n",
262 | " # For each batch, output, converts the sequence of logits (of length 'vocab_size') in the 'logits' tensor to probabilities, using softmax\n",
263 | "\n",
264 | " logits = (logits - logits.min(dim=-1)[0].unsqueeze(-1)) / (logits.max(dim=-1)[0].unsqueeze(-1) - logits.min(dim=-1)[0].unsqueeze(-1))\n",
265 | " # This appears to be normalising the logits for each batch/output embedding so they're all between 0 and 1... \n",
266 | " # This is for ease of visualisation.\n",
267 | "\n",
268 | " perp_loss = perp.mean() * perp_reg\n",
269 | " # That's taking the mean perp value across all batches, then regularising it. Currently perp_reg is set to 0, so perp_loss = 0.\n",
270 | "\n",
271 | " if output_len > output_ix.shape[0]:\n",
272 | " target_logits = torch.stack([logits[:, :, ix] for ix in output_ix], dim=-1)\n",
273 | " target_logits = torch.max(target_logits, dim=-1)[0]\n",
274 | " # logits is shape (batch_size, output_len, vocab_size) \n",
275 | " # We throw out everything in the final dimension except those logits corresponding to indices of tokens in the target_ouput\n",
276 | " # This gives tensor with shape (batch_size, output_len, output_ix.shape[0])\n",
277 | " # We then take the maximum of those for each batch, output; this gives shape (batch_size, output_len)\n",
278 | " # The [0] returns just the max (torch.max returns max, indices tuple)\n",
279 | " target_probs = torch.stack([probs[:, :, ix] for ix in output_ix], dim=-1)\n",
280 | " target_probs = torch.max(target_probs, dim=-1)[0]\n",
281 | " # This does the analogous thing for probs.\n",
282 | "\n",
283 | " else:\n",
284 | " target_logits = torch.stack([logits[:,i, ix] for i, ix in enumerate(output_ix)], dim=-1)\n",
285 | " target_probs = torch.stack([probs[:,i, ix] for i, ix in enumerate(output_ix)], dim=-1)\n",
286 | " # This handles case where output_len == output_ix.shape[0]\n",
287 | " # target_logits now of shape (batch_size, output_len)\n",
288 | " # output_len < output_ix.shape[0] was dealt with in line 133\n",
289 | " \n",
290 | " token_dist = torch.stack([torch.stack([closest_tokens(e)[2].squeeze(-1) for e in input[b]]) for b in range(batch_size)])\n",
291 | " # As far as I can tell, this creates a tensor of shape (batch_size, input_len, 1) which gives distance to nearest\n",
292 | " # legal token embedding for each input embedding in each batch\n",
293 | " mean_token_dist = token_dist.mean() * dist_reg\n",
294 | " # A single scalar value, taking mean across the batch and input embeddings? \n",
295 | "\n",
296 | "\n",
297 | " # There are currently four loss types, many more could be introduced.\n",
298 | " # log_prob_loss is the current default.\n",
299 | " if loss_type == 'logit_loss':\n",
300 | " loss = torch.mean(1-target_logits)\n",
301 | " elif loss_type == 'log_prob_loss':\n",
302 | " loss = -torch.log(target_probs).mean()\n",
303 | " elif loss_type == 'prob_loss':\n",
304 | " loss = 1-torch.mean(target_probs)\n",
305 | " elif loss_type == 'CE':\n",
306 | " loss = torch.nn.functional.cross_entropy(logits.swapaxes(-1,-2), output_ix.repeat(batch_size, 1))\n",
307 | "\n",
308 | " else:\n",
309 | " print(loss_type + 'is not implemented.')\n",
310 | " return\n",
311 | "\n",
312 | " total_loss = torch.stack([mean_token_dist, loss, perp_loss]).mean()\n",
313 | " # This is this just (mean_token_dist + loss + perp_loss)/3 tensorised across batches, yes?\n",
314 | "\n",
315 | " total_losses.append(total_loss.detach().cpu().data)\n",
316 | " losses.append(loss.detach().cpu().data)\n",
317 | " dists.append(mean_token_dist.detach().cpu().data)\n",
318 | " perps.append(perp_loss.detach().cpu().data)\n",
319 | " # these four lists were intialised above. We're appeneding to the list each epoch. All are scalars.\n",
320 | "\n",
321 | " closest_ix = torch.stack([torch.stack([closest_tokens(e)[1] for e in b]) for b in input]).squeeze(-1)\n",
322 | " # This is similar to above, but building a tensor of indices of nearest embeddings, rather than distances.\n",
323 | " # Iterates over batches, and for each batch iterates over embeddings, giving tensor of shape (batch_size, input_len).\n",
324 | "\n",
325 | " model_outs = model.generate(closest_ix, max_length = output_len+input_len)\n",
326 | " # The 'closest_ix' tensor is passed as the initial input sequence to the model, \n",
327 | " # and the max_length parameter specifies the maximum length of the total sequence to generate.\n",
328 | " # The output sequence will be terminated either when the end-of-sequence token is generated \n",
329 | " # or when the maximum length is reached, whichever occurs first.\n",
330 | " # \n",
331 | " # The output of the model.generate method will be a tuple containing the generated sequences and the model's internal states. \n",
332 | " # The generated sequences will be stored in a tensor of shape (batch_size, output_len+input_len). \n",
333 | " # Each element of the tensor will be a sequence of tokens with a length of at most output_len+input_len.\n",
334 | " \n",
335 | " for b in range(batch_size):\n",
336 | " # iterate over batches \n",
337 | " if output_len > output_ix.shape[0]:\n",
338 | " if target_output in tokenizer.decode(model_outs[b][input_len:]):\n",
339 | " done = tokenizer.decode(model_outs[b][:input_len])\n",
340 | " optimised_inputs.add(done)\n",
341 | " # model_outs[b][input_len:], for a batch b, is only looking at the *output* embeddings \n",
342 | " # we decode these as tokens... is the target_output a substring?\n",
343 | " # if so, we print the target_output and the decoded string that contains it\n",
344 | " # 'done' is the string version of the model's output for given input, we add this to set 'optimised_inputs'.\n",
345 | "\n",
346 | " if rand_after:\n",
347 | " input.data[b] = torch.rand_like(input[b])\n",
348 | " # This will require new normalisation function.\n",
349 | " # The idea here seems to be randomly re-initialise the input tensor once we've found an optimised input,\n",
350 | " # input.data is the tensor version of the 'input' Parameter object. Current values, without gradient!\n",
351 | " # That's of shape (batch_size, input_len, embedding_dim)\n",
352 | "\n",
353 | " if tokenizer.decode(model_outs[b][input_len:]) == target_output:\n",
354 | " done = tokenizer.decode(model_outs[b][:input_len])\n",
355 | " optimised_inputs.add(done)\n",
356 | " # model_outs[b][input_len:], for a batch b, is only looking at the *output* embeddings \n",
357 | " # we decode these as tokens... is the target_output equal to output string?\n",
358 | " # Nothing printed in this case.\n",
359 | " # 'done' is the string version of the model's output for given input, we add this to set 'optimised_inputs'.\n",
360 | " if rand_after:\n",
361 | " input.data[b] = torch.rand_like(input[b])\n",
362 | " # Random re-initialisation (if 'rand_after' set to True)\n",
363 | "\n",
364 | " \n",
365 | " if ((e+1) % w_freq == 0) or done and return_early:\n",
366 | " display.clear_output(wait=True) \n",
367 | " # Every w epochs we write to log, unless we have found an optimised input before that and 'return_early' == True. \n",
368 | " # I'm still not entirely sure about the idea of 'return_early'.\n",
369 | "\n",
370 | " if plt_loss:\n",
371 | " plt.plot(range(len(total_losses)), total_losses, label='Total Loss', color='black')\n",
372 | " plt.plot(range(len(losses)), losses, label='Output Loss')\n",
373 | " plt.plot(range(len(dists)), dists, label='Emb Dist Loss')\n",
374 | " plt.plot(range(len(perps)), perps, label='Perp Loss')\n",
375 | " plt.yscale('log')\n",
376 | " plt.legend()\n",
377 | "\n",
378 | " plt.show()\n",
379 | "\n",
380 | " print('Inputs found: ', optimised_inputs)\n",
381 | " print('{}/{} Output Loss: {} Emb Dist Loss: {} Perp Loss: {} LR: {}'.format(e+1, epochs, loss, mean_token_dist, perp_loss, optimiser.param_groups[0]['lr']))\n",
382 | " if verbose == 3:\n",
383 | " print('Target Probs: {}\\nTarget Logits: {}\\nInput Dists: {}\\nInput Perplexity: {}\\n'.format(target_probs.detach().cpu().numpy(), target_logits.detach().cpu().numpy(), token_dist.detach().cpu().numpy(), perp.detach().reshape(-1).cpu().numpy()))\n",
384 | " # Optimised inputs and additional information are printed as part of log\n",
385 | "\n",
386 | " for b in range(batch_size):\n",
387 | " if verbose > 0:\n",
388 | " if verbose == 2:\n",
389 | " print(b, repr(' Raw embeddings: {}'.format(''.join([closest_tokens(e)[0][0] for e in emb[b]]))))\n",
390 | " # Change name to clarify (output of model if we just put in raw embeddings)\n",
391 | " # prints batch number; closest_tokens(e)[0] is a list of tokens, closest_tokens(e)[0] is the first (closest) of these\n",
392 | " # these get joined with separator '' (SHOULDN'T THAT BE ' '?) \n",
393 | " print(b, repr(' Closest embeddings: {}'.format(tokenizer.decode(model_outs[b]), '\\n')))\n",
394 | " # WON'T THIS give string decodings of the embeddings, rather than the embeddings themselves?\n",
395 | " else:\n",
396 | " print(repr(tokenizer.decode(model_outs[b])), end=' ')\n",
397 | " # The least verbose printed output. The 'end' parameter is used to specify the end-of-line string that is appended to the output. \n",
398 | " # By default, this is a newline character, but in this case it has been set to a single space character, \n",
399 | " # so the output will be separated by spaces rather than newlines.\n",
400 | "\n",
401 | " if done and return_early:\n",
402 | " print('\\nOptimised Input: \"{}\"'.format(done))\n",
403 | " return optimised_inputs\n",
404 | " # we know optimised_inputs set contains a single element in this case\n",
405 | " \n",
406 | " optimiser.zero_grad()\n",
407 | " total_loss.backward()\n",
408 | " optimiser.step()\n",
409 | " # I assume these three lines are standard NN optimisation stuff?\n",
410 | "\n",
411 | " if lr_decay:\n",
412 | " scheduler.step(total_loss)\n",
413 | " # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=20, cooldown=20, factor=0.5) gets used if lr_decay == True\n",
414 | " \n",
415 | " return optimised_inputs\n",
416 | " # that's a set of strings\n",
417 | "\n"
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "metadata": {
424 | "colab": {
425 | "base_uri": "https://localhost:8080/"
426 | },
427 | "id": "tev0cDbkdbxO",
428 | "outputId": "45b23d11-a2b5-4e46-d592-5f8f0076390b"
429 | },
430 | "outputs": [
431 | {
432 | "output_type": "stream",
433 | "name": "stdout",
434 | "text": [
435 | "[767, 767, 767]\n",
436 | "[' 7', ' 7', ' 7']\n",
437 | "3\n",
438 | " 7 7 7 7 7 7 7\n"
439 | ]
440 | }
441 | ],
442 | "source": [
443 | "ix = tokenizer.encode(\" 7 7 7\")\n",
444 | "# list of 'vocab indices'\n",
445 | "print(ix)\n",
446 | "print([tokenizer.decode(i) for i in ix])\n",
447 | "# prints reconstruction of input string\n",
448 | "print(len(ix))\n",
449 | "# prints number of tokens\n",
450 | "output_len=4\n",
451 | "model_out = model.generate(torch.tensor(ix).unsqueeze(0).to(device), max_length = output_len + len(ix))\n",
452 | "print(tokenizer.decode(model_out[0]))\n",
453 | "# pushes input string throught GPT2 (or whichever model) iteratively producing output_len number of tokens, then prints input + output."
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "source": [
459 | "from time import time\n",
460 | "target_output = \" the best player?\"\n",
461 | "input_len = 2\n",
462 | "\n",
463 | "base_input = word_embeddings[tokenizer.encode(target_output)].mean(dim=0)\n",
464 | "base_input = base_input.repeat(1, input_len, 1)\n",
465 | "\n",
466 | "tic = time()\n",
467 | "oi = optimise_input(base_input=base_input, \n",
468 | " plt_loss=False,\n",
469 | " verbose=2, \n",
470 | " epochs=500, \n",
471 | " lr_decay=False,\n",
472 | " return_early=False, \n",
473 | " lr=0.1, \n",
474 | " batch_size=20, \n",
475 | " target_output=target_output, \n",
476 | " output_len=4,\n",
477 | " input_len=input_len, \n",
478 | " w_freq=20, \n",
479 | " dist_reg=1, \n",
480 | " perp_reg=0,\n",
481 | " loss_type='log_prob_loss',\n",
482 | " noise_coeff = 0.75)\n",
483 | "toc = time()\n",
484 | "tt = toc - tic\n",
485 | "print('Time Taken: ', tt)\n"
486 | ],
487 | "metadata": {
488 | "colab": {
489 | "base_uri": "https://localhost:8080/",
490 | "height": 1000
491 | },
492 | "id": "81r4gT4m2NQ1",
493 | "outputId": "3ebaad31-0828-43b3-fb29-5c8c9d0c39ae"
494 | },
495 | "execution_count": null,
496 | "outputs": [
497 | {
498 | "output_type": "stream",
499 | "name": "stdout",
500 | "text": [
501 | "Inputs found: set()\n",
502 | "300/500 Output Loss: 2.0041980743408203 Emb Dist Loss: 5.514309883117676 Perp Loss: 0.0 LR: 0.1\n",
503 | "0 ' Raw embeddings: the an the best,\\n'\n",
504 | "0 ' Closest embeddings: the ancients, and the'\n",
505 | "1 ' Raw embeddings: play bruised the best player?'\n",
506 | "1 ' Closest embeddings: play bruised and battered.\\n'\n",
507 | "2 ' Raw embeddings: 80<|endoftext|> the best best.'\n",
508 | "2 ' Closest embeddings: 80<|endoftext|>The U.S'\n",
509 | "3 ' Raw embeddings: Main those the best player?'\n",
510 | "3 ' Closest embeddings: Main those who have been in'\n",
511 | "4 ' Raw embeddings: ( Continuing the best player?'\n",
512 | "4 ' Closest embeddings: ( Continuing )\\n\\n('\n",
513 | "5 ' Raw embeddings: 1 objects the the the the'\n",
514 | "5 ' Closest embeddings: 1 objects.\\n\\nThe'\n",
515 | "6 ' Raw embeddings: her Berlin the best best.'\n",
516 | "6 ' Closest embeddings: her Berlin Wall.\\n\\n'\n",
517 | "7 ' Raw embeddings: Obviously jobs the best player?'\n",
518 | "7 ' Closest embeddings: Obviously jobs are not just for'\n",
519 | "8 ' Raw embeddings: bestAmid the best team.'\n",
520 | "8 ' Closest embeddings: bestAmid the chaos, the'\n",
521 | "9 ' Raw embeddings: Cleveland if the best player?'\n",
522 | "9 \" Closest embeddings: Cleveland if he's healthy.\"\n",
523 | "10 ' Raw embeddings: linedSummary the best player?'\n",
524 | "10 ' Closest embeddings: linedSummary = \"\\n\\n'\n",
525 | "11 ' Raw embeddings: blocking has the the..'\n",
526 | "11 ' Closest embeddings: blocking has been a problem for'\n",
527 | "12 \" Raw embeddings: sensational s's team.\\n\"\n",
528 | "12 ' Closest embeddings: sensational siren song.\\n'\n",
529 | "13 ' Raw embeddings: glance family the best player?'\n",
530 | "13 ' Closest embeddings: glance family.\\n\\n\"'\n",
531 | "14 ' Raw embeddings: using game the best player?'\n",
532 | "14 ' Closest embeddings: using game-play.\\n'\n",
533 | "15 ' Raw embeddings: clothes( the best player?'\n",
534 | "15 ' Closest embeddings: clothes(s) of the'\n",
535 | "16 ' Raw embeddings: high at the best player.'\n",
536 | "16 ' Closest embeddings: high at the time.\\n'\n",
537 | "17 ' Raw embeddings: wielded getting the best player?'\n",
538 | "17 ' Closest embeddings: wielded getting a new job.'\n",
539 | "18 ' Raw embeddings: way Spot the best player?'\n",
540 | "18 \" Closest embeddings: way Spotty's been doing\"\n",
541 | "19 ' Raw embeddings: predictable best fact fact fact\\n'\n",
542 | "19 ' Closest embeddings: predictable best-case scenario for'\n"
543 | ]
544 | },
545 | {
546 | "output_type": "error",
547 | "ename": "KeyboardInterrupt",
548 | "evalue": "ignored",
549 | "traceback": [
550 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
551 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
552 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mtic\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m oi = optimise_input(base_input=base_input, \n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mplt_loss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
553 | "\u001b[0;32m\u001b[0m in \u001b[0;36moptimise_input\u001b[0;34m(epochs, lr, rand_after, w_freq, base_input, batch_size, input_len, target_output, output_len, dist_reg, perp_reg, plt_loss, loss_type, seed, return_early, verbose, lr_decay, noise_coeff)\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 327\u001b[0m \u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 328\u001b[0;31m \u001b[0mtotal_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 329\u001b[0m \u001b[0moptimiser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;31m# I assume these three lines are standard NN optimisation stuff?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
554 | "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n",
555 | "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 197\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 198\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
556 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
557 | ]
558 | }
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "source": [
564 | " "
565 | ],
566 | "metadata": {
567 | "id": "2jZJMakszJ_U"
568 | },
569 | "execution_count": null,
570 | "outputs": []
571 | }
572 | ],
573 | "metadata": {
574 | "accelerator": "GPU",
575 | "colab": {
576 | "machine_shape": "hm",
577 | "provenance": [],
578 | "include_colab_link": true
579 | },
580 | "gpuClass": "premium",
581 | "kernelspec": {
582 | "display_name": "Python 3",
583 | "name": "python3"
584 | },
585 | "language_info": {
586 | "name": "python"
587 | }
588 | },
589 | "nbformat": 4,
590 | "nbformat_minor": 0
591 | }
--------------------------------------------------------------------------------
/backwards.py:
--------------------------------------------------------------------------------
1 | from helpers import *
2 |
3 | import torch
4 | from matplotlib import pyplot as plt
5 | from IPython import display
6 | import numpy as np
7 | import argparse
8 | import json
9 | import os
10 | from collections import Counter
11 | import random
12 |
13 | os.environ["WANDB_API_KEY"] = "YOUR_WANDB_API_KEY"
14 | os.environ["WANDB_SILENT"] = "true"
15 |
16 |
17 | def optimise_input(model,
18 | word_embeddings,
19 | tokenizer,
20 | device,
21 | epochs=100,
22 | lr=0.1,
23 | no_reinit=False,
24 | w_freq=10,
25 | rand_input=False,
26 | local_input=False,
27 | batch_size=20,
28 | input_len=10,
29 | target_output=' world',
30 | output_len=None,
31 | dist_reg=0.1,
32 | perp_reg=0,
33 | loss_type='log_prob_loss',
34 | seed=0,
35 | return_early=False, # finishes if single optimised input is found
36 | verbose=1,
37 | lr_decay=False, # Use learning rate decay? If so, a scheduler gets invoked.
38 | run_random=0,
39 | equal_clusters=False,
40 | penalise_repetition=False,
41 | optimiser='Adam',
42 | **kwargs):
43 |
44 | if run_random > 0:
45 | random_ix = (torch.rand(1) * word_embeddings.shape[0]).int()
46 | target_output = tokenizer.decode(random_ix) # Converts token index to string representation
47 | wandb.config.update({'target_output': target_output}, allow_val_change=True)
48 |
49 | print('Optimising input of length {} to maximise output logits for "{}"'.format(input_len, target_output))
50 | done = None
51 |
52 | output_ix = tokenizer.encode(target_output, return_tensors='pt')[0].to(device)
53 |
54 | word_embeddings = word_embeddings / torch.sqrt(torch.sum(word_embeddings**2, dim=-1, keepdim=True))
55 |
56 | optimised_inputs = set()
57 | optimised_tokens = []
58 | metrics_table = wandb.Table(columns=['Input', 'Output', 'Loss', 'Perplexity', 'Distance', 'Probs'])
59 |
60 |     if output_len is None or output_len < output_ix.shape[0]:
61 |         # generate at least as many tokens as the target contains
62 |         output_len = output_ix.shape[0]
63 |     else:
64 |         # every contiguous window of target-length positions within the longer output
65 | possible_target_positions = torch.stack(
66 | [torch.arange(0, output_ix.shape[0]) + i for i in range(output_len - output_ix.shape[0] + 1)])
67 |
68 | if rand_input == True:
69 | start_input = word_embeddings[torch.randperm(word_embeddings.shape[0])[:input_len * batch_size]].reshape(
70 | batch_size, input_len, -1)
71 | elif local_input == True:
72 | local_embs = closest_tokens(word_embeddings[output_ix].mean(dim=0), word_embeddings, tokenizer, n=batch_size)[-1].unsqueeze(1)
73 | start_input = local_embs.repeat(1, input_len, 1)
74 | else:
75 | num_clusters = batch_size * input_len
76 | _, centroids = kkmeans(word_embeddings.detach(), num_clusters, seed=seed,
77 | equal_clusters=equal_clusters)
78 | start_input = centroids.reshape(batch_size, input_len, -1)
79 |
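    # Summary of the three start_input strategies above (chosen via --rand_input / --local_input, else the default):
    #   rand_input : input_len * batch_size token embeddings drawn at random (without replacement) from the vocabulary
    #   local_input: the batch_size tokens closest to the mean embedding of the target, each repeated input_len times
    #   default    : k-means centroids of the unit-normalised embedding space, one centroid per input position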
80 | input = torch.nn.Parameter(start_input.to(device), requires_grad=True)
81 |
82 | if optimiser == 'Adam':
83 | optimiser = torch.optim.Adam([input], lr=lr, eps=0.0001)
84 | elif optimiser == 'SGD':
85 | optimiser = torch.optim.SGD([input], lr=lr)
86 | else:
87 | print('Unsupported optimiser: ', optimiser)
88 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=20, cooldown=20, factor=0.5)
89 |
90 | for e in range(epochs):
91 | norm_input = input / torch.sqrt(torch.sum(input**2, dim=-1, keepdim=True))
92 | logits, emb, perp = model_emb(model, norm_input,
93 | word_embeddings, output_len)
94 |
95 | probs = torch.softmax(logits, dim=-1)
96 |
97 | perp_loss = perp.mean() # across all elements in the batch
98 |
99 | if output_len > output_ix.shape[0]:
100 | target_logits = logits[:, possible_target_positions, output_ix].max(dim=1)[0]
101 | target_probs = probs[:, possible_target_positions, output_ix].max(dim=1)[0]
102 |
103 | else:
104 | target_logits = logits[:, torch.arange(output_len), output_ix]
105 | target_probs = probs[:, torch.arange(output_len), output_ix]
106 |
107 | token_dist, closest_ix = [], []
108 | for b in norm_input:
109 | tds, cixs = [], []
110 | for be in b:
111 | _, cix, td, _ = closest_tokens(be, word_embeddings, tokenizer)
112 | tds.append(td)
113 | cixs.append(cix)
114 | token_dist.append(torch.stack(tds))
115 | closest_ix.append(torch.stack(cixs))
116 |
117 | token_dist, closest_ix = torch.stack(token_dist).squeeze(-1), torch.stack(closest_ix).squeeze(-1)
118 |
119 | mean_token_dist = token_dist.mean()
120 |
121 | if loss_type == 'log_prob_loss':
122 | loss = -torch.log(target_probs)
123 | elif loss_type == 'CE':
124 | if output_len > 1:
125 | print('CE not supported with output length > 1.')
126 | return
127 | loss = torch.nn.functional.cross_entropy(logits.swapaxes(-1, -2), output_ix.repeat(batch_size, 1),
128 | reduction='none')
129 | else:
130 |             print(loss_type + ' is not implemented.')
131 | return
132 |
133 | batch_loss = loss.mean()
134 |
135 | total_loss = torch.stack([mean_token_dist * dist_reg, batch_loss, perp_loss * perp_reg]).mean()
136 |
137 | if penalise_repetition:
138 | rep_penalty = logits[:, :input_len, output_ix].sum()
139 | total_loss += rep_penalty
140 | else:
141 | rep_penalty = 0
142 |
143 | model_outs = model.generate(closest_ix, max_length=output_len + input_len)
144 |
145 | for b in range(batch_size):
146 | if target_output in tokenizer.decode(model_outs[b][input_len:]):
147 | if tokenizer.decode(model_outs[b]) not in optimised_inputs:
148 | optimised_tokens += [tokenizer.decode(t) for t in model_outs[b][:input_len]]
149 |
150 | counts = Counter(optimised_tokens)
151 | labels, values = zip(*counts.items())
152 |
153 | data = [[label, val] for (label, val) in zip(labels, values)]
154 | table = wandb.Table(data=data, columns=["Token", "Count"])
155 | wandb.log({"token_freqs": wandb.plot.bar(table, "Token",
156 | "Count", title="Token Freqs")})
157 |
158 | done = tokenizer.decode(model_outs[b])
159 | optimised_inputs.add(done)
160 | metrics_table.add_data(*[tokenizer.decode(model_outs[b][:input_len]),
161 | tokenizer.decode(model_outs[b][input_len:])] + torch.stack(
162 | [loss.squeeze(-1)[b].mean(), perp[b], token_dist.mean(dim=1)[b]], dim=-1).tolist() + [target_probs[
163 | b].tolist()])
164 | wandb.log({'Optimised Inputs': wandb.Html(
165 |                     ''.join(['{}.{}<br>'.format(i, repr(s)) for i, s in enumerate(optimised_inputs)]))})  # HTML line break between entries
166 |
167 | if no_reinit == False:
168 | if rand_input == True or local_input == True:
169 | input.data[b] = word_embeddings[torch.randperm(word_embeddings.shape[0])[:input_len]].reshape(1,
170 | input_len,
171 | -1).to(
172 | device)
173 | else:
174 | rand_centroids = centroids[np.random.randint(0, batch_size, size=input_len)].unsqueeze(0)
175 | input.data[b] = rand_centroids
176 |
177 | if ((e + 1) % w_freq == 0) or done and return_early:
178 |
179 | print("Optimised Inputs:", optimised_inputs)
180 | print('{}/{} Output Loss: {} Emb Dist Loss: {} Perp Loss: {} LR: {}'.format(e + 1, epochs, batch_loss,
181 | mean_token_dist, perp_loss,
182 | optimiser.param_groups[0][
183 | 'lr']))
184 | if verbose == 3:
185 | print('Target Probs: {}\nTarget Logits: {}\nInput Dists: {}\nInput Perplexity: {}\n'.format(
186 | target_probs.detach().cpu().numpy(), target_logits.detach().cpu().numpy(),
187 | token_dist.detach().cpu().numpy(), perp.detach().reshape(-1).cpu().numpy()))
188 |
189 | closest_embeddings = []
190 |
191 | for b in range(batch_size):
192 | if verbose > 0:
193 | if verbose == 2:
194 |                     print(b, repr(' Raw embeddings: {}'.format(''.join([closest_tokens(e, word_embeddings, tokenizer)[0][0] for e in emb[b]]))))
195 |
196 | print(b, repr(' Closest embeddings: {}'.format(tokenizer.decode(model_outs[b]), '\n')))
197 | closest_embeddings.append(tokenizer.decode(model_outs[b]))
198 |
199 | wandb.log({'Closest Embeddings': wandb.Html(
200 |                ''.join(['{}.{}<br>'.format(i, repr(ce)) for i, ce in enumerate(closest_embeddings)])),  # HTML line break between entries
201 | 'Total Loss': total_loss, 'Mean Token Distance': mean_token_dist, 'Mean Loss': batch_loss,
202 | 'Mean Perplexity Loss': perp_loss, 'Epoch': e, 'LR': optimiser.param_groups[0]['lr'],
203 | 'Num Inputs Found': len(optimised_inputs), 'Repetition Penalty': rep_penalty})
204 |
205 | if done and return_early:
206 | print('\nOptimised Input: "{}"'.format(done))
207 | return {'Metrics': metrics_table}
208 |
209 | optimiser.zero_grad()
210 | total_loss.backward()
211 | optimiser.step()
212 |
213 | if lr_decay:
214 | scheduler.step(total_loss)
215 | done = None
216 |
217 | return {'Metrics': metrics_table}
218 |
219 |
220 | if __name__ == '__main__':
221 | parser = argparse.ArgumentParser()
222 | parser.add_argument('--wandb_user', type=str, default='jessicamarycooper')
223 | parser.add_argument('--model_name', type=str, default='gpt2')
224 | parser.add_argument('--epochs', type=int, default=100)
225 | parser.add_argument('--lr', type=float, default=0.1)
226 | parser.add_argument('--no_reinit', action='store_true')
227 | parser.add_argument('--w_freq', type=int, default=10)
228 | parser.add_argument('--rand_input', action='store_true')
229 | parser.add_argument('--local_input', action='store_true')
230 | parser.add_argument('--batch_size', type=int, default=20)
231 | parser.add_argument('--input_len', type=int, default=10)
232 | parser.add_argument('--target_output', type=str, default=' world')
233 | parser.add_argument('--output_len', type=int)
234 | parser.add_argument('--dist_reg', type=float, default=0.1)
235 | parser.add_argument('--perp_reg', type=float, default=0)
236 | parser.add_argument('--loss_type', type=str, default='log_prob_loss')
237 | parser.add_argument('--seed', type=int, default=0)
238 | parser.add_argument('--return_early', action='store_true')
239 | parser.add_argument('--verbose', type=int, default=1)
240 | parser.add_argument('--lr_decay', action='store_true')
241 | parser.add_argument('--note', type=str, default='')
242 | parser.add_argument('--run_test_set', type=int, default=-1)
243 | parser.add_argument('--run_random', type=int, default=0)
244 | parser.add_argument('--optimiser', type=str, default='Adam')
245 | parser.add_argument('--equal_clusters', action='store_true')
246 | parser.add_argument('--penalise_repetition', action='store_true')
247 |
248 | args = parser.parse_args()
249 |
250 | test_sets = [
251 | [' externalToEVA', 'quickShip', ' TheNitrome', 'embedreportprint', 'rawdownload', 'reportprint', ' サーティ',
252 | ' RandomRedditor', 'oreAndOnline', 'InstoreAndOnline', ' externalTo', 'StreamerBot', 'ActionCode', 'Nitrome', ' SolidGoldMagikarp', 'PsyNetMessage'],
253 | [' girl', ' boy', 'good', ' evil', ' science', ' art', ' England', ' USA'],
254 | [' newcom', 'slaught', 'senal', 'imei']]
255 |
256 | torch.manual_seed(args.seed)
257 | random.seed(0)
258 | np.random.seed(0)
259 |
260 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
261 |
262 | print('Using {} device.'.format(args.device))
263 |
264 | args.model, args.word_embeddings, args.tokenizer = load_all(args.model_name, args.device)
265 |
266 | if args.run_test_set > -1:
267 | for to in test_sets[args.run_test_set]:
268 | args.target_output = to
269 | run = wandb.init(config=args, project='backwards', entity=args.wandb_user, reinit=True)
270 | results = optimise_input(**vars(args))
271 | wandb.log(results)
272 | run.finish()
273 |
274 | if args.run_random > 0:
275 |
276 | seeds = (torch.rand(args.run_random) * 60000).int()
277 | for r in range(args.run_random):
278 | args.seed = seeds[r]
279 | args.target_output = 'RANDOM'
280 | run = wandb.init(config=args, project='backwards', entity=args.wandb_user, reinit=True)
281 | results = optimise_input(**vars(args))
282 | wandb.log(results)
283 | run.finish()
284 |
285 | if args.run_test_set == -1 and args.run_random == 0:
286 | run = wandb.init(config=args, project='backwards', entity=args.wandb_user, reinit=True)
287 | results = optimise_input(**vars(args))
288 | wandb.log(results)
289 | run.finish()
290 |
291 |
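# Example invocations (illustrative; the flags correspond to the argparse options above, and
# logging requires a valid WANDB_API_KEY plus an appropriate --wandb_user):
#
#   python backwards.py --target_output " world" --input_len 10 --batch_size 20 --epochs 100
#   python backwards.py --run_test_set 0 --lr_decay     # sweep the first built-in test set of target tokens
#   python backwards.py --run_random 5                  # five runs, each with a randomly chosen single-token target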
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from helpers import *
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('-m', '--model', type=str, default='gpt2')
7 | parser.add_argument('-i', '--input', type=str)
8 | parser.add_argument('-o', '--output_length', type=int)
9 | args = parser.parse_args()
10 |
11 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
12 |
13 | model, embeddings, tokenizer = load_all(args.model, device)
14 |
15 | ix = tokenizer.encode(args.input)
16 |
17 | print('{} input tokens: {}'.format(len(ix), [tokenizer.decode(i) for i in ix]))
18 |
19 | model_out = model.generate(torch.tensor(ix).unsqueeze(0).to(device), max_length = args.output_length + len(ix))
20 |
21 | print('\nOutput:\n{}'.format(tokenizer.decode(model_out[0])))
22 |
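# Example (illustrative), reproducing the " 7 7 7" check from the notebook:
#
#   python generate.py -m gpt2 -i " 7 7 7" -o 4
#
# prints the 3 input tokens and then the greedy continuation " 7 7 7 7 7 7 7".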
--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 | import os
4 | import torch
5 |
6 |
7 | def install(package):
8 | subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--upgrade"])
9 |
10 | try:
11 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils, AutoTokenizer, GPTJForCausalLM
12 | except ImportError:
13 | install('transformers')
14 | install('accelerate')
15 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils, AutoTokenizer, GPTJForCausalLM
16 |
17 | try:
18 | import wandb
19 | except ImportError:
20 | install('wandb')
21 | import wandb
22 |
23 | utils.logging.set_verbosity_error()
24 |
25 |
26 | def load_all(model_name="gpt2", device='cpu', save_dir=''):
27 | print(save_dir)
28 | if save_dir == '':
29 | cur_dir = os.listdir()
30 | else:
31 | cur_dir = os.listdir(save_dir)
32 |
33 | if model_name + '_tokenizer' in cur_dir:
34 | print('Loading tokenizer...')
35 | tokenizer = torch.load(save_dir + model_name + '_tokenizer')
36 | else:
37 | print('Downloading tokenizer...')
38 | if 'gpt-j' in model_name:
39 | print('WARNING: I haven\'t tuned hyperparameters for gpt-j. Don\'t expect it to work very well on defaults!')
40 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
41 | else:
42 | tokenizer = GPT2Tokenizer.from_pretrained(model_name, padding_side='left')
43 | torch.save(tokenizer, save_dir + model_name + '_tokenizer')
44 | pad_token_id = tokenizer.eos_token_id
45 |
46 | if model_name + '_model' in cur_dir:
47 | print('Loading model...')
48 | model = torch.load(save_dir + model_name + '_model').to(device)
49 | else:
50 | print('Downloading model...')
51 |
52 | if 'gpt-j' in model_name:
53 | model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16",
54 | torch_dtype=torch.float16, low_cpu_mem_usage=True)
55 | else:
56 | model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
57 | torch.save(model, save_dir + model_name + '_model')
58 | model.eval()
59 |
60 |
61 | embeddings = model.transformer.wte.weight.detach()
62 | if model_name + '_embeddings' not in cur_dir:
63 | torch.save(embeddings, save_dir + model_name + '_embeddings')
64 |
65 | return model.to(device), embeddings.to(device), tokenizer
66 |
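# Illustrative usage, mirroring generate.py and backwards.py (the cache files
# 'gpt2_tokenizer', 'gpt2_model' and 'gpt2_embeddings' are written on first use):
#
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   model, embeddings, tokenizer = load_all('gpt2', device)
#   # embeddings has shape (vocab_size, embedding_dim) == (50257, 768) for 'gpt2'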
67 |
68 | def kkmeans(embeddings, num_clusters, threshold=0.00001, max_iter=1000, seed=0, overwrite=False,
69 | save_dir='', equal_clusters=False, cluster_dim=-1):
70 |
71 | if cluster_dim != -1 and equal_clusters:
72 | print('WARNING! Equal clusters not supported for dimension clustering.')
73 | embeddings = embeddings.detach()
74 |
75 | centroid_fname = str(embeddings.shape[0]) + '_' + str(embeddings.shape[1]) + '_' + str(num_clusters) + '_' + str(
76 | equal_clusters) + '_' + str(seed) + ' dim' + str(cluster_dim) + '_centroids'
77 | cluster_fname = str(embeddings.shape[0]) + '_' + str(embeddings.shape[1]) + '_' + str(num_clusters) + '_' + str(
78 | equal_clusters) + '_' + str(seed) + ' dim' + str(cluster_dim) + '_cluster'
79 |
80 | if not overwrite:
81 | cur_dir = os.listdir()
82 | if centroid_fname in cur_dir:
83 | print('Loading clusters...')
84 | return torch.load(cluster_fname), torch.load(centroid_fname)
85 |
86 | print('Finding clusters...')
87 | if seed != -1:
88 | torch.manual_seed(seed)
89 | cluster_size = embeddings.shape[0] // num_clusters
90 | # initial centroids is a set of random token embeddings (one for each cluster)
91 | centroids = embeddings[torch.randperm(embeddings.shape[0])[:num_clusters]]
92 |
93 | movement = 9999 # this will be used in each iteration step as mean centroid movement distance
94 | i = 0
95 |
96 | while movement > threshold and i < max_iter:
97 | i += 1
98 |
99 | print(embeddings.shape, centroids.shape)
100 | if cluster_dim > -1:
101 | distances = 1 - (embeddings[:, cluster_dim] @ centroids[cluster_dim].T)
102 |
103 | else:
104 | distances = 1 - (embeddings @ centroids.T)
105 |
106 | closest_distance, closest_centroid = torch.sort(distances, dim=-1)
107 | clusters = [embeddings[(closest_centroid[:, 0] == i)] for i in range(num_clusters)]
108 |
109 | if equal_clusters:
110 | for c in range(num_clusters):
111 | if clusters[c].shape[0] > cluster_size:
112 | # sort cluster embs by distance from centroid so spares are furthest away
113 | _, sorted_cluster_embs_ix = torch.sort(
114 | 1 - (clusters[c] @ clusters[c].mean(dim=0).unsqueeze(0).T).squeeze(-1))
115 |
116 | clusters[c] = clusters[c][sorted_cluster_embs_ix]
117 | spare_embs = clusters[c][cluster_size:]
118 | clusters[c] = clusters[c][:cluster_size]
119 | for cc in range(num_clusters):
120 | if clusters[cc].shape[0] < cluster_size:
121 |
122 | _, sorted_spare_embs_ix = torch.sort(
123 | 1 - (spare_embs @ clusters[cc].mean(dim=0).unsqueeze(0).T).squeeze(-1))
124 |
125 | free_space = cluster_size - clusters[cc].shape[0]
126 | clusters[cc] = torch.cat([clusters[cc], spare_embs[sorted_spare_embs_ix][:free_space]])
127 | spare_embs = spare_embs[free_space:]
128 |
129 | new_centroids = torch.stack([c.mean(dim=0)/torch.sqrt(torch.sum(c.mean(dim=0)**2, dim=-1, keepdim=True)) for c in clusters])
130 | movement = torch.abs(new_centroids - centroids).mean()
131 | print('Movement :', movement)
132 | centroids = new_centroids
133 |
134 | centroids = torch.stack([c.mean(dim=0)/torch.sqrt(torch.sum(c.mean(dim=0)**2, dim=-1, keepdim=True)) for c in clusters])
135 | print([c.shape[0] for c in clusters])
136 | torch.save(clusters, save_dir + cluster_fname)
137 | torch.save(centroids, save_dir + centroid_fname)
138 | return clusters, centroids
139 |
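# Illustrative usage, a sketch of how backwards.py seeds its optimisation ('batch_size' and
# 'input_len' here stand in for whatever values the caller has chosen):
#
#   clusters, centroids = kkmeans(embeddings, num_clusters=batch_size * input_len, seed=0)
#   start_input = centroids.reshape(batch_size, input_len, -1)   # one unit-norm centroid per input position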
140 |
141 | def normalise(x, min_max=[]):
142 |
143 | rnge = x.max() - x.min()
144 | if rnge > 0:
145 | x = (x - x.min()) / rnge
146 |
147 | if len(min_max) > 1:
148 | rnge = min_max[1] - min_max[0]
149 | x = x * rnge + min_max[0]
150 |
151 | return x
152 |
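# Worked example (illustrative): normalise first rescales to [0, 1], then into [min, max] if given.
#
#   normalise(torch.tensor([1., 2., 3.]))             # -> tensor([0.0, 0.5, 1.0])
#   normalise(torch.tensor([1., 2., 3.]), [-1., 1.])  # -> tensor([-1.0, 0.0, 1.0])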
153 |
154 | def closest_tokens(emb, word_embeddings, tokenizer, n=1):
155 | torch.cuda.empty_cache()
156 | dists = 1 - (emb.unsqueeze(0) @ word_embeddings.T).squeeze(0)
157 | sorted_dists, ix = torch.sort(dists)
158 |
159 | tokens = [tokenizer.decode(i) for i in ix[:n]]
160 | ixs = ix[:n]
161 | dists = sorted_dists[:n]
162 | embs = word_embeddings[ixs]
163 | return tokens, ixs, dists, embs
164 |
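# Note: unlike the notebook version (which uses Euclidean distance), the 'distance' here is
# 1 - (emb . w) for each vocab embedding w, i.e. cosine distance once the embeddings have been
# L2-normalised, as backwards.py does before calling this. Illustrative usage:
#
#   tokens, ixs, dists, embs = closest_tokens(word_embeddings[0], word_embeddings, tokenizer, n=5)
#   # with unit-normalised word_embeddings, dists[0] is ~0: the token is closest to itself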
165 |
166 | def model_emb(model, inputs_embeds, word_embeddings, output_len):
167 |
168 | embs = inputs_embeds
169 | logits = []
170 | ixs = []
171 | input_logits = None
172 | for i in range(output_len):
173 | model_out = model(inputs_embeds=embs, return_dict=True)
174 |
175 | if i == 0:
176 | input_logits = model_out.logits[:, :-1]
177 |
178 | last_logits = model_out.logits[:, -1].unsqueeze(1)
179 | logits.append(last_logits)
180 | ix = torch.argmax(last_logits,
181 | dim=-1)
182 | ixs.append(ix)
183 | output_embs = word_embeddings[ix]
184 | embs = torch.cat([embs, output_embs],
185 | dim=1)
186 |
187 | logits = torch.cat(logits,
188 | dim=1)
189 | perp = perplexity(torch.cat([input_logits, logits], dim=1))
190 | return logits, embs, perp
191 |
192 |
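# Shape summary for model_emb (following the notebook's commentary):
#   inputs_embeds: (batch_size, input_len, embedding_dim)
#   logits:        (batch_size, output_len, vocab_size)  - logits of the greedily generated continuation only
#   embs:          (batch_size, input_len + output_len, embedding_dim)
#   perp:          (batch_size,)  - perplexity over the input positions (bar the last) plus the continuation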
193 |
194 | def perplexity(logits):
195 | probs, ix = torch.max(torch.softmax(logits, dim=-1), dim=-1)
196 |
197 | perp = 1 / (torch.prod(probs, dim=-1) ** (1 / probs.shape[-1])) - 1
198 | return perp
199 |
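# Worked example (illustrative): perp is the reciprocal geometric mean of the per-position
# maximum probabilities, minus 1, so it is 0 when the model is maximally confident everywhere
# and grows with uncertainty. For a length-4 sequence whose max probabilities are all 0.5:
#
#   probs = torch.full((1, 4), 0.5)
#   1 / (torch.prod(probs, dim=-1) ** (1 / 4)) - 1   # -> tensor([1.])  (= 1/0.5 - 1)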
200 |
--------------------------------------------------------------------------------