├── README.md
├── Slides.pdf
├── sliding_window_attention.ipynb
└── test_attention_types.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # mistral-llm-notes
2 | Notes on the Mistral AI model
3 |
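The `sliding_window_attention.ipynb` notebook simulates how stacked layers of sliding-window attention enlarge the receptive field. Below is a minimal standalone sketch of the same idea (the `receptive_field` helper is illustrative, not part of the notebooks): with window size `w`, each position attends to itself and the `w - 1` previous positions, and stacking layers compounds the reach.

```python
# Hypothetical helper (not from the notebooks): simulates how the
# receptive field of causal sliding-window attention grows with depth.
def receptive_field(seq_len: int, w: int, layers: int) -> list[set[int]]:
    """Return, for each position, the set of input positions that can
    influence it after `layers` rounds of sliding-window attention."""
    state = [{i} for i in range(seq_len)]
    for _ in range(layers):
        # Each position merges the sets of the previous w positions
        # (including itself), exactly like the set-based simulation
        # in the notebook.
        state = [
            set().union(*(state[j] for j in range(max(0, i - w + 1), i + 1)))
            for i in range(seq_len)
        ]
    return state

print(sorted(receptive_field(6, 3, 2)[5]))  # → [1, 2, 3, 4, 5]
```

With `w = 3`, two layers already let position 5 gather information from positions 1 through 5, matching the notebook's `Layer 2 output`.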
--------------------------------------------------------------------------------
/Slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkproj/mistral-llm-notes/a2575de8145f652908c5417529de8c622f2bc5dd/Slides.pdf
--------------------------------------------------------------------------------
/sliding_window_attention.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Visualize the sliding window attention receptive field\n",
8 | "\n",
9 | "The concept of a receptive field comes from convolutional neural networks: it is the region of the input space that a particular CNN feature depends on. The receptive field is determined by the CNN's architecture and number of layers, and is a useful concept for understanding the feature extraction process.\n",
10 | "\n",
11 | "In sliding window attention, we start with our input sequence, in which each token is represented by a vector. We then compute the attention using a sliding window of size $w$. The output of the self-attention (computed by the final multiplication of the attention weights and the input sequence) is a weighted sum of the input sequence. In our case, I will represent the input as a sequence of `set` objects, such that each set represents the tokens from which information has been captured by the attention mechanism.\n",
12 | "\n",
13 | "Initially, the sequence is a list of sets, all containing a single token.\n",
14 | "\n",
15 | "```\n",
16 | "Layer 1 input:\n",
17 | "0: ['the']\n",
18 | "1: ['cat']\n",
19 | "2: ['is']\n",
20 | "3: ['on']\n",
21 | "4: ['a']\n",
22 | "5: ['chair']\n",
23 | "```\n",
24 | "\n",
25 | "After the first layer, considering a sliding window size of 3, the output of the attention mechanism is:\n",
26 | "\n",
27 | "```\n",
28 | "Layer 1 output:\n",
29 | "0: ['the']\n",
30 | "1: ['the', 'cat']\n",
31 | "2: ['the', 'cat', 'is']\n",
32 | "3: ['cat', 'is', 'on']\n",
33 | "4: ['is', 'on', 'a']\n",
34 | "5: ['on', 'a', 'chair']\n",
35 | "```\n",
36 | "\n",
37 | "The output of the first layer becomes the input of the second layer. The output of the second layer is:\n",
38 | "\n",
39 | "```\n",
40 | "Layer 2 output:\n",
41 | "0: ['the']\n",
42 | "1: ['the', 'cat']\n",
43 | "2: ['the', 'cat', 'is']\n",
44 | "3: ['the', 'cat', 'is', 'on']\n",
45 | "4: ['the', 'cat', 'is', 'on', 'a']\n",
46 | "5: ['cat', 'is', 'on', 'a', 'chair']\n",
47 | "```\n",
48 | "\n",
49 | "As we can see, even with a sliding window of size 3, after just two layers, the attention mechanism can capture long-range dependencies. This is because the output of the first layer is used as the input of the second layer, and the attention mechanism is applied again. This is similar to the idea of stacking multiple layers of CNNs to increase the receptive field."
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 15,
55 | "metadata": {},
56 | "outputs": [
57 | {
58 | "data": {
59 | "text/plain": [
60 | "[{'the'}, {'cat'}, {'is'}, {'on'}, {'a'}, {'chair'}]"
61 | ]
62 | },
63 | "execution_count": 15,
64 | "metadata": {},
65 | "output_type": "execute_result"
66 | }
67 | ],
68 | "source": [
69 | "# Create a list of sets, each containing a single token\n",
70 | "#print_order = [chr(i) for i in range(ord('A'), ord('Z') + 1)]\n",
71 | "print_order = ['the', 'cat', 'is', 'on', 'a', 'chair']\n",
72 | "sequence = [{print_order[i]} for i in range(len(print_order))]\n",
73 | "sequence"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 16,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "sliding_window_size = 3\n",
83 | "\n",
84 | "def sliding_window_attention(seq: list[set[str]], w: int):\n",
85 | " seq_len = len(seq)\n",
86 | " attention_scores: list[list[set]] = [[None for _ in range(seq_len)] for _ in range(seq_len)]\n",
87 | " for i, q_tokens_set in enumerate(seq):\n",
88 | " for j, k_tokens_set in enumerate(seq):\n",
89 | " # The upper triangle is all None\n",
90 | " if j > i:\n",
91 | " continue\n",
92 | " # Each token can only attend to the previous W tokens\n",
93 | " if i - j >= w:\n",
94 | " continue\n",
95 | "\n",
96 | " attention = set()\n",
97 | "            # Add all tokens from q_tokens_set to attention\n",
98 | " attention.update(q_tokens_set)\n",
99 | "            # Add all tokens from k_tokens_set to attention\n",
100 | " attention.update(k_tokens_set)\n",
101 | "\n",
102 | " attention_scores[i][j] = attention\n",
103 | " return attention_scores\n",
104 | "\n",
105 | "def multiple_by_v(attention_scores: list[list[set]], v_sequence: list[set[str]]) -> list[set[str]]:\n",
106 | " seq_len = len(v_sequence)\n",
107 | " result = [set() for _ in range(seq_len)]\n",
108 | " for i in range(seq_len):\n",
109 | " for j in range(seq_len):\n",
110 | " attention = attention_scores[i][j]\n",
111 | " v = v_sequence[j]\n",
112 | " r = result[i]\n",
113 | " # Add all the tokens in the attention (if not None) to r\n",
114 | " if attention is not None:\n",
115 | " # Add all the tokens in v to r\n",
116 | " r.update(v)\n",
117 | " r.update(attention)\n",
118 | " return result\n",
119 | "\n",
120 | "def print_attention(attention_scores: list[list[set[str]]]):\n",
121 | " for i, row in enumerate(attention_scores):\n",
122 | " for j, attention in enumerate(row):\n",
123 | " if attention is None:\n",
124 | " print('None', end='\\t')\n",
125 | " else:\n",
126 | " print(f'{sorted(attention, key=lambda x: print_order.index(x))}', end='\\t')\n",
127 | " print()\n",
128 | "\n",
129 | "def print_sequence(seq: list[set[str]]):\n",
130 | " for i, tokens_set in enumerate(seq):\n",
131 | " print(f'{i}: {sorted(tokens_set, key=lambda x: print_order.index(x))}')\n",
132 | "\n",
133 | "def print_layer(input: list[set[str]], layer_num: int) -> list[set[str]]:\n",
134 | " print(f'Layer {layer_num} input:')\n",
135 | " print_sequence(input)\n",
136 | " attention_scores = sliding_window_attention(input, sliding_window_size)\n",
137 | " print()\n",
138 | " print(f'Layer {layer_num} attention scores:')\n",
139 | " print_attention(attention_scores)\n",
140 | " output = multiple_by_v(attention_scores, input)\n",
141 | " print()\n",
142 | " print(f'Layer {layer_num} output:')\n",
143 | " print_sequence(output)\n",
144 | " return output"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 17,
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "name": "stdout",
154 | "output_type": "stream",
155 | "text": [
156 | "Layer 1 input:\n",
157 | "0: ['the']\n",
158 | "1: ['cat']\n",
159 | "2: ['is']\n",
160 | "3: ['on']\n",
161 | "4: ['a']\n",
162 | "5: ['chair']\n",
163 | "\n",
164 | "Layer 1 attention scores:\n",
165 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n",
166 | "['the', 'cat']\t['cat']\tNone\tNone\tNone\tNone\t\n",
167 | "['the', 'is']\t['cat', 'is']\t['is']\tNone\tNone\tNone\t\n",
168 | "None\t['cat', 'on']\t['is', 'on']\t['on']\tNone\tNone\t\n",
169 | "None\tNone\t['is', 'a']\t['on', 'a']\t['a']\tNone\t\n",
170 | "None\tNone\tNone\t['on', 'chair']\t['a', 'chair']\t['chair']\t\n",
171 | "\n",
172 | "Layer 1 output:\n",
173 | "0: ['the']\n",
174 | "1: ['the', 'cat']\n",
175 | "2: ['the', 'cat', 'is']\n",
176 | "3: ['cat', 'is', 'on']\n",
177 | "4: ['is', 'on', 'a']\n",
178 | "5: ['on', 'a', 'chair']\n"
179 | ]
180 | }
181 | ],
182 | "source": [
183 | "# Layer 1\n",
184 | "output_layer_1 = print_layer(sequence, 1)"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 18,
190 | "metadata": {},
191 | "outputs": [
192 | {
193 | "name": "stdout",
194 | "output_type": "stream",
195 | "text": [
196 | "Layer 2 input:\n",
197 | "0: ['the']\n",
198 | "1: ['the', 'cat']\n",
199 | "2: ['the', 'cat', 'is']\n",
200 | "3: ['cat', 'is', 'on']\n",
201 | "4: ['is', 'on', 'a']\n",
202 | "5: ['on', 'a', 'chair']\n",
203 | "\n",
204 | "Layer 2 attention scores:\n",
205 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n",
206 | "['the', 'cat']\t['the', 'cat']\tNone\tNone\tNone\tNone\t\n",
207 | "['the', 'cat', 'is']\t['the', 'cat', 'is']\t['the', 'cat', 'is']\tNone\tNone\tNone\t\n",
208 | "None\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\t['cat', 'is', 'on']\tNone\tNone\t\n",
209 | "None\tNone\t['the', 'cat', 'is', 'on', 'a']\t['cat', 'is', 'on', 'a']\t['is', 'on', 'a']\tNone\t\n",
210 | "None\tNone\tNone\t['cat', 'is', 'on', 'a', 'chair']\t['is', 'on', 'a', 'chair']\t['on', 'a', 'chair']\t\n",
211 | "\n",
212 | "Layer 2 output:\n",
213 | "0: ['the']\n",
214 | "1: ['the', 'cat']\n",
215 | "2: ['the', 'cat', 'is']\n",
216 | "3: ['the', 'cat', 'is', 'on']\n",
217 | "4: ['the', 'cat', 'is', 'on', 'a']\n",
218 | "5: ['cat', 'is', 'on', 'a', 'chair']\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "# Layer 2\n",
224 | "output_layer_2 = print_layer(output_layer_1, 2)"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 19,
230 | "metadata": {},
231 | "outputs": [
232 | {
233 | "name": "stdout",
234 | "output_type": "stream",
235 | "text": [
236 | "Layer 3 input:\n",
237 | "0: ['the']\n",
238 | "1: ['the', 'cat']\n",
239 | "2: ['the', 'cat', 'is']\n",
240 | "3: ['the', 'cat', 'is', 'on']\n",
241 | "4: ['the', 'cat', 'is', 'on', 'a']\n",
242 | "5: ['cat', 'is', 'on', 'a', 'chair']\n",
243 | "\n",
244 | "Layer 3 attention scores:\n",
245 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n",
246 | "['the', 'cat']\t['the', 'cat']\tNone\tNone\tNone\tNone\t\n",
247 | "['the', 'cat', 'is']\t['the', 'cat', 'is']\t['the', 'cat', 'is']\tNone\tNone\tNone\t\n",
248 | "None\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\tNone\tNone\t\n",
249 | "None\tNone\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\tNone\t\n",
250 | "None\tNone\tNone\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['cat', 'is', 'on', 'a', 'chair']\t\n",
251 | "\n",
252 | "Layer 3 output:\n",
253 | "0: ['the']\n",
254 | "1: ['the', 'cat']\n",
255 | "2: ['the', 'cat', 'is']\n",
256 | "3: ['the', 'cat', 'is', 'on']\n",
257 | "4: ['the', 'cat', 'is', 'on', 'a']\n",
258 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n"
259 | ]
260 | }
261 | ],
262 | "source": [
263 | "# Layer 3\n",
264 | "output_layer_3 = print_layer(output_layer_2, 3)"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 20,
270 | "metadata": {},
271 | "outputs": [
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "Layer 4 input:\n",
277 | "0: ['the']\n",
278 | "1: ['the', 'cat']\n",
279 | "2: ['the', 'cat', 'is']\n",
280 | "3: ['the', 'cat', 'is', 'on']\n",
281 | "4: ['the', 'cat', 'is', 'on', 'a']\n",
282 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n",
283 | "\n",
284 | "Layer 4 attention scores:\n",
285 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n",
286 | "['the', 'cat']\t['the', 'cat']\tNone\tNone\tNone\tNone\t\n",
287 | "['the', 'cat', 'is']\t['the', 'cat', 'is']\t['the', 'cat', 'is']\tNone\tNone\tNone\t\n",
288 | "None\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\tNone\tNone\t\n",
289 | "None\tNone\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\tNone\t\n",
290 | "None\tNone\tNone\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t\n",
291 | "\n",
292 | "Layer 4 output:\n",
293 | "0: ['the']\n",
294 | "1: ['the', 'cat']\n",
295 | "2: ['the', 'cat', 'is']\n",
296 | "3: ['the', 'cat', 'is', 'on']\n",
297 | "4: ['the', 'cat', 'is', 'on', 'a']\n",
298 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n"
299 | ]
300 | }
301 | ],
302 | "source": [
303 | "# Layer 4\n",
304 | "output_layer_4 = print_layer(output_layer_3, 4)"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 21,
310 | "metadata": {},
311 | "outputs": [
312 | {
313 | "name": "stdout",
314 | "output_type": "stream",
315 | "text": [
316 | "Layer 5 input:\n",
317 | "0: ['the']\n",
318 | "1: ['the', 'cat']\n",
319 | "2: ['the', 'cat', 'is']\n",
320 | "3: ['the', 'cat', 'is', 'on']\n",
321 | "4: ['the', 'cat', 'is', 'on', 'a']\n",
322 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n",
323 | "\n",
324 | "Layer 5 attention scores:\n",
325 | "['the']\tNone\tNone\tNone\tNone\tNone\t\n",
326 | "['the', 'cat']\t['the', 'cat']\tNone\tNone\tNone\tNone\t\n",
327 | "['the', 'cat', 'is']\t['the', 'cat', 'is']\t['the', 'cat', 'is']\tNone\tNone\tNone\t\n",
328 | "None\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\t['the', 'cat', 'is', 'on']\tNone\tNone\t\n",
329 | "None\tNone\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\t['the', 'cat', 'is', 'on', 'a']\tNone\t\n",
330 | "None\tNone\tNone\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t['the', 'cat', 'is', 'on', 'a', 'chair']\t\n",
331 | "\n",
332 | "Layer 5 output:\n",
333 | "0: ['the']\n",
334 | "1: ['the', 'cat']\n",
335 | "2: ['the', 'cat', 'is']\n",
336 | "3: ['the', 'cat', 'is', 'on']\n",
337 | "4: ['the', 'cat', 'is', 'on', 'a']\n",
338 | "5: ['the', 'cat', 'is', 'on', 'a', 'chair']\n"
339 | ]
340 | }
341 | ],
342 | "source": [
343 | "# Layer 5\n",
344 | "output_layer_5 = print_layer(output_layer_4, 5)"
345 | ]
346 | }
347 | ],
348 | "metadata": {
349 | "kernelspec": {
350 | "display_name": "pytorch-mistral",
351 | "language": "python",
352 | "name": "python3"
353 | },
354 | "language_info": {
355 | "codemirror_mode": {
356 | "name": "ipython",
357 | "version": 3
358 | },
359 | "file_extension": ".py",
360 | "mimetype": "text/x-python",
361 | "name": "python",
362 | "nbconvert_exporter": "python",
363 | "pygments_lexer": "ipython3",
364 | "version": "3.9.18"
365 | }
366 | },
367 | "nbformat": 4,
368 | "nbformat_minor": 2
369 | }
370 |
--------------------------------------------------------------------------------
/test_attention_types.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# !pip install xformers pandas"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from xformers.ops.fmha.attn_bias import (\n",
19 | " AttentionBias,\n",
20 | " BlockDiagonalCausalMask,\n",
21 | " BlockDiagonalCausalWithOffsetPaddedKeysMask,\n",
22 | " BlockDiagonalMask,\n",
23 | ")\n",
24 | "\n",
25 | "import pandas as pd"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 3,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "col_dict = {0.0: '#68A357', float('-inf'): '#C97064'}\n",
35 | "def colour_cell(val):\n",
36 | " if val in col_dict:\n",
37 | "        return 'background-color: %s' % col_dict[val]\n",
38 | " return ''"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 4,
44 | "metadata": {},
45 | "outputs": [
46 | {
47 | "name": "stderr",
48 | "output_type": "stream",
49 | "text": [
50 | "/tmp/ipykernel_17144/2484427568.py:13: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.\n",
51 | " df.style.applymap(colour_cell)\n"
52 | ]
53 | },
54 | {
55 | "data": {
56 | "text/html": [
57 | "[styled DataFrame: an 18x18 attention-bias matrix with 0.0 where a query may attend and -inf elsewhere; causal blocks of lengths 7, 5 and 6 lie on the diagonal, each row further limited to a sliding window of 3]"
470 | ],
471 | "text/plain": [
472 | ""
473 | ]
474 | },
475 | "execution_count": 4,
476 | "metadata": {},
477 | "output_type": "execute_result"
478 | }
479 | ],
480 | "source": [
481 | "## BlockDiagonalCausalMask\n",
482 | "\n",
483 | "seqlens = [7, 5, 6]\n",
484 | "sliding_window_size = 3\n",
485 | "\n",
486 | "mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(sliding_window_size)\n",
487 | "\n",
488 | "batch_size = 1\n",
489 | "total_seq_len = sum(seqlens)\n",
490 | "mask_tensor = mask.materialize((batch_size, total_seq_len, total_seq_len))\n",
491 | "\n",
492 | "df = pd.DataFrame(mask_tensor[0, :, :].numpy())\n",
493 | "df.style.applymap(colour_cell)"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": 5,
499 | "metadata": {},
500 | "outputs": [
501 | {
502 | "name": "stderr",
503 | "output_type": "stream",
504 | "text": [
505 | "/tmp/ipykernel_17144/1506970582.py:15: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.\n",
506 | " df.style.applymap(colour_cell)\n"
507 | ]
508 | },
509 | {
510 | "data": {
511 | "text/html": [
512 | "[styled DataFrame: an 8x18 attention-bias matrix; queries of lengths 3 and 5 attend to key blocks of lengths 10 and 8, aligned bottom-right, each query limited to a sliding window of 3 keys inside its own block]"
715 | ],
716 | "text/plain": [
717 | ""
718 | ]
719 | },
720 | "execution_count": 5,
721 | "metadata": {},
722 | "output_type": "execute_result"
723 | }
724 | ],
725 | "source": [
726 | "## BlockDiagonalMask\n",
727 | "\n",
728 | "q_seqlens = [3, 5]\n",
729 | "kv_seqlens = [10, 8] # (3 + 7, 5 + 3)\n",
730 | "sliding_window_size = 3\n",
731 | "\n",
732 | "mask = BlockDiagonalMask.from_seqlens(q_seqlens, kv_seqlens).make_local_attention_from_bottomright(sliding_window_size)\n",
733 | "\n",
734 | "batch_size = 1\n",
735 | "total_seq_len = sum(q_seqlens)\n",
736 | "total_kv_seq_len = sum(kv_seqlens)\n",
737 | "mask_tensor = mask.materialize((batch_size, total_seq_len, total_kv_seq_len))\n",
738 | "\n",
739 | "df = pd.DataFrame(mask_tensor[0, :, :].numpy())\n",
740 | "df.style.applymap(colour_cell)"
741 | ]
742 | },
743 | {
744 | "cell_type": "code",
745 | "execution_count": 6,
746 | "metadata": {},
747 | "outputs": [
748 | {
749 | "name": "stderr",
750 | "output_type": "stream",
751 | "text": [
752 | "/tmp/ipykernel_17144/2637469656.py:16: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.\n",
753 | " df.style.applymap(colour_cell)\n"
754 | ]
755 | },
756 | {
757 | "data": {
758 | "text/html": [
759 | "[styled DataFrame: a 2x12 attention-bias matrix; query 0 attends to keys 0-2 of the first padded KV-Cache (size 6), query 1 attends to keys 6-10 of the second, all other entries are -inf]"
818 | ],
819 | "text/plain": [
820 | ""
821 | ]
822 | },
823 | "execution_count": 6,
824 | "metadata": {},
825 | "output_type": "execute_result"
826 | }
827 | ],
828 | "source": [
829 | "## BlockDiagonalCausalWithOffsetPaddedKeysMask\n",
830 | "\n",
831 | "# We use this mask with padding because the overall size of the KV-Cache is the same for all the prompts, but each prompt may only need some of its KV-Cache entries.\n",
832 | "\n",
833 | "q_seqlen = [1, 1]\n",
834 | "kv_seq_len = [3, 5]\n",
835 | "kv_padding = 6\n",
836 | "\n",
837 | "mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(q_seqlen=q_seqlen, kv_padding=kv_padding, kv_seqlen=kv_seq_len)\n",
838 | "\n",
839 | "batch_size = 1\n",
840 | "total_seq_len = sum(q_seqlen)\n",
841 | "total_kv_seq_len = kv_padding * len(kv_seq_len)\n",
842 | "\n",
843 | "mask_tensor = mask.materialize((batch_size, total_seq_len, total_kv_seq_len))\n",
844 | "\n",
845 | "df = pd.DataFrame(mask_tensor[0, :, :].numpy())\n",
846 | "df.style.applymap(colour_cell)"
847 | ]
848 | },
849 | {
850 | "cell_type": "code",
851 | "execution_count": null,
852 | "metadata": {},
853 | "outputs": [],
854 | "source": []
855 | }
856 | ],
857 | "metadata": {
858 | "kernelspec": {
859 | "display_name": "pytorch-mistral",
860 | "language": "python",
861 | "name": "python3"
862 | },
863 | "language_info": {
864 | "codemirror_mode": {
865 | "name": "ipython",
866 | "version": 3
867 | },
868 | "file_extension": ".py",
869 | "mimetype": "text/x-python",
870 | "name": "python",
871 | "nbconvert_exporter": "python",
872 | "pygments_lexer": "ipython3",
873 | "version": "3.9.18"
874 | }
875 | },
876 | "nbformat": 4,
877 | "nbformat_minor": 2
878 | }
879 |
--------------------------------------------------------------------------------