"
198 | ]
199 | },
200 | "metadata": {},
201 | "output_type": "display_data"
202 | }
203 | ],
204 | "execution_count": 5
205 | },
206 | {
207 | "metadata": {},
208 | "cell_type": "markdown",
209 | "source": "For the purpose of our classification model, we shall employ the **encoder** part of the architecture, which constitutes the left-hand part in the image above. First, the embedded input samples enter the multi-head attention layer, whose output is then summed with the original input coming through a residual connection. Following a normalization, the tensor enters a fully connected segment containing two Dense layers. The output from this dense projection is then added to the input tensor via a residual connection, and normalized once more to produce the final output of the Transformer encoder.",
210 | "id": "f3304f680edc0de6"
211 | },
212 | {
213 | "metadata": {
214 | "ExecuteTime": {
215 | "end_time": "2024-11-27T13:53:23.773292Z",
216 | "start_time": "2024-11-27T13:53:23.764666Z"
217 | }
218 | },
219 | "cell_type": "code",
220 | "source": [
221 | "class TransformerEncoder(layers.Layer):\n",
222 | " def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):\n",
223 | " super().__init__(**kwargs)\n",
224 | " self.embed_dim = embed_dim\n",
225 | " self.dense_dim = dense_dim\n",
226 | " self.num_heads = num_heads\n",
227 | " self.attention = layers.MultiHeadAttention(\n",
228 | " num_heads=num_heads, key_dim=embed_dim\n",
229 | " )\n",
230 | " self.dense_proj = keras.Sequential(\n",
231 | " [\n",
232 | " layers.Dense(dense_dim, activation=\"relu\"),\n",
233 | " layers.Dense(embed_dim),\n",
234 | " ]\n",
235 | " )\n",
236 | " self.layernorm_1 = layers.LayerNormalization()\n",
237 | " self.layernorm_2 = layers.LayerNormalization()\n",
238 | " self.supports_masking = True\n",
239 | "\n",
240 | " def call(self, inputs, mask=None):\n",
241 | " if mask is not None:\n",
242 | " padding_mask = ops.cast(mask[:, None, :], dtype=\"int32\")\n",
243 | " else:\n",
244 | " padding_mask = None\n",
245 | "\n",
246 | " attention_output = self.attention(\n",
247 | " query=inputs, value=inputs, key=inputs, attention_mask=padding_mask\n",
248 | " )\n",
249 | " proj_input = self.layernorm_1(inputs + attention_output)\n",
250 | " proj_output = self.dense_proj(proj_input)\n",
251 | " return self.layernorm_2(proj_input + proj_output)\n",
252 | "\n",
253 | " def get_config(self):\n",
254 | " config = super().get_config()\n",
255 | " config.update(\n",
256 | " {\n",
257 | " \"embed_dim\": self.embed_dim,\n",
258 | " \"dense_dim\": self.dense_dim,\n",
259 | " \"num_heads\": self.num_heads,\n",
260 | " }\n",
261 | " )\n",
262 | " return config\n"
263 | ],
264 | "id": "b258aa4d3d86c12d",
265 | "outputs": [],
266 | "execution_count": 6
267 | },
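{
"metadata": {},
"cell_type": "markdown",
"source": [
"As a quick sanity check (a hypothetical snippet, not part of the original run), note that the encoder preserves the shape of its input: every sub-layer maps `embed_dim` back to `embed_dim`, so a batch of embedded sequences comes out with the same `(batch, sequence, embed_dim)` shape it went in with.\n",
"\n",
"```python\n",
"# Hypothetical shape check; not executed above.\n",
"encoder = TransformerEncoder(embed_dim=32, dense_dim=64, num_heads=2)\n",
"dummy = ops.ones((4, 10, 32))  # batch of 4 sequences, length 10, 32-dim embeddings\n",
"print(encoder(dummy).shape)    # (4, 10, 32)\n",
"```"
],
"id": "9f3a51c27d0e4b86"
},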
268 | {
269 | "metadata": {},
270 | "cell_type": "markdown",
271 | "source": "We are now ready to build the actual classifier model. The output from the encoder is flattened by a global pooling layer, and then fed straight into the output layer of a single neuron with sigmoid activation.",
272 | "id": "5bf1eb515995b216"
273 | },
274 | {
275 | "metadata": {
276 | "ExecuteTime": {
277 | "end_time": "2024-11-27T13:53:24.672410Z",
278 | "start_time": "2024-11-27T13:53:23.773292Z"
279 | }
280 | },
281 | "cell_type": "code",
282 | "source": [
283 | "embed_dim = 32 # dimension of word embeddings (token + pos)\n",
284 | "dense_dim = 64 # \n",
285 | "num_heads = 2\n",
286 | "\n",
287 | "inputs = keras.Input(shape=(None,), dtype=\"int64\")\n",
288 | "x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)\n",
289 | "x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x)\n",
290 | "x = layers.GlobalAveragePooling1D()(x)\n",
291 | "x = layers.Dropout(0.5)(x)\n",
292 | "outputs = layers.Dense(1, activation='sigmoid')(x)\n",
293 | "\n",
294 | "model = keras.Model(inputs=inputs, outputs=outputs)\n",
295 | "\n",
296 | "model.compile(optimizer=keras.optimizers.RMSprop(), \n",
297 | " loss='binary_crossentropy', \n",
298 | " metrics=['accuracy'])\n",
299 | "\n",
300 | "model.summary()"
301 | ],
302 | "id": "8ecae66d00669c53",
303 | "outputs": [
304 | {
305 | "name": "stdout",
306 | "output_type": "stream",
307 | "text": [
308 | "WARNING:tensorflow:From C:\\Users\\kopuj\\Anaconda3\\envs\\keras-cpu\\Lib\\site-packages\\keras\\src\\backend\\tensorflow\\core.py:204: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
309 | "\n"
310 | ]
311 | },
312 | {
313 | "data": {
314 | "text/plain": [
315 | "\u001B[1mModel: \"functional_1\"\u001B[0m\n"
316 | ],
317 | "text/html": [
318 | "Model: \"functional_1\"\n",
319 | "
\n"
320 | ]
321 | },
322 | "metadata": {},
323 | "output_type": "display_data"
324 | },
325 | {
326 | "data": {
327 | "text/plain": [
328 | "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
329 | "┃\u001B[1m \u001B[0m\u001B[1mLayer (type) \u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1mOutput Shape \u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1m Param #\u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1mConnected to \u001B[0m\u001B[1m \u001B[0m┃\n",
330 | "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
331 | "│ input_layer │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;45mNone\u001B[0m) │ \u001B[38;5;34m0\u001B[0m │ - │\n",
332 | "│ (\u001B[38;5;33mInputLayer\u001B[0m) │ │ │ │\n",
333 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
334 | "│ positional_embeddi… │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;45mNone\u001B[0m, \u001B[38;5;34m32\u001B[0m) │ \u001B[38;5;34m328,000\u001B[0m │ input_layer[\u001B[38;5;34m0\u001B[0m][\u001B[38;5;34m0\u001B[0m] │\n",
335 | "│ (\u001B[38;5;33mPositionalEmbeddi…\u001B[0m │ │ │ │\n",
336 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
337 | "│ not_equal │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;45mNone\u001B[0m) │ \u001B[38;5;34m0\u001B[0m │ input_layer[\u001B[38;5;34m0\u001B[0m][\u001B[38;5;34m0\u001B[0m] │\n",
338 | "│ (\u001B[38;5;33mNotEqual\u001B[0m) │ │ │ │\n",
339 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
340 | "│ transformer_encoder │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;45mNone\u001B[0m, \u001B[38;5;34m32\u001B[0m) │ \u001B[38;5;34m12,736\u001B[0m │ positional_embed… │\n",
341 | "│ (\u001B[38;5;33mTransformerEncode…\u001B[0m │ │ │ not_equal[\u001B[38;5;34m0\u001B[0m][\u001B[38;5;34m0\u001B[0m] │\n",
342 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
343 | "│ global_average_poo… │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;34m32\u001B[0m) │ \u001B[38;5;34m0\u001B[0m │ transformer_enco… │\n",
344 | "│ (\u001B[38;5;33mGlobalAveragePool…\u001B[0m │ │ │ not_equal[\u001B[38;5;34m0\u001B[0m][\u001B[38;5;34m0\u001B[0m] │\n",
345 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
346 | "│ dropout_1 (\u001B[38;5;33mDropout\u001B[0m) │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;34m32\u001B[0m) │ \u001B[38;5;34m0\u001B[0m │ global_average_p… │\n",
347 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
348 | "│ dense_2 (\u001B[38;5;33mDense\u001B[0m) │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;34m1\u001B[0m) │ \u001B[38;5;34m33\u001B[0m │ dropout_1[\u001B[38;5;34m0\u001B[0m][\u001B[38;5;34m0\u001B[0m] │\n",
349 | "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
350 | ],
351 | "text/html": [
352 | "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
353 | "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n",
354 | "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
355 | "│ input_layer │ (None, None) │ 0 │ - │\n",
356 | "│ (InputLayer) │ │ │ │\n",
357 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
358 | "│ positional_embeddi… │ (None, None, 32) │ 328,000 │ input_layer[0][0] │\n",
359 | "│ (PositionalEmbeddi… │ │ │ │\n",
360 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
361 | "│ not_equal │ (None, None) │ 0 │ input_layer[0][0] │\n",
362 | "│ (NotEqual) │ │ │ │\n",
363 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
364 | "│ transformer_encoder │ (None, None, 32) │ 12,736 │ positional_embed… │\n",
365 | "│ (TransformerEncode… │ │ │ not_equal[0][0] │\n",
366 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
367 | "│ global_average_poo… │ (None, 32) │ 0 │ transformer_enco… │\n",
368 | "│ (GlobalAveragePool… │ │ │ not_equal[0][0] │\n",
369 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
370 | "│ dropout_1 (Dropout) │ (None, 32) │ 0 │ global_average_p… │\n",
371 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
372 | "│ dense_2 (Dense) │ (None, 1) │ 33 │ dropout_1[0][0] │\n",
373 | "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
374 | "
\n"
375 | ]
376 | },
377 | "metadata": {},
378 | "output_type": "display_data"
379 | },
380 | {
381 | "data": {
382 | "text/plain": [
383 | "\u001B[1m Total params: \u001B[0m\u001B[38;5;34m340,769\u001B[0m (1.30 MB)\n"
384 | ],
385 | "text/html": [
386 | " Total params: 340,769 (1.30 MB)\n",
387 | "
\n"
388 | ]
389 | },
390 | "metadata": {},
391 | "output_type": "display_data"
392 | },
393 | {
394 | "data": {
395 | "text/plain": [
396 | "\u001B[1m Trainable params: \u001B[0m\u001B[38;5;34m340,769\u001B[0m (1.30 MB)\n"
397 | ],
398 | "text/html": [
399 | " Trainable params: 340,769 (1.30 MB)\n",
400 | "
\n"
401 | ]
402 | },
403 | "metadata": {},
404 | "output_type": "display_data"
405 | },
406 | {
407 | "data": {
408 | "text/plain": [
409 | "\u001B[1m Non-trainable params: \u001B[0m\u001B[38;5;34m0\u001B[0m (0.00 B)\n"
410 | ],
411 | "text/html": [
412 | " Non-trainable params: 0 (0.00 B)\n",
413 | "
\n"
414 | ]
415 | },
416 | "metadata": {},
417 | "output_type": "display_data"
418 | }
419 | ],
420 | "execution_count": 7
421 | },
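{
"metadata": {},
"cell_type": "markdown",
"source": [
"The parameter count reported by `model.summary()` can be verified by hand, assuming the `vocab_size` of 10,000 and `sequence_length` of 250 set earlier in the notebook (a combination consistent with the 328,000 figure, since 32 × (10,000 + 250) = 328,000):\n",
"\n",
"- positional embedding: token and position embeddings of width 32, i.e. 32 × (10,000 + 250) = 328,000\n",
"- multi-head attention: three input projections of 32 → 2 × 32, plus one output projection of 2 × 32 → 32, i.e. 3 × (32 × 64 + 64) + (64 × 32 + 32) = 8,416\n",
"- feed-forward block: (32 × 64 + 64) + (64 × 32 + 32) = 4,192\n",
"- two LayerNormalizations: 2 × (32 + 32) = 128, giving 8,416 + 4,192 + 128 = 12,736 for the encoder\n",
"- output layer: 32 weights + 1 bias = 33\n",
"\n",
"Altogether 328,000 + 12,736 + 33 = 340,769, matching the summary."
],
"id": "5c8e02b7a4d1f639"
},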
422 | {
423 | "metadata": {},
424 | "cell_type": "markdown",
425 | "source": "Train the model ...",
426 | "id": "dc117025b5a0faab"
427 | },
428 | {
429 | "metadata": {
430 | "ExecuteTime": {
431 | "end_time": "2024-11-27T13:54:25.724615Z",
432 | "start_time": "2024-11-27T13:53:24.672410Z"
433 | }
434 | },
435 | "cell_type": "code",
436 | "source": [
437 | "history = model.fit(\n",
438 | " train_ds_int, \n",
439 | " batch_size=batch_size, \n",
440 | " epochs=3, \n",
441 | " validation_data=(val_ds_int)\n",
442 | ")\n"
443 | ],
444 | "id": "d0bd0b59acca4b53",
445 | "outputs": [
446 | {
447 | "name": "stdout",
448 | "output_type": "stream",
449 | "text": [
450 | "Epoch 1/3\n"
451 | ]
452 | },
453 | {
454 | "name": "stderr",
455 | "output_type": "stream",
456 | "text": [
457 | "C:\\Users\\kopuj\\Anaconda3\\envs\\keras-cpu\\Lib\\site-packages\\keras\\src\\layers\\layer.py:932: UserWarning: Layer 'query' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.\n",
458 | " warnings.warn(\n",
459 | "C:\\Users\\kopuj\\Anaconda3\\envs\\keras-cpu\\Lib\\site-packages\\keras\\src\\layers\\layer.py:932: UserWarning: Layer 'key' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.\n",
460 | " warnings.warn(\n",
461 | "C:\\Users\\kopuj\\Anaconda3\\envs\\keras-cpu\\Lib\\site-packages\\keras\\src\\layers\\layer.py:932: UserWarning: Layer 'value' (of type EinsumDense) was passed an input with a mask attached to it. However, this layer does not support masking and will therefore destroy the mask information. Downstream layers will not see the mask.\n",
462 | " warnings.warn(\n"
463 | ]
464 | },
465 | {
466 | "name": "stdout",
467 | "output_type": "stream",
468 | "text": [
469 | "\u001B[1m625/625\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m21s\u001B[0m 30ms/step - accuracy: 0.6883 - loss: 0.5721 - val_accuracy: 0.7454 - val_loss: 0.5455\n",
470 | "Epoch 2/3\n",
471 | "\u001B[1m625/625\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m19s\u001B[0m 30ms/step - accuracy: 0.8607 - loss: 0.3280 - val_accuracy: 0.8264 - val_loss: 0.4016\n",
472 | "Epoch 3/3\n",
473 | "\u001B[1m625/625\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m21s\u001B[0m 33ms/step - accuracy: 0.8887 - loss: 0.2775 - val_accuracy: 0.8722 - val_loss: 0.3088\n"
474 | ]
475 | }
476 | ],
477 | "execution_count": 8
478 | },
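{
"metadata": {},
"cell_type": "markdown",
"source": "The `UserWarning` messages about a destroyed mask are expected here: the padding mask attached to the inputs by the embedding layer does not propagate through the attention layer's internal `EinsumDense` projections. No information is lost, however, because `TransformerEncoder.call` passes the mask explicitly via the `attention_mask` argument.",
"id": "7d41a9c3e8f25b60"
},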
479 | {
480 | "metadata": {},
481 | "cell_type": "markdown",
482 | "source": "... and test it.",
483 | "id": "d1a6a3f6a109812e"
484 | },
485 | {
486 | "metadata": {
487 | "ExecuteTime": {
488 | "end_time": "2024-11-27T13:54:39.310606Z",
489 | "start_time": "2024-11-27T13:54:25.724615Z"
490 | }
491 | },
492 | "cell_type": "code",
493 | "source": "model.evaluate(test_ds_int)",
494 | "id": "f3c097972c6aea9a",
495 | "outputs": [
496 | {
497 | "name": "stdout",
498 | "output_type": "stream",
499 | "text": [
500 | "\u001B[1m782/782\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m14s\u001B[0m 17ms/step - accuracy: 0.8578 - loss: 0.3268\n"
501 | ]
502 | },
503 | {
504 | "data": {
505 | "text/plain": [
506 | "[0.3285871744155884, 0.8590400218963623]"
507 | ]
508 | },
509 | "execution_count": 9,
510 | "metadata": {},
511 | "output_type": "execute_result"
512 | }
513 | ],
514 | "execution_count": 9
515 | },
516 | {
517 | "metadata": {},
518 | "cell_type": "markdown",
519 | "source": [
520 | "## Text generation\n",
521 | "\n",
522 | "In addition to simple tasks like classification, Transformer architecture models can also be used for more ambitious natural language processing tasks, such as machine translation and text generation. Here we shall take a look at how a neural network model can be trained to produce movie reviews from an initial prompt. \n",
523 | "\n",
524 | "We begin by creating a new Dataset object from the data in the training directory only. This time, we do not need the sentiment labels, and remove the line break tokens in order to avoid generating those later."
525 | ],
526 | "id": "3cb96f236d7cd3fc"
527 | },
528 | {
529 | "metadata": {
530 | "ExecuteTime": {
531 | "end_time": "2024-11-27T13:54:41.067718Z",
532 | "start_time": "2024-11-27T13:54:39.310606Z"
533 | }
534 | },
535 | "cell_type": "code",
536 | "source": [
537 | "dataset = keras.utils.text_dataset_from_directory(\n",
538 | " directory=\"../../aclImdb/train/\", \n",
539 | " label_mode=None, \n",
540 | " batch_size=256)\n",
541 | "\n",
542 | "dataset = dataset.map(lambda x: tf.strings.regex_replace(x, \"
\", \"\"))"
543 | ],
544 | "id": "24c15dc8aa7bf7f1",
545 | "outputs": [
546 | {
547 | "name": "stdout",
548 | "output_type": "stream",
549 | "text": [
550 | "Found 25000 files.\n"
551 | ]
552 | }
553 | ],
554 | "execution_count": 10
555 | },
556 | {
557 | "metadata": {},
558 | "cell_type": "markdown",
559 | "source": "Next, we convert the strings to integer lists, as before. To speed up training, we restrict the sequence lengths to be somewhat shorter than before.",
560 | "id": "4a8e929d6bae3c27"
561 | },
562 | {
563 | "metadata": {
564 | "ExecuteTime": {
565 | "end_time": "2024-11-27T13:54:47.026413Z",
566 | "start_time": "2024-11-27T13:54:41.067718Z"
567 | }
568 | },
569 | "cell_type": "code",
570 | "source": [
571 | "sequence_length = 100\n",
572 | "vocab_size = 10000\n",
573 | "text_vectorization = layers.TextVectorization(\n",
574 | " max_tokens=vocab_size,\n",
575 | " output_mode=\"int\",\n",
576 | " output_sequence_length=sequence_length,\n",
577 | ")\n",
578 | "text_vectorization.adapt(dataset)"
579 | ],
580 | "id": "3d7856bb7ce47196",
581 | "outputs": [],
582 | "execution_count": 11
583 | },
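{
"metadata": {},
"cell_type": "markdown",
"source": [
"To see what the vectorizer produces, we can feed it a short string (a hypothetical check; the exact token ids depend on the adapted vocabulary):\n",
"\n",
"```python\n",
"# Hypothetical check; the ids depend on the learned vocabulary.\n",
"sample = text_vectorization([\"this movie was great\"])\n",
"print(sample.shape)  # (1, 100): each review is padded/truncated to sequence_length\n",
"```"
],
"id": "2b6f90d4c1a8e357"
},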
584 | {
585 | "metadata": {},
586 | "cell_type": "markdown",
587 | "source": "Next, we generate specialized targets for the purpose of training the model: the targets are simply the same integer sequences as the samples, but shifted one token to the right (this means that the sequence length gets reduced by one token from the original ones). ",
588 | "id": "36aa3103da705b07"
589 | },
590 | {
591 | "metadata": {
592 | "ExecuteTime": {
593 | "end_time": "2024-11-27T13:54:47.100553Z",
594 | "start_time": "2024-11-27T13:54:47.026413Z"
595 | }
596 | },
597 | "cell_type": "code",
598 | "source": [
599 | "def prepare_lm_dataset(text_batch):\n",
600 | " vectorized_sequences = text_vectorization(text_batch)\n",
601 | " x = vectorized_sequences[:, :-1]\n",
602 | " y = vectorized_sequences[:, 1:]\n",
603 | " return x, y\n",
604 | "\n",
605 | "lm_dataset = dataset.map(prepare_lm_dataset)"
606 | ],
607 | "id": "f9841b801b257f8c",
608 | "outputs": [],
609 | "execution_count": 12
610 | },
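{
"metadata": {},
"cell_type": "markdown",
"source": [
"As a toy illustration (with made-up token ids), a vectorized review `[12, 7, 45, 3]` would be split as follows:\n",
"\n",
"```python\n",
"# Toy illustration, made-up ids:\n",
"# vectorized: [12,  7, 45,  3]\n",
"# x (input):  [12,  7, 45]      # what the model sees up to each position\n",
"# y (target): [ 7, 45,  3]      # the next token at each position\n",
"```"
],
"id": "8e5d13a7f0b9c246"
},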
611 | {
612 | "metadata": {},
613 | "cell_type": "markdown",
614 | "source": [
615 | "For the actual text generation, we need to employ the **decoder** part of the Transformer (the right-hand side of the architecture image above). The building blocks of the decoder are very similar to those of the encoder, but with a couple of differences. First, there are two separate attention layers; the first layer takes the embedded inputs in as query, key and value. The second attention layer takes the output of the first one in as the query. In the case of a machine translation task the key and value would come from in from the encoder outputs generated by the source sequence to be translated; here, however, there is no such separate source sequence, but the keys and values are provided by the original inputs to the decoder.\n",
616 | "\n",
617 | "Another important detail is the **causal mask**, which prevents the generated from attending to future words, when predicting a given word in a sequence. The implementation of this, as well as all the other essentials of the code are from F. Chollet: **Deep Learning with Python** (Chapter 12), with migration guidelines to Keras 3 in [this example](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)."
618 | ],
619 | "id": "e5f17086db0e2ea1"
620 | },
621 | {
622 | "metadata": {
623 | "ExecuteTime": {
624 | "end_time": "2024-11-27T13:54:47.108871Z",
625 | "start_time": "2024-11-27T13:54:47.100553Z"
626 | }
627 | },
628 | "cell_type": "code",
629 | "source": [
630 | "class TransformerDecoder(layers.Layer):\n",
631 | " def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):\n",
632 | " super().__init__(**kwargs)\n",
633 | " self.embed_dim = embed_dim\n",
634 | " self.latent_dim = latent_dim\n",
635 | " self.num_heads = num_heads\n",
636 | " self.attention_1 = layers.MultiHeadAttention(\n",
637 | " num_heads=num_heads, key_dim=embed_dim\n",
638 | " )\n",
639 | " self.attention_2 = layers.MultiHeadAttention(\n",
640 | " num_heads=num_heads, key_dim=embed_dim\n",
641 | " )\n",
642 | " self.dense_proj = keras.Sequential(\n",
643 | " [\n",
644 | " layers.Dense(latent_dim, activation=\"relu\"),\n",
645 | " layers.Dense(embed_dim),\n",
646 | " ]\n",
647 | " )\n",
648 | " self.layernorm_1 = layers.LayerNormalization()\n",
649 | " self.layernorm_2 = layers.LayerNormalization()\n",
650 | " self.layernorm_3 = layers.LayerNormalization()\n",
651 | " self.supports_masking = True\n",
652 | "\n",
653 | " def call(self, inputs, encoder_outputs, mask=None):\n",
654 | " causal_mask = self.get_causal_attention_mask(inputs)\n",
655 | " if mask is not None:\n",
656 | " padding_mask = ops.cast(mask[:, None, :], dtype=\"int32\")\n",
657 | " padding_mask = ops.minimum(padding_mask, causal_mask)\n",
658 | " else:\n",
659 | " padding_mask = None\n",
660 | "\n",
661 | " attention_output_1 = self.attention_1(\n",
662 | " query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n",
663 | " )\n",
664 | " out_1 = self.layernorm_1(inputs + attention_output_1)\n",
665 | "\n",
666 | " attention_output_2 = self.attention_2(\n",
667 | " query=out_1,\n",
668 | " value=encoder_outputs,\n",
669 | " key=encoder_outputs,\n",
670 | " attention_mask=padding_mask,\n",
671 | " )\n",
672 | " out_2 = self.layernorm_2(out_1 + attention_output_2)\n",
673 | "\n",
674 | " proj_output = self.dense_proj(out_2)\n",
675 | " return self.layernorm_3(out_2 + proj_output)\n",
676 | "\n",
677 | " def get_causal_attention_mask(self, inputs):\n",
678 | " input_shape = ops.shape(inputs)\n",
679 | " batch_size, sequence_length = input_shape[0], input_shape[1]\n",
680 | " i = ops.arange(sequence_length)[:, None]\n",
681 | " j = ops.arange(sequence_length)\n",
682 | " mask = ops.cast(i >= j, dtype=\"int32\")\n",
683 | " mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))\n",
684 | " mult = ops.concatenate(\n",
685 | " [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],\n",
686 | " axis=0,\n",
687 | " )\n",
688 | " return ops.tile(mask, mult)\n",
689 | "\n",
690 | " def get_config(self):\n",
691 | " config = super().get_config()\n",
692 | " config.update(\n",
693 | " {\n",
694 | " \"embed_dim\": self.embed_dim,\n",
695 | " \"latent_dim\": self.latent_dim,\n",
696 | " \"num_heads\": self.num_heads,\n",
697 | " }\n",
698 | " )\n",
699 | " return config"
700 | ],
701 | "id": "400ba417d38ab6f2",
702 | "outputs": [],
703 | "execution_count": 13
704 | },
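{
"metadata": {},
"cell_type": "markdown",
"source": [
"The causal mask produced by `get_causal_attention_mask` is simply a lower-triangular matrix of ones, tiled over the batch: position i may attend to positions j ≤ i only. A hypothetical check for a length-4 sequence (not part of the original run):\n",
"\n",
"```python\n",
"# Hypothetical check of the causal mask for a batch of one length-4 sequence.\n",
"decoder = TransformerDecoder(embed_dim=32, latent_dim=64, num_heads=2)\n",
"mask = decoder.get_causal_attention_mask(ops.zeros((1, 4, 32)))\n",
"print(ops.convert_to_numpy(mask)[0])\n",
"# [[1 0 0 0]\n",
"#  [1 1 0 0]\n",
"#  [1 1 1 0]\n",
"#  [1 1 1 1]]\n",
"```"
],
"id": "6a2c47e91b3d8f05"
},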
711 | {
712 | "metadata": {},
713 | "cell_type": "markdown",
714 | "source": "We are now ready to build the model with the Transformer decoder. Its output is directly connected to the output layer with the size of the vocabulary and softmax activation, to provide a probability distribution for token predictions. Note that we choose `sparse_categorical_crossentropy` as our loss function, because the target labels consist of integers (instead of being one-hot encoded).",
715 | "id": "d486ccb406b67d7d"
716 | },
717 | {
718 | "metadata": {
719 | "ExecuteTime": {
720 | "end_time": "2024-11-27T13:54:47.356113Z",
721 | "start_time": "2024-11-27T13:54:47.108871Z"
722 | }
723 | },
724 | "cell_type": "code",
725 | "source": [
726 | "embed_dim = 32\n",
727 | "latent_dim = 64\n",
728 | "num_heads = 2\n",
729 | "\n",
730 | "inputs = keras.Input(shape=(None,), dtype=\"int64\")\n",
731 | "x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)\n",
732 | "x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, x)\n",
733 | "outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n",
734 | "\n",
735 | "model = keras.Model(inputs, outputs)\n",
736 | "\n",
737 | "model.compile(loss=\"sparse_categorical_crossentropy\", \n",
738 | " optimizer=keras.optimizers.RMSprop())\n",
739 | "\n",
740 | "model.summary()"
741 | ],
742 | "id": "1b51e72262245b76",
743 | "outputs": [
744 | {
745 | "data": {
746 | "text/plain": [
747 | "\u001B[1mModel: \"functional_3\"\u001B[0m\n"
748 | ],
749 | "text/html": [
750 | "Model: \"functional_3\"\n",
751 | "
\n"
752 | ]
753 | },
754 | "metadata": {},
755 | "output_type": "display_data"
756 | },
757 | {
758 | "data": {
759 | "text/plain": [
760 | "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
761 | "┃\u001B[1m \u001B[0m\u001B[1mLayer (type) \u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1mOutput Shape \u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1m Param #\u001B[0m\u001B[1m \u001B[0m┃\u001B[1m \u001B[0m\u001B[1mConnected to \u001B[0m\u001B[1m \u001B[0m┃\n",
762 | "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
763 | "│ input_layer_2 │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;45mNone\u001B[0m) │ \u001B[38;5;34m0\u001B[0m │ - │\n",
764 | "│ (\u001B[38;5;33mInputLayer\u001B[0m) │ │ │ │\n",
765 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
766 | "│ positional_embeddi… │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;45mNone\u001B[0m, \u001B[38;5;34m32\u001B[0m) │ \u001B[38;5;34m323,200\u001B[0m │ input_layer_2[\u001B[38;5;34m0\u001B[0m]… │\n",
767 | "│ (\u001B[38;5;33mPositionalEmbeddi…\u001B[0m │ │ │ │\n",
768 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
769 | "│ transformer_decoder │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;45mNone\u001B[0m, \u001B[38;5;34m32\u001B[0m) │ \u001B[38;5;34m21,216\u001B[0m │ positional_embed… │\n",
770 | "│ (\u001B[38;5;33mTransformerDecode…\u001B[0m │ │ │ positional_embed… │\n",
771 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
772 | "│ dense_5 (\u001B[38;5;33mDense\u001B[0m) │ (\u001B[38;5;45mNone\u001B[0m, \u001B[38;5;45mNone\u001B[0m, │ \u001B[38;5;34m330,000\u001B[0m │ transformer_deco… │\n",
773 | "│ │ \u001B[38;5;34m10000\u001B[0m) │ │ │\n",
774 | "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n"
775 | ],
776 | "text/html": [
777 | "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓\n",
778 | "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n",
779 | "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩\n",
780 | "│ input_layer_2 │ (None, None) │ 0 │ - │\n",
781 | "│ (InputLayer) │ │ │ │\n",
782 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
783 | "│ positional_embeddi… │ (None, None, 32) │ 323,200 │ input_layer_2[0]… │\n",
784 | "│ (PositionalEmbeddi… │ │ │ │\n",
785 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
786 | "│ transformer_decoder │ (None, None, 32) │ 21,216 │ positional_embed… │\n",
787 | "│ (TransformerDecode… │ │ │ positional_embed… │\n",
788 | "├─────────────────────┼───────────────────┼────────────┼───────────────────┤\n",
789 | "│ dense_5 (Dense) │ (None, None, │ 330,000 │ transformer_deco… │\n",
790 | "│ │ 10000) │ │ │\n",
791 | "└─────────────────────┴───────────────────┴────────────┴───────────────────┘\n",
792 | "
\n"
793 | ]
794 | },
795 | "metadata": {},
796 | "output_type": "display_data"
797 | },
798 | {
799 | "data": {
800 | "text/plain": [
801 | "\u001B[1m Total params: \u001B[0m\u001B[38;5;34m674,416\u001B[0m (2.57 MB)\n"
802 | ],
803 | "text/html": [
804 | " Total params: 674,416 (2.57 MB)\n",
805 | "
\n"
806 | ]
807 | },
808 | "metadata": {},
809 | "output_type": "display_data"
810 | },
811 | {
812 | "data": {
813 | "text/plain": [
814 | "\u001B[1m Trainable params: \u001B[0m\u001B[38;5;34m674,416\u001B[0m (2.57 MB)\n"
815 | ],
816 | "text/html": [
817 | " Trainable params: 674,416 (2.57 MB)\n",
818 | "
\n"
819 | ]
820 | },
821 | "metadata": {},
822 | "output_type": "display_data"
823 | },
824 | {
825 | "data": {
826 | "text/plain": [
827 | "\u001B[1m Non-trainable params: \u001B[0m\u001B[38;5;34m0\u001B[0m (0.00 B)\n"
828 | ],
829 | "text/html": [
830 | " Non-trainable params: 0 (0.00 B)\n",
831 | "
\n"
832 | ]
833 | },
834 | "metadata": {},
835 | "output_type": "display_data"
836 | }
837 | ],
838 | "execution_count": 14
839 | },
840 | {
841 | "metadata": {},
842 | "cell_type": "markdown",
843 | "source": "Now we can train the model.",
844 | "id": "a234182d70731a37"
845 | },
846 | {
847 | "metadata": {
848 | "ExecuteTime": {
849 | "end_time": "2024-11-27T14:51:48.432106Z",
850 | "start_time": "2024-11-27T14:35:09.876053Z"
851 | }
852 | },
853 | "cell_type": "code",
854 | "source": "model.fit(lm_dataset, epochs=10)",
855 | "id": "f65d340549ae25e4",
856 | "outputs": [
857 | {
858 | "name": "stdout",
859 | "output_type": "stream",
860 | "text": [
861 | "Epoch 1/10\n",
862 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m95s\u001B[0m 967ms/step - loss: 5.3083\n",
863 | "Epoch 2/10\n",
864 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m106s\u001B[0m 1s/step - loss: 5.2834\n",
865 | "Epoch 3/10\n",
866 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m108s\u001B[0m 1s/step - loss: 5.2622\n",
867 | "Epoch 4/10\n",
868 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m105s\u001B[0m 1s/step - loss: 5.2397\n",
869 | "Epoch 5/10\n",
870 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m102s\u001B[0m 1s/step - loss: 5.2251\n",
871 | "Epoch 6/10\n",
872 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m103s\u001B[0m 1s/step - loss: 5.2083\n",
873 | "Epoch 7/10\n",
874 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m102s\u001B[0m 1s/step - loss: 5.1917\n",
875 | "Epoch 8/10\n",
876 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m104s\u001B[0m 1s/step - loss: 5.1768\n",
877 | "Epoch 9/10\n",
878 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m85s\u001B[0m 860ms/step - loss: 5.1650\n",
879 | "Epoch 10/10\n",
880 | "\u001B[1m98/98\u001B[0m \u001B[32m━━━━━━━━━━━━━━━━━━━━\u001B[0m\u001B[37m\u001B[0m \u001B[1m88s\u001B[0m 895ms/step - loss: 5.1521\n"
881 | ]
882 | },
883 | {
884 | "data": {
885 | "text/plain": [
886 | ""
887 | ]
888 | },
889 | "execution_count": 22,
890 | "metadata": {},
891 | "output_type": "execute_result"
892 | }
893 | ],
894 | "execution_count": 22
895 | },
896 | {
897 | "metadata": {},
898 | "cell_type": "markdown",
899 | "source": "Once the model has been trained, we can experiment with generating text from an initial seed prompt. Instead of always predicting the token with the highest probability in the predicted distribution, we scale the distribution with a parameter referred to as **temperature**: high values of temperature tend to flatten the probability distribution, which leads to somewhat more surprising choices for next tokens.",
900 | "id": "721ade4e1cf77948"
901 | },
902 | {
903 | "metadata": {
904 | "ExecuteTime": {
905 | "end_time": "2024-11-27T14:34:09.027122Z",
906 | "start_time": "2024-11-27T14:34:09.002197Z"
907 | }
908 | },
909 | "cell_type": "code",
910 | "source": [
911 | "tokens_index = dict(enumerate(text_vectorization.get_vocabulary()))\n",
912 | "\n",
913 | "def sample_next(predictions, temperature=1.0):\n",
914 | " predictions = np.asarray(predictions).astype(\"float64\")\n",
915 | " predictions = np.log(predictions) / temperature\n",
916 | " exp_preds = np.exp(predictions)\n",
917 | " predictions = exp_preds / np.sum(exp_preds)\n",
918 | " probas = np.random.multinomial(1, predictions, 1)\n",
919 | " return np.argmax(probas)"
920 | ],
921 | "id": "81505246136ec361",
922 | "outputs": [],
923 | "execution_count": 19
924 | },
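{
"metadata": {},
"cell_type": "markdown",
"source": [
"In other words, `sample_next` draws the next token from the rescaled distribution\n",
"\n",
"$$q_i = \\frac{p_i^{1/T}}{\\sum_j p_j^{1/T}},$$\n",
"\n",
"so $T = 1$ recovers the model's distribution, $T \\to 0$ approaches greedy argmax decoding, and $T > 1$ flattens the distribution. For example, probabilities $(0.7, 0.2, 0.1)$ become roughly $(0.91, 0.07, 0.02)$ at $T = 0.5$ and $(0.52, 0.28, 0.20)$ at $T = 2$."
],
"id": "4f7b28c60d9e1a53"
},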
925 | {
926 | "metadata": {},
927 | "cell_type": "markdown",
928 | "source": "Finally, we can define an initial prompt, and check out the quality of generated text. Unfortunately, due to the short training time, the results are fairly disappointing. ",
929 | "id": "3ced59f70f9d0ac1"
930 | },
931 | {
932 | "metadata": {
933 | "ExecuteTime": {
934 | "end_time": "2024-11-27T14:52:17.333110Z",
935 | "start_time": "2024-11-27T14:52:15.678818Z"
936 | }
937 | },
938 | "cell_type": "code",
939 | "source": [
940 | "temperature = 0.7\n",
941 | "\n",
942 | "sentence = \"in my view\"\n",
943 | "generate_length = 50\n",
944 | "for i in range(generate_length):\n",
945 | " tokenized_sentence = text_vectorization([sentence])\n",
946 | " predictions = model(tokenized_sentence)\n",
947 | " next_token = sample_next(predictions[0, i, :])\n",
948 | " sampled_token = tokens_index[next_token]\n",
949 | " sentence += \" \" + sampled_token\n",
950 | "print(sentence)"
951 | ],
952 | "id": "dab069e1cca9bc04",
953 | "outputs": [
954 | {
955 | "name": "stdout",
956 | "output_type": "stream",
957 | "text": [
958 | "in my view all foreign of the movies the film have sex or few hunting the years the greatest not buffs live here or in this [UNK] the film that way is the included land stadium the surf which film is all most a the soup deranged and of [UNK] any an the\n"
959 | ]
960 | }
961 | ],
962 | "execution_count": 24
963 |   }
977 | ],
978 | "metadata": {
979 | "kernelspec": {
980 | "display_name": "Python 3",
981 | "language": "python",
982 | "name": "python3"
983 | },
984 | "language_info": {
985 | "codemirror_mode": {
986 | "name": "ipython",
987 | "version": 2
988 | },
989 | "file_extension": ".py",
990 | "mimetype": "text/x-python",
991 | "name": "python",
992 | "nbconvert_exporter": "python",
993 | "pygments_lexer": "ipython2",
994 | "version": "2.7.6"
995 | }
996 | },
997 | "nbformat": 4,
998 | "nbformat_minor": 5
999 | }
1000 |
--------------------------------------------------------------------------------