├── README.md ├── Slides.pdf ├── sliding_window_attention.ipynb └── test_attention_types.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # mistral-llm-notes 2 | Notes on the Mistral AI model 3 | -------------------------------------------------------------------------------- /Slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/mistral-llm-notes/a2575de8145f652908c5417529de8c622f2bc5dd/Slides.pdf -------------------------------------------------------------------------------- /sliding_window_attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualize the sliding window attention receptive field\n", 8 | "\n", 9 | "The concept of a receptive field comes from convolutional neural networks: it is the region of the input space that a particular CNN feature is looking at. The receptive field is a function of the network's architecture, in particular its depth, and is a useful tool for understanding the feature extraction process.\n", 10 | "\n", 11 | "In sliding window attention, we start with our input sequence, in which each token is represented by a vector. We then compute the attention using a sliding window of size $w$. The output of the self-attention (computed by the final multiplication of the attention weights and the input sequence) is a weighted sum of the input sequence. Here, I represent the input as a sequence of `set` objects, such that each set contains the tokens from which information has been captured by the attention mechanism.\n", 12 | "\n", 13 | "Initially, the sequence is a list of sets, all containing a single token.\n", 14 | "\n", 15 | "```\n", 16 | "Layer 1 input:\n", 17 | "0: ['the']\n", 18 | "1: ['cat']\n", 19 | "2: ['is']\n", 20 | "3: ['on']\n", 21 | "4: ['a']\n", 22 | "5: ['chair']\n", 23 | "```\n", 24 | "\n", 25 | "After the first layer, considering a sliding window size of 3, the output of the attention mechanism is:\n", 26 | "\n", 27 | "```\n", 28 | "Layer 1 output:\n", 29 | "0: ['the']\n", 30 | "1: ['the', 'cat']\n", 31 | "2: ['the', 'cat', 'is']\n", 32 | "3: ['cat', 'is', 'on']\n", 33 | "4: ['is', 'on', 'a']\n", 34 | "5: ['on', 'a', 'chair']\n", 35 | "```\n", 36 | "\n", 37 | "The output of the first layer becomes the input of the second layer. The output of the second layer is:\n", 38 | "\n", 39 | "```\n", 40 | "Layer 2 output:\n", 41 | "0: ['the']\n", 42 | "1: ['the', 'cat']\n", 43 | "2: ['the', 'cat', 'is']\n", 44 | "3: ['the', 'cat', 'is', 'on']\n", 45 | "4: ['the', 'cat', 'is', 'on', 'a']\n", 46 | "5: ['cat', 'is', 'on', 'a', 'chair']\n", 47 | "```\n", 48 | "\n", 49 | "As we can see, even with a sliding window of size 3, after just two layers the attention mechanism can capture long-range dependencies: the output of the first layer becomes the input of the second layer, and the attention mechanism is applied again. This is similar to stacking multiple convolutional layers to increase the receptive field.\n",
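"\n", "A quick back-of-the-envelope check (my own formula, not something stated in the Mistral paper): with a causal window of size $w$, each token sees itself plus the $w - 1$ previous tokens, so after $k$ layers the receptive field grows to at most $1 + k(w - 1)$ tokens. A minimal sketch:\n", "\n", "```python\n", "w, seq_len = 3, 6\n", "for k in range(1, 5):\n", "    rf = 1 + k * (w - 1)  # receptive field after k layers\n", "    print(f'layer {k}: sees up to {min(rf, seq_len)} of {seq_len} tokens')\n", "# layer 3 already covers the whole 6-token sequence, matching the outputs below\n", "```"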
50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 15, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "[{'the'}, {'cat'}, {'is'}, {'on'}, {'a'}, {'chair'}]" 61 | ] 62 | }, 63 | "execution_count": 15, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "# Create a list of singleton sets, one per token\n", 70 | "#print_order = [chr(i) for i in range(ord('A'), ord('Z') + 1)]\n", 71 | "print_order = ['the', 'cat', 'is', 'on', 'a', 'chair']\n", 72 | "sequence = [{print_order[i]} for i in range(len(print_order))]\n", 73 | "sequence" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 16, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "sliding_window_size = 3\n", 83 | "\n", 84 | "def sliding_window_attention(seq: list[set[str]], w: int):\n", 85 | "    seq_len = len(seq)\n", 86 | "    attention_scores: list[list[set]] = [[None for _ in range(seq_len)] for _ in range(seq_len)]\n", 87 | "    for i, q_tokens_set in enumerate(seq):\n", 88 | "        for j, k_tokens_set in enumerate(seq):\n", 89 | "            # The upper triangle is all None (causal masking)\n", 90 | "            if j > i:\n", 91 | "                continue\n", 92 | "            # Each token can only attend to itself and the previous w - 1 tokens\n", 93 | "            if i - j >= w:\n", 94 | "                continue\n", 95 | "\n", 96 | "            attention = set()\n", 97 | "            # Add all tokens from q_tokens_set to attention\n", 98 | "            attention.update(q_tokens_set)\n", 99 | "            # Add all tokens from k_tokens_set to attention\n", 100 | "            attention.update(k_tokens_set)\n", 101 | "\n", 102 | "            attention_scores[i][j] = attention\n", 103 | "    return attention_scores\n", 104 | "\n", 105 | "def multiply_by_v(attention_scores: list[list[set]], v_sequence: list[set[str]]) -> list[set[str]]:\n", 106 | "    seq_len = len(v_sequence)\n", 107 | "    result = [set() for _ in range(seq_len)]\n", 108 | "    for i in range(seq_len):\n", 109 | "        for j in range(seq_len):\n", 110 | "            attention = attention_scores[i][j]\n", 111 | "            v = v_sequence[j]\n", 112 | "            r = result[i]\n", 113 | "            # Only positions inside the causal window carry information\n", 114 | "            if attention is not None:\n", 115 | "                # Add all the tokens in v and in the attention set to r\n", 116 | "                r.update(v)\n", 117 | "                r.update(attention)\n", 118 | "    return result\n", 119 | "\n", 120 | "def print_attention(attention_scores: list[list[set[str]]]):\n", 121 | "    for i, row in enumerate(attention_scores):\n", 122 | "        for j, attention in enumerate(row):\n", 123 | "            if attention is None:\n", 124 | "                print('None', end='\\t')\n", 125 | "            else:\n", 126 | "                print(f'{sorted(attention, key=lambda x: print_order.index(x))}', end='\\t')\n", 127 | "        print()\n", 128 | "\n", 129 | "def print_sequence(seq: list[set[str]]):\n", 130 | "    for i, tokens_set in enumerate(seq):\n", 131 | "        print(f'{i}: {sorted(tokens_set, key=lambda x: print_order.index(x))}')\n", 132 | "\n", 133 | "def print_layer(input: list[set[str]], layer_num: int) -> list[set[str]]:\n", 134 | "    print(f'Layer {layer_num} input:')\n", 135 | "    print_sequence(input)\n", 136 | "    attention_scores = sliding_window_attention(input, sliding_window_size)\n", 137 | "    print()\n", 138 | "    print(f'Layer {layer_num} attention scores:')\n", 139 | "    print_attention(attention_scores)\n", 140 | "    output = multiply_by_v(attention_scores, input)\n", 141 | "    print()\n", 142 | "    print(f'Layer {layer_num} output:')\n", 143 | "    print_sequence(output)\n", 144 | "    return output" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 17, 150 | "metadata": {},
151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "Layer 1 input:\n", 157 | "0: ['the']\n", 158 | "1: ['cat']\n", 159 | "2: ['is']\n", 160 | "3: ['on']\n", 161 | "4: ['a']\n", 162 | "5: ['chair']\n", 163 | "\n", 164 | "Layer 1 attention scores:\n", 165 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n", 166 | "['the', 'cat']\t['cat']\tNone\tNone\tNone\tNone\t\n", 167 | "['the', 'is']\t['cat', 'is']\t['is']\tNone\tNone\tNone\t\n", 168 | "None\t['cat', 'on']\t['is', 'on']\t['on']\tNone\tNone\t\n", 169 | "None\tNone\t['is', 'a']\t['on', 'a']\t['a']\tNone\t\n", 170 | "None\tNone\tNone\t['on', 'chair']\t['a', 'chair']\t['chair']\t\n", 171 | "\n", 172 | "Layer 1 output:\n", 173 | "0: ['the']\n", 174 | "1: ['the', 'cat']\n", 175 | "2: ['the', 'cat', 'is']\n", 176 | "3: ['cat', 'is', 'on']\n", 177 | "4: ['is', 'on', 'a']\n", 178 | "5: ['on', 'a', 'chair']\n" 179 | ] 180 | } 181 | ], 182 | "source": [ 183 | "# Layer 1\n", 184 | "output_layer_1 = print_layer(sequence, 1)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 18, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "Layer 2 input:\n", 197 | "0: ['the']\n", 198 | "1: ['the', 'cat']\n", 199 | "2: ['the', 'cat', 'is']\n", 200 | "3: ['cat', 'is', 'on']\n", 201 | "4: ['is', 'on', 'a']\n", 202 | "5: ['on', 'a', 'chair']\n", 203 | "\n", 204 | "Layer 2 attention scores:\n", 205 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n", 206 | "['the', 'cat']\t['the', 'cat']\tNone\tNone\tNone\tNone\t\n", 207 | "['the', 'cat', 'is']\t['the', 'cat', 'is']\t['the', 'cat', 'is']\tNone\tNone\tNone\t\n", 208 | "None\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\t['cat', 'is', 'on']\tNone\tNone\t\n", 209 | "None\tNone\t['the', 'cat', 'is', 'on', 'a']\t['cat', 'is', 'on', 'a']\t['is', 'on', 'a']\tNone\t\n", 210 | "None\tNone\tNone\t['cat', 'is', 'on', 'a', 'chair']\t['is', 'on', 'a', 'chair']\t['on', 'a', 'chair']\t\n", 211 | "\n", 212 | "Layer 2 output:\n", 213 | "0: ['the']\n", 214 | "1: ['the', 'cat']\n", 215 | "2: ['the', 'cat', 'is']\n", 216 | "3: ['the', 'cat', 'is', 'on']\n", 217 | "4: ['the', 'cat', 'is', 'on', 'a']\n", 218 | "5: ['cat', 'is', 'on', 'a', 'chair']\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "# Layer 2\n", 224 | "output_layer_2 = print_layer(output_layer_1, 2)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 19, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "name": "stdout", 234 | "output_type": "stream", 235 | "text": [ 236 | "Layer 3 input:\n", 237 | "0: ['the']\n", 238 | "1: ['the', 'cat']\n", 239 | "2: ['the', 'cat', 'is']\n", 240 | "3: ['the', 'cat', 'is', 'on']\n", 241 | "4: ['the', 'cat', 'is', 'on', 'a']\n", 242 | "5: ['cat', 'is', 'on', 'a', 'chair']\n", 243 | "\n", 244 | "Layer 3 attention scores:\n", 245 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n", 246 | "['the', 'cat']\t['the', 'cat']\tNone\tNone\tNone\tNone\t\n", 247 | "['the', 'cat', 'is']\t['the', 'cat', 'is']\t['the', 'cat', 'is']\tNone\tNone\tNone\t\n", 248 | "None\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\tNone\tNone\t\n", 249 | "None\tNone\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\tNone\t\n", 250 | "None\tNone\tNone\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['cat', 'is', 'on', 'a', 'chair']\t\n", 251 | "\n", 252 | "Layer 3 
output:\n", 253 | "0: ['the']\n", 254 | "1: ['the', 'cat']\n", 255 | "2: ['the', 'cat', 'is']\n", 256 | "3: ['the', 'cat', 'is', 'on']\n", 257 | "4: ['the', 'cat', 'is', 'on', 'a']\n", 258 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "# Layer 3\n", 264 | "output_layer_3 = print_layer(output_layer_2, 3)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 20, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "Layer 4 input:\n", 277 | "0: ['the']\n", 278 | "1: ['the', 'cat']\n", 279 | "2: ['the', 'cat', 'is']\n", 280 | "3: ['the', 'cat', 'is', 'on']\n", 281 | "4: ['the', 'cat', 'is', 'on', 'a']\n", 282 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n", 283 | "\n", 284 | "Layer 4 attention scores:\n", 285 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n", 286 | "['the', 'cat']\t['the', 'cat']\tNone\tNone\tNone\tNone\t\n", 287 | "['the', 'cat', 'is']\t['the', 'cat', 'is']\t['the', 'cat', 'is']\tNone\tNone\tNone\t\n", 288 | "None\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\tNone\tNone\t\n", 289 | "None\tNone\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\tNone\t\n", 290 | "None\tNone\tNone\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t\n", 291 | "\n", 292 | "Layer 4 output:\n", 293 | "0: ['the']\n", 294 | "1: ['the', 'cat']\n", 295 | "2: ['the', 'cat', 'is']\n", 296 | "3: ['the', 'cat', 'is', 'on']\n", 297 | "4: ['the', 'cat', 'is', 'on', 'a']\n", 298 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n" 299 | ] 300 | } 301 | ], 302 | "source": [ 303 | "# Layer 4\n", 304 | "output_layer_4 = print_layer(output_layer_3, 4)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 21, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "name": "stdout", 314 | "output_type": "stream", 315 | "text": [ 316 | "Layer 5 input:\n", 317 | "0: ['the']\n", 318 | "1: ['the', 'cat']\n", 319 | "2: ['the', 'cat', 'is']\n", 320 | "3: ['the', 'cat', 'is', 'on']\n", 321 | "4: ['the', 'cat', 'is', 'on', 'a']\n", 322 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n", 323 | "\n", 324 | "Layer 5 attention scores:\n", 325 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n", 326 | "['the', 'cat']\t['the', 'cat']\tNone\tNone\tNone\tNone\t\n", 327 | "['the', 'cat', 'is']\t['the', 'cat', 'is']\t['the', 'cat', 'is']\tNone\tNone\tNone\t\n", 328 | "None\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\tNone\tNone\t\n", 329 | "None\tNone\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\tNone\t\n", 330 | "None\tNone\tNone\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t\n", 331 | "\n", 332 | "Layer 5 output:\n", 333 | "0: ['the']\n", 334 | "1: ['the', 'cat']\n", 335 | "2: ['the', 'cat', 'is']\n", 336 | "3: ['the', 'cat', 'is', 'on']\n", 337 | "4: ['the', 'cat', 'is', 'on', 'a']\n", 338 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "# Layer 5\n", 344 | "output_layer_5 = print_layer(output_layer_4, 5)" 345 | ] 346 | } 347 | ], 348 | "metadata": { 349 | "kernelspec": { 350 | "display_name": "pytorch-mistral", 351 | "language": "python", 352 | "name": "python3" 353 | }, 354 | 
"language_info": { 355 | "codemirror_mode": { 356 | "name": "ipython", 357 | "version": 3 358 | }, 359 | "file_extension": ".py", 360 | "mimetype": "text/x-python", 361 | "name": "python", 362 | "nbconvert_exporter": "python", 363 | "pygments_lexer": "ipython3", 364 | "version": "3.9.18" 365 | } 366 | }, 367 | "nbformat": 4, 368 | "nbformat_minor": 2 369 | } 370 | -------------------------------------------------------------------------------- /test_attention_types.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# !pip install xformers pandas" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from xformers.ops.fmha.attn_bias import (\n", 19 | " AttentionBias,\n", 20 | " BlockDiagonalCausalMask,\n", 21 | " BlockDiagonalCausalWithOffsetPaddedKeysMask,\n", 22 | " BlockDiagonalMask,\n", 23 | ")\n", 24 | "\n", 25 | "import pandas as pd" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "col_dict = {0.0: '#68A357', float('-inf'): '#C97064'}\n", 35 | "def colour_cell(val):\n", 36 | " if val in col_dict:\n", 37 | " return 'Background-color: %s' % col_dict[val]\n", 38 | " return ''" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 4, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stderr", 48 | "output_type": "stream", 49 | "text": [ 50 | "/tmp/ipykernel_17144/2484427568.py:13: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.\n", 51 | " df.style.applymap(colour_cell)\n" 52 | ] 53 | }, 54 | { 55 | "data": { 56 | "text/html": [ 57 | "\n", 65 | "\n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | 
" \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | "
[reconstructed from the stripped HTML output — 18×18 attention bias; rows = query positions, columns = key positions; 0 = bias 0.0 (attend), . = -inf (masked)]

      0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
 0:   0 . . . . . . . . .  .  .  .  .  .  .  .  .
 1:   0 0 . . . . . . . .  .  .  .  .  .  .  .  .
 2:   0 0 0 . . . . . . .  .  .  .  .  .  .  .  .
 3:   . 0 0 0 . . . . . .  .  .  .  .  .  .  .  .
 4:   . . 0 0 0 . . . . .  .  .  .  .  .  .  .  .
 5:   . . . 0 0 0 . . . .  .  .  .  .  .  .  .  .
 6:   . . . . 0 0 0 . . .  .  .  .  .  .  .  .  .
 7:   . . . . . . . 0 . .  .  .  .  .  .  .  .  .
 8:   . . . . . . . 0 0 .  .  .  .  .  .  .  .  .
 9:   . . . . . . . 0 0 0  .  .  .  .  .  .  .  .
10:   . . . . . . . . 0 0  0  .  .  .  .  .  .  .
11:   . . . . . . . . . 0  0  0  .  .  .  .  .  .
12:   . . . . . . . . . .  .  .  0  .  .  .  .  .
13:   . . . . . . . . . .  .  .  0  0  .  .  .  .
14:   . . . . . . . . . .  .  .  0  0  0  .  .  .
15:   . . . . . . . . . .  .  .  .  0  0  0  .  .
16:   . . . . . . . . . .  .  .  .  .  0  0  0  .
17:   . . . . . . . . . .  .  .  .  .  .  0  0  0

(three causal blocks of 7, 5 and 6 tokens, each banded to a local window of 3)
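To make the band structure explicit, here is a minimal pure-Python sketch (my own helper, not part of xformers) that recomputes which (query, key) pairs survive a block-diagonal causal mask with a local window:

```python
def allowed_pairs(seqlens, window):
    """(query, key) pairs kept by a block-diagonal causal mask with a local window."""
    allowed, start = set(), 0
    for n in seqlens:
        for i in range(n):
            # causal + window: key j must satisfy i - window < j <= i, within the same block
            for j in range(max(0, i - window + 1), i + 1):
                allowed.add((start + i, start + j))
        start += n
    return allowed

pairs = allowed_pairs([7, 5, 6], 3)
assert (2, 0) in pairs and (3, 0) not in pairs  # row 3 no longer sees key 0
assert (7, 6) not in pairs                      # the second prompt cannot see the first
```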
\n" 470 | ], 471 | "text/plain": [ 472 | "" 473 | ] 474 | }, 475 | "execution_count": 4, 476 | "metadata": {}, 477 | "output_type": "execute_result" 478 | } 479 | ], 480 | "source": [ 481 | "## BlockDiagonalCausalMask\n", 482 | "\n", 483 | "seqlens = [7, 5, 6]\n", 484 | "sliding_window_size = 3\n", 485 | "\n", 486 | "mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(sliding_window_size)\n", 487 | "\n", 488 | "batch_size = 1\n", 489 | "total_seq_len = sum(seqlens)\n", 490 | "mask_tensor = mask.materialize((batch_size, total_seq_len, total_seq_len))\n", 491 | "\n", 492 | "df = pd.DataFrame(mask_tensor[0, :, :].numpy())\n", 493 | "df.style.applymap(colour_cell)" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 5, 499 | "metadata": {}, 500 | "outputs": [ 501 | { 502 | "name": "stderr", 503 | "output_type": "stream", 504 | "text": [ 505 | "/tmp/ipykernel_17144/1506970582.py:15: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.\n", 506 | " df.style.applymap(colour_cell)\n" 507 | ] 508 | }, 509 | { 510 | "data": { 511 | "text/html": [ 512 | "\n", 520 | "\n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 
| " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | "
[reconstructed from the stripped HTML output — 8×18 attention bias; rows = the 8 new query tokens, columns = the 18 cached+new keys; 0 = bias 0.0 (attend), . = -inf (masked)]

      0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
 0:   . . . . . 0 0 0 . .  .  .  .  .  .  .  .  .
 1:   . . . . . . 0 0 0 .  .  .  .  .  .  .  .  .
 2:   . . . . . . . 0 0 0  .  .  .  .  .  .  .  .
 3:   . . . . . . . . . .  .  0  0  0  .  .  .  .
 4:   . . . . . . . . . .  .  .  0  0  0  .  .  .
 5:   . . . . . . . . . .  .  .  .  0  0  0  .  .
 6:   . . . . . . . . . .  .  .  .  .  0  0  0  .
 7:   . . . . . . . . . .  .  .  .  .  .  0  0  0

(two blocks, bottom-right aligned: queries 0-2 end at key 9, queries 3-7 end at key 17, each with a window of 3)
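This grid follows from the bottom-right alignment: the last query of each block lines up with the last key of its block, so query i (in a block of Lq queries over Lkv keys) ends its window at key Lkv - Lq + i. A small sketch under that assumption (`window_cols` is a hypothetical helper of mine, not an xformers API):

```python
def window_cols(Lq, Lkv, w, kv_offset=0):
    # Key columns attended by each query row under bottom-right aligned local attention.
    cols = []
    for i in range(Lq):
        end = Lkv - Lq + i        # query i lines up with this key position
        lo = max(0, end - w + 1)  # window of w keys ending at `end`
        cols.append([kv_offset + j for j in range(lo, end + 1)])
    return cols

print(window_cols(3, 10, 3))               # [[5, 6, 7], [6, 7, 8], [7, 8, 9]]
print(window_cols(5, 8, 3, kv_offset=10))  # rows 3-7 of the grid above
```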
\n" 715 | ], 716 | "text/plain": [ 717 | "" 718 | ] 719 | }, 720 | "execution_count": 5, 721 | "metadata": {}, 722 | "output_type": "execute_result" 723 | } 724 | ], 725 | "source": [ 726 | "## BlockDiagonalMask\n", 727 | "\n", 728 | "q_seqlens = [3, 5]\n", 729 | "kv_seqlens = [10, 8] # (3 + 7, 5 + 3)\n", 730 | "sliding_window_size = 3\n", 731 | "\n", 732 | "mask = BlockDiagonalMask.from_seqlens(q_seqlens, kv_seqlens).make_local_attention_from_bottomright(sliding_window_size)\n", 733 | "\n", 734 | "batch_size = 1\n", 735 | "total_seq_len = sum(q_seqlens)\n", 736 | "total_kv_seq_len = sum(kv_seqlens)\n", 737 | "mask_tensor = mask.materialize((batch_size, total_seq_len, total_kv_seq_len))\n", 738 | "\n", 739 | "df = pd.DataFrame(mask_tensor[0, :, :].numpy())\n", 740 | "df.style.applymap(colour_cell)" 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "execution_count": 6, 746 | "metadata": {}, 747 | "outputs": [ 748 | { 749 | "name": "stderr", 750 | "output_type": "stream", 751 | "text": [ 752 | "/tmp/ipykernel_17144/2637469656.py:16: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.\n", 753 | " df.style.applymap(colour_cell)\n" 754 | ] 755 | }, 756 | { 757 | "data": { 758 | "text/html": [ 759 | "\n", 767 | "\n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | "
[reconstructed from the stripped HTML output — 2×12 attention bias; rows = 2 single-token queries, columns = 2 padded KV slots of 6; 0 = bias 0.0 (attend), . = -inf (masked)]

      0 1 2 3 4 5 6 7 8 9 10 11
 0:   0 0 0 . . . . . . .  .  .
 1:   . . . . . . 0 0 0 0  0  .

(each query sees only the first kv_seqlen entries of its own slot: 3 of 6, and 5 of 6)
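The padded-keys grid can be read the same way: each single-token query owns a slot of kv_padding columns and may only look at the first kv_seqlen of them. A tiny sketch (`valid_cols` is a hypothetical helper of mine) reproducing the attended columns:

```python
def valid_cols(kv_seqlen, kv_padding):
    # Only the first kv_seqlen[b] columns of each block's fixed-size slot hold real keys.
    return [list(range(b * kv_padding, b * kv_padding + n))
            for b, n in enumerate(kv_seqlen)]

print(valid_cols([3, 5], 6))  # [[0, 1, 2], [6, 7, 8, 9, 10]]
```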
\n" 818 | ], 819 | "text/plain": [ 820 | "" 821 | ] 822 | }, 823 | "execution_count": 6, 824 | "metadata": {}, 825 | "output_type": "execute_result" 826 | } 827 | ], 828 | "source": [ 829 | "## BlockDiagonalCausalWithOffsetPaddedKeysMask\n", 830 | "\n", 831 | "# We use this mask with padding because the overall size of the KV-Cache is the same for all the prompts, but for each KV-Cache we may need to use only some of the items.\n", 832 | "\n", 833 | "q_seqlen = [1, 1]\n", 834 | "kv_seq_len = [3, 5]\n", 835 | "kv_padding = 6\n", 836 | "\n", 837 | "mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(q_seqlen=q_seqlen, kv_padding=kv_padding, kv_seqlen=kv_seq_len)\n", 838 | "\n", 839 | "batch_size = 1\n", 840 | "total_seq_len = sum(q_seqlen)\n", 841 | "total_kv_seq_len = kv_padding * len(kv_seq_len)\n", 842 | "\n", 843 | "mask_tensor = mask.materialize((batch_size, total_seq_len, total_kv_seq_len))\n", 844 | "\n", 845 | "df = pd.DataFrame(mask_tensor[0, :, :].numpy())\n", 846 | "df.style.applymap(colour_cell)" 847 | ] 848 | }, 849 | { 850 | "cell_type": "code", 851 | "execution_count": null, 852 | "metadata": {}, 853 | "outputs": [], 854 | "source": [] 855 | } 856 | ], 857 | "metadata": { 858 | "kernelspec": { 859 | "display_name": "pytorch-mistral", 860 | "language": "python", 861 | "name": "python3" 862 | }, 863 | "language_info": { 864 | "codemirror_mode": { 865 | "name": "ipython", 866 | "version": 3 867 | }, 868 | "file_extension": ".py", 869 | "mimetype": "text/x-python", 870 | "name": "python", 871 | "nbconvert_exporter": "python", 872 | "pygments_lexer": "ipython3", 873 | "version": "3.9.18" 874 | } 875 | }, 876 | "nbformat": 4, 877 | "nbformat_minor": 2 878 | } 879 | --------------------------------------------------------------------------------