├── GRPO.py ├── GRPO_From_Scratch_Multi_GPU_DataParallel_Qwen_2_5_1_5B_Instruct.ipynb ├── GRPO_Qwen_0_5_Instruct.ipynb ├── README.md ├── byte_pair_encoding.ipynb ├── count_language_model.ipynb ├── document_classifier_with_LLMs_as_labelers.ipynb ├── embedding_vs_linear.py ├── emotion_GPT2_as_classifier.ipynb ├── emotion_GPT2_as_text_generator.ipynb ├── emotion_GPT2_as_text_generator_LoRA.ipynb ├── emotion_classifier_CNN.ipynb ├── emotion_classifier_LR.ipynb ├── instruct_GPT2.ipynb ├── news_RNN_language_model.ipynb ├── news_decoder_language_model.ipynb ├── quadratic_loss.py ├── sampling_method.ipynb ├── spotify_gemini_playlist.py └── wiki ├── GPU-rental.md ├── MoE.md ├── PyTorch.md ├── VLM.md ├── alignment.md ├── colabs.md ├── compression.md ├── corrections.md ├── deployment.md ├── distributed.md ├── embeddings.md ├── encoder-decoder.md ├── encoder.md ├── evaluation.md ├── function-calling.md ├── index.md ├── inference.md ├── math.md ├── merging.md ├── non-transformer.md ├── notebook-services.md ├── online-finetuning.md ├── overfitting.md ├── prompting.md ├── scaling.md ├── scripts.md ├── security.md ├── test.md └── tokenization.md /README.md: -------------------------------------------------------------------------------- 1 | # The Hundred-Page Language Models Book 2 | ![cover](https://github.com/user-attachments/assets/06a77a36-3022-4a21-a522-d5d213320bf0) 3 | -------------------------------------------------------------------------------- /byte_pair_encoding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "
\n", 17 | "
\n", 18 | " \n", 19 | " \n", 20 | " \n", 26 | " \n", 31 | " \n", 32 | "
\n", 21 | " \n", 22 | " A notebook for The Hundred-Page Language Models Book by Andriy Burkov

\n", 23 | " Code repository: https://github.com/aburkov/theLMbook\n", 24 | "
\n", 25 | "
\n", 27 | " \n", 28 | " \"The\n", 29 | " \n", 30 | "
\n", 33 | "
# Training the BPE model
#
# Below, we load the data and train the BPE model.
# (In the original notebook this text is a markdown cell; the code that
# follows is the first code cell.)

# Import required libraries
import os  # For file operations and path handling
import urllib.request  # For downloading files
import tarfile  # For extracting tar files
import pickle  # For saving/loading tokenizer
import re  # For regex in merge operations
import time  # For timing operations
from collections import defaultdict  # For counting tokens and pairs


def download_file(url, filename):
    """
    Downloads a file from a URL if it doesn't exist locally.
    Prevents redundant downloads by checking file existence.

    Args:
        url (str): URL to download the file from
        filename (str): Local path to save the downloaded file

    Returns:
        None: Prints status messages about download progress
    """
    # Check if file already exists to avoid re-downloading
    if not os.path.exists(filename):
        print(f"Downloading dataset from {url}...")
        urllib.request.urlretrieve(url, filename)
        print("Download completed.")
    else:
        print(f"{filename} already downloaded.")


def is_within_directory(directory, target):
    """
    Security check to prevent path traversal attacks by verifying target path.
    Ensures extracted files remain within the intended directory.

    Both paths are converted to absolute form first (relative paths resolve
    against the current working directory).

    Args:
        directory (str): Base directory path to check against
        target (str): Target path to validate

    Returns:
        bool: True if target is the directory itself or lies inside it,
            False otherwise
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    # Compare whole path components. os.path.commonprefix works character by
    # character and would report "/data/out" as containing
    # "/data/outside/file"; os.path.commonpath does not.
    try:
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # commonpath raises ValueError e.g. for paths on different drives
        # (Windows); such a target cannot be inside the directory.
        return False


def safe_extract_tar(tar_file, required_files):
    """
    Safely extracts specific files from a tar archive with security checks.
    Prevents path traversal attacks and extracts only required files.

    Args:
        tar_file (str): Path to the gzip-compressed tar archive
        required_files (list): File names (base names) to extract

    Returns:
        None: Extracts files into the current directory and prints progress

    Raises:
        Exception: If path traversal attempt is detected
    """
    wanted = set(required_files)
    # Use the stdlib "data" extraction filter where this Python provides it
    # (defense in depth on top of the checks below).
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    with tarfile.open(tar_file, "r:gz") as tar:
        members = tar.getmembers()

        # Perform security check on all archive members
        for member in members:
            if not is_within_directory('.', member.name):
                raise Exception("Attempted Path Traversal in Tar File")

        # Extract only regular files whose base name is exactly one of the
        # required names (a suffix match would also accept e.g. "pretrain.txt").
        for member in members:
            base_name = os.path.basename(member.name)
            if member.isfile() and base_name in wanted:
                # Remove path prefix for safety
                member.name = base_name
                tar.extract(member, '.', **extract_kwargs)
                print(f"Extracted {member.name}")


def create_word_generator(filepath):
    """
    Creates a generator that yields words from a text file one at a time.
    Memory efficient way to process large text files.

    Args:
        filepath (str): Path to a UTF-8 text file to read

    Returns:
        generator: Yields whitespace-separated words from the file
    """
    def generator():
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                yield from line.split()
    return generator()


def download_and_prepare_data(url):
    """
    Downloads, extracts, and prepares dataset for training.
    Handles both downloading and extraction with security checks.

    Args:
        url (str): URL of the dataset to download

    Returns:
        generator: Word generator for the training data
    """
    required_files = ["train.txt", "test.txt"]
    filename = os.path.basename(url)

    # Download dataset if needed
    download_file(url, filename)

    # Extract required files if they don't exist
    if not all(os.path.exists(file) for file in required_files):
        print("Extracting files...")
        safe_extract_tar(filename, required_files)
        print("Extraction completed.")
    else:
        print("'train.txt' and 'test.txt' already extracted.")

    # Create and return word generator
    return create_word_generator("train.txt")


def initialize_vocabulary(corpus):
    """
    Creates initial vocabulary from corpus by splitting words into characters.
    Adds word boundary marker '_' and tracks unique characters.

    Args:
        corpus (iterable): Iterator or list of words to process

    Returns:
        tuple: (vocabulary dict mapping tokenized words to counts,
                set of unique characters in corpus)
    """
    # Track word counts and unique characters
    vocabulary = defaultdict(int)
    charset = set()

    for word in corpus:
        # Add word boundary marker and split into characters
        characters = list('_' + word)
        charset.update(characters)
        # Key is the space-separated character sequence, e.g. "_ t h e"
        vocabulary[" ".join(characters)] += 1

    return vocabulary, charset


def get_pair_counts(vocabulary):
    """
    Counts frequencies of adjacent symbol pairs in the vocabulary.
    Used to identify most common pairs for merging.

    Args:
        vocabulary (dict): Dictionary mapping tokenized words to their counts

    Returns:
        defaultdict: Maps token pairs to their frequency counts
    """
    pair_counts = defaultdict(int)
    for tokenized_word, count in vocabulary.items():
        tokens = tokenized_word.split()
        # Count adjacent pairs weighted by word frequency
        for pair in zip(tokens, tokens[1:]):
            pair_counts[pair] += count
    return pair_counts


def merge_pair(vocab, pair):
    """
    Merges all occurrences of a specific symbol pair in the vocabulary.
    Uses regex for accurate token boundary matching.

    Args:
        vocab (dict): Current vocabulary dictionary
        pair (tuple): Pair of tokens to merge

    Returns:
        dict: New vocabulary with specified pair merged
    """
    new_vocab = {}
    # Match the pair only as two whole tokens: the lookarounds forbid a
    # non-space character on either side, so ("a", "b") matches "a b" in
    # "_ a b" but not inside "_ xa b".
    bigram = re.escape(' '.join(pair))
    pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
    merged = ''.join(pair)

    for tokenized_word, count in vocab.items():
        # A function replacement inserts `merged` literally; a string
        # replacement would interpret backslashes in it as escape sequences.
        new_word = pattern.sub(lambda _match: merged, tokenized_word)
        new_vocab[new_word] = new_vocab.get(new_word, 0) + count
    return new_vocab


def byte_pair_encoding(corpus, vocab_size):
    """
    Learns BPE merge rules from a corpus until the token set reaches vocab_size.

    Args:
        corpus (iterable): Words to learn from
        vocab_size (int): Target number of distinct tokens
            (characters plus merged tokens)

    Returns:
        tuple: (vocab, merges, charset, tokens) where
            vocab (dict): tokenized words (space-separated tokens) -> counts,
                after all merges
            merges (list): merged (left, right) pairs in the order learned
            charset (set): characters seen in the corpus (including '_')
            tokens (set): all tokens, i.e. characters plus merged tokens
    """
    # NOTE(review): in the source dump, the end of merge_pair, this whole
    # function and the head of tokenize_word were lost (text between '<' and
    # '>' was stripped). They are reconstructed here from the docstrings and
    # call sites; compare against the upstream notebook.
    vocab, charset = initialize_vocabulary(corpus)
    merges = []
    tokens = set(charset)

    while len(tokens) < vocab_size:
        pair_counts = get_pair_counts(vocab)
        if not pair_counts:
            break  # every word is already a single token
        most_frequent_pair = max(pair_counts, key=pair_counts.get)
        merges.append(most_frequent_pair)
        vocab = merge_pair(vocab, most_frequent_pair)
        tokens.add(''.join(most_frequent_pair))

    return vocab, merges, charset, tokens


def tokenize_word(word, merges, charset, unk_token="<UNK>"):
    """
    Tokenizes a single word using learned BPE merges.
    Handles unknown characters with UNK token.

    Args:
        word (str): Word to tokenize
        merges (list): List of learned merge operations
        charset (set): Set of known characters
        unk_token (str): Token to use for unknown characters

    Returns:
        list: List of tokens for the word
    """
    # Add word boundary marker and convert to characters
    word = '_' + word
    tokens = [char if char in charset else unk_token for char in word]

    # Apply merges in the order they were learned
    for left, right in merges:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i:i+2] == [left, right]:
                tokens[i:i+2] = [left + right]
            else:
                i += 1
    return tokens


def build_merge_map(merges):
    """
    Creates a mapping from token pairs to their merged forms.
    Preserves merge order for consistent tokenization.

    Args:
        merges (list): List of merge operations

    Returns:
        dict: Maps token pairs to (merged_token, merge_priority) tuples,
            where a lower priority means the merge was learned earlier
    """
    return {
        (left, right): (left + right, i)
        for i, (left, right) in enumerate(merges)
    }


def tokenize_word_fast(word, merge_map, vocabulary, charset, unk_token="<UNK>"):
    """
    Optimized tokenization function using pre-computed merge map.
    Intended to produce the same tokens as tokenize_word, but faster.

    Args:
        word (str): Word to tokenize
        merge_map (dict): Mapping of token pairs to merged forms
        vocabulary (dict): Training vocabulary (tokenized words -> counts);
            an empty dict disables the whole-word shortcut
        charset (set): Set of known characters
        unk_token (str): Token to use for unknown characters

    Returns:
        list: List of tokens for the word
    """
    # Vocabulary keys are space-separated token sequences, so a key equal to
    # '_' + word means training merged that word into a single token.
    word_with_prefix = '_' + word
    if word_with_prefix in vocabulary:
        return [word_with_prefix]

    # Initialize with characters, replacing unknown ones
    tokens = [char if char in charset else unk_token for char in word_with_prefix]

    # Repeatedly apply the earliest-learned applicable merge (leftmost on ties)
    while True:
        candidates = [
            (merge_map[pair][1], pos, merge_map[pair][0])
            for pos, pair in enumerate(zip(tokens, tokens[1:]))
            if pair in merge_map
        ]
        if not candidates:
            break
        _, pos, merged_token = min(candidates)
        tokens[pos:pos+2] = [merged_token]

    return tokens


def save_tokenizer(merges, charset, tokens, filename="tokenizer.pkl"):
    """
    Saves tokenizer state to a pickle file for later use.

    Args:
        merges (list): List of merge operations
        charset (set): Set of known characters
        tokens (set): Set of all tokens
        filename (str): Path to save tokenizer state

    Returns:
        None: Saves tokenizer to disk
    """
    with open(filename, "wb") as f:
        pickle.dump({
            "merges": merges,
            "charset": charset,
            "tokens": tokens
        }, f)


def load_tokenizer(filename="tokenizer.pkl"):
    """
    Loads tokenizer state from a pickle file.

    Only load files you created yourself: unpickling untrusted data can
    execute arbitrary code.

    Args:
        filename (str): Path to saved tokenizer state

    Returns:
        dict: Dictionary containing tokenizer components
    """
    with open(filename, "rb") as f:
        return pickle.load(f)


# Main function for downloading, training BPE, saving, and loading tokenizer
if __name__ == "__main__":
    # Configuration parameters
    vocab_size = 5_000  # Target vocabulary size
    max_corpus_size = 500_000  # Maximum number of words to process
    data_url = "https://www.thelmbook.com/data/news"  # Dataset source

    # Download and prepare training data
    word_gen = download_and_prepare_data(data_url)

    # Collect corpus up to maximum size
    word_list = []
    for word in word_gen:
        word_list.append(word)
        if len(word_list) >= max_corpus_size:
            break

    # Train BPE tokenizer
    print("Training BPE tokenizer...")
    vocab, merges, charset, tokens = byte_pair_encoding(word_list, vocab_size)

    # Save trained tokenizer
    print("Saving the tokenizer...")
    save_tokenizer(merges, charset, tokens)


# Testing the trained BPE tokenizer
#
# Once the BPE tokenizer is trained, we can load it and apply it to new text.
# (Markdown cell followed by the second code cell in the original notebook.)

if __name__ == "__main__":
    print("Loading the tokenizer...")
    tokenizer = load_tokenizer()

    # Tokenize the sample sentence using the loaded tokenizer
    sentence = "Let's proceed to the language modeling part."
    words = sentence.split()

    start_time = time.time()
    tokenized_sentence = [tokenize_word(word, tokenizer["merges"], tokenizer["charset"]) for word in words]
    elapsed = time.time() - start_time
    print("\nSentence tokenized with the straightforward implementation:")
    for word, tokens in zip(words, tokenized_sentence):
        print(f"{word} -> {tokens}")
    print("--- Elapsed: %s seconds ---" % (elapsed))

    merge_map = build_merge_map(tokenizer["merges"])
    # The training vocabulary comes from the training cell and is not part of
    # the saved tokenizer; without it the fast path just skips its shortcut.
    known_words = globals().get("vocab", {})
    start_time = time.time()
    fast_tokenized_sentence = [tokenize_word_fast(word, merge_map, known_words, tokenizer["charset"]) for word in words]
    elapsed = time.time() - start_time
    print("\nSentence tokenized with a fast implementation:")
    for word, tokens in zip(words, fast_tokenized_sentence):
        print(f"{word} -> {tokens}")
    print("--- Elapsed: %s seconds ---" % (elapsed))

    print("\nVocabulary size:", len(tokenizer["tokens"]))

# --- end of byte_pair_encoding.ipynb (next file in the dump: /count_language_model.ipynb) ---
"Python 3", 561 | "name": "python3" 562 | }, 563 | "language_info": { 564 | "name": "python" 565 | } 566 | }, 567 | "nbformat": 4, 568 | "nbformat_minor": 0 569 | } -------------------------------------------------------------------------------- /count_language_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "
\n", 17 | "
\n", 18 | " \n", 19 | " \n", 20 | " \n", 26 | " \n", 31 | " \n", 32 | "
\n", 21 | " \n", 22 | " A notebook for The Hundred-Page Language Models Book by Andriy Burkov

\n", 23 | " Code repository: https://github.com/aburkov/theLMbook\n", 24 | "
\n", 25 | "
\n", 27 | " \n", 28 | " \"The\n", 29 | " \n", 30 | "
\n", 33 | "
# Count-based language model
#
# Utility functions and classes: below, we import the dependencies and define
# the utility functions and the model class.
# (Markdown cell followed by the first code cell in the original notebook.)

# Import required libraries
import re  # For regular expressions (text tokenization)
import gzip  # For decompressing the downloaded corpus
import io  # For handling byte streams
import math  # For mathematical operations (log, exp)
import random  # For random number generation
from collections import defaultdict  # For efficient dictionary operations
import os  # For building the model file path
import pickle  # For saving and loading the model


def set_seed(seed):
    """
    Sets random seeds for reproducibility.

    Args:
        seed (int): Seed value for the random number generator
    """
    random.seed(seed)


def download_corpus(url, timeout=60):
    """
    Downloads and decompresses a gzipped corpus file from the given URL.

    Args:
        url (str): URL of the gzipped corpus file
        timeout (float): Seconds requests waits for the server to connect or
            send data before giving up, so a stalled download cannot hang forever

    Returns:
        str: Decoded text content of the corpus

    Raises:
        HTTPError: If the download fails
        requests.exceptions.Timeout: If the server does not respond in time
    """
    # Imported lazily so the model, training and evaluation code in this
    # module can be used without the third-party `requests` package.
    import requests

    print(f"Downloading corpus from {url}...")
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # Raises an exception for bad HTTP responses

    print("Decompressing and reading the corpus...")
    with gzip.GzipFile(fileobj=io.BytesIO(response.content)) as f:
        corpus = f.read().decode('utf-8')

    print(f"Corpus size: {len(corpus)} characters")
    return corpus


class CountLanguageModel:
    """
    Implements an n-gram language model using count-based probability estimation.
    Supports variable context lengths up to n-grams.
    """
    def __init__(self, n):
        """
        Initialize the model with maximum n-gram length.

        Args:
            n (int): Maximum length of n-grams to use (must be >= 1)

        Raises:
            ValueError: If n < 1
        """
        if n < 1:
            raise ValueError(f"n must be at least 1, got {n}")
        self.n = n  # Maximum n-gram length
        # ngram_counts[k] maps a context tuple of length k to a
        # {next_token: count} dict, i.e. it holds the (k + 1)-gram counts.
        self.ngram_counts = [{} for _ in range(n)]
        self.total_unigrams = 0  # Total number of tokens seen in training

    def predict_next_token(self, context):
        """
        Predicts the most likely next token given a context.
        Uses backoff strategy: tries largest n-gram first, then backs off to smaller n-grams.

        Args:
            context (list): List of tokens providing context for prediction

        Returns:
            str: Most likely next token, or None if no prediction can be made
        """
        for n in range(self.n, 1, -1):  # Start with largest n-gram, back off to smaller ones
            if len(context) >= n - 1:
                context_n = tuple(context[-(n - 1):])  # Relevant context for this n-gram
                counts = self.ngram_counts[n - 1].get(context_n)
                if counts:
                    return max(counts.items(), key=lambda x: x[1])[0]  # Most frequent token
        # Backoff to unigram if no larger context matches
        unigram_counts = self.ngram_counts[0].get(())
        if unigram_counts:
            return max(unigram_counts.items(), key=lambda x: x[1])[0]
        return None

    def get_probability(self, token, context):
        """
        Estimates P(token | context) with a simple backoff scheme.

        Tries the longest usable context first and returns the relative
        frequency count(context, token) / count(context) as soon as the token
        has been observed after that context. If it was not observed after any
        context of length >= 1, falls back to an add-one (Laplace) smoothed
        unigram estimate, which is always > 0.

        NOTE: backed-off estimates are not discounted, so for a fixed context
        the returned values are not guaranteed to sum to 1 over the vocabulary.

        Args:
            token (str): Token whose probability is requested
            context (sequence): Preceding tokens; only the last n-1 are used

        Returns:
            float: Estimated probability (always > 0)

        Raises:
            ValueError: If the model has not been trained
        """
        for n in range(self.n, 1, -1):
            if len(context) >= n - 1:
                context_n = tuple(context[-(n - 1):])
                counts = self.ngram_counts[n - 1].get(context_n)
                if counts:
                    count = counts.get(token, 0)
                    if count > 0:
                        return count / sum(counts.values())
        unigram_counts = self.ngram_counts[0].get(())
        if not unigram_counts:
            raise ValueError("Model has not been trained: no unigram counts")
        count = unigram_counts.get(token, 0)
        V = len(unigram_counts)  # Number of distinct tokens seen in training
        return (count + 1) / (self.total_unigrams + V)


def train(model, tokens):
    """
    Trains the language model by counting n-grams in the training data.
    Calling it again adds the new counts to the existing ones.

    Args:
        model (CountLanguageModel): Model to train
        tokens (list): List of tokens from the training corpus
    """
    # Train models for each n-gram size from 1 to n
    for n in range(1, model.n + 1):
        counts = model.ngram_counts[n - 1]
        # Slide a window of size n over the corpus
        for i in range(len(tokens) - n + 1):
            # Split into context (n-1 tokens) and next token
            context = tuple(tokens[i:i + n - 1])
            next_token = tokens[i + n - 1]

            context_counts = counts.get(context)
            if context_counts is None:
                context_counts = counts[context] = defaultdict(int)
            context_counts[next_token] += 1

    # Accumulate (not overwrite) so the total stays consistent with the
    # accumulated n-gram counts if train() is called more than once.
    model.total_unigrams += len(tokens)


def generate_text(model, context, num_tokens):
    """
    Generates text greedily: each step appends the model's single most likely
    next token (no random sampling), so the output is deterministic.

    Generation stops after exactly num_tokens new tokens, or earlier if the
    model cannot predict any token (e.g. it has not been trained).

    Args:
        model (CountLanguageModel): Trained language model
        context (list): Initial context tokens
        num_tokens (int): Number of tokens to generate

    Returns:
        str: Generated text including initial context, tokens joined by spaces
    """
    generated = list(context)

    for _ in range(num_tokens):
        # Use the last n-1 tokens as context for prediction
        next_token = model.predict_next_token(generated[-(model.n-1):])
        if next_token is None:
            break  # Nothing to predict; ' '.join would fail on None
        generated.append(next_token)

    return ' '.join(generated)


def compute_perplexity(model, tokens, context_size):
    """
    Computes perplexity of the model on given tokens.

    Args:
        model (CountLanguageModel): Trained language model
        tokens (list): List of tokens to evaluate on
        context_size (int): Maximum context size to consider

    Returns:
        float: Perplexity score (lower is better); inf for an empty token list
    """
    if not tokens:
        return float('inf')

    total_log_likelihood = 0
    num_tokens = len(tokens)

    # Calculate probability for each token given its context
    for i, token in enumerate(tokens):
        # Get appropriate context window, handling start of sequence
        context_start = max(0, i - context_size)
        context = tuple(tokens[context_start:i])

        # get_probability never returns 0, so the log is always defined
        probability = model.get_probability(token, context)
        total_log_likelihood += math.log(probability)

    # Perplexity = exp(-average log likelihood)
    average_log_likelihood = total_log_likelihood / num_tokens
    return math.exp(-average_log_likelihood)


def tokenize(text):
    """
    Tokenizes text into words and periods.

    Args:
        text (str): Input text to tokenize

    Returns:
        list: List of lowercase tokens matching words or periods
    """
    return re.findall(r"\b[a-zA-Z0-9]+\b|[.]", text.lower())


def download_and_prepare_data(data_url):
    """
    Downloads and prepares training and test data.

    Args:
        data_url (str): URL of the corpus to download

    Returns:
        tuple: (training_tokens, test_tokens) split 90/10
    """
    # Download and extract the corpus
    corpus = download_corpus(data_url)

    # Convert text to tokens
    tokens = tokenize(corpus)

    # Split into training (90%) and test (10%) sets
    split_index = int(len(tokens) * 0.9)
    train_corpus = tokens[:split_index]
    test_corpus = tokens[split_index:]

    return train_corpus, test_corpus


def save_model(model, model_name):
    """
    Saves the trained language model to disk.

    Args:
        model (CountLanguageModel): Trained model to save
        model_name (str): Name to use for the saved model file

    Returns:
        str: Path to the saved model file

    Raises:
        IOError: If there's an error writing to disk
    """
    # Create models directory if it doesn't exist
    os.makedirs('models', exist_ok=True)

    # Construct file path
    model_path = os.path.join('models', f'{model_name}.pkl')

    try:
        print(f"Saving model to {model_path}...")
        with open(model_path, 'wb') as f:
            pickle.dump({
                'n': model.n,
                'ngram_counts': model.ngram_counts,
                'total_unigrams': model.total_unigrams
            }, f)
        print("Model saved successfully.")
        return model_path
    except IOError as e:
        print(f"Error saving model: {e}")
        raise


def load_model(model_name):
    """
    Loads a trained language model from disk.

    Only load files you created yourself: unpickling untrusted data can
    execute arbitrary code.

    Args:
        model_name (str): Name of the model to load

    Returns:
        CountLanguageModel: Loaded model instance

    Raises:
        FileNotFoundError: If the model file doesn't exist
        IOError: If there's an error reading the file
    """
    model_path = os.path.join('models', f'{model_name}.pkl')

    try:
        print(f"Loading model from {model_path}...")
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)

        # Create new model instance
        model = CountLanguageModel(model_data['n'])

        # Restore model state
        model.ngram_counts = model_data['ngram_counts']
        model.total_unigrams = model_data['total_unigrams']

        print("Model loaded successfully.")
        return model
    except FileNotFoundError:
        print(f"Model file not found: {model_path}")
        raise
    except IOError as e:
        print(f"Error loading model: {e}")
        raise


def get_hyperparameters():
    """
    Returns model hyperparameters.

    Returns:
        int: Size of n-grams to use in the model
    """
    n = 5
    return n


# Training the model
#
# Below, we load the data, train, and save the model.
# (Markdown cell followed by the second code cell in the original notebook.)

# Main model training block
if __name__ == "__main__":
    # Initialize random seeds for reproducibility
    set_seed(42)
    n = get_hyperparameters()
    model_name = "count_model"

    # Download and prepare the Brown corpus
    data_url = "https://www.thelmbook.com/data/brown"
    train_corpus, test_corpus = download_and_prepare_data(data_url)

    # Train the model and evaluate its performance
    print("\nTraining the model...")
    model = CountLanguageModel(n)
    train(model, train_corpus)
    print("\nModel training complete.")

    perplexity = compute_perplexity(model, test_corpus, n)
    print(f"\nPerplexity on test corpus: {perplexity:.2f}")

    save_model(model, model_name)
411 | "colab": { 412 | "base_uri": "https://localhost:8080/" 413 | } 414 | }, 415 | "execution_count": null, 416 | "outputs": [ 417 | { 418 | "output_type": "stream", 419 | "name": "stdout", 420 | "text": [ 421 | "Downloading corpus from https://www.thelmbook.com/data/brown...\n", 422 | "Decompressing and reading the corpus...\n", 423 | "Corpus size: 6185606 characters\n", 424 | "\n", 425 | "Training the model...\n", 426 | "\n", 427 | "Model training complete.\n", 428 | "\n", 429 | "Perplexity on test corpus: 299.06\n", 430 | "Saving model to models/count_model.pkl...\n", 431 | "Model saved successfully.\n" 432 | ] 433 | } 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "source": [ 439 | "## Testing the model\n", 440 | "\n", 441 | "Below, we load the trained model and use it to generate text:" 442 | ], 443 | "metadata": { 444 | "id": "Glud7JWgDyMb" 445 | } 446 | }, 447 | { 448 | "cell_type": "code", 449 | "source": [ 450 | "# Main model testing block\n", 451 | "if __name__ == \"__main__\":\n", 452 | "\n", 453 | " model = load_model(model_name)\n", 454 | "\n", 455 | " # Test the model with some example contexts\n", 456 | " contexts = [\n", 457 | " \"i will build a\",\n", 458 | " \"the best place to\",\n", 459 | " \"she was riding a\"\n", 460 | " ]\n", 461 | "\n", 462 | " # Generate completions for each context\n", 463 | " for context in contexts:\n", 464 | " tokens = tokenize(context)\n", 465 | " next_token = model.predict_next_token(tokens)\n", 466 | " print(f\"\\nContext: {context}\")\n", 467 | " print(f\"Next token: {next_token}\")\n", 468 | " print(f\"Generated text: {generate_text(model, tokens, 10)}\")" 469 | ], 470 | "metadata": { 471 | "id": "tesX-DH9D23a", 472 | "outputId": "886a9061-e2eb-478d-add1-62dee4052540", 473 | "colab": { 474 | "base_uri": "https://localhost:8080/" 475 | } 476 | }, 477 | "execution_count": null, 478 | "outputs": [ 479 | { 480 | "output_type": "stream", 481 | "name": "stdout", 482 | "text": [ 483 | "Loading model from 
models/count_model.pkl...\n", 484 | "Model loaded successfully.\n", 485 | "\n", 486 | "Context: i will build a\n", 487 | "Next token: wall\n", 488 | "Generated text: i will build a wall to keep the people in and added so long\n", 489 | "\n", 490 | "Context: the best place to\n", 491 | "Next token: live\n", 492 | "Generated text: the best place to live in 30 per cent to get happiness for yourself\n", 493 | "\n", 494 | "Context: she was riding a\n", 495 | "Next token: horse\n", 496 | "Generated text: she was riding a horse and showing a dog are very similar your aids\n" 497 | ] 498 | } 499 | ] 500 | } 501 | ], 502 | "metadata": { 503 | "colab": { 504 | "provenance": [], 505 | "authorship_tag": "ABX9TyM1ddESY73Kx2xDPOENzz6k", 506 | "include_colab_link": true 507 | }, 508 | "kernelspec": { 509 | "display_name": "Python 3", 510 | "name": "python3" 511 | }, 512 | "language_info": { 513 | "name": "python" 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 0 518 | } -------------------------------------------------------------------------------- /embedding_vs_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | vocab_size = 10 # Smaller vocabulary for better visualization 5 | emb_dim = 4 # Smaller embedding dimension 6 | token_idx = 3 # Index of the token we want to embed 7 | 8 | # Approach 1: Using nn.Embedding which directly maps indices to dense vectors 9 | embedding = nn.Embedding(vocab_size, emb_dim) 10 | 11 | # Approach 2: Using nn.Linear to achieve the same effect 12 | # The Linear layer performs: output = input @ weight.t() + bias 13 | # In our case, bias=False so: output = input @ weight.t() 14 | linear = nn.Linear(vocab_size, emb_dim, bias=False) 15 | 16 | # Copy and transpose embedding weights for the linear layer 17 | # Embedding weights shape: (vocab_size, emb_dim) 18 | # Linear weights shape: (emb_dim, vocab_size) <- notice the transpose 19 | 
linear.weight.data = embedding.weight.data.t() 20 | 21 | # Create one-hot input for linear layer - zeros everywhere except position token_idx 22 | one_hot = torch.zeros(vocab_size) 23 | one_hot[token_idx] = 1 24 | 25 | # Get embedding by directly indexing into the embedding matrix 26 | # embedding.weight[token_idx] is what happens under the hood 27 | emb_output = embedding(torch.tensor([token_idx])) 28 | 29 | # For linear layer: add batch dimension since linear expects shape (batch_size, input_dim) 30 | # Result will be: one_hot @ weight.t(), which selects the token_idx row of weight.t() 31 | linear_output = linear(one_hot.unsqueeze(0)) 32 | 33 | # Print outputs and comparison to see the equivalence 34 | print(f"Embedding output:\n{emb_output}\n") 35 | print(f"Linear output:\n{linear_output}\n") 36 | print(f"Are tensors equal? {torch.equal(emb_output, linear_output)}") 37 | print(f"Are tensors close? {torch.allclose(emb_output, linear_output)}") 38 | 39 | # Verify outputs are the same within numerical precision 40 | # Uses default tolerances: rtol=1e-5, atol=1e-8 41 | # Formula: |x - y| ≤ atol + rtol * |y| 42 | print(f"\nMaximum difference between tensors: {(emb_output - linear_output).abs().max().item()}") -------------------------------------------------------------------------------- /emotion_GPT2_as_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "
\n", 17 | "
\n", 18 | " \n", 19 | " \n", 20 | " \n", 26 | " \n", 31 | " \n", 32 | "
\n", 21 | " \n", 22 | " A notebook for The Hundred-Page Language Models Book by Andriy Burkov

\n", 23 | " Code repository: https://github.com/aburkov/theLMbook\n", 24 | "
\n", 25 | "
\n", 27 | " \n", 28 | " \"The\n", 29 | " \n", 30 | "
\n", 33 | "
\n", 34 | "
" 35 | ], 36 | "metadata": { 37 | "id": "LpCleINoAMtW" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "yy0zjL_2ouOU", 48 | "outputId": "96ee626b-f141-4438-afa1-ae39e5305184" 49 | }, 50 | "outputs": [ 51 | { 52 | "output_type": "stream", 53 | "name": "stderr", 54 | "text": [ 55 | "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at openai-community/gpt2 and are newly initialized: ['score.weight']\n", 56 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", 57 | "Epoch 1/8: 100%|██████████| 1125/1125 [00:51<00:00, 21.68it/s, Loss=0.488]\n" 58 | ] 59 | }, 60 | { 61 | "output_type": "stream", 62 | "name": "stdout", 63 | "text": [ 64 | "Average loss: 0.4882, Test accuracy: 0.9230\n" 65 | ] 66 | }, 67 | { 68 | "output_type": "stream", 69 | "name": "stderr", 70 | "text": [ 71 | "Epoch 2/8: 100%|██████████| 1125/1125 [00:51<00:00, 21.73it/s, Loss=0.144]\n" 72 | ] 73 | }, 74 | { 75 | "output_type": "stream", 76 | "name": "stdout", 77 | "text": [ 78 | "Average loss: 0.1437, Test accuracy: 0.9310\n" 79 | ] 80 | }, 81 | { 82 | "output_type": "stream", 83 | "name": "stderr", 84 | "text": [ 85 | "Epoch 3/8: 100%|██████████| 1125/1125 [00:51<00:00, 21.77it/s, Loss=0.115]\n" 86 | ] 87 | }, 88 | { 89 | "output_type": "stream", 90 | "name": "stdout", 91 | "text": [ 92 | "Average loss: 0.1149, Test accuracy: 0.9410\n" 93 | ] 94 | }, 95 | { 96 | "output_type": "stream", 97 | "name": "stderr", 98 | "text": [ 99 | "Epoch 4/8: 100%|██████████| 1125/1125 [00:51<00:00, 21.81it/s, Loss=0.104]\n" 100 | ] 101 | }, 102 | { 103 | "output_type": "stream", 104 | "name": "stdout", 105 | "text": [ 106 | "Average loss: 0.1038, Test accuracy: 0.9395\n" 107 | ] 108 | }, 109 | { 110 | "output_type": "stream", 111 | "name": "stderr", 112 | "text": [ 113 | "Epoch 5/8: 
100%|██████████| 1125/1125 [00:51<00:00, 21.78it/s, Loss=0.0963]\n" 114 | ] 115 | }, 116 | { 117 | "output_type": "stream", 118 | "name": "stdout", 119 | "text": [ 120 | "Average loss: 0.0963, Test accuracy: 0.9340\n" 121 | ] 122 | }, 123 | { 124 | "output_type": "stream", 125 | "name": "stderr", 126 | "text": [ 127 | "Epoch 6/8: 100%|██████████| 1125/1125 [00:51<00:00, 21.71it/s, Loss=0.0851]\n" 128 | ] 129 | }, 130 | { 131 | "output_type": "stream", 132 | "name": "stdout", 133 | "text": [ 134 | "Average loss: 0.0851, Test accuracy: 0.9395\n" 135 | ] 136 | }, 137 | { 138 | "output_type": "stream", 139 | "name": "stderr", 140 | "text": [ 141 | "Epoch 7/8: 100%|██████████| 1125/1125 [00:51<00:00, 21.74it/s, Loss=0.0806]\n" 142 | ] 143 | }, 144 | { 145 | "output_type": "stream", 146 | "name": "stdout", 147 | "text": [ 148 | "Average loss: 0.0806, Test accuracy: 0.9400\n" 149 | ] 150 | }, 151 | { 152 | "output_type": "stream", 153 | "name": "stderr", 154 | "text": [ 155 | "Epoch 8/8: 100%|██████████| 1125/1125 [00:51<00:00, 21.74it/s, Loss=0.0767]\n" 156 | ] 157 | }, 158 | { 159 | "output_type": "stream", 160 | "name": "stdout", 161 | "text": [ 162 | "Average loss: 0.0767, Test accuracy: 0.9460\n", 163 | "Input: I'm so happy to be able to finetune an LLM!\n", 164 | "Predicted emotion: joy\n" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "# Import required libraries\n", 170 | "import torch # Main PyTorch library\n", 171 | "from torch.utils.data import DataLoader # For dataset handling\n", 172 | "from torch.optim import AdamW # Optimizer for training\n", 173 | "from transformers import AutoTokenizer, AutoModelForSequenceClassification # Hugging Face components\n", 174 | "from tqdm import tqdm # Progress bar utilities\n", 175 | "import json # For parsing JSON data\n", 176 | "import requests # For downloading dataset from URL\n", 177 | "import gzip # For decompressing dataset\n", 178 | "import random # For setting seeds and shuffling data\n", 179 | "\n", 180 | "def 
set_seed(seed):\n", 181 | " \"\"\"\n", 182 | " Sets random seeds for reproducibility across different libraries.\n", 183 | "\n", 184 | " Args:\n", 185 | " seed (int): Seed value for random number generation\n", 186 | " \"\"\"\n", 187 | " # Set Python's built-in random seed\n", 188 | " random.seed(seed)\n", 189 | " # Set PyTorch's CPU random seed\n", 190 | " torch.manual_seed(seed)\n", 191 | " # Set seed for all available GPUs\n", 192 | " torch.cuda.manual_seed_all(seed)\n", 193 | " # Request cuDNN to use deterministic algorithms\n", 194 | " torch.backends.cudnn.deterministic = True\n", 195 | " # Disable cuDNN's auto-tuner for consistent behavior\n", 196 | " torch.backends.cudnn.benchmark = False\n", 197 | "\n", 198 | "def load_and_split_dataset(url, test_ratio=0.1):\n", 199 | " \"\"\"\n", 200 | " Downloads and splits dataset into train and test sets.\n", 201 | "\n", 202 | " Args:\n", 203 | " url (str): URL of the dataset\n", 204 | " test_ratio (float): Proportion of data for testing\n", 205 | "\n", 206 | " Returns:\n", 207 | " tuple: (train_dataset, test_dataset)\n", 208 | " \"\"\"\n", 209 | " # Download and decompress dataset\n", 210 | " response = requests.get(url)\n", 211 | " content = gzip.decompress(response.content).decode()\n", 212 | "\n", 213 | " # Parse JSON lines into list of examples\n", 214 | " dataset = [json.loads(line) for line in content.splitlines()]\n", 215 | "\n", 216 | " # Shuffle and split dataset\n", 217 | " random.shuffle(dataset)\n", 218 | " split_index = int(len(dataset) * (1 - test_ratio))\n", 219 | "\n", 220 | " return dataset[:split_index], dataset[split_index:]\n", 221 | "\n", 222 | "def load_model_and_tokenizer(model_name, device, label_to_id, id_to_label, unique_labels):\n", 223 | " \"\"\"\n", 224 | " Loads and configures the model and tokenizer for sequence classification.\n", 225 | "\n", 226 | " Args:\n", 227 | " model_name (str): Name of pre-trained model\n", 228 | " device: Device to load model on\n", 229 | " label_to_id (dict): 
Mapping from label strings to IDs\n", 230 | " id_to_label (dict): Mapping from IDs to label strings\n", 231 | " unique_labels (list): List of all possible labels\n", 232 | "\n", 233 | " Returns:\n", 234 | " tuple: (model, tokenizer)\n", 235 | " \"\"\"\n", 236 | " # Initialize model with correct number of output classes\n", 237 | " model = AutoModelForSequenceClassification.from_pretrained(\n", 238 | " model_name,\n", 239 | " num_labels=len(unique_labels)\n", 240 | " )\n", 241 | "\n", 242 | " # Configure padding and label mappings\n", 243 | " model.config.pad_token_id = model.config.eos_token_id\n", 244 | " model.config.id2label = id_to_label\n", 245 | " model.config.label2id = label_to_id\n", 246 | "\n", 247 | " # Initialize and configure tokenizer\n", 248 | " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 249 | " tokenizer.pad_token = tokenizer.eos_token\n", 250 | "\n", 251 | " return (model.to(device), tokenizer)\n", 252 | "\n", 253 | "def encode_text(tokenizer, text, return_tensor=False):\n", 254 | " \"\"\"\n", 255 | " Encodes text using the provided tokenizer.\n", 256 | "\n", 257 | " Args:\n", 258 | " tokenizer: Hugging Face tokenizer\n", 259 | " text (str): Text to encode\n", 260 | " return_tensor (bool): Whether to return PyTorch tensor\n", 261 | "\n", 262 | " Returns:\n", 263 | " List or tensor of token IDs\n", 264 | " \"\"\"\n", 265 | " # If tensor output is requested, encode with PyTorch tensors\n", 266 | " if return_tensor:\n", 267 | " return tokenizer.encode(\n", 268 | " text, add_special_tokens=False, return_tensors=\"pt\"\n", 269 | " )\n", 270 | " # Otherwise return list of token IDs\n", 271 | " else:\n", 272 | " return tokenizer.encode(text, add_special_tokens=False)\n", 273 | "\n", 274 | "class TextClassificationDataset(torch.utils.data.Dataset):\n", 275 | " \"\"\"\n", 276 | " PyTorch Dataset for text classification.\n", 277 | " Converts text and labels into model-ready format.\n", 278 | "\n", 279 | " Args:\n", 280 | " data (list): List of 
dictionaries containing text and labels\n", 281 | " tokenizer: Hugging Face tokenizer\n", 282 | " label_to_id (dict): Mapping from label strings to IDs\n", 283 | " \"\"\"\n", 284 | " def __init__(self, data, tokenizer, label_to_id):\n", 285 | " self.data = data\n", 286 | " self.tokenizer = tokenizer\n", 287 | " self.label_to_id = label_to_id\n", 288 | "\n", 289 | " def __len__(self):\n", 290 | " # Return total number of examples\n", 291 | " return len(self.data)\n", 292 | "\n", 293 | " def __getitem__(self, idx):\n", 294 | " \"\"\"\n", 295 | " Returns a single training example.\n", 296 | "\n", 297 | " Args:\n", 298 | " idx (int): Index of the example to fetch\n", 299 | "\n", 300 | " Returns:\n", 301 | " dict: Contains input_ids and labels\n", 302 | " \"\"\"\n", 303 | " # Get example from dataset\n", 304 | " item = self.data[idx]\n", 305 | " # Convert text to token IDs\n", 306 | " input_ids = encode_text(self.tokenizer, item[\"text\"])\n", 307 | " # Convert label string to ID\n", 308 | " labels = self.label_to_id[item[\"label\"]]\n", 309 | "\n", 310 | " return {\n", 311 | " \"input_ids\": input_ids,\n", 312 | " \"labels\": labels\n", 313 | " }\n", 314 | "\n", 315 | "def collate_fn(batch):\n", 316 | " \"\"\"\n", 317 | " Collates batch of examples into training-ready format.\n", 318 | " Handles padding and conversion to tensors.\n", 319 | "\n", 320 | " Args:\n", 321 | " batch: List of examples from Dataset\n", 322 | "\n", 323 | " Returns:\n", 324 | " dict: Contains input_ids, labels, and attention_mask tensors\n", 325 | " \"\"\"\n", 326 | " # Find longest sequence for padding\n", 327 | " max_length = max(len(item[\"input_ids\"]) for item in batch)\n", 328 | "\n", 329 | " # Pad input sequences with zeros\n", 330 | " input_ids = [\n", 331 | " item[\"input_ids\"] +\n", 332 | " [0] * (max_length - len(item[\"input_ids\"]))\n", 333 | " for item in batch\n", 334 | " ]\n", 335 | "\n", 336 | " # Create attention masks (1 for tokens, 0 for padding)\n", 337 | " attention_mask = 
[\n", 338 | " [1] * len(item[\"input_ids\"]) +\n", 339 | " [0] * (max_length - len(item[\"input_ids\"]))\n", 340 | " for item in batch\n", 341 | " ]\n", 342 | "\n", 343 | " # Collect labels\n", 344 | " labels = [item[\"labels\"] for item in batch]\n", 345 | "\n", 346 | " # Convert everything to tensors\n", 347 | " return {\n", 348 | " \"input_ids\": torch.tensor(input_ids),\n", 349 | " \"labels\": torch.tensor(labels),\n", 350 | " \"attention_mask\": torch.tensor(attention_mask)\n", 351 | " }\n", 352 | "\n", 353 | "def generate_label(model, tokenizer, text):\n", 354 | " \"\"\"\n", 355 | " Generates label prediction for input text.\n", 356 | "\n", 357 | " Args:\n", 358 | " model: Fine-tuned model\n", 359 | " tokenizer: Associated tokenizer\n", 360 | " text (str): Input text to classify\n", 361 | "\n", 362 | " Returns:\n", 363 | " str: Predicted label\n", 364 | " \"\"\"\n", 365 | " # Encode text and move to model's device\n", 366 | " input_ids = encode_text(\n", 367 | " tokenizer,\n", 368 | " text,\n", 369 | " return_tensor=True\n", 370 | " ).to(model.device)\n", 371 | "\n", 372 | " # Get model predictions\n", 373 | " outputs = model(input_ids)\n", 374 | " logits = outputs.logits[0]\n", 375 | " # Get class with highest probability\n", 376 | " predicted_class = logits.argmax().item()\n", 377 | " # Convert class ID to label string\n", 378 | " return model.config.id2label[predicted_class]\n", 379 | "\n", 380 | "def calculate_accuracy(model, dataloader):\n", 381 | " \"\"\"\n", 382 | " Calculates prediction accuracy on a dataset.\n", 383 | "\n", 384 | " Args:\n", 385 | " model: Fine-tuned model\n", 386 | " dataloader: DataLoader containing evaluation examples\n", 387 | "\n", 388 | " Returns:\n", 389 | " float: Accuracy score\n", 390 | " \"\"\"\n", 391 | " # Set model to evaluation mode\n", 392 | " model.eval()\n", 393 | " correct = 0\n", 394 | " total = 0\n", 395 | "\n", 396 | " # Disable gradient computation for efficiency\n", 397 | " with torch.no_grad():\n", 398 | " 
for batch in dataloader:\n", 399 | " # Move batch to device\n", 400 | " input_ids = batch[\"input_ids\"].to(model.device)\n", 401 | " attention_mask = batch[\"attention_mask\"].to(model.device)\n", 402 | " labels = batch[\"labels\"].to(model.device)\n", 403 | "\n", 404 | " # Get model predictions\n", 405 | " outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n", 406 | " predictions = outputs.logits.argmax(dim=-1)\n", 407 | "\n", 408 | " # Update accuracy counters\n", 409 | " correct += (predictions == labels).sum().item()\n", 410 | " total += labels.size(0)\n", 411 | "\n", 412 | " # Calculate accuracy\n", 413 | " accuracy = correct / total\n", 414 | " # Reset model to training mode\n", 415 | " model.train()\n", 416 | " return accuracy\n", 417 | "\n", 418 | "def create_label_mappings(train_dataset):\n", 419 | " \"\"\"\n", 420 | " Creates mappings between label strings and IDs.\n", 421 | "\n", 422 | " Args:\n", 423 | " train_dataset: List of training examples\n", 424 | "\n", 425 | " Returns:\n", 426 | " tuple: (label_to_id, id_to_label, unique_labels)\n", 427 | " \"\"\"\n", 428 | " # Get sorted list of unique labels\n", 429 | " unique_labels = sorted(set(item[\"label\"] for item in train_dataset))\n", 430 | " # Create mappings between labels and IDs\n", 431 | " label_to_id = {label: i for i, label in enumerate(unique_labels)}\n", 432 | " id_to_label = {i: label for label, i in label_to_id.items()}\n", 433 | " return label_to_id, id_to_label, unique_labels\n", 434 | "\n", 435 | "def test_model(model_path, test_input):\n", 436 | " \"\"\"\n", 437 | " Tests a saved model on a single input.\n", 438 | "\n", 439 | " Args:\n", 440 | " model_path (str): Path to saved model\n", 441 | " test_input (str): Text to classify\n", 442 | " \"\"\"\n", 443 | " # Setup device and load model\n", 444 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 445 | " model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)\n", 
446 | " tokenizer = AutoTokenizer.from_pretrained(model_path)\n", 447 | "\n", 448 | " # Generate and display prediction\n", 449 | " emotion = generate_label(model, tokenizer, test_input)\n", 450 | " print(f\"Input: {test_input}\")\n", 451 | " print(f\"Predicted emotion: {emotion}\")\n", 452 | "\n", 453 | "def download_and_prepare_data(data_url, tokenizer, batch_size):\n", 454 | " \"\"\"\n", 455 | " Downloads and prepares dataset for training.\n", 456 | "\n", 457 | " Args:\n", 458 | " data_url (str): URL of the dataset\n", 459 | " tokenizer: Tokenizer for text processing\n", 460 | " batch_size (int): Batch size for DataLoader\n", 461 | "\n", 462 | " Returns:\n", 463 | " tuple: (train_dataloader, test_dataloader, label_to_id, id_to_label, unique_labels)\n", 464 | " \"\"\"\n", 465 | " # Load and split dataset\n", 466 | " train_dataset, test_dataset = load_and_split_dataset(data_url)\n", 467 | "\n", 468 | " # Create label mappings\n", 469 | " label_to_id, id_to_label, unique_labels = create_label_mappings(train_dataset)\n", 470 | "\n", 471 | " # Create datasets\n", 472 | " train_data = TextClassificationDataset(\n", 473 | " train_dataset,\n", 474 | " tokenizer,\n", 475 | " label_to_id\n", 476 | " )\n", 477 | " test_data = TextClassificationDataset(\n", 478 | " test_dataset,\n", 479 | " tokenizer,\n", 480 | " label_to_id\n", 481 | " )\n", 482 | "\n", 483 | " # Create dataloaders\n", 484 | " train_dataloader = DataLoader(\n", 485 | " train_data,\n", 486 | " batch_size=batch_size,\n", 487 | " shuffle=True,\n", 488 | " collate_fn=collate_fn\n", 489 | " )\n", 490 | " test_dataloader = DataLoader(\n", 491 | " test_data,\n", 492 | " batch_size=batch_size,\n", 493 | " shuffle=True,\n", 494 | " collate_fn=collate_fn\n", 495 | " )\n", 496 | " return train_dataloader, test_dataloader, label_to_id, id_to_label, unique_labels\n", 497 | "\n", 498 | "def get_hyperparameters():\n", 499 | " \"\"\"\n", 500 | " Returns training hyperparameters.\n", 501 | "\n", 502 | " Returns:\n", 503 | 
" tuple: (num_epochs, batch_size, learning_rate)\n", 504 | " \"\"\"\n", 505 | " # Train for fewer epochs as sequence classification converges faster\n", 506 | " num_epochs=8\n", 507 | " # Standard batch size that works well with most GPU memory\n", 508 | " batch_size=16\n", 509 | " # Standard learning rate for fine-tuning transformers\n", 510 | " learning_rate=5e-5\n", 511 | " return num_epochs, batch_size, learning_rate\n", 512 | "\n", 513 | "# Main training script\n", 514 | "if __name__ == \"__main__\":\n", 515 | " # Set random seed for reproducibility\n", 516 | " seed = 42\n", 517 | " set_seed(seed)\n", 518 | "\n", 519 | " # Configure training parameters\n", 520 | " data_url = \"https://www.thelmbook.com/data/emotions\"\n", 521 | " model_name = \"openai-community/gpt2\"\n", 522 | "\n", 523 | " # Get hyperparameters\n", 524 | " num_epochs, batch_size, learning_rate = get_hyperparameters()\n", 525 | "\n", 526 | " # Setup device\n", 527 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 528 | "\n", 529 | " # Initialize tokenizer\n", 530 | " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 531 | " tokenizer.pad_token = tokenizer.eos_token\n", 532 | "\n", 533 | " # Prepare data and get label mappings\n", 534 | " train_loader, test_loader, label_to_id, id_to_label, unique_labels = download_and_prepare_data(\n", 535 | " data_url,\n", 536 | " tokenizer,\n", 537 | " batch_size\n", 538 | " )\n", 539 | "\n", 540 | " # Initialize model for sequence classification\n", 541 | " model = AutoModelForSequenceClassification.from_pretrained(\n", 542 | " model_name,\n", 543 | " num_labels=len(unique_labels)\n", 544 | " ).to(device)\n", 545 | "\n", 546 | " # Configure model's label handling\n", 547 | " model.config.pad_token_id = model.config.eos_token_id\n", 548 | " model.config.id2label = id_to_label\n", 549 | " model.config.label2id = label_to_id\n", 550 | "\n", 551 | " # Initialize optimizer\n", 552 | " optimizer = AdamW(model.parameters(), 
lr=learning_rate)\n", 553 | "\n", 554 | " # Training loop\n", 555 | " for epoch in range(num_epochs):\n", 556 | " model.train()\n", 557 | " total_loss = 0\n", 558 | " num_batches = 0\n", 559 | " progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs}\")\n", 560 | "\n", 561 | " for batch in progress_bar:\n", 562 | " # Move batch to device\n", 563 | " input_ids = batch[\"input_ids\"].to(device)\n", 564 | " attention_mask = batch[\"attention_mask\"].to(device)\n", 565 | " labels = batch[\"labels\"].to(device)\n", 566 | "\n", 567 | " # Forward pass\n", 568 | " outputs = model(\n", 569 | " input_ids=input_ids,\n", 570 | " attention_mask=attention_mask,\n", 571 | " labels=labels\n", 572 | " )\n", 573 | "\n", 574 | " # Backward pass and optimization\n", 575 | " loss = outputs.loss\n", 576 | " loss.backward()\n", 577 | " optimizer.step()\n", 578 | " optimizer.zero_grad()\n", 579 | "\n", 580 | " # Update metrics\n", 581 | " total_loss += loss.item()\n", 582 | " num_batches += 1\n", 583 | "\n", 584 | " progress_bar.set_postfix({\"Loss\": total_loss / num_batches})\n", 585 | "\n", 586 | " # Display epoch metrics\n", 587 | " avg_loss = total_loss / num_batches\n", 588 | " test_acc = calculate_accuracy(model, test_loader)\n", 589 | " print(f\"Average loss: {avg_loss:.4f}, Test accuracy: {test_acc:.4f}\")\n", 590 | "\n", 591 | " # Save the fine-tuned model\n", 592 | " model.save_pretrained(\"./finetuned_model\")\n", 593 | " tokenizer.save_pretrained(\"./finetuned_model\")\n", 594 | "\n", 595 | " # Test the model\n", 596 | " test_input = \"I'm so happy to be able to finetune an LLM!\"\n", 597 | " test_model(\"./finetuned_model\", test_input)" 598 | ] 599 | } 600 | ], 601 | "metadata": { 602 | "colab": { 603 | "provenance": [], 604 | "gpuType": "A100", 605 | "authorship_tag": "ABX9TyOfPNbJZ+ZzhLYc0JRHSozh", 606 | "include_colab_link": true 607 | }, 608 | "kernelspec": { 609 | "display_name": "Python 3", 610 | "name": "python3" 611 | }, 612 | "language_info": { 613 | 
"name": "python" 614 | }, 615 | "accelerator": "GPU" 616 | }, 617 | "nbformat": 4, 618 | "nbformat_minor": 0 619 | } -------------------------------------------------------------------------------- /emotion_GPT2_as_text_generator_LoRA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "
\n", 17 | "
\n", 18 | " \n", 19 | " \n", 20 | " \n", 26 | " \n", 31 | " \n", 32 | "
\n", 21 | " \n", 22 | " A notebook for The Hundred-Page Language Models Book by Andriy Burkov

\n", 23 | " Code repository: https://github.com/aburkov/theLMbook\n", 24 | "
\n", 25 | "
\n", 27 | " \n", 28 | " \"The\n", 29 | " \n", 30 | "
\n", 33 | "
\n", 34 | "
" 35 | ], 36 | "metadata": { 37 | "id": "fl7Fu-B4uARb" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "yy0zjL_2ouOU", 48 | "outputId": "e12e16c1-7815-411f-db02-e40a3bd05659" 49 | }, 50 | "outputs": [ 51 | { 52 | "output_type": "stream", 53 | "name": "stdout", 54 | "text": [ 55 | "Using device: cuda\n" 56 | ] 57 | }, 58 | { 59 | "output_type": "stream", 60 | "name": "stderr", 61 | "text": [ 62 | "Epoch 1/18: 100%|██████████| 1125/1125 [00:53<00:00, 20.98it/s, Loss=0.613]\n" 63 | ] 64 | }, 65 | { 66 | "output_type": "stream", 67 | "name": "stdout", 68 | "text": [ 69 | "Epoch 1 - Average loss: 0.6127, Test accuracy: 0.7605\n" 70 | ] 71 | }, 72 | { 73 | "output_type": "stream", 74 | "name": "stderr", 75 | "text": [ 76 | "Epoch 2/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.62it/s, Loss=0.353]\n" 77 | ] 78 | }, 79 | { 80 | "output_type": "stream", 81 | "name": "stdout", 82 | "text": [ 83 | "Epoch 2 - Average loss: 0.3532, Test accuracy: 0.7970\n" 84 | ] 85 | }, 86 | { 87 | "output_type": "stream", 88 | "name": "stderr", 89 | "text": [ 90 | "Epoch 3/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.61it/s, Loss=0.237]\n" 91 | ] 92 | }, 93 | { 94 | "output_type": "stream", 95 | "name": "stdout", 96 | "text": [ 97 | "Epoch 3 - Average loss: 0.2375, Test accuracy: 0.8530\n" 98 | ] 99 | }, 100 | { 101 | "output_type": "stream", 102 | "name": "stderr", 103 | "text": [ 104 | "Epoch 4/18: 100%|██████████| 1125/1125 [00:55<00:00, 20.45it/s, Loss=0.184]\n" 105 | ] 106 | }, 107 | { 108 | "output_type": "stream", 109 | "name": "stdout", 110 | "text": [ 111 | "Epoch 4 - Average loss: 0.1843, Test accuracy: 0.8985\n" 112 | ] 113 | }, 114 | { 115 | "output_type": "stream", 116 | "name": "stderr", 117 | "text": [ 118 | "Epoch 5/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.68it/s, Loss=0.146]\n" 119 | ] 120 | }, 121 | { 122 | "output_type": "stream", 123 | 
"name": "stdout", 124 | "text": [ 125 | "Epoch 5 - Average loss: 0.1457, Test accuracy: 0.9175\n" 126 | ] 127 | }, 128 | { 129 | "output_type": "stream", 130 | "name": "stderr", 131 | "text": [ 132 | "Epoch 6/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.63it/s, Loss=0.121]\n" 133 | ] 134 | }, 135 | { 136 | "output_type": "stream", 137 | "name": "stdout", 138 | "text": [ 139 | "Epoch 6 - Average loss: 0.1208, Test accuracy: 0.9215\n" 140 | ] 141 | }, 142 | { 143 | "output_type": "stream", 144 | "name": "stderr", 145 | "text": [ 146 | "Epoch 7/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.50it/s, Loss=0.103]\n" 147 | ] 148 | }, 149 | { 150 | "output_type": "stream", 151 | "name": "stdout", 152 | "text": [ 153 | "Epoch 7 - Average loss: 0.1028, Test accuracy: 0.9260\n" 154 | ] 155 | }, 156 | { 157 | "output_type": "stream", 158 | "name": "stderr", 159 | "text": [ 160 | "Epoch 8/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.62it/s, Loss=0.0927]\n" 161 | ] 162 | }, 163 | { 164 | "output_type": "stream", 165 | "name": "stdout", 166 | "text": [ 167 | "Epoch 8 - Average loss: 0.0927, Test accuracy: 0.9260\n" 168 | ] 169 | }, 170 | { 171 | "output_type": "stream", 172 | "name": "stderr", 173 | "text": [ 174 | "Epoch 9/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.59it/s, Loss=0.0887]\n" 175 | ] 176 | }, 177 | { 178 | "output_type": "stream", 179 | "name": "stdout", 180 | "text": [ 181 | "Epoch 9 - Average loss: 0.0887, Test accuracy: 0.9330\n" 182 | ] 183 | }, 184 | { 185 | "output_type": "stream", 186 | "name": "stderr", 187 | "text": [ 188 | "Epoch 10/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.58it/s, Loss=0.079]\n" 189 | ] 190 | }, 191 | { 192 | "output_type": "stream", 193 | "name": "stdout", 194 | "text": [ 195 | "Epoch 10 - Average loss: 0.0790, Test accuracy: 0.9315\n" 196 | ] 197 | }, 198 | { 199 | "output_type": "stream", 200 | "name": "stderr", 201 | "text": [ 202 | "Epoch 11/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.69it/s, Loss=0.0771]\n" 203 
| ] 204 | }, 205 | { 206 | "output_type": "stream", 207 | "name": "stdout", 208 | "text": [ 209 | "Epoch 11 - Average loss: 0.0771, Test accuracy: 0.9325\n" 210 | ] 211 | }, 212 | { 213 | "output_type": "stream", 214 | "name": "stderr", 215 | "text": [ 216 | "Epoch 12/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.60it/s, Loss=0.0699]\n" 217 | ] 218 | }, 219 | { 220 | "output_type": "stream", 221 | "name": "stdout", 222 | "text": [ 223 | "Epoch 12 - Average loss: 0.0699, Test accuracy: 0.9345\n" 224 | ] 225 | }, 226 | { 227 | "output_type": "stream", 228 | "name": "stderr", 229 | "text": [ 230 | "Epoch 13/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.60it/s, Loss=0.0663]\n" 231 | ] 232 | }, 233 | { 234 | "output_type": "stream", 235 | "name": "stdout", 236 | "text": [ 237 | "Epoch 13 - Average loss: 0.0663, Test accuracy: 0.9265\n" 238 | ] 239 | }, 240 | { 241 | "output_type": "stream", 242 | "name": "stderr", 243 | "text": [ 244 | "Epoch 14/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.61it/s, Loss=0.064]\n" 245 | ] 246 | }, 247 | { 248 | "output_type": "stream", 249 | "name": "stdout", 250 | "text": [ 251 | "Epoch 14 - Average loss: 0.0640, Test accuracy: 0.9375\n" 252 | ] 253 | }, 254 | { 255 | "output_type": "stream", 256 | "name": "stderr", 257 | "text": [ 258 | "Epoch 15/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.60it/s, Loss=0.0633]\n" 259 | ] 260 | }, 261 | { 262 | "output_type": "stream", 263 | "name": "stdout", 264 | "text": [ 265 | "Epoch 15 - Average loss: 0.0633, Test accuracy: 0.9380\n" 266 | ] 267 | }, 268 | { 269 | "output_type": "stream", 270 | "name": "stderr", 271 | "text": [ 272 | "Epoch 16/18: 100%|██████████| 1125/1125 [00:54<00:00, 20.75it/s, Loss=0.0605]\n" 273 | ] 274 | }, 275 | { 276 | "output_type": "stream", 277 | "name": "stdout", 278 | "text": [ 279 | "Epoch 16 - Average loss: 0.0605, Test accuracy: 0.9390\n" 280 | ] 281 | }, 282 | { 283 | "output_type": "stream", 284 | "name": "stderr", 285 | "text": [ 286 | "Epoch 17/18: 
100%|██████████| 1125/1125 [00:53<00:00, 20.85it/s, Loss=0.0571]\n" 287 | ] 288 | }, 289 | { 290 | "output_type": "stream", 291 | "name": "stdout", 292 | "text": [ 293 | "Epoch 17 - Average loss: 0.0571, Test accuracy: 0.9370\n" 294 | ] 295 | }, 296 | { 297 | "output_type": "stream", 298 | "name": "stderr", 299 | "text": [ 300 | "Epoch 18/18: 100%|██████████| 1125/1125 [00:55<00:00, 20.26it/s, Loss=0.0574]\n" 301 | ] 302 | }, 303 | { 304 | "output_type": "stream", 305 | "name": "stdout", 306 | "text": [ 307 | "Epoch 18 - Average loss: 0.0574, Test accuracy: 0.9420\n", 308 | "Training accuracy: 0.9424\n", 309 | "Test accuracy: 0.9420\n", 310 | "Using device: cuda\n", 311 | "Input: I'm so happy to be able to finetune an LLM!\n", 312 | "Generated emotion: joy\n" 313 | ] 314 | } 315 | ], 316 | "source": [ 317 | "# Import required libraries\n", 318 | "import json # For parsing JSON data\n", 319 | "import random # For setting seeds and shuffling data\n", 320 | "import gzip # For decompressing dataset\n", 321 | "import requests # For downloading dataset from URL\n", 322 | "import torch # Main PyTorch library\n", 323 | "from peft import get_peft_model, LoraConfig, TaskType # For efficient finetuning using LoRA\n", 324 | "from torch.utils.data import Dataset, DataLoader # For dataset handling\n", 325 | "from transformers import AutoTokenizer, AutoModelForCausalLM # Hugging Face model components\n", 326 | "from torch.optim import AdamW # Optimizer for training\n", 327 | "from tqdm import tqdm # Progress bar utilities\n", 328 | "import re # For text normalization\n", 329 | "\n", 330 | "def set_seed(seed):\n", 331 | " \"\"\"\n", 332 | " Sets random seeds for reproducibility across different libraries.\n", 333 | "\n", 334 | " Args:\n", 335 | " seed (int): Seed value for random number generation\n", 336 | " \"\"\"\n", 337 | " # Set Python's built-in random seed\n", 338 | " random.seed(seed)\n", 339 | " # Set PyTorch's CPU random seed\n", 340 | " torch.manual_seed(seed)\n", 341 | 
" # Set seed for all available GPUs\n", 342 | " torch.cuda.manual_seed_all(seed)\n", 343 | " # Request cuDNN to use deterministic algorithms\n", 344 | " torch.backends.cudnn.deterministic = True\n", 345 | " # Disable cuDNN's auto-tuner for consistent behavior\n", 346 | " torch.backends.cudnn.benchmark = False\n", 347 | "\n", 348 | "def build_prompt(text):\n", 349 | " \"\"\"\n", 350 | " Creates a standardized prompt for emotion classification.\n", 351 | "\n", 352 | " Args:\n", 353 | " text (str): Input text to classify\n", 354 | "\n", 355 | " Returns:\n", 356 | " str: Formatted prompt for the model\n", 357 | " \"\"\"\n", 358 | " # Format the input text into a consistent prompt structure\n", 359 | " return f\"Predict the emotion for the following text: {text}\\nEmotion:\"\n", 360 | "\n", 361 | "def encode_text(tokenizer, text, return_tensor=False):\n", 362 | " \"\"\"\n", 363 | " Encodes text using the provided tokenizer.\n", 364 | "\n", 365 | " Args:\n", 366 | " tokenizer: Hugging Face tokenizer\n", 367 | " text (str): Text to encode\n", 368 | " return_tensor (bool): Whether to return PyTorch tensor\n", 369 | "\n", 370 | " Returns:\n", 371 | " List or tensor of token IDs\n", 372 | " \"\"\"\n", 373 | " # If tensor output is requested, encode with PyTorch tensors\n", 374 | " if return_tensor:\n", 375 | " return tokenizer.encode(\n", 376 | " text, add_special_tokens=False, return_tensors=\"pt\"\n", 377 | " )\n", 378 | " # Otherwise return list of token IDs\n", 379 | " else:\n", 380 | " return tokenizer.encode(text, add_special_tokens=False)\n", 381 | "\n", 382 | "def decode_text(tokenizer, token_ids):\n", 383 | " \"\"\"\n", 384 | " Decodes token IDs back to text.\n", 385 | "\n", 386 | " Args:\n", 387 | " tokenizer: Hugging Face tokenizer\n", 388 | " token_ids: List or tensor of token IDs\n", 389 | "\n", 390 | " Returns:\n", 391 | " str: Decoded text\n", 392 | " \"\"\"\n", 393 | " # Convert token IDs back to text, skipping special tokens\n", 394 | " return 
tokenizer.decode(token_ids, skip_special_tokens=True)\n", 395 | "\n", 396 | "class PromptCompletionDataset(Dataset):\n", 397 | " \"\"\"\n", 398 | " PyTorch Dataset for prompt-completion pairs.\n", 399 | " Handles the conversion of text data into model-ready format.\n", 400 | "\n", 401 | " Args:\n", 402 | " data (list): List of dictionaries containing prompts and completions\n", 403 | " tokenizer: Hugging Face tokenizer\n", 404 | " \"\"\"\n", 405 | " def __init__(self, data, tokenizer):\n", 406 | " # Store the raw data and tokenizer for later use\n", 407 | " self.data = data\n", 408 | " self.tokenizer = tokenizer\n", 409 | "\n", 410 | " def __len__(self):\n", 411 | " # Return the total number of examples in the dataset\n", 412 | " return len(self.data)\n", 413 | "\n", 414 | " def __getitem__(self, idx):\n", 415 | " \"\"\"\n", 416 | " Returns a single training example.\n", 417 | "\n", 418 | " Args:\n", 419 | " idx (int): Index of the example to fetch\n", 420 | "\n", 421 | " Returns:\n", 422 | " dict: Contains input_ids, labels, prompt, and expected completion\n", 423 | " \"\"\"\n", 424 | " # Get the specific example from our dataset\n", 425 | " item = self.data[idx]\n", 426 | " prompt = item[\"prompt\"]\n", 427 | " completion = item[\"completion\"]\n", 428 | "\n", 429 | " # Convert text to token IDs for both prompt and completion\n", 430 | " encoded_prompt = encode_text(self.tokenizer, prompt)\n", 431 | " encoded_completion = encode_text(self.tokenizer, completion)\n", 432 | " # Get the end-of-sequence token ID\n", 433 | " eos_token = self.tokenizer.eos_token_id\n", 434 | "\n", 435 | " # Combine prompt and completion tokens with EOS token\n", 436 | " input_ids = encoded_prompt + encoded_completion + [eos_token]\n", 437 | " # Create labels: -100 for prompt (ignored in loss), completion tokens for learning\n", 438 | " labels = [-100] * len(encoded_prompt) + encoded_completion + [eos_token]\n", 439 | "\n", 440 | " return {\n", 441 | " \"input_ids\": input_ids,\n", 442 | 
" \"labels\": labels,\n", 443 | " \"prompt\": prompt,\n", 444 | " \"expected_completion\": completion\n", 445 | " }\n", 446 | "\n", 447 | "def collate_fn(batch):\n", 448 | " \"\"\"\n", 449 | " Collates batch of examples into training-ready format.\n", 450 | " Handles padding and conversion to tensors.\n", 451 | "\n", 452 | " Args:\n", 453 | " batch: List of examples from Dataset\n", 454 | "\n", 455 | " Returns:\n", 456 | " tuple: (input_ids, attention_mask, labels, prompts, expected_completions)\n", 457 | " \"\"\"\n", 458 | " # Find the longest sequence in the batch for padding\n", 459 | " max_length = max(len(item[\"input_ids\"]) for item in batch)\n", 460 | "\n", 461 | " # Pad input sequences to max_length with pad token\n", 462 | " input_ids = [\n", 463 | " item[\"input_ids\"] +\n", 464 | " [tokenizer.pad_token_id] * (max_length - len(item[\"input_ids\"]))\n", 465 | " for item in batch\n", 466 | " ]\n", 467 | "\n", 468 | " # Pad label sequences with -100 (ignored in loss calculation)\n", 469 | " labels = [\n", 470 | " item[\"labels\"] +\n", 471 | " [-100] * (max_length - len(item[\"labels\"]))\n", 472 | " for item in batch\n", 473 | " ]\n", 474 | "\n", 475 | " # Create attention masks: 1 for real tokens, 0 for padding\n", 476 | " attention_mask = [\n", 477 | " [1] * len(item[\"input_ids\"]) +\n", 478 | " [0] * (max_length - len(item[\"input_ids\"]))\n", 479 | " for item in batch\n", 480 | " ]\n", 481 | "\n", 482 | " # Keep original prompts and completions for evaluation\n", 483 | " prompts = [item[\"prompt\"] for item in batch]\n", 484 | " expected_completions = [item[\"expected_completion\"] for item in batch]\n", 485 | "\n", 486 | " # Convert everything to PyTorch tensors except text\n", 487 | " return (\n", 488 | " torch.tensor(input_ids),\n", 489 | " torch.tensor(attention_mask),\n", 490 | " torch.tensor(labels),\n", 491 | " prompts,\n", 492 | " expected_completions\n", 493 | " )\n", 494 | "\n", 495 | "def normalize_text(text):\n", 496 | " \"\"\"\n", 497 | " 
Normalizes text for consistent comparison.\n", 498 | "\n", 499 | " Args:\n", 500 | " text (str): Input text\n", 501 | "\n", 502 | " Returns:\n", 503 | " str: Normalized text\n", 504 | " \"\"\"\n", 505 | " # Remove leading/trailing whitespace and convert to lowercase\n", 506 | " text = text.strip().lower()\n", 507 | " # Replace multiple whitespace characters with single space\n", 508 | " text = re.sub(r\"\\s+\", ' ', text)\n", 509 | " return text\n", 510 | "\n", 511 | "def calculate_accuracy(model, tokenizer, loader):\n", 512 | " \"\"\"\n", 513 | " Calculates prediction accuracy on a dataset.\n", 514 | "\n", 515 | " Args:\n", 516 | " model: Finetuned model\n", 517 | " tokenizer: Associated tokenizer\n", 518 | " loader: DataLoader containing evaluation examples\n", 519 | "\n", 520 | " Returns:\n", 521 | " float: Accuracy score\n", 522 | " \"\"\"\n", 523 | " # Set model to evaluation mode\n", 524 | " model.eval()\n", 525 | " # Initialize counters for accuracy calculation\n", 526 | " correct = 0\n", 527 | " total = 0\n", 528 | "\n", 529 | " # Disable gradient computation for efficiency\n", 530 | " with torch.no_grad():\n", 531 | " for input_ids, attention_mask, labels, prompts, expected_completions in loader:\n", 532 | " for prompt, expected_completion in zip(prompts, expected_completions):\n", 533 | " # Generate model's prediction\n", 534 | " generated_text = generate_text(model, tokenizer, prompt)\n", 535 | " # Compare normalized versions of prediction and target\n", 536 | " if normalize_text(generated_text) == normalize_text(expected_completion):\n", 537 | " correct += 1\n", 538 | " total += 1\n", 539 | "\n", 540 | " # Calculate accuracy, handling empty dataset case\n", 541 | " accuracy = correct / total if total > 0 else 0\n", 542 | " # Reset model to training mode\n", 543 | " model.train()\n", 544 | " return accuracy\n", 545 | "\n", 546 | "def generate_text(model, tokenizer, prompt, max_new_tokens=50):\n", 547 | " \"\"\"\n", 548 | " Generates text completion for a 
given prompt.\n", 549 | "\n", 550 | " Args:\n", 551 | " model: Finetuned model\n", 552 | " tokenizer: Associated tokenizer\n", 553 | " prompt (str): Input prompt\n", 554 | " max_new_tokens (int): Maximum number of tokens to generate\n", 555 | "\n", 556 | " Returns:\n", 557 | " str: Generated completion\n", 558 | " \"\"\"\n", 559 | " # Encode prompt and move to model's device\n", 560 | " input_ids = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n", 561 | "\n", 562 | " # Generate completion using model's generate method\n", 563 | " output_ids = model.generate(\n", 564 | " input_ids=input_ids[\"input_ids\"],\n", 565 | " attention_mask=input_ids[\"attention_mask\"],\n", 566 | " max_new_tokens=max_new_tokens,\n", 567 | " pad_token_id=tokenizer.pad_token_id,\n", 568 | " eos_token_id=tokenizer.eos_token_id\n", 569 | " )[0]\n", 570 | "\n", 571 | " # Extract and decode only the generated part (excluding prompt)\n", 572 | " generated_text = decode_text(tokenizer, output_ids[input_ids[\"input_ids\"].shape[1]:])\n", 573 | " return generated_text.strip()\n", 574 | "\n", 575 | "def test_model(model_path, test_input):\n", 576 | " \"\"\"\n", 577 | " Tests a saved model on a single input.\n", 578 | "\n", 579 | " Args:\n", 580 | " model_path (str): Path to saved model\n", 581 | " test_input (str): Text to classify\n", 582 | " \"\"\"\n", 583 | " # Determine device (GPU if available, else CPU)\n", 584 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 585 | " print(f\"Using device: {device}\")\n", 586 | "\n", 587 | " # Load saved model and tokenizer\n", 588 | " model = AutoModelForCausalLM.from_pretrained(model_path).to(device)\n", 589 | " tokenizer = AutoTokenizer.from_pretrained(model_path)\n", 590 | "\n", 591 | " # Configure padding token\n", 592 | " if tokenizer.pad_token is None:\n", 593 | " tokenizer.pad_token = tokenizer.eos_token\n", 594 | " model.config.pad_token_id = tokenizer.pad_token_id\n", 595 | "\n", 596 | " # Generate and 
display prediction\n", 597 | " prompt = build_prompt(test_input)\n", 598 | " generated_text = generate_text(model, tokenizer, prompt)\n", 599 | "\n", 600 | " print(f\"Input: {test_input}\")\n", 601 | " print(f\"Generated emotion: {generated_text}\")\n", 602 | "\n", 603 | "def download_and_prepare_data(data_url, tokenizer, batch_size, test_ratio=0.1):\n", 604 | " \"\"\"\n", 605 | " Downloads and prepares dataset for training.\n", 606 | "\n", 607 | " Args:\n", 608 | " data_url (str): URL of the dataset\n", 609 | " tokenizer: Tokenizer for text processing\n", 610 | " batch_size (int): Batch size for DataLoader\n", 611 | " test_ratio (float): Proportion of data for testing\n", 612 | "\n", 613 | " Returns:\n", 614 | " tuple: (train_loader, test_loader)\n", 615 | " \"\"\"\n", 616 | " # Download and decompress dataset\n", 617 | " response = requests.get(data_url)\n", 618 | " content = gzip.decompress(response.content).decode()\n", 619 | "\n", 620 | " # Process each example into prompt-completion pairs\n", 621 | " dataset = []\n", 622 | " for entry in map(json.loads, content.splitlines()):\n", 623 | " dataset.append({\n", 624 | " \"prompt\": build_prompt(entry['text']),\n", 625 | " \"completion\": entry[\"label\"].strip()\n", 626 | " })\n", 627 | "\n", 628 | " # Split into train and test sets\n", 629 | " random.shuffle(dataset)\n", 630 | " split_index = int(len(dataset) * (1 - test_ratio))\n", 631 | " train_data = dataset[:split_index]\n", 632 | " test_data = dataset[split_index:]\n", 633 | "\n", 634 | " # Create datasets\n", 635 | " train_dataset = PromptCompletionDataset(train_data, tokenizer)\n", 636 | " test_dataset = PromptCompletionDataset(test_data, tokenizer)\n", 637 | "\n", 638 | " # Create data loaders\n", 639 | " train_loader = DataLoader(\n", 640 | " train_dataset,\n", 641 | " batch_size=batch_size,\n", 642 | " shuffle=True,\n", 643 | " collate_fn=collate_fn\n", 644 | " )\n", 645 | " test_loader = DataLoader(\n", 646 | " test_dataset,\n", 647 | " 
batch_size=batch_size,\n", 648 | " shuffle=False,\n", 649 | " collate_fn=collate_fn\n", 650 | " )\n", 651 | "\n", 652 | " return train_loader, test_loader\n", 653 | "\n", 654 | "def get_hyperparameters():\n", 655 | " \"\"\"\n", 656 | " Returns training hyperparameters.\n", 657 | "\n", 658 | " Returns:\n", 659 | " tuple: (num_epochs, batch_size, learning_rate)\n", 660 | " \"\"\"\n", 661 | " # Train for more epochs with LoRA as it's more efficient\n", 662 | " num_epochs = 18\n", 663 | " # Batch size\n", 664 | " batch_size = 16\n", 665 | " # Standard learning rate for finetuning transformers\n", 666 | " learning_rate = 5e-5\n", 667 | "\n", 668 | " return num_epochs, batch_size, learning_rate\n", 669 | "\n", 670 | "# Main training script\n", 671 | "if __name__ == \"__main__\":\n", 672 | " # Set random seeds for reproducibility\n", 673 | " set_seed(42)\n", 674 | "\n", 675 | " # Configure basic training parameters\n", 676 | " data_url = \"https://www.thelmbook.com/data/emotions\"\n", 677 | " model_name = \"openai-community/gpt2\"\n", 678 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 679 | " print(f\"Using device: {device}\")\n", 680 | "\n", 681 | " # Initialize tokenizer\n", 682 | " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 683 | " tokenizer.pad_token = tokenizer.eos_token\n", 684 | "\n", 685 | " # Configure LoRA parameters\n", 686 | " peft_config = LoraConfig(\n", 687 | " task_type = TaskType.CAUSAL_LM, # Set task type for causal language modeling\n", 688 | " inference_mode = False, # Enable training mode\n", 689 | " r = 16, # Rank of LoRA update matrices\n", 690 | " lora_alpha = 32 # LoRA scaling factor\n", 691 | " )\n", 692 | "\n", 693 | " # Load model and apply LoRA configuration\n", 694 | " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n", 695 | " model = get_peft_model(model, peft_config)\n", 696 | "\n", 697 | " # Get hyperparameters and prepare data\n", 698 | " num_epochs, batch_size, 
learning_rate = get_hyperparameters()\n", 699 | " train_loader, test_loader = download_and_prepare_data(data_url, tokenizer, batch_size)\n", 700 | "\n", 701 | " # Initialize optimizer\n", 702 | " optimizer = AdamW(model.parameters(), lr=learning_rate)\n", 703 | "\n", 704 | " # Training loop\n", 705 | " for epoch in range(num_epochs):\n", 706 | " total_loss = 0\n", 707 | " num_batches = 0\n", 708 | " progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs}\")\n", 709 | "\n", 710 | " for input_ids, attention_mask, labels, _, _ in progress_bar:\n", 711 | " # Move batch to device\n", 712 | " input_ids = input_ids.to(device)\n", 713 | " attention_mask = attention_mask.to(device)\n", 714 | " labels = labels.to(device)\n", 715 | "\n", 716 | " # Forward pass\n", 717 | " outputs = model(\n", 718 | " input_ids=input_ids,\n", 719 | " attention_mask=attention_mask,\n", 720 | " labels=labels\n", 721 | " )\n", 722 | " loss = outputs.loss\n", 723 | "\n", 724 | " # Backward pass and optimization\n", 725 | " loss.backward()\n", 726 | " optimizer.step()\n", 727 | " optimizer.zero_grad()\n", 728 | "\n", 729 | " # Update metrics\n", 730 | " total_loss += loss.item()\n", 731 | " num_batches += 1\n", 732 | " progress_bar.set_postfix({\"Loss\": total_loss / num_batches})\n", 733 | "\n", 734 | " # Calculate and display epoch metrics\n", 735 | " avg_loss = total_loss / num_batches\n", 736 | " test_acc = calculate_accuracy(model, tokenizer, test_loader)\n", 737 | " print(f\"Epoch {epoch+1} - Average loss: {avg_loss:.4f}, Test accuracy: {test_acc:.4f}\")\n", 738 | "\n", 739 | " # Calculate final model performance\n", 740 | " train_acc = calculate_accuracy(model, tokenizer, train_loader)\n", 741 | " print(f\"Training accuracy: {train_acc:.4f}\")\n", 742 | " print(f\"Test accuracy: {test_acc:.4f}\")\n", 743 | "\n", 744 | " # Save the LoRA-tuned model and tokenizer\n", 745 | " model.save_pretrained(\"./finetuned_model\")\n", 746 | " 
tokenizer.save_pretrained(\"./finetuned_model\")\n", 747 | "\n", 748 | " # Test the finetuned model with a sample input\n", 749 | " test_input = \"I'm so happy to be able to finetune an LLM!\"\n", 750 | " test_model(\"./finetuned_model\", test_input)" 751 | ] 752 | } 753 | ], 754 | "metadata": { 755 | "colab": { 756 | "provenance": [], 757 | "gpuType": "A100", 758 | "authorship_tag": "ABX9TyOF3tQeAXCKCxUGZKAJ7kw5", 759 | "include_colab_link": true 760 | }, 761 | "kernelspec": { 762 | "display_name": "Python 3", 763 | "name": "python3" 764 | }, 765 | "language_info": { 766 | "name": "python" 767 | }, 768 | "accelerator": "GPU" 769 | }, 770 | "nbformat": 4, 771 | "nbformat_minor": 0 772 | } -------------------------------------------------------------------------------- /emotion_classifier_LR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "
\n", 17 | "
\n", 18 | " \n", 19 | " \n", 20 | " \n", 26 | " \n", 31 | " \n", 32 | "
\n", 21 | " \n", 22 | " A notebook for The Hundred-Page Language Models Book by Andriy Burkov

\n", 23 | " Code repository: https://github.com/aburkov/theLMbook\n", 24 | "
\n", 25 | "
\n", 27 | " \n", 28 | " \"The\n", 29 | " \n", 30 | "
\n", 33 | "
\n", 34 | "
" 35 | ], 36 | "metadata": { 37 | "id": "P9WVD-1ZAmYf" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "yy0zjL_2ouOU", 48 | "outputId": "63da21d9-c17e-42fc-d9fd-aa1041b38dba" 49 | }, 50 | "outputs": [ 51 | { 52 | "output_type": "stream", 53 | "name": "stdout", 54 | "text": [ 55 | "Number of training examples: 18000\n", 56 | "Number of test examples: 2000\n", 57 | "\n", 58 | "Train accuracy: 0.9854\n", 59 | "Test accuracy: 0.8855\n", 60 | "\n", 61 | "--- Better hyperparameters ---\n", 62 | "Train accuracy: 0.9962\n", 63 | "Test accuracy: 0.8910\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# Import required libraries\n", 69 | "import gzip # For decompressing gzipped data files\n", 70 | "import json # For parsing JSON-formatted data\n", 71 | "import random # For shuffling dataset and setting seeds\n", 72 | "import requests # For downloading dataset from URL\n", 73 | "from sklearn.feature_extraction.text import CountVectorizer # Text vectorization utility\n", 74 | "from sklearn.linear_model import LogisticRegression # Logistic regression model\n", 75 | "from sklearn.metrics import accuracy_score # For model evaluation\n", 76 | "\n", 77 | "# ----------------------------\n", 78 | "# Utility Functions\n", 79 | "# ----------------------------\n", 80 | "\n", 81 | "def set_seed(seed):\n", 82 | " \"\"\"\n", 83 | " Sets random seed for reproducibility.\n", 84 | "\n", 85 | " Args:\n", 86 | " seed (int): Seed value for random number generation\n", 87 | " \"\"\"\n", 88 | " random.seed(seed)\n", 89 | "\n", 90 | "def download_and_split_data(data_url, test_ratio=0.1):\n", 91 | " \"\"\"\n", 92 | " Downloads emotion classification dataset from URL and splits into train/test sets.\n", 93 | " Handles decompression and JSON parsing of the raw data.\n", 94 | "\n", 95 | " Args:\n", 96 | " data_url (str): URL of the gzipped JSON dataset\n", 97 | " test_ratio 
(float): Proportion of data to use for testing (default: 0.1)\n", 98 | "\n", 99 | " Returns:\n", 100 | " tuple: (X_train, y_train, X_test, y_test) containing:\n", 101 | " - X_train, X_test: Lists of text examples for training and testing\n", 102 | " - y_train, y_test: Lists of corresponding emotion labels\n", 103 | " \"\"\"\n", 104 | " # Download and decompress the dataset\n", 105 | " response = requests.get(data_url)\n", 106 | " content = gzip.decompress(response.content).decode()\n", 107 | "\n", 108 | " # Parse JSON lines into list of dictionaries\n", 109 | " dataset = [json.loads(line) for line in content.splitlines()]\n", 110 | "\n", 111 | " # Shuffle dataset for random split\n", 112 | " random.shuffle(dataset)\n", 113 | "\n", 114 | " # Split into train and test sets\n", 115 | " split_index = int(len(dataset) * (1 - test_ratio))\n", 116 | " train, test = dataset[:split_index], dataset[split_index:]\n", 117 | "\n", 118 | " # Separate text and labels\n", 119 | " X_train = [item[\"text\"] for item in train]\n", 120 | " y_train = [item[\"label\"] for item in train]\n", 121 | " X_test = [item[\"text\"] for item in test]\n", 122 | " y_test = [item[\"label\"] for item in test]\n", 123 | "\n", 124 | " return X_train, y_train, X_test, y_test\n", 125 | "\n", 126 | "# ----------------------------\n", 127 | "# Main Execution\n", 128 | "# ----------------------------\n", 129 | "\n", 130 | "# Set random seed for reproducibility\n", 131 | "set_seed(42)\n", 132 | "\n", 133 | "# Download and prepare dataset\n", 134 | "data_url = \"https://www.thelmbook.com/data/emotions\"\n", 135 | "X_train_text, y_train, X_test_text, y_test = download_and_split_data(\n", 136 | " data_url, test_ratio=0.1\n", 137 | ")\n", 138 | "\n", 139 | "print(\"Number of training examples:\", len(X_train_text))\n", 140 | "print(\"Number of test examples:\", len(X_test_text))\n", 141 | "\n", 142 | "# ----------------------------\n", 143 | "# Baseline Model\n", 144 | "# ----------------------------\n", 145 | 
"\n", 146 | "# Initialize text vectorizer with basic parameters\n", 147 | "# max_features=10_000: Limit vocabulary to top 10k most frequent words\n", 148 | "# binary=True: Convert counts to binary indicators (0/1)\n", 149 | "vectorizer = CountVectorizer(max_features=10_000, binary=True)\n", 150 | "\n", 151 | "# Transform text data to numerical features\n", 152 | "X_train = vectorizer.fit_transform(X_train_text)\n", 153 | "X_test = vectorizer.transform(X_test_text)\n", 154 | "\n", 155 | "# Initialize and train logistic regression model\n", 156 | "model = LogisticRegression(random_state=42, max_iter=1000)\n", 157 | "model.fit(X_train, y_train)\n", 158 | "\n", 159 | "# Make predictions on train and test sets\n", 160 | "y_train_pred = model.predict(X_train)\n", 161 | "y_test_pred = model.predict(X_test)\n", 162 | "\n", 163 | "# Calculate and display accuracy metrics\n", 164 | "train_accuracy = accuracy_score(y_train, y_train_pred)\n", 165 | "test_accuracy = accuracy_score(y_test, y_test_pred)\n", 166 | "\n", 167 | "print(f\"\\nTrain accuracy: {train_accuracy:.4f}\")\n", 168 | "print(f\"Test accuracy: {test_accuracy:.4f}\")\n", 169 | "\n", 170 | "# ----------------------------\n", 171 | "# Improved Model\n", 172 | "# ----------------------------\n", 173 | "\n", 174 | "print(\"\\n--- Better hyperparameters ---\")\n", 175 | "\n", 176 | "# Initialize vectorizer with improved parameters\n", 177 | "# max_features=20000: Increased vocabulary size\n", 178 | "# ngram_range=(1, 2): Include both unigrams and bigrams\n", 179 | "vectorizer = CountVectorizer(max_features=20000, ngram_range=(1, 2))\n", 180 | "\n", 181 | "# Transform text data with new vectorizer\n", 182 | "X_train = vectorizer.fit_transform(X_train_text)\n", 183 | "X_test = vectorizer.transform(X_test_text)\n", 184 | "\n", 185 | "# Train and evaluate model with same parameters\n", 186 | "model = LogisticRegression(random_state=42, max_iter=1000)\n", 187 | "model.fit(X_train, y_train)\n", 188 | "\n", 189 | 
"y_train_pred = model.predict(X_train)\n", 190 | "y_test_pred = model.predict(X_test)\n", 191 | "\n", 192 | "train_accuracy = accuracy_score(y_train, y_train_pred)\n", 193 | "test_accuracy = accuracy_score(y_test, y_test_pred)\n", 194 | "\n", 195 | "print(f\"Train accuracy: {train_accuracy:.4f}\")\n", 196 | "print(f\"Test accuracy: {test_accuracy:.4f}\")" 197 | ] 198 | } 199 | ], 200 | "metadata": { 201 | "colab": { 202 | "provenance": [], 203 | "authorship_tag": "ABX9TyNKH13VydD5aNyFkNnbyP3F", 204 | "include_colab_link": true 205 | }, 206 | "kernelspec": { 207 | "display_name": "Python 3", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "name": "python" 212 | } 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 0 216 | } -------------------------------------------------------------------------------- /instruct_GPT2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "
\n", 17 | "
\n", 18 | " \n", 19 | " \n", 20 | " \n", 26 | " \n", 31 | " \n", 32 | "
\n", 21 | " \n", 22 | " A notebook for The Hundred-Page Language Models Book by Andriy Burkov

\n", 23 | " Code repository: https://github.com/aburkov/theLMbook\n", 24 | "
\n", 25 | "
\n", 27 | " \n", 28 | " \"The\n", 29 | " \n", 30 | "
\n", 33 | "
\n", 34 | "
" 35 | ], 36 | "metadata": { 37 | "id": "Wu6Fr-_WuSMZ" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "yy0zjL_2ouOU", 48 | "outputId": "fa6f8e9f-283b-4362-d1e1-c6a47fd7d74b" 49 | }, 50 | "outputs": [ 51 | { 52 | "output_type": "stream", 53 | "name": "stdout", 54 | "text": [ 55 | "Using device: cuda\n", 56 | "\n", 57 | "Dataset size: 510\n", 58 | "Training samples: 459\n", 59 | "Test samples: 51\n" 60 | ] 61 | }, 62 | { 63 | "output_type": "stream", 64 | "name": "stderr", 65 | "text": [ 66 | "Epoch 1/4: 100%|██████████| 29/29 [00:03<00:00, 7.61it/s, Loss=1.62]\n" 67 | ] 68 | }, 69 | { 70 | "output_type": "stream", 71 | "name": "stdout", 72 | "text": [ 73 | "Epoch 1 - Average loss: 1.6233\n" 74 | ] 75 | }, 76 | { 77 | "output_type": "stream", 78 | "name": "stderr", 79 | "text": [ 80 | "Epoch 2/4: 100%|██████████| 29/29 [00:03<00:00, 7.80it/s, Loss=1.02]\n" 81 | ] 82 | }, 83 | { 84 | "output_type": "stream", 85 | "name": "stdout", 86 | "text": [ 87 | "Epoch 2 - Average loss: 1.0221\n" 88 | ] 89 | }, 90 | { 91 | "output_type": "stream", 92 | "name": "stderr", 93 | "text": [ 94 | "Epoch 3/4: 100%|██████████| 29/29 [00:03<00:00, 7.83it/s, Loss=0.688]\n" 95 | ] 96 | }, 97 | { 98 | "output_type": "stream", 99 | "name": "stdout", 100 | "text": [ 101 | "Epoch 3 - Average loss: 0.6885\n" 102 | ] 103 | }, 104 | { 105 | "output_type": "stream", 106 | "name": "stderr", 107 | "text": [ 108 | "Epoch 4/4: 100%|██████████| 29/29 [00:03<00:00, 7.74it/s, Loss=0.422]\n" 109 | ] 110 | }, 111 | { 112 | "output_type": "stream", 113 | "name": "stdout", 114 | "text": [ 115 | "Epoch 4 - Average loss: 0.4222\n", 116 | "\n", 117 | "Testing finetuned model:\n", 118 | "Using device: cuda\n", 119 | "\n", 120 | "Input: Who is the President of the United States?\n", 121 | "Full generated text: George W. Bush\n", 122 | "<|im_end|>\n", 123 | "Cleaned response: George W. 
Bush\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "# Import required libraries\n", 129 | "import json # For parsing JSON data\n", 130 | "import random # For setting seeds and shuffling data\n", 131 | "import requests # For downloading dataset from URL\n", 132 | "import torch # Main PyTorch library\n", 133 | "from torch.utils.data import Dataset, DataLoader # For dataset handling\n", 134 | "from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria # HuggingFace components\n", 135 | "from tqdm import tqdm # Progress bar utilities\n", 136 | "import re # For text normalization\n", 137 | "\n", 138 | "def set_seed(seed):\n", 139 | " \"\"\"\n", 140 | " Sets random seeds for reproducibility across different libraries.\n", 141 | "\n", 142 | " Args:\n", 143 | " seed (int): Seed value for random number generation\n", 144 | " \"\"\"\n", 145 | " # Set Python's built-in random seed\n", 146 | " random.seed(seed)\n", 147 | " # Set PyTorch's CPU random seed\n", 148 | " torch.manual_seed(seed)\n", 149 | " # Set seed for all available GPUs\n", 150 | " torch.cuda.manual_seed_all(seed)\n", 151 | " # Request cuDNN to use deterministic algorithms\n", 152 | " torch.backends.cudnn.deterministic = True\n", 153 | " # Disable cuDNN's auto-tuner for consistent behavior\n", 154 | " torch.backends.cudnn.benchmark = False\n", 155 | "\n", 156 | "def build_prompt(instruction, solution=None):\n", 157 | " \"\"\"\n", 158 | " Creates a chat-formatted prompt with system, user, and assistant messages.\n", 159 | "\n", 160 | " Args:\n", 161 | " instruction (str): User's instruction/question\n", 162 | " solution (str, optional): Expected response for training\n", 163 | "\n", 164 | " Returns:\n", 165 | " str: Formatted prompt string\n", 166 | " \"\"\"\n", 167 | " # Add solution with end token if provided\n", 168 | " wrapped_solution = \"\"\n", 169 | " if solution:\n", 170 | " wrapped_solution = f\"\\n{solution}\\n<|im_end|>\"\n", 171 | "\n", 172 | " # Build chat format with 
system, user, and assistant messages\n", 173 | " return f\"\"\"<|im_start|>system\n", 174 | "You are a helpful assistant.\n", 175 | "<|im_end|>\n", 176 | "<|im_start|>user\n", 177 | "{instruction}\n", 178 | "<|im_end|>\n", 179 | "<|im_start|>assistant\"\"\" + wrapped_solution\n", 180 | "\n", 181 | "def encode_text(tokenizer, text, return_tensor=False):\n", 182 | " \"\"\"\n", 183 | " Encodes text using the provided tokenizer.\n", 184 | "\n", 185 | " Args:\n", 186 | " tokenizer: Hugging Face tokenizer\n", 187 | " text (str): Text to encode\n", 188 | " return_tensor (bool): Whether to return PyTorch tensor\n", 189 | "\n", 190 | " Returns:\n", 191 | " List or tensor of token IDs\n", 192 | " \"\"\"\n", 193 | " # If tensor output is requested, encode with PyTorch tensors\n", 194 | " if return_tensor:\n", 195 | " return tokenizer.encode(\n", 196 | " text, add_special_tokens=False, return_tensors=\"pt\"\n", 197 | " )\n", 198 | " # Otherwise return list of token IDs\n", 199 | " else:\n", 200 | " return tokenizer.encode(text, add_special_tokens=False)\n", 201 | "\n", 202 | "class EndTokenStoppingCriteria(StoppingCriteria):\n", 203 | " \"\"\"\n", 204 | " Custom stopping criteria for text generation.\n", 205 | " Stops when a specific end token sequence is generated.\n", 206 | "\n", 207 | " Args:\n", 208 | " end_tokens (list): Token IDs that signal generation should stop\n", 209 | " device: Device where the model is running\n", 210 | " \"\"\"\n", 211 | " def __init__(self, end_tokens, device):\n", 212 | " self.end_tokens = torch.tensor(end_tokens).to(device)\n", 213 | "\n", 214 | " def __call__(self, input_ids, scores):\n", 215 | " \"\"\"\n", 216 | " Checks if generation should stop for each sequence.\n", 217 | "\n", 218 | " Args:\n", 219 | " input_ids: Current generated token IDs\n", 220 | " scores: Token probabilities\n", 221 | "\n", 222 | " Returns:\n", 223 | " tensor: Boolean tensor indicating which sequences should stop\n", 224 | " \"\"\"\n", 225 | " should_stop = []\n", 
226 | "\n", 227 | " # Check each sequence for end tokens\n", 228 | " for sequence in input_ids:\n", 229 | " if len(sequence) >= len(self.end_tokens):\n", 230 | " # Compare last tokens with end tokens\n", 231 | " last_tokens = sequence[-len(self.end_tokens):]\n", 232 | " should_stop.append(torch.all(last_tokens == self.end_tokens))\n", 233 | " else:\n", 234 | " should_stop.append(False)\n", 235 | "\n", 236 | " return torch.tensor(should_stop, device=input_ids.device)\n", 237 | "\n", 238 | "class PromptCompletionDataset(Dataset):\n", 239 | " \"\"\"\n", 240 | " PyTorch Dataset for instruction-completion pairs.\n", 241 | " Handles the conversion of text data into model-ready format.\n", 242 | "\n", 243 | " Args:\n", 244 | " data (list): List of dictionaries containing instructions and solutions\n", 245 | " tokenizer: Hugging Face tokenizer\n", 246 | " \"\"\"\n", 247 | " def __init__(self, data, tokenizer):\n", 248 | " self.data = data\n", 249 | " self.tokenizer = tokenizer\n", 250 | "\n", 251 | " def __len__(self):\n", 252 | " # Return total number of examples\n", 253 | " return len(self.data)\n", 254 | "\n", 255 | " def __getitem__(self, idx):\n", 256 | " \"\"\"\n", 257 | " Returns a single training example.\n", 258 | "\n", 259 | " Args:\n", 260 | " idx (int): Index of the example to fetch\n", 261 | "\n", 262 | " Returns:\n", 263 | " dict: Contains input_ids, labels, prompt, and expected completion\n", 264 | " \"\"\"\n", 265 | " # Get example from dataset\n", 266 | " item = self.data[idx]\n", 267 | " # Build full prompt with instruction\n", 268 | " prompt = build_prompt(item[\"instruction\"])\n", 269 | " # Format completion with end token\n", 270 | " completion = f\"\"\"{item[\"solution\"]}\\n<|im_end|>\"\"\"\n", 271 | "\n", 272 | " # Convert text to token IDs\n", 273 | " encoded_prompt = encode_text(self.tokenizer, prompt)\n", 274 | " encoded_completion = encode_text(self.tokenizer, completion)\n", 275 | " eos_token = [self.tokenizer.eos_token_id]\n", 276 | "\n", 277 
| " # Combine for full input sequence\n", 278 | " input_ids = encoded_prompt + encoded_completion + eos_token\n", 279 | " # Create labels: -100 for prompt (ignored in loss)\n", 280 | " labels = [-100] * len(encoded_prompt) + encoded_completion + eos_token\n", 281 | "\n", 282 | " return {\n", 283 | " \"input_ids\": input_ids,\n", 284 | " \"labels\": labels,\n", 285 | " \"prompt\": prompt,\n", 286 | " \"expected_completion\": completion\n", 287 | " }\n", 288 | "\n", 289 | "def collate_fn(batch):\n", 290 | " \"\"\"\n", 291 | " Collates batch of examples into training-ready format.\n", 292 | " Handles padding and conversion to tensors.\n", 293 | "\n", 294 | " Args:\n", 295 | " batch: List of examples from Dataset\n", 296 | "\n", 297 | " Returns:\n", 298 | " tuple: (input_ids, attention_mask, labels, prompts, expected_completions)\n", 299 | " \"\"\"\n", 300 | " # Find longest sequence for padding\n", 301 | " max_length = max(len(item[\"input_ids\"]) for item in batch)\n", 302 | "\n", 303 | " # Pad input sequences\n", 304 | " input_ids = [\n", 305 | " item[\"input_ids\"] +\n", 306 | " [tokenizer.pad_token_id] * (max_length - len(item[\"input_ids\"]))\n", 307 | " for item in batch\n", 308 | " ]\n", 309 | " # Pad label sequences\n", 310 | " labels = [\n", 311 | " item[\"labels\"] +\n", 312 | " [-100] * (max_length - len(item[\"labels\"]))\n", 313 | " for item in batch\n", 314 | " ]\n", 315 | " # Create attention masks\n", 316 | " attention_mask = [\n", 317 | " [1] * len(item[\"input_ids\"]) +\n", 318 | " [0] * (max_length - len(item[\"input_ids\"]))\n", 319 | " for item in batch\n", 320 | " ]\n", 321 | " prompts = [item[\"prompt\"] for item in batch]\n", 322 | " expected_completions = [item[\"expected_completion\"] for item in batch]\n", 323 | "\n", 324 | " return (\n", 325 | " torch.tensor(input_ids),\n", 326 | " torch.tensor(attention_mask),\n", 327 | " torch.tensor(labels),\n", 328 | " prompts,\n", 329 | " expected_completions\n", 330 | " )\n", 331 | "\n", 332 | "def 
normalize_text(text):\n", 333 | " \"\"\"\n", 334 | " Normalizes text for consistent comparison.\n", 335 | "\n", 336 | " Args:\n", 337 | " text (str): Input text\n", 338 | "\n", 339 | " Returns:\n", 340 | " str: Normalized text\n", 341 | " \"\"\"\n", 342 | " # Remove leading/trailing whitespace and convert to lowercase\n", 343 | " text = text.strip().lower()\n", 344 | " # Replace multiple whitespace characters with single space\n", 345 | " text = re.sub(r'\\s+', ' ', text)\n", 346 | " return text\n", 347 | "\n", 348 | "def generate_text(model, tokenizer, prompt, max_new_tokens=100):\n", 349 | " \"\"\"\n", 350 | " Generates text completion for a given prompt.\n", 351 | "\n", 352 | " Args:\n", 353 | " model: Fine-tuned model\n", 354 | " tokenizer: Associated tokenizer\n", 355 | " prompt (str): Input prompt\n", 356 | " max_new_tokens (int): Maximum number of tokens to generate\n", 357 | "\n", 358 | " Returns:\n", 359 | " str: Generated completion\n", 360 | " \"\"\"\n", 361 | " # Encode prompt and move to model's device\n", 362 | " input_ids = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n", 363 | "\n", 364 | " # Setup end token detection\n", 365 | " end_tokens = tokenizer.encode(\"<|im_end|>\", add_special_tokens=False)\n", 366 | " stopping_criteria = [EndTokenStoppingCriteria(end_tokens, model.device)]\n", 367 | "\n", 368 | " # Generate completion\n", 369 | " output_ids = model.generate(\n", 370 | " input_ids=input_ids[\"input_ids\"],\n", 371 | " attention_mask=input_ids[\"attention_mask\"],\n", 372 | " max_new_tokens=max_new_tokens,\n", 373 | " pad_token_id=tokenizer.pad_token_id,\n", 374 | " stopping_criteria=stopping_criteria\n", 375 | " )[0]\n", 376 | "\n", 377 | " # Extract and decode only the generated part\n", 378 | " generated_ids = output_ids[input_ids[\"input_ids\"].shape[1]:]\n", 379 | " generated_text = tokenizer.decode(generated_ids).strip()\n", 380 | " return generated_text\n", 381 | "\n", 382 | "def test_model(model_path, test_input):\n", 
383 | " \"\"\"\n", 384 | " Tests a saved model on a single input.\n", 385 | "\n", 386 | " Args:\n", 387 | " model_path (str): Path to saved model\n", 388 | " test_input (str): Instruction to test\n", 389 | " \"\"\"\n", 390 | " # Setup device and load model\n", 391 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 392 | " print(f\"Using device: {device}\")\n", 393 | "\n", 394 | " # Load model and tokenizer\n", 395 | " model = AutoModelForCausalLM.from_pretrained(model_path).to(device)\n", 396 | " tokenizer = AutoTokenizer.from_pretrained(model_path)\n", 397 | " tokenizer.pad_token = tokenizer.eos_token\n", 398 | "\n", 399 | " # Generate and display prediction\n", 400 | " prompt = build_prompt(test_input)\n", 401 | " generated_text = generate_text(model, tokenizer, prompt)\n", 402 | "\n", 403 | " print(f\"\\nInput: {test_input}\")\n", 404 | " print(f\"Full generated text: {generated_text}\")\n", 405 | " print(f\"\"\"Cleaned response: {generated_text.replace(\"<|im_end|>\", \"\").strip()}\"\"\")\n", 406 | "\n", 407 | "def download_and_prepare_data(data_url, tokenizer, batch_size, test_ratio=0.1):\n", 408 | " \"\"\"\n", 409 | " Downloads and prepares dataset for training.\n", 410 | "\n", 411 | " Args:\n", 412 | " data_url (str): URL of the dataset\n", 413 | " tokenizer: Tokenizer for text processing\n", 414 | " batch_size (int): Batch size for DataLoader\n", 415 | " test_ratio (float): Proportion of data for testing\n", 416 | "\n", 417 | " Returns:\n", 418 | " tuple: (train_loader, test_loader)\n", 419 | " \"\"\"\n", 420 | " # Download dataset\n", 421 | " response = requests.get(data_url)\n", 422 | " dataset = []\n", 423 | " # Parse each line as an instruction-solution pair\n", 424 | " for line in response.text.splitlines():\n", 425 | " if line.strip(): # Skip empty lines\n", 426 | " entry = json.loads(line)\n", 427 | " dataset.append({\n", 428 | " \"instruction\": entry[\"instruction\"],\n", 429 | " \"solution\": entry[\"solution\"]\n", 
430 | " })\n", 431 | "\n", 432 | " # Split into train and test sets\n", 433 | " random.shuffle(dataset)\n", 434 | " split_index = int(len(dataset) * (1 - test_ratio))\n", 435 | " train_data = dataset[:split_index]\n", 436 | " test_data = dataset[split_index:]\n", 437 | "\n", 438 | " # Print dataset statistics\n", 439 | " print(f\"\\nDataset size: {len(dataset)}\")\n", 440 | " print(f\"Training samples: {len(train_data)}\")\n", 441 | " print(f\"Test samples: {len(test_data)}\")\n", 442 | "\n", 443 | " # Create datasets\n", 444 | " train_dataset = PromptCompletionDataset(train_data, tokenizer)\n", 445 | " test_dataset = PromptCompletionDataset(test_data, tokenizer)\n", 446 | "\n", 447 | " # Create dataloaders\n", 448 | " train_loader = DataLoader(\n", 449 | " train_dataset,\n", 450 | " batch_size=batch_size,\n", 451 | " shuffle=True,\n", 452 | " collate_fn=collate_fn\n", 453 | " )\n", 454 | " test_loader = DataLoader(\n", 455 | " test_dataset,\n", 456 | " batch_size=batch_size,\n", 457 | " shuffle=False,\n", 458 | " collate_fn=collate_fn\n", 459 | " )\n", 460 | "\n", 461 | " return train_loader, test_loader\n", 462 | "\n", 463 | "def get_hyperparameters():\n", 464 | " \"\"\"\n", 465 | " Returns training hyperparameters.\n", 466 | "\n", 467 | " Returns:\n", 468 | " tuple: (num_epochs, batch_size, learning_rate)\n", 469 | " \"\"\"\n", 470 | " # Fewer epochs for instruction tuning as it's more data-efficient\n", 471 | " num_epochs = 4\n", 472 | " # Standard batch size that works well with most GPU memory\n", 473 | " batch_size = 16\n", 474 | " # Standard learning rate for fine-tuning transformers\n", 475 | " learning_rate = 5e-5\n", 476 | "\n", 477 | " return num_epochs, batch_size, learning_rate\n", 478 | "\n", 479 | "# Main training script\n", 480 | "if __name__ == \"__main__\":\n", 481 | " # Set random seed for reproducibility\n", 482 | " set_seed(42)\n", 483 | "\n", 484 | " # Configure training parameters\n", 485 | " data_url = 
\"https://www.thelmbook.com/data/instruct\"\n", 486 | " model_name = \"openai-community/gpt2\"\n", 487 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 488 | " print(f\"Using device: {device}\")\n", 489 | "\n", 490 | " # Initialize tokenizer and model\n", 491 | " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 492 | " tokenizer.pad_token = tokenizer.eos_token\n", 493 | "\n", 494 | " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n", 495 | "\n", 496 | " # Get hyperparameters and prepare data\n", 497 | " num_epochs, batch_size, learning_rate = get_hyperparameters()\n", 498 | " train_loader, test_loader = download_and_prepare_data(data_url, tokenizer, batch_size)\n", 499 | "\n", 500 | " # Initialize optimizer\n", 501 | " optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n", 502 | "\n", 503 | " # Training loop\n", 504 | " for epoch in range(num_epochs):\n", 505 | " total_loss = 0\n", 506 | " num_batches = 0\n", 507 | " progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{num_epochs}\")\n", 508 | "\n", 509 | " for input_ids, attention_mask, labels, _, _ in progress_bar:\n", 510 | " # Move batch to device\n", 511 | " input_ids = input_ids.to(device)\n", 512 | " attention_mask = attention_mask.to(device)\n", 513 | " labels = labels.to(device)\n", 514 | "\n", 515 | " # Forward pass\n", 516 | " outputs = model(\n", 517 | " input_ids=input_ids,\n", 518 | " attention_mask=attention_mask,\n", 519 | " labels=labels\n", 520 | " )\n", 521 | " loss = outputs.loss\n", 522 | "\n", 523 | " # Backward pass and optimization\n", 524 | " loss.backward()\n", 525 | " optimizer.step()\n", 526 | " optimizer.zero_grad()\n", 527 | "\n", 528 | " # Update metrics\n", 529 | " total_loss += loss.item()\n", 530 | " num_batches += 1\n", 531 | "\n", 532 | " progress_bar.set_postfix({\"Loss\": total_loss / num_batches})\n", 533 | "\n", 534 | " # Display epoch metrics\n", 535 | " avg_loss = total_loss / 
num_batches\n", 536 | " print(f\"Epoch {epoch+1} - Average loss: {avg_loss:.4f}\")\n", 537 | "\n", 538 | " # Save the fine-tuned model\n", 539 | " model.save_pretrained(\"./finetuned_model\")\n", 540 | " tokenizer.save_pretrained(\"./finetuned_model\")\n", 541 | "\n", 542 | " # Test the model\n", 543 | " print(\"\\nTesting finetuned model:\")\n", 544 | " test_input = \"Who is the President of the United States?\"\n", 545 | " test_model(\"./finetuned_model\", test_input)" 546 | ] 547 | } 548 | ], 549 | "metadata": { 550 | "colab": { 551 | "provenance": [], 552 | "gpuType": "A100", 553 | "authorship_tag": "ABX9TyMNhBkBP6E26XdVthfpawAc", 554 | "include_colab_link": true 555 | }, 556 | "kernelspec": { 557 | "display_name": "Python 3", 558 | "name": "python3" 559 | }, 560 | "language_info": { 561 | "name": "python" 562 | }, 563 | "accelerator": "GPU" 564 | }, 565 | "nbformat": 4, 566 | "nbformat_minor": 0 567 | } -------------------------------------------------------------------------------- /quadratic_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import matplotlib.colors as colors 5 | 6 | # Define the quadratic loss function J(w, b) 7 | # This function calculates the mean squared error for three data points: 8 | # (150, 200), (200, 600), and (260, 500) 9 | def calculate_loss(w, b): 10 | return ( 11 | ((150*w + b - 200)**2 + 12 | (200*w + b - 600)**2 + 13 | (260*w + b - 500)**2) / 3 14 | ) 15 | 16 | # Set up the plot parameters 17 | plt.rcParams['font.size'] = 16 18 | 19 | # Generate parameter space for w and b 20 | w_values = np.linspace(-10, 10, 400) 21 | b_values = np.linspace(-1000, 1000, 400) 22 | W, B = np.meshgrid(w_values, b_values) 23 | Z = calculate_loss(W, B) 24 | 25 | # Create custom colormap 26 | colors_palette = [ 27 | '#4a90e2', # Blue 28 | '#f8e71c', # Yellow 29 | '#ff6b6b' # Coral/Red 30 | ] 31 | 
custom_cmap = colors.LinearSegmentedColormap.from_list('custom', colors_palette) 32 | 33 | # Create and setup the 3D plot 34 | fig = plt.figure(figsize=(10, 8)) 35 | ax = fig.add_subplot(111, projection='3d') 36 | 37 | # Plot the surface 38 | surface = ax.plot_surface(W, B, Z, cmap=custom_cmap) 39 | 40 | # Set labels and adjust plot appearance 41 | ax.set_xlabel('$w$', fontsize=16) 42 | ax.set_ylabel('$b$', fontsize=16) 43 | ax.set_zlabel('$J(w,b)$', fontsize=16) 44 | ax.set_box_aspect(aspect=None, zoom=0.95) 45 | 46 | # Adjust layout and display plot 47 | plt.tight_layout() 48 | plt.show() -------------------------------------------------------------------------------- /sampling_method.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "
\n", 17 | "
\n", 18 | " \n", 19 | " \n", 20 | " \n", 26 | " \n", 31 | " \n", 32 | "
\n", 21 | " \n", 22 | " A notebook for The Hundred-Page Language Models Book by Andriy Burkov

\n", 23 | " Code repository: https://github.com/aburkov/theLMbook\n", 24 | "
\n", 25 | "
\n", 27 | " \n", 28 | " \"The\n", 29 | " \n", 30 | "
\n", 33 | "
\n", 34 | "
" 35 | ], 36 | "metadata": { 37 | "id": "YMLqrwiuulzT" 38 | } 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "source": [ 43 | "# Token sampling method\n", 44 | "\n", 45 | "## Method implementation\n", 46 | "\n", 47 | "In the cell below, we implement the token sampling method that combines temperature, top-k, and top-p:" 48 | ], 49 | "metadata": { 50 | "id": "kb9Akwe7xttX" 51 | } 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "id": "yy0zjL_2ouOU", 58 | "colab": { 59 | "base_uri": "https://localhost:8080/" 60 | }, 61 | "outputId": "46ce8acf-0762-4e64-fc08-0711e330c5b1" 62 | }, 63 | "outputs": [ 64 | { 65 | "output_type": "stream", 66 | "name": "stdout", 67 | "text": [ 68 | "Test vocabulary: ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'lazy', 'dog']\n", 69 | "Initial logits: [ 2. 1.5 1. 0.5 0. -0.5 -1. -1.5]\n", 70 | "\n", 71 | "Sampling with different parameters:\n", 72 | "\n", 73 | "Test 1: Default parameters (temperature=0.7, no top-k/p filtering)\n", 74 | "Samples: ['the', 'the', 'fox', 'brown', 'brown']\n", 75 | "\n", 76 | "Test 2: High temperature (temperature=2.0)\n", 77 | "Samples: ['jumps', 'quick', 'dog', 'the', 'the']\n", 78 | "\n", 79 | "Test 3: Low temperature (temperature=0.2)\n", 80 | "Samples: ['quick', 'the', 'the', 'the', 'the']\n", 81 | "\n", 82 | "Test 4: Top-k filtering (top_k=3)\n", 83 | "Samples: ['the', 'the', 'quick', 'the', 'the']\n", 84 | "\n", 85 | "Test 5: Top-p filtering (top_p=0.9)\n", 86 | "Samples: ['the', 'the', 'brown', 'the', 'quick']\n", 87 | "\n", 88 | "Test 6: Combined filtering (temperature=0.5, top_k=3, top_p=0.9)\n", 89 | "Samples: ['the', 'the', 'the', 'the', 'the']\n", 90 | "\n", 91 | "Error handling examples:\n", 92 | "Expected error: Mismatch between logits and vocabulary sizes.\n", 93 | "Expected error: Temperature must be positive.\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "import numpy as np\n", 99 | "\n", 100 | "def validate_inputs(logits, vocabulary, 
temperature, top_k, top_p):\n", 101 | " \"\"\"\n", 102 | " Validate all input parameters for the token sampling process.\n", 103 | "\n", 104 | " Args:\n", 105 | " logits (list): Raw model output scores for each token\n", 106 | " vocabulary (list): List of all possible tokens\n", 107 | " temperature (float): Temperature parameter for logits scaling\n", 108 | " top_k (int): Number of highest probability tokens to keep\n", 109 | " top_p (float): Cumulative probability threshold for nucleus sampling\n", 110 | "\n", 111 | " Raises:\n", 112 | " ValueError: If any parameters are invalid or out of expected ranges\n", 113 | " \"\"\"\n", 114 | " if len(logits) != len(vocabulary):\n", 115 | " raise ValueError(\"Mismatch between logits and vocabulary sizes.\")\n", 116 | " if temperature <= 0:\n", 117 | " raise ValueError(\"Temperature must be positive.\")\n", 118 | " if top_k < 0 or top_k > len(logits):\n", 119 | " raise ValueError(\"top_k must be between 0 and len(logits).\")\n", 120 | " if not 0 < top_p <= 1:\n", 121 | " raise ValueError(\"top_p must be in the range (0, 1].\")\n", 122 | "\n", 123 | "def get_token_counts(prev_tokens, vocabulary):\n", 124 | " \"\"\"\n", 125 | " Count the frequency of each token in the previous generation history.\n", 126 | "\n", 127 | " Args:\n", 128 | " prev_tokens (list): Previously generated tokens\n", 129 | " vocabulary (list): List of all possible tokens\n", 130 | "\n", 131 | " Returns:\n", 132 | " dict: Mapping of token indices to their frequencies\n", 133 | " \"\"\"\n", 134 | " token_counts = {}\n", 135 | " if prev_tokens is not None:\n", 136 | " for token in prev_tokens:\n", 137 | " if token in vocabulary:\n", 138 | " idx = vocabulary.index(token)\n", 139 | " token_counts[idx] = token_counts.get(idx, 0) + 1\n", 140 | " return token_counts\n", 141 | "\n", 142 | "def apply_presence_penalty(logits, token_counts, presence_penalty):\n", 143 | " \"\"\"\n", 144 | " Apply presence penalty to tokens that have appeared before.\n", 145 | "\n", 
146 | " Args:\n", 147 | " logits (numpy.ndarray): Token logits\n", 148 | " token_counts (dict): Mapping of token indices to their frequencies\n", 149 | " presence_penalty (float): Fixed penalty to subtract from logits of present tokens\n", 150 | "\n", 151 | " Returns:\n", 152 | " numpy.ndarray: Modified logits with presence penalty applied\n", 153 | "\n", 154 | " Note:\n", 155 | " Unlike frequency penalty, this applies the same penalty regardless of frequency\n", 156 | " \"\"\"\n", 157 | " if presence_penalty != 0.0:\n", 158 | " for idx in token_counts:\n", 159 | " logits[idx] -= presence_penalty\n", 160 | " return logits\n", 161 | "\n", 162 | "def apply_frequency_penalty(logits, token_counts, frequency_penalty):\n", 163 | " \"\"\"\n", 164 | " Apply frequency penalty proportional to token occurrence count.\n", 165 | "\n", 166 | " Args:\n", 167 | " logits (numpy.ndarray): Token logits\n", 168 | " token_counts (dict): Mapping of token indices to their frequencies\n", 169 | " frequency_penalty (float): Penalty factor to multiply with token frequency\n", 170 | "\n", 171 | " Returns:\n", 172 | " numpy.ndarray: Modified logits with frequency penalty applied\n", 173 | "\n", 174 | " Note:\n", 175 | " Penalty increases linearly with token frequency\n", 176 | " \"\"\"\n", 177 | " if frequency_penalty != 0.0:\n", 178 | " for idx, count in token_counts.items():\n", 179 | " logits[idx] -= frequency_penalty * count\n", 180 | " return logits\n", 181 | "\n", 182 | "def apply_temperature(logits, temperature):\n", 183 | " \"\"\"\n", 184 | " Apply temperature scaling to logits to control randomness.\n", 185 | "\n", 186 | " Args:\n", 187 | " logits (numpy.ndarray): Token logits\n", 188 | " temperature (float): Temperature parameter (>1 increases randomness, <1 decreases it)\n", 189 | "\n", 190 | " Returns:\n", 191 | " numpy.ndarray: Temperature-scaled and normalized logits\n", 192 | "\n", 193 | " Note:\n", 194 | " - Higher temperature makes distribution more uniform\n", 195 | " - 
Lower temperature makes distribution more peaked\n", 196 | " - Normalizes by subtracting max for numerical stability\n", 197 | " \"\"\"\n", 198 | " if temperature != 1.0:\n", 199 | " logits = logits / temperature\n", 200 | " return logits - np.max(logits)\n", 201 | "\n", 202 | "def apply_top_k_filtering(logits, top_k, min_tokens_to_keep=1):\n", 203 | " \"\"\"\n", 204 | " Apply top-k filtering to keep only the k highest probability tokens.\n", 205 | "\n", 206 | " Args:\n", 207 | " logits (numpy.ndarray): Token logits\n", 208 | " top_k (int): Number of top tokens to keep\n", 209 | " min_tokens_to_keep (int): Minimum number of tokens to keep regardless of top-k\n", 210 | "\n", 211 | " Returns:\n", 212 | " numpy.ndarray: Modified logits with all but top-k tokens set to -inf\n", 213 | "\n", 214 | " Note:\n", 215 | " Ensures at least min_tokens_to_keep tokens remain available for sampling\n", 216 | " \"\"\"\n", 217 | " if top_k > 0:\n", 218 | " indices_to_remove = np.argsort(logits)[:-min_tokens_to_keep]\n", 219 | " indices_to_keep = np.argsort(logits)[-top_k:]\n", 220 | " for idx in indices_to_remove:\n", 221 | " if idx not in indices_to_keep:\n", 222 | " logits[idx] = float('-inf')\n", 223 | " return logits\n", 224 | "\n", 225 | "def apply_top_p_filtering(logits, top_p, min_tokens_to_keep=1):\n", 226 | " \"\"\"\n", 227 | " Apply nucleus (top-p) filtering to keep tokens comprising top p probability mass.\n", 228 | "\n", 229 | " Args:\n", 230 | " logits (numpy.ndarray): Token logits\n", 231 | " top_p (float): Cumulative probability threshold (0 to 1)\n", 232 | " min_tokens_to_keep (int): Minimum number of tokens to keep regardless of top-p\n", 233 | "\n", 234 | " Returns:\n", 235 | " numpy.ndarray: Modified logits with unlikely tokens set to -inf\n", 236 | "\n", 237 | " Note:\n", 238 | " 1. Converts logits to probabilities\n", 239 | " 2. Sorts tokens by probability\n", 240 | " 3. Keeps minimal set of tokens whose cumulative probability >= top_p\n", 241 | " 4. 
Ensures at least min_tokens_to_keep tokens remain\n", 242 | " \"\"\"\n", 243 | " if top_p < 1.0:\n", 244 | " probs = np.exp(logits)\n", 245 | " probs = probs / probs.sum()\n", 246 | "\n", 247 | " sorted_indices = np.argsort(probs)[::-1]\n", 248 | " sorted_probs = probs[sorted_indices]\n", 249 | " cumulative_probs = np.cumsum(sorted_probs)\n", 250 | "\n", 251 | " sorted_indices_to_remove = sorted_indices[cumulative_probs > top_p]\n", 252 | "\n", 253 | " if len(sorted_indices_to_remove) > len(sorted_indices) - min_tokens_to_keep:\n", 254 | " sorted_indices_to_remove = sorted_indices_to_remove[\n", 255 | " :len(sorted_indices) - min_tokens_to_keep\n", 256 | " ]\n", 257 | "\n", 258 | " logits[sorted_indices_to_remove] = float('-inf')\n", 259 | " return logits\n", 260 | "\n", 261 | "def convert_to_probabilities(logits):\n", 262 | " \"\"\"\n", 263 | " Convert logits to a valid probability distribution using softmax.\n", 264 | "\n", 265 | " Args:\n", 266 | " logits (numpy.ndarray): Token logits\n", 267 | "\n", 268 | " Returns:\n", 269 | " numpy.ndarray: Probability distribution summing to 1\n", 270 | " \"\"\"\n", 271 | " probs = np.exp(logits)\n", 272 | " return probs / probs.sum()\n", 273 | "\n", 274 | "def sample_token(logits, vocabulary, temperature=0.7, top_k=0, top_p=1.0,\n", 275 | " repetition_penalty=1.0, presence_penalty=0.0, frequency_penalty=0.0,\n", 276 | " prev_tokens=None):\n", 277 | " \"\"\"\n", 278 | " Main function for sampling the next token using various sampling strategies.\n", 279 | " Applies sampling methods in the same order as the transformers library.\n", 280 | "\n", 281 | " Args:\n", 282 | " logits (list): Raw model output scores for each token\n", 283 | " vocabulary (list): List of all possible tokens\n", 284 | " temperature (float): Temperature for logits scaling (default: 0.7)\n", 285 | " top_k (int): Number of highest probability tokens to keep (default: 0, disabled)\n", 286 | " top_p (float): Cumulative probability threshold for nucleus 
sampling (default: 1.0)\n", 287 | " repetition_penalty (float): Penalty for repeated tokens (default: 1.0, no penalty)\n", 288 | " presence_penalty (float): Fixed penalty for token presence (default: 0.0)\n", 289 | " frequency_penalty (float): Penalty scaled by token frequency (default: 0.0)\n", 290 | " prev_tokens (list): Previously generated tokens (default: None)\n", 291 | "\n", 292 | " Returns:\n", 293 | " str: Sampled token from vocabulary\n", 294 | "\n", 295 | " Process:\n", 296 | " 1. Validate all input parameters\n", 297 | " 2. Apply repetition, presence, and frequency penalties\n", 298 | " 3. Apply temperature scaling\n", 299 | " 4. Apply top-k and top-p filtering\n", 300 | " 5. Convert to probability distribution and sample\n", 301 | " \"\"\"\n", 302 | " validate_inputs(logits, vocabulary, temperature, top_k, top_p)\n", 303 | "\n", 304 | " logits = np.array(logits, dtype=np.float64)\n", 305 | "\n", 306 | " # 1. Apply penalties\n", 307 | " token_counts = get_token_counts(prev_tokens, vocabulary)\n", 308 | " logits = apply_presence_penalty(logits, token_counts, presence_penalty)\n", 309 | " logits = apply_frequency_penalty(logits, token_counts, frequency_penalty)\n", 310 | "\n", 311 | " # 2. Apply temperature scaling\n", 312 | " logits = apply_temperature(logits, temperature)\n", 313 | "\n", 314 | " # 3. Apply filtering\n", 315 | " logits = apply_top_k_filtering(logits, top_k)\n", 316 | " logits = apply_top_p_filtering(logits, top_p)\n", 317 | "\n", 318 | " # 4. 
Convert to probabilities and sample\n", 319 | " probabilities = convert_to_probabilities(logits)\n", 320 | " return np.random.choice(vocabulary, p=probabilities)\n", 321 | "\n", 322 | "if __name__ == \"__main__\":\n", 323 | " # Create a test vocabulary and corresponding logits\n", 324 | " vocabulary = [\"the\", \"quick\", \"brown\", \"fox\", \"jumps\", \"over\", \"lazy\", \"dog\"]\n", 325 | " logits = np.array([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -1.5])\n", 326 | "\n", 327 | " print(\"Test vocabulary:\", vocabulary)\n", 328 | " print(\"Initial logits:\", logits)\n", 329 | " print(\"\\nSampling with different parameters:\")\n", 330 | "\n", 331 | " # Test 1: Default parameters\n", 332 | " print(\"\\nTest 1: Default parameters (temperature=0.7, no top-k/p filtering)\")\n", 333 | " samples = [sample_token(logits.copy(), vocabulary) for _ in range(5)]\n", 334 | " print(\"Samples:\", samples)\n", 335 | "\n", 336 | " # Test 2: High temperature (more random)\n", 337 | " print(\"\\nTest 2: High temperature (temperature=2.0)\")\n", 338 | " samples = [sample_token(logits.copy(), vocabulary, temperature=2.0) for _ in range(5)]\n", 339 | " print(\"Samples:\", samples)\n", 340 | "\n", 341 | " # Test 3: Low temperature (more deterministic)\n", 342 | " print(\"\\nTest 3: Low temperature (temperature=0.2)\")\n", 343 | " samples = [sample_token(logits.copy(), vocabulary, temperature=0.2) for _ in range(5)]\n", 344 | " print(\"Samples:\", samples)\n", 345 | "\n", 346 | " # Test 4: Top-k filtering\n", 347 | " print(\"\\nTest 4: Top-k filtering (top_k=3)\")\n", 348 | " samples = [sample_token(logits.copy(), vocabulary, top_k=3) for _ in range(5)]\n", 349 | " print(\"Samples:\", samples)\n", 350 | "\n", 351 | " # Test 5: Top-p filtering\n", 352 | " print(\"\\nTest 5: Top-p filtering (top_p=0.9)\")\n", 353 | " samples = [sample_token(logits.copy(), vocabulary, top_p=0.9) for _ in range(5)]\n", 354 | " print(\"Samples:\", samples)\n", 355 | "\n", 356 | " # Test 6: Combined filtering\n", 
357 | " print(\"\\nTest 6: Combined filtering (temperature=0.5, top_k=3, top_p=0.9)\")\n", 358 | " samples = [sample_token(logits.copy(), vocabulary, temperature=0.5, top_k=3, top_p=0.9)\n", 359 | " for _ in range(5)]\n", 360 | " print(\"Samples:\", samples)\n", 361 | "\n", 362 | " # Demonstrate error handling\n", 363 | " print(\"\\nError handling examples:\")\n", 364 | " try:\n", 365 | " # Test with mismatched sizes\n", 366 | " sample_token(logits[:5], vocabulary)\n", 367 | " except ValueError as e:\n", 368 | " print(\"Expected error:\", e)\n", 369 | "\n", 370 | " try:\n", 371 | " # Test with invalid temperature\n", 372 | " sample_token(logits, vocabulary, temperature=0)\n", 373 | " except ValueError as e:\n", 374 | " print(\"Expected error:\", e)" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": { 381 | "id": "s61ovCwawq3f" 382 | }, 383 | "outputs": [], 384 | "source": [] 385 | } 386 | ], 387 | "metadata": { 388 | "colab": { 389 | "provenance": [], 390 | "authorship_tag": "ABX9TyNHhXoTui8A3Sg3rEKzvmNi", 391 | "include_colab_link": true 392 | }, 393 | "kernelspec": { 394 | "display_name": "Python 3", 395 | "name": "python3" 396 | }, 397 | "language_info": { 398 | "name": "python" 399 | } 400 | }, 401 | "nbformat": 4, 402 | "nbformat_minor": 0 403 | } -------------------------------------------------------------------------------- /spotify_gemini_playlist.py: -------------------------------------------------------------------------------- 1 | # Read instructions here: https://x.com/burkov/status/1921303279562064098 2 | 3 | import os 4 | import json 5 | import random 6 | import time 7 | import spotipy 8 | from spotipy.oauth2 import SpotifyOAuth 9 | import requests 10 | from dotenv import load_dotenv 11 | 12 | # --- Configuration --- 13 | load_dotenv() 14 | 15 | SPOTIFY_CLIENT_ID = os.getenv("SPOTIPY_CLIENT_ID") 16 | SPOTIFY_CLIENT_SECRET = os.getenv("SPOTIPY_CLIENT_SECRET") 17 | SPOTIFY_REDIRECT_URI = 
os.getenv("SPOTIPY_REDIRECT_URI") 18 | OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") 19 | 20 | SCOPES = "user-library-read playlist-modify-public playlist-read-private playlist-read-collaborative" 21 | NEW_PLAYLIST_NAME = "New Gemini Recommendations" 22 | ALL_RECS_PLAYLIST_NAME = "All Gemini Recommendations" 23 | 24 | TARGET_NEW_SONGS_COUNT = 20 25 | MAX_GEMINI_ATTEMPTS = 10 # Increased attempts as we ask for exactly 20 each time 26 | MAX_SONGS_TO_GEMINI_PROMPT = 200 # Max liked songs for the initial Gemini prompt 27 | 28 | GEMINI_MODEL = "google/gemini-2.5-flash-preview" 29 | # Ensure this model ID is valid. If not, try "google/gemini-flash-1.5" 30 | 31 | # --- Helper Functions --- 32 | 33 | def get_spotify_client(): 34 | auth_manager = SpotifyOAuth( 35 | client_id=SPOTIFY_CLIENT_ID, 36 | client_secret=SPOTIFY_CLIENT_SECRET, 37 | redirect_uri=SPOTIFY_REDIRECT_URI, 38 | scope=SCOPES, 39 | cache_path=".spotify_cache" 40 | ) 41 | sp = spotipy.Spotify(auth_manager=auth_manager) 42 | print("Successfully authenticated with Spotify.") 43 | return sp 44 | 45 | def get_all_liked_songs_details(sp): 46 | print("Fetching all liked songs details...") 47 | liked_songs_details = [] 48 | offset = 0 49 | limit = 50 50 | while True: 51 | try: 52 | results = sp.current_user_saved_tracks(limit=limit, offset=offset) 53 | if not results or not results['items']: 54 | break 55 | for item in results['items']: 56 | track = item.get('track') 57 | if track and track.get('name') and track.get('artists'): 58 | if track['artists']: # Ensure artist list is not empty 59 | artist_name = track['artists'][0]['name'] 60 | liked_songs_details.append({"track": track['name'], "artist": artist_name}) 61 | offset += limit 62 | print(f"Fetched {len(liked_songs_details)} liked songs so far...") 63 | if not results.get('next'): 64 | break 65 | time.sleep(0.05) 66 | except Exception as e: 67 | print(f"Error fetching liked songs page: {e}") 68 | break 69 | print(f"Total liked songs details fetched: 
{len(liked_songs_details)}") 70 | return liked_songs_details 71 | 72 | def get_playlist_by_name(sp, playlist_name, user_id): 73 | playlists = sp.current_user_playlists(limit=50) 74 | while playlists: 75 | for playlist in playlists['items']: 76 | if playlist['name'] == playlist_name and playlist['owner']['id'] == user_id: 77 | return playlist 78 | if playlists['next']: 79 | playlists = sp.next(playlists) 80 | time.sleep(0.05) 81 | else: 82 | playlists = None 83 | return None 84 | 85 | def get_or_create_playlist_id(sp, user_id, playlist_name, public=True): 86 | playlist_object = get_playlist_by_name(sp, playlist_name, user_id) 87 | if playlist_object: 88 | print(f"Found existing playlist: '{playlist_name}' (ID: {playlist_object['id']})") 89 | return playlist_object['id'] 90 | else: 91 | print(f"Playlist '{playlist_name}' not found. Creating it...") 92 | try: 93 | new_playlist = sp.user_playlist_create(user=user_id, name=playlist_name, public=public) 94 | print(f"Successfully created playlist: '{playlist_name}' (ID: {new_playlist['id']})") 95 | return new_playlist['id'] 96 | except Exception as e: 97 | print(f"Error creating playlist '{playlist_name}': {e}") 98 | return None 99 | 100 | def get_playlist_tracks_simplified(sp, playlist_id): 101 | if not playlist_id: return [] 102 | print(f"Fetching tracks from playlist ID: {playlist_id}...") 103 | playlist_tracks = [] 104 | offset = 0 105 | limit = 100 106 | while True: 107 | try: 108 | results = sp.playlist_items(playlist_id, limit=limit, offset=offset, fields="items(track(name,artists(name))),next") 109 | if not results or not results['items']: break 110 | for item in results['items']: 111 | track_info = item.get('track') 112 | if track_info and track_info.get('name') and track_info.get('artists'): 113 | if track_info['artists']: 114 | artist_name = track_info['artists'][0]['name'] 115 | playlist_tracks.append({"track": track_info['name'], "artist": artist_name}) 116 | offset += limit 117 | print(f"Fetched 
{len(playlist_tracks)} tracks from playlist ID {playlist_id} so far...") 118 | if not results.get('next'): break 119 | time.sleep(0.05) 120 | except Exception as e: 121 | print(f"Error fetching playlist items for {playlist_id}: {e}") 122 | break 123 | print(f"Total tracks fetched from playlist ID {playlist_id}: {len(playlist_tracks)}") 124 | return playlist_tracks 125 | 126 | def get_gemini_recommendations(api_key, conversation_history): 127 | """ 128 | Sends the conversation history to Gemini and requests recommendations. 129 | Returns a tuple: (parsed_recommendations_list, raw_assistant_response_content_string) 130 | The last message in conversation_history is assumed to be the current user prompt. 131 | """ 132 | print(f"\nSending request to Gemini with {len(conversation_history)} messages in history...") 133 | if not conversation_history or conversation_history[-1]["role"] != "user": 134 | print("Error: Conversation history is empty or does not end with a user message.") 135 | return [], None 136 | 137 | try: 138 | response = requests.post( 139 | url="https://openrouter.ai/api/v1/chat/completions", 140 | headers={ 141 | "Authorization": f"Bearer {api_key}", 142 | "Content-Type": "application/json" 143 | }, 144 | json={ 145 | "model": GEMINI_MODEL, 146 | "messages": conversation_history, 147 | "response_format": {"type": "json_object"} 148 | }, 149 | timeout=60 # Increased timeout for potentially longer LLM responses 150 | ) 151 | response.raise_for_status() 152 | 153 | response_data = response.json() 154 | raw_assistant_response_content = response_data['choices'][0]['message']['content'] 155 | 156 | recommendations = [] 157 | try: 158 | parsed_content = json.loads(raw_assistant_response_content) 159 | if isinstance(parsed_content, list): 160 | recommendations = parsed_content 161 | elif isinstance(parsed_content, dict) and len(parsed_content.keys()) == 1: 162 | key = list(parsed_content.keys())[0] 163 | if isinstance(parsed_content[key], list): 164 | 
recommendations = parsed_content[key] 165 | except json.JSONDecodeError: 166 | print("Gemini response was not directly parsable JSON. Attempting to clean...") 167 | content_to_parse = raw_assistant_response_content 168 | if content_to_parse.startswith("```json"): content_to_parse = content_to_parse[7:] 169 | if content_to_parse.endswith("```"): content_to_parse = content_to_parse[:-3] 170 | content_to_parse = content_to_parse.strip() 171 | try: 172 | parsed_content = json.loads(content_to_parse) 173 | if isinstance(parsed_content, list): recommendations = parsed_content 174 | elif isinstance(parsed_content, dict) and len(parsed_content.keys()) == 1: 175 | key = list(parsed_content.keys())[0] 176 | if isinstance(parsed_content[key], list): recommendations = parsed_content[key] 177 | except json.JSONDecodeError as e_clean: 178 | print(f"Error: Gemini response could not be parsed as JSON even after cleaning: {e_clean}") 179 | print(f"Gemini Raw Response Content:\n{raw_assistant_response_content}") 180 | return [], raw_assistant_response_content # Return raw content for history even on parse error 181 | 182 | valid_recommendations = [] 183 | for rec in recommendations: 184 | if isinstance(rec, dict) and "track" in rec and "artist" in rec: 185 | valid_recommendations.append({"track": str(rec["track"]), "artist": str(rec["artist"])}) 186 | else: 187 | print(f"Warning: Skipping invalid recommendation format from Gemini: {rec}") 188 | 189 | print(f"Received {len(valid_recommendations)} validly structured recommendations from Gemini.") 190 | return valid_recommendations, raw_assistant_response_content 191 | 192 | except requests.exceptions.RequestException as e: 193 | print(f"Error calling OpenRouter API: {e}") 194 | if hasattr(e, 'response') and e.response is not None: 195 | print(f"Response status: {e.response.status_code}") 196 | try: print(f"Response content: {e.response.json()}") 197 | except json.JSONDecodeError: print(f"Response content: {e.response.text}") 198 | 
return [], None 199 | except (KeyError, IndexError) as e: 200 | raw_resp_text = response.text if 'response' in locals() else 'No response object' 201 | print(f"Error parsing Gemini response structure: {e}") 202 | print(f"Gemini Raw Response (full): {raw_resp_text}") 203 | return [], None 204 | 205 | 206 | def verify_songs_on_spotify_v2(sp, recommended_songs_details): 207 | print("\nVerifying recommended songs on Spotify...") 208 | available_songs_info = [] 209 | for song_detail in recommended_songs_details: 210 | track_name = song_detail.get('track') 211 | artist_name = song_detail.get('artist') 212 | if not track_name or not artist_name: continue 213 | query = f"track:{track_name} artist:{artist_name}" 214 | try: 215 | results = sp.search(q=query, type="track", limit=1) 216 | time.sleep(0.05) 217 | if results and results['tracks']['items']: 218 | found_track = results['tracks']['items'][0] 219 | available_songs_info.append({ 220 | "uri": found_track['uri'], 221 | "track": found_track['name'], 222 | "artist": found_track['artists'][0]['name'] 223 | }) 224 | print(f" Found on Spotify: '{found_track['name']}' by {found_track['artists'][0]['name']}") 225 | else: 226 | print(f" Not found on Spotify: '{track_name}' by {artist_name}") 227 | except Exception as e: 228 | print(f" Error searching for '{track_name}' by {artist_name}: {e}") 229 | print(f"\nVerified {len(available_songs_info)} songs as available on Spotify.") 230 | return available_songs_info 231 | 232 | def update_playlist_items(sp, playlist_id, track_uris, replace=False): 233 | if not playlist_id: return False 234 | if not track_uris and not replace: return True 235 | if not track_uris and replace: 236 | try: 237 | sp.playlist_replace_items(playlist_id, []) 238 | print(f"Cleared all items from playlist ID {playlist_id}.") 239 | return True 240 | except Exception as e: print(f"Error clearing playlist {playlist_id}: {e}"); return False 241 | 242 | action = "Replacing" if replace else "Adding" 243 | 
print(f"{action} {len(track_uris)} songs for playlist ID {playlist_id}...") 244 | try: 245 | if replace: 246 | # Spotipy's playlist_replace_items handles batching internally up to 100. 247 | # For >100, it might still be one call to Spotify API that errors, 248 | # or spotipy might make multiple calls. 249 | # Let's stick to safer manual batching if >100 for replace. 250 | if len(track_uris) <= 100: 251 | sp.playlist_replace_items(playlist_id, track_uris) 252 | else: 253 | sp.playlist_replace_items(playlist_id, []) # Clear 254 | for i in range(0, len(track_uris), 100): 255 | sp.playlist_add_items(playlist_id, track_uris[i:i + 100]) 256 | time.sleep(0.1) 257 | else: # Appending 258 | for i in range(0, len(track_uris), 100): 259 | sp.playlist_add_items(playlist_id, track_uris[i:i + 100]) 260 | time.sleep(0.1) 261 | print(f"Successfully {action.lower()}ed songs in playlist ID {playlist_id}.") 262 | return True 263 | except Exception as e: 264 | print(f"Error {action.lower()}ing songs in playlist {playlist_id}: {e}") 265 | return False 266 | 267 | # --- Main Execution --- 268 | if __name__ == "__main__": 269 | if not all([SPOTIFY_CLIENT_ID, SPOTIFY_CLIENT_SECRET, SPOTIFY_REDIRECT_URI, OPENROUTER_API_KEY]): 270 | print("Error: Missing environment variables. Please check .env file."); exit(1) 271 | 272 | sp_client = get_spotify_client() 273 | if not sp_client: exit(1) 274 | 275 | user_info = sp_client.me() 276 | user_id = user_info['id'] 277 | print(f"Logged in as: {user_info.get('display_name', user_id)}") 278 | 279 | # 1. Get ALL liked songs and create a set for filtering 280 | all_my_liked_songs_details = get_all_liked_songs_details(sp_client) 281 | if not all_my_liked_songs_details: 282 | print("No liked songs found. 
Exiting."); exit() 283 | 284 | all_my_liked_songs_set = set() 285 | for song_detail in all_my_liked_songs_details: 286 | track = song_detail.get('track', "").strip().lower() 287 | artist = song_detail.get('artist', "").strip().lower() 288 | if track and artist: 289 | all_my_liked_songs_set.add((track, artist)) 290 | print(f"Created set of {len(all_my_liked_songs_set)} unique liked songs for de-duplication.") 291 | 292 | # 2. Shuffle liked songs and take a sample for the initial Gemini prompt 293 | random.shuffle(all_my_liked_songs_details) 294 | sample_liked_songs_for_gemini_prompt = all_my_liked_songs_details[:MAX_SONGS_TO_GEMINI_PROMPT] 295 | 296 | # Get "All Gemini Recommendations" playlist history 297 | all_recs_playlist_id = get_or_create_playlist_id(sp_client, user_id, ALL_RECS_PLAYLIST_NAME) 298 | existing_all_recs_songs_details = [] 299 | if all_recs_playlist_id: 300 | existing_all_recs_songs_details = get_playlist_tracks_simplified(sp_client, all_recs_playlist_id) 301 | 302 | all_recs_history_set = set() # Stores (track.lower(), artist.lower()) from "All Gemini Recs" playlist 303 | for song_detail in existing_all_recs_songs_details: 304 | track = song_detail.get('track', "").strip().lower() 305 | artist = song_detail.get('artist', "").strip().lower() 306 | if track and artist: 307 | all_recs_history_set.add((track, artist)) 308 | print(f"Found {len(all_recs_history_set)} unique songs in '{ALL_RECS_PLAYLIST_NAME}' history.") 309 | 310 | # 3-5. Iteratively get new recommendations 311 | collected_new_songs_for_playlist_uris = [] 312 | collected_new_songs_for_playlist_details = [] # Stores Spotify-verified dicts for final playlist 313 | 314 | conversation_history = [] 315 | # Stores dicts {track, artist} of ALL songs Gemini suggests in this session (raw names from Gemini) 316 | # Used to tell Gemini what to avoid in follow-up prompts. 
317 | all_gemini_suggestions_this_session_raw_details = [] 318 | 319 | # Initial user prompt for the very first message to Gemini 320 | liked_songs_prompt_str = "\n".join([f"- \"{s['track']}\" by {s['artist']}" for s in sample_liked_songs_for_gemini_prompt]) 321 | initial_user_prompt_content = f"""You are a music recommendation assistant. I will provide you with a list of songs I like. 322 | Based on this list, please recommend {TARGET_NEW_SONGS_COUNT} additional songs that I might enjoy. 323 | It's important that your response is ONLY a valid JSON array of objects, where each object has a "track" key (song title) and an "artist" key (artist name). 324 | Do not include any other text, explanations, or markdown formatting outside of the JSON array. 325 | 326 | Here are some songs I like: 327 | {liked_songs_prompt_str} 328 | 329 | Please provide {TARGET_NEW_SONGS_COUNT} new song recommendations in the specified JSON format.""" 330 | conversation_history.append({"role": "user", "content": initial_user_prompt_content}) 331 | 332 | 333 | for attempt in range(MAX_GEMINI_ATTEMPTS): 334 | if len(collected_new_songs_for_playlist_uris) >= TARGET_NEW_SONGS_COUNT: 335 | print("\nTarget number of new songs reached.") 336 | break 337 | 338 | print(f"\n--- Gemini Request Attempt {attempt + 1}/{MAX_GEMINI_ATTEMPTS} ---") 339 | 340 | # If this is not the first attempt, construct and add follow-up user message 341 | if attempt > 0: 342 | songs_suggested_by_gemini_this_session_str = "\n".join( 343 | [f"- \"{s['track']}\" by {s['artist']}" for s in all_gemini_suggestions_this_session_raw_details] 344 | ) 345 | if not songs_suggested_by_gemini_this_session_str: 346 | songs_suggested_by_gemini_this_session_str = "(None previously suggested in this session)" 347 | 348 | follow_up_user_prompt_content = f"""Okay, thank you. Now, please provide {TARGET_NEW_SONGS_COUNT} MORE unique song recommendations based on the initial list of songs I like (provided at the start of our conversation). 
349 | It is very important that these new recommendations are different from any songs you've already suggested to me in this conversation. For reference, here are the songs you've suggested so far (please avoid these): 350 | {songs_suggested_by_gemini_this_session_str} 351 | 352 | Also, ensure these new recommendations are different from the initial list of liked songs I provided. 353 | Your response must be ONLY a valid JSON array of objects, with "track" and "artist" keys, as before.""" 354 | conversation_history.append({"role": "user", "content": follow_up_user_prompt_content}) 355 | # Prune conversation history if it gets too long (optional, depends on model limits) 356 | # For now, let it grow for a few turns. Gemini Flash has a decent context window. 357 | # if len(conversation_history) > 10: # Example: keep last 10 messages + initial prompt 358 | # conversation_history = [conversation_history[0]] + conversation_history[-9:] 359 | 360 | 361 | gemini_batch_recs_parsed, raw_assistant_response_str = get_gemini_recommendations( 362 | OPENROUTER_API_KEY, 363 | conversation_history 364 | ) 365 | 366 | if raw_assistant_response_str: # If Gemini responded, add its response to history 367 | conversation_history.append({"role": "assistant", "content": raw_assistant_response_str}) 368 | 369 | if not gemini_batch_recs_parsed: 370 | print("Gemini returned no valid recommendations in this batch or there was an API error.") 371 | if attempt < MAX_GEMINI_ATTEMPTS - 1: time.sleep(3) 372 | continue 373 | 374 | # Add raw suggestions from this Gemini batch to `all_gemini_suggestions_this_session_raw_details` 375 | # This list helps construct the "avoid these" part of the next follow-up prompt. 376 | for rec in gemini_batch_recs_parsed: # rec is dict {track, artist} 377 | all_gemini_suggestions_this_session_raw_details.append(rec) 378 | 379 | print(f"Gemini suggested {len(gemini_batch_recs_parsed)} songs. 
Verifying on Spotify and filtering...") 380 | verified_spotify_songs_this_batch = verify_songs_on_spotify_v2(sp_client, gemini_batch_recs_parsed) 381 | 382 | newly_added_this_turn_count = 0 383 | for verified_song_info in verified_spotify_songs_this_batch: # dict {'uri', 'track', 'artist'} 384 | if len(collected_new_songs_for_playlist_uris) >= TARGET_NEW_SONGS_COUNT: 385 | break 386 | 387 | # Use Spotify's canonical track/artist names for consistent checking 388 | spotify_track_name_lower = verified_song_info['track'].lower() 389 | spotify_artist_name_lower = verified_song_info['artist'].lower() 390 | spotify_song_key = (spotify_track_name_lower, spotify_artist_name_lower) 391 | 392 | is_liked = spotify_song_key in all_my_liked_songs_set 393 | is_in_all_recs_playlist_history = spotify_song_key in all_recs_history_set 394 | 395 | # Check if URI is already in the list we are building this session 396 | is_already_collected_for_new_playlist_this_session = any( 397 | vs['uri'] == verified_song_info['uri'] for vs in collected_new_songs_for_playlist_details 398 | ) 399 | 400 | if not is_liked and not is_in_all_recs_playlist_history and not is_already_collected_for_new_playlist_this_session: 401 | collected_new_songs_for_playlist_uris.append(verified_song_info['uri']) 402 | collected_new_songs_for_playlist_details.append(verified_song_info) 403 | newly_added_this_turn_count +=1 404 | print(f" ++ Collected for new playlist: '{verified_song_info['track']}' by '{verified_song_info['artist']}'") 405 | else: 406 | reason = [] 407 | if is_liked: reason.append("is liked") 408 | if is_in_all_recs_playlist_history: reason.append("in all_recs history") 409 | if is_already_collected_for_new_playlist_this_session: reason.append("already collected this session") 410 | print(f" -- Skipped '{verified_song_info['track']}' by '{verified_song_info['artist']}' (Reason: {', '.join(reason)})") 411 | 412 | print(f"Added {newly_added_this_turn_count} new songs this turn.") 413 | print(f"Total 
collected for new playlist so far: {len(collected_new_songs_for_playlist_uris)}/{TARGET_NEW_SONGS_COUNT}") 414 | 415 | if len(collected_new_songs_for_playlist_uris) >= TARGET_NEW_SONGS_COUNT: 416 | break 417 | elif attempt < MAX_GEMINI_ATTEMPTS -1 : 418 | time.sleep(2) # Pause before next Gemini attempt 419 | 420 | # --- End of iterative collection --- 421 | 422 | final_uris_for_new_playlist = collected_new_songs_for_playlist_uris[:TARGET_NEW_SONGS_COUNT] 423 | final_details_for_all_recs_update = collected_new_songs_for_playlist_details[:TARGET_NEW_SONGS_COUNT] 424 | 425 | if not final_uris_for_new_playlist: 426 | print("\nNo new, verifiable songs were collected from Gemini after all attempts. Exiting.") 427 | exit() 428 | 429 | print(f"\nCollected {len(final_uris_for_new_playlist)} final new songs for '{NEW_PLAYLIST_NAME}'.") 430 | 431 | # 5. Save to "New Gemini Recommendations" (replacing) 432 | new_playlist_id = get_or_create_playlist_id(sp_client, user_id, NEW_PLAYLIST_NAME) 433 | if new_playlist_id: 434 | print(f"\nUpdating playlist '{NEW_PLAYLIST_NAME}' by replacing items...") 435 | if update_playlist_items(sp_client, new_playlist_id, final_uris_for_new_playlist, replace=True): 436 | playlist_url_new = sp_client.playlist(new_playlist_id)['external_urls']['spotify'] 437 | print(f"Successfully updated '{NEW_PLAYLIST_NAME}'. URL: {playlist_url_new}") 438 | else: 439 | print(f"Could not create or find playlist '{NEW_PLAYLIST_NAME}'.") 440 | 441 | # 6. 
Add these songs to "All Gemini Recommendations" (appending) 442 | if all_recs_playlist_id and final_details_for_all_recs_update: # Use details to get URIs 443 | uris_to_add_to_all_recs = [song['uri'] for song in final_details_for_all_recs_update] 444 | print(f"\nAppending {len(uris_to_add_to_all_recs)} songs to '{ALL_RECS_PLAYLIST_NAME}'...") 445 | if update_playlist_items(sp_client, all_recs_playlist_id, uris_to_add_to_all_recs, replace=False): 446 | playlist_url_all = sp_client.playlist(all_recs_playlist_id)['external_urls']['spotify'] 447 | print(f"Successfully appended songs to '{ALL_RECS_PLAYLIST_NAME}'. URL: {playlist_url_all}") 448 | elif not all_recs_playlist_id: 449 | print(f"Could not find or create playlist '{ALL_RECS_PLAYLIST_NAME}' to append songs.") 450 | 451 | print("\nScript finished.") 452 | -------------------------------------------------------------------------------- /wiki/GPU-rental.md: -------------------------------------------------------------------------------- 1 | # GPU Rental Services 2 | 3 | - [Lambda](https://lambdalabs.com/) - If you purchase the book, you can claim $150 in free GPU credits on Lambda. 
Contact the author at author@thelmbook.com 4 | - [Vast.ai](https://vast.ai/) 5 | - [RunPod](https://www.runpod.io/) 6 | - [Hyperstack](https://www.hyperstack.cloud/) 7 | - [Vultr](https://www.vultr.com/) 8 | -------------------------------------------------------------------------------- /wiki/MoE.md: -------------------------------------------------------------------------------- 1 | # Mixture of Experts 2 | 3 | - [Mixture of Experts Explained](https://huggingface.co/blog/moe) 4 | - [Create Mixtures of Experts with MergeKit](https://huggingface.co/blog/mlabonne/frankenmoe) 5 | - [Applying Mixture of Experts in LLM Architectures](https://developer.nvidia.com/blog/applying-mixture-of-experts-in-llm-architectures/) 6 | - [Mixture-of-Experts in the Era of LLMs: A New Odyssey](https://moe-tutorial.github.io/) -------------------------------------------------------------------------------- /wiki/PyTorch.md: -------------------------------------------------------------------------------- 1 | # Learning PyTorch 2 | 3 | [gimmick:theme](spacelab) 4 | 5 | - [Introduction to PyTorch](https://pytorch.org/tutorials/beginner/basics/intro.html) 6 | - [Learn PyTorch for deep learning in a day](https://www.youtube.com/watch?v=Z_ikDlimN6A) 7 | - [Learn PyTorch for deep learning: Zero to mastery book](https://www.learnpytorch.io/) 8 | - [PyTorch recipes](https://pytorch.org/tutorials/recipes/recipes_index.html) -------------------------------------------------------------------------------- /wiki/VLM.md: -------------------------------------------------------------------------------- 1 | # Vision Language Models 2 | 3 | - [Vision Language Models Explained](https://huggingface.co/blog/vlms) 4 | - [A Dive into Vision-Language Models](https://huggingface.co/blog/vision_language_pretraining) 5 | - [SmolVLM - small yet mighty Vision Language Model](https://huggingface.co/blog/smolvlm) 6 | - [Introduction to Vision Language 
Models](https://huggingface.co/learn/computer-vision-course/en/unit4/multimodal-models/vlm-intro) 7 | - [An Introduction to Vision-Language Modeling](https://arxiv.org/pdf/2405.17247v1) 8 | - [Vision Language Models from Scratch](https://sachinruk.github.io/blog/2024-08-11-vision-language-models.html) -------------------------------------------------------------------------------- /wiki/alignment.md: -------------------------------------------------------------------------------- 1 | # Alignment 2 | 3 | ## RLHF 4 | 5 | - [Reinforcement Learning from Human Feedback](https://icml.cc/media/icml-2023/Slides/21554.pdf) 6 | - [A curated list of reinforcement learning with human feedback resources](https://github.com/opendilab/awesome-RLHF) 7 | 8 | ## DPO 9 | 10 | - [Preference Tuning LLMs with Direct Preference Optimization Methods](https://huggingface.co/blog/pref-tuning) 11 | - [Fine-tune Llama 2 with DPO](https://huggingface.co/blog/dpo-trl) 12 | - [RLHF in 2024 with DPO & Hugging Face](https://www.philschmid.de/dpo-align-llms-in-2024-with-trl) 13 | - [Unveiling the Hidden Reward System in Language Models: A Dive into DPO](https://allam.vercel.app/post/dpo/) -------------------------------------------------------------------------------- /wiki/colabs.md: -------------------------------------------------------------------------------- 1 | # Colab Notebooks 2 | 3 | - [2.3. Byte-Pair Encoding](https://github.com/aburkov/theLMbook/blob/main/byte_pair_encoding.ipynb) 4 | - [2.5. Count-Based Language Model](https://github.com/aburkov/theLMbook/blob/main/count_language_model.ipynb) 5 | - [3.6. Training an RNN Language Model](https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb) 6 | - [4.9. Transformer in Python](https://github.com/aburkov/theLMbook/blob/main/news_decoder_language_model.ipynb) 7 | - [5.3.1. Baseline Emotion Classifier](https://github.com/aburkov/theLMbook/blob/main/emotion_classifier_LR.ipynb) 8 | - [5.3.2.
Emotion Generation](https://github.com/aburkov/theLMbook/blob/main/emotion_GPT2_as_text_generator.ipynb) 9 | - [5.3.3. Finetuning to Follow Instructions](https://github.com/aburkov/theLMbook/blob/main/instruct_GPT2.ipynb) 10 | - [5.4.4. Penalties](https://github.com/aburkov/theLMbook/blob/main/sampling_method.ipynb) 11 | - [5.5.2. Parameter-Efficient Finetuning (PEFT)](https://github.com/aburkov/theLMbook/blob/main/emotion_GPT2_as_text_generator_LoRA.ipynb) 12 | - [5.6. LLM as a Classifier](https://github.com/aburkov/theLMbook/blob/main/emotion_GPT2_as_classifier.ipynb) -------------------------------------------------------------------------------- /wiki/compression.md: -------------------------------------------------------------------------------- 1 | # Model Compression 2 | 3 | ## Pruning 4 | 5 | - [Pruning Tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) 6 | 7 | ## Distillation 8 | 9 | - [Knowledge Distillation Tutorial](https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html) 10 | - [Distilling Llama3.1 8B into Llama3.2 1B using Knowledge Distillation](https://pytorch.org/torchtune/0.3/tutorials/llama_kd_tutorial.html) 11 | - [Distilling Llama3.1 8B into 1B in torchtune](https://pytorch.org/blog/llama-into-torchtune/) 12 | 13 | ## Quantization 14 | 15 | - [Quantization](https://pytorch.org/docs/stable/quantization.html) 16 | - [Practical Quantization in PyTorch](https://pytorch.org/blog/quantization-in-practice/) 17 | - [Quantization](https://huggingface.co/docs/optimum/en/concept_guides/quantization) 18 | - [Quantization](https://huggingface.co/docs/transformers/en/quantization/overview) 19 | - [Quantize 🤗 Transformers models](https://huggingface.co/docs/transformers/v4.27.0/en/main_classes/quantization) 20 | - [Quantization-Aware Training for Large Language Models with PyTorch](https://pytorch.org/blog/quantization-aware-training/) 21 | - [Introduction to Quantization cooked in 🤗 with 
💗🧑‍🍳](https://huggingface.co/blog/merve/quantization) -------------------------------------------------------------------------------- /wiki/corrections.md: -------------------------------------------------------------------------------- 1 | ## 4.6. Residual Connection 2 | 3 | **The text reads:** "The illustration depicts an **encoding** block with residual connections" 4 | **The text should read:** "The illustration depicts a **decoder** block with residual connections" 5 | -------------------------------------------------------------------------------- /wiki/deployment.md: -------------------------------------------------------------------------------- 1 | # Deployment 2 | 3 | - [Executive guide on secure LLM deployment](https://www.run.ai/blog/executive-guide-on-secure-llm-deployment) 4 | - [LLMOps: MLOps for large language models](https://www.giskard.ai/knowledge/llmops-mlops-for-large-language-models) 5 | - [A complete guide to LLMOps for machine learning](https://www.truefoundry.com/blog/llmops-mastering-the-art-of-managing-large-language-models-challenges-best-practices-and-future-trends) 6 | - [The ultimate guide to deploying large language models safely and securely](https://www.lakera.ai/blog/how-to-deploy-an-llm) 7 | - [Best practices for large language model deployment](https://arize.com/blog-course/large-language-model-llm-deployment/) 8 | -------------------------------------------------------------------------------- /wiki/distributed.md: -------------------------------------------------------------------------------- 1 | # Distributed Training 2 | 3 | - [Distributed and Parallel Training Tutorials](https://pytorch.org/tutorials/distributed/home.html) 4 | - [Efficient Training on Multiple GPUs](https://huggingface.co/docs/transformers/en/perf_train_gpu_many) 5 | - [Distributed training with 🤗 Accelerate](https://huggingface.co/docs/transformers/en/accelerate) -------------------------------------------------------------------------------- 
/wiki/embeddings.md: -------------------------------------------------------------------------------- 1 | # Embeddings 2 | 3 | ## Word Embeddings 4 | 5 | - [A Deep Dive into NLP Tokenization and Encoding with Word and Sentence Embeddings](https://datajenius.com/2022/03/13/a-deep-dive-into-nlp-tokenization-encoding-word-embeddings-sentence-embeddings-word2vec-bert/) 6 | - [What Are Word Embeddings for Text?](https://machinelearningmastery.com/what-are-word-embeddings/) 7 | - [BERT Word Embeddings Tutorial](https://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/) 8 | 9 | ## Document Embeddings 10 | 11 | - [Introduction to Sentence Embeddings](https://osanseviero.github.io/hackerllama/blog/posts/sentence_embeddings/) 12 | - [Sentence Transformer](https://www.sbert.net/docs/quickstart.html) 13 | - [Text and Code Embeddings by Contrastive Pre-Training](https://arxiv.org/pdf/2201.10005) 14 | - [How to Train a Custom LLM Embedding Model](https://dagshub.com/blog/how-to-train-a-custom-llm-embedding-model/) 15 | - [Demystifying Embeddings 101: The Foundation of Large Language Models](https://datasciencedojo.com/blog/embeddings-and-llm/) -------------------------------------------------------------------------------- /wiki/encoder-decoder.md: -------------------------------------------------------------------------------- 1 | # Encoder-Decoder Architecture 2 | 3 | - [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 4 | - [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/pdf/1910.13461) 5 | - [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) 6 | - [Fast Neural Machine Translation in C++](https://marian-nmt.github.io/) 7 | - [Scaling Instruction-Finetuned Language Models](https://arxiv.org/abs/2210.11416) 8 | - [Switch Transformers: Scaling to Trillion
Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) 9 | - [PaLM: Scaling Language Modeling with Pathways](https://arxiv.org/abs/2204.02311) 10 | - [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 11 | - [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) 12 | - [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) 13 | -------------------------------------------------------------------------------- /wiki/encoder.md: -------------------------------------------------------------------------------- 1 | # Encoder Architecture 2 | 3 | ## BERT 4 | 5 | - [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) 6 | - [BERT 101](https://huggingface.co/blog/bert-101) 7 | - [How to Code BERT Using PyTorch – Tutorial With Examples](https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial) 8 | - [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) 9 | 10 | ## RoBERTa 11 | 12 | - [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) 13 | - [Training RoBERTa from scratch - the missing guide](https://zablo.net/blog/post/training-roberta-from-scratch-the-missing-guide-polish-language-model/) 14 | -------------------------------------------------------------------------------- /wiki/evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluation Methods 2 | 3 | - [Evaluating Large Language Models: A Comprehensive Survey](https://arxiv.org/pdf/2310.19736) 4 | - [Challenges in Language Model Evaluations](https://lm-evaluation-challenges.github.io/) 5 | - [Holistic Evaluation of Language Models (HELM)](https://crfm.stanford.edu/helm/) 6 | - [Elo Uncovered: Robustness and Best 
Practices in Language Model Evaluation](https://arxiv.org/pdf/2311.17295) 7 | - [A Survey on Evaluation of Large Language Models](https://dl.acm.org/doi/pdf/10.1145/3641289) 8 | - [A Systematic Evaluation of Large Language Models of Code](https://dl.acm.org/doi/pdf/10.1145/3520312.3534862) 9 | - [Evaluating Large Language Models Trained on Code](https://arxiv.org/pdf/2107.03374) 10 | - [Is Your Code Generated by ChatGPT Really Correct? Rigorous Evaluation of Large Language Models for Code Generation](https://proceedings.neurips.cc/paper_files/paper/2023/file/43e9d647ccd3e4b7b5baab53f0368686-Paper-Conference.pdf) 11 | - [Elo vs Bradley-Terry model](https://www.keiruaprod.fr/blog/2021/06/02/elo-vs-bradley-terry-model.html) 12 | - [Chatbot Arena - New models & Elo system update](https://blog.lmarena.ai/blog/2023/leaderboard-elo-update/) 13 | - [Does Style Matter?](https://blog.lmarena.ai/blog/2024/style-control/) 14 | - [Statistical Extensions of the Bradley-Terry and Elo Models](https://blog.lmarena.ai/blog/2024/extended-arena/) 15 | - [Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena](https://arxiv.org/abs/2306.05685) 16 | - [Chatbot Arena Leaderboard Calculation (Bradley-Terry model)](https://colab.research.google.com/drive/1_X0OmMCMxZzGyCXwTl9ruZo3cTB1HaNi) 17 | - [Lecture 24 — The Bradley-Terry model](https://web.stanford.edu/class/archive/stats/stats200/stats200.1172/Lecture24.pdf) 18 | - [Bradley–Terry Model](https://real-statistics.com/reliability/bradley-terry-model/) 19 | - [Code: How Bradley Terry Model Works](https://www.kaggle.com/code/shaz13/code-how-bradley-terry-model-works) 20 | - [A Gentle Introduction to the Bootstrap Method](https://machinelearningmastery.com/a-gentle-introduction-to-the-bootstrap-method/) -------------------------------------------------------------------------------- /wiki/function-calling.md: -------------------------------------------------------------------------------- 1 | # Function Calling 2 | 3 | - [Introduction to
function calling](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling) 4 | - [Function Calling](https://huggingface.co/docs/hugs/en/guides/function-calling) 5 | - [Function Calling with Open-Source LLMs](https://www.bentoml.com/blog/function-calling-with-open-source-llms) 6 | - [Beyond the Leaderboard: Unpacking Function Calling Evaluation](https://www.databricks.com/blog/unpacking-function-calling-eval) 7 | - [What is LLM Function Calling and How Does it Work?](https://quiq.com/blog/llm-function-calling/) 8 | - [Function Calling with Granite Tutorial](https://www.ibm.com/think/tutorials/granite-function-calling) -------------------------------------------------------------------------------- /wiki/index.md: -------------------------------------------------------------------------------- 1 | # The Hundred-Page Language Models Book's Wiki 2 | 3 | [gimmick:theme](spacelab) 4 | 5 | ## Getting Started 6 | 7 | - [PyTorch tutorials](PyTorch.md) 8 | - [Math fundamentals](math.md) 9 | - [GPU-enabled notebook services](notebook-services.md) 10 | - [GPU rental services](GPU-rental.md) 11 | 12 | ## Extended Chapters 13 | 14 | * [Derivation for 1.7 Gradient Descent](https://www.dropbox.com/scl/fi/zpnwrmhfatnoyy2sepucd/chapter_1_extra_1.pdf?rlkey=bj28oknku9ofs81nl59nv2wwu&dl=0) 15 | * [Derivation for 2.6. 
Evaluating Language Models](https://www.dropbox.com/scl/fi/9i1r13h06jevahdhez2dg/chapter_6_extra_1.pdf?rlkey=r9723j6dz2g3zo36um1fjytp4&dl=0) 16 | * [Extra Chapter A: Convolutional Neural Network](https://www.dropbox.com/scl/fi/ytm50ol12mv5sq0fvu741/chapter_A.pdf?rlkey=greb4k4j335qbtm5o7qju8p9g&dl=0) 17 | 18 | ## Code 19 | 20 | * [Python scripts](scripts.md) 21 | * [Colab notebooks](colabs.md) 22 | 23 | ## Engineering 24 | 25 | * [Online finetuning services](online-finetuning.md) 26 | * [Deployment](deployment.md) 27 | * [Inference cost and speed optimization](inference.md) 28 | * [Distributed training](distributed.md) 29 | * [Preventing overfitting](overfitting.md) 30 | 31 | ## Language Model 32 | 33 | * [Evaluation methods](evaluation.md) 34 | * [Prompt engineering](prompting.md) 35 | * [Function calling](function-calling.md) 36 | 37 | ## Advanced Topics 38 | 39 | * [Scaling laws](scaling.md) 40 | * [Mixture of experts](MoE.md) 41 | * [Model merging](merging.md) 42 | * [Model compression](compression.md) 43 | * [Preference-based alignment](alignment.md) 44 | * [Security](security.md) 45 | * [Vision language models](VLM.md) 46 | 47 | ## Additional Reading 48 | 49 | * [Embeddings](embeddings.md) 50 | * [Tokenization methods](tokenization.md) 51 | * [Encoder architecture](encoder.md) 52 | * [Encoder-decoder architecture](encoder-decoder.md) 53 | * [Non-Transformer architectures](non-transformer.md) 54 | 55 | ## Corrections 56 | 57 | * [Corrections](corrections.md) 58 | -------------------------------------------------------------------------------- /wiki/inference.md: -------------------------------------------------------------------------------- 1 | # Inference Cost and Speed Optimization 2 | 3 | ## Distillation 4 | 5 | - [LLM distillation demystified: a complete guide](https://snorkel.ai/blog/llm-distillation-demystified-a-complete-guide/) 6 | - [Distilling Step-by-Step! 
Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes](https://arxiv.org/pdf/2305.02301) 7 | - [MiniLLM: Knowledge Distillation of Large Language Models](https://arxiv.org/pdf/2306.08543) 8 | - [A Survey on Knowledge Distillation of Large Language Models](https://arxiv.org/pdf/2402.13116) 9 | - [Knowledge distillation: Teaching LLM's with synthetic data](https://wandb.ai/byyoung3/ML_NEWS3/reports/Knowledge-distillation-Teaching-LLM-s-with-synthetic-data--Vmlldzo5MTMyMzA2) 10 | 11 | ## Pruning 12 | 13 | - [The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks](https://arxiv.org/abs/1803.03635) 14 | - [LLM-Pruner: On the Structural Pruning of Large Language Models](https://arxiv.org/abs/2305.11627) 15 | - [A Simple and Effective Pruning Approach for Large Language Models](https://arxiv.org/abs/2306.11695) 16 | - [The Truth is in There: Improving Reasoning in Language Models with Layer-Selective Rank Reduction](https://arxiv.org/pdf/2312.13558) 17 | - [Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning](https://arxiv.org/abs/2310.06694) 18 | 19 | ## Speculative Decoding 20 | 21 | - [A Hitchhiker's Guide to Speculative Decoding](https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/) 22 | - [Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192) 23 | - [Accelerating Large Language Model Decoding with Speculative Sampling](https://arxiv.org/abs/2302.01318) 24 | - [Inference with Reference: Lossless Acceleration of Large Language Models](https://arxiv.org/pdf/2304.04487) 25 | - [Predictive Pipelined Decoding: A Compute-Latency Trade-off for Exact LLM Decoding](https://arxiv.org/pdf/2307.05908) 26 | - [Accelerating LLM Inference with Staged Speculative Decoding](https://arxiv.org/pdf/2308.04623) 27 | - [Looking back at speculative decoding](https://research.google/blog/looking-back-at-speculative-decoding/) 28 | - [SpecTr: Fast Speculative Decoding 
via Optimal Transport](https://openreview.net/pdf?id=SdYHLTCC5J) 29 | 30 | ## Post-Training Quantization 31 | 32 | - [Quantization](https://huggingface.co/docs/optimum/en/concept_guides/quantization) 33 | - [Top LLM quantization methods and their impact on model quality](https://www.deepchecks.com/top-llm-quantization-methods-impact-on-model-quality/) 34 | - [Doing more with less: LLM quantization](https://www.redhat.com/en/blog/doing-more-less-llm-quantization-part-2) 35 | - [SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://proceedings.mlr.press/v202/xiao23c/xiao23c.pdf) 36 | - [A Comprehensive Study on Post-Training Quantization for Large Language Models](https://cli99.com/pdf/ds-w4a16-23.pdf) 37 | - [LoftQ: LoRA-Fine-Tuning-Aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659) 38 | - [QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2309.14717) 39 | - [RPTQ: Reorder-based Post-training Quantization for Large Language Models](https://arxiv.org/pdf/2304.01089) -------------------------------------------------------------------------------- /wiki/math.md: -------------------------------------------------------------------------------- 1 | # Math Fundamentals 2 | 3 | ## Linear Algebra 4 | 5 | - [3Blue1Brown's linear algebra collection](https://www.3blue1brown.com/topics/linear-algebra) 6 | - [Linear Algebra MIT open course](https://ocw.mit.edu/courses/18-06-linear-algebra-spring-2010/) 7 | - [Linear Algebra - As an Introduction to Abstract Mathematics textbook](https://www.math.ucdavis.edu/%7Eanne/linear_algebra/) 8 | 9 | ## Calculus 10 | 11 | - [3Blue1Brown's essence of calculus collection](https://www.youtube.com/playlist?list=PLZHQObOWTQDMsr9K-rj53DwVRMYO3t5Yr) 12 | - [Calculus Online Textbook](https://ocw.mit.edu/courses/res-18-001-calculus-fall-2023/pages/textbook/) 13 | - [Calculus 1 by Professor 
Leonard](https://www.youtube.com/playlist?list=PLF797E961509B4EB5) 14 | - [CALCULUS by Michel van Biezen](https://www.youtube.com/playlist?list=PLX2gX-ftPVXWMrapS-ROUEKCxv5lpA5zh) 15 | - [Paul's online notes](https://tutorial.math.lamar.edu/) 16 | 17 | ## Probability and Statistics 18 | 19 | - [3Blue1Brown's probability collection](https://www.3blue1brown.com/topics/probability) 20 | - [Introduction to Probability and Statistics MIT open course](https://ocw.mit.edu/courses/18-05-introduction-to-probability-and-statistics-spring-2022/) 21 | - [A Modern Introduction to Probability and Statistics](https://cis.temple.edu/~latecki/Courses/CIS2033-Spring13/Modern_intro_probability_statistics_Dekking05.pdf) -------------------------------------------------------------------------------- /wiki/merging.md: -------------------------------------------------------------------------------- 1 | # Model Merging 2 | 3 | - [Model merging](https://huggingface.co/docs/peft/en/developer_guides/model_merging) 4 | - [Merge Large Language Models with mergekit](https://huggingface.co/blog/mlabonne/merge-models) 5 | - [An Introduction to Model Merging for LLMs](https://developer.nvidia.com/blog/an-introduction-to-model-merging-for-llms/) -------------------------------------------------------------------------------- /wiki/non-transformer.md: -------------------------------------------------------------------------------- 1 | # Non-Transformer Architectures 2 | 3 | - [Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling](https://arxiv.org/abs/2406.07522) 4 | - [xLSTM: Extended Long Short-Term Memory](https://arxiv.org/abs/2405.04517) 5 | - [Were RNNs All We Needed?](https://arxiv.org/html/2410.01201v1) 6 | - [RWKV: Reinventing RNNs for the Transformer Era](https://arxiv.org/abs/2305.13048) 7 | - [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) 8 | - [Transformers are RNNs: Fast Autoregressive 
Transformers with Linear Attention](https://arxiv.org/abs/2006.16236) 9 | - [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) 10 | - [Hyena Hierarchy: Towards Larger Convolutional Language Models](https://arxiv.org/abs/2302.10866) 11 | - [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/abs/2307.08621) 12 | - [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) 13 | - [Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models](https://arxiv.org/abs/2402.19427) 14 | - [Recurrent Memory Transformer](https://arxiv.org/abs/2207.06881) 15 | -------------------------------------------------------------------------------- /wiki/notebook-services.md: -------------------------------------------------------------------------------- 1 | # GPU-Enabled Notebook Services 2 | 3 | ## Free 4 | 5 | - [Google Colab](https://colab.research.google.com/) 6 | - [Lightning.ai](https://lightning.ai/) 7 | 8 | ## Paid 9 | 10 | - [Google Colab Pro](https://colab.research.google.com/) 11 | - [Lightning.ai Pro](https://lightning.ai/) 12 | - [Amazon SageMaker](https://aws.amazon.com/sagemaker/) 13 | - [Deepnote](https://deepnote.com/) -------------------------------------------------------------------------------- /wiki/online-finetuning.md: -------------------------------------------------------------------------------- 1 | # Online finetuning services 2 | 3 | - [Together.ai](https://www.together.ai/) 4 | - [Fireworks.ai](https://fireworks.ai/) 5 | - [OpenAI](https://platform.openai.com) -------------------------------------------------------------------------------- /wiki/overfitting.md: -------------------------------------------------------------------------------- 1 | # Preventing Overfitting 2 | 3 | - [The Hundred-Page Machine Learning Book](https://themlbook.com/wiki/doku.php) 4 | - [How to Avoid Overfitting in Deep Learning Neural 
Networks](https://machinelearningmastery.com/introduction-to-regularization-to-reduce-overfitting-and-improve-generalization-error/) 5 | - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf) 6 | - [The Theory Behind Overfitting, Cross Validation, Regularization, Bagging, and Boosting: Tutorial](https://arxiv.org/pdf/1905.12787) -------------------------------------------------------------------------------- /wiki/prompting.md: -------------------------------------------------------------------------------- 1 | # Prompt Engineering 2 | 3 | - [LLM Prompt Engineering Simplified Book](https://llmnanban.akmmusai.pro/Book/LLM-Prompt-Engineering-Simplified-Book/) 4 | - [Best practices for prompt engineering with the OpenAI API](https://help.openai.com/en/articles/6654000-best-practices-for-prompt-engineering-with-the-openai-api) 5 | - [Prompt Engineering Best Practices: Tips, Tricks, and Tools](https://www.digitalocean.com/resources/articles/prompt-engineering-best-practices) 6 | - [Tips to enhance your prompt-engineering abilities](https://cloud.google.com/blog/products/application-development/five-best-practices-for-prompt-engineering) 7 | - [Prompt Engineering Guide](https://www.promptingguide.ai/) 8 | - [Prompt Engineering Guide](https://learnprompting.org/docs/introduction) 9 | - [Unleashing the potential of prompt engineering in Large Language Models: a comprehensive review](https://arxiv.org/pdf/2310.14735) -------------------------------------------------------------------------------- /wiki/scaling.md: -------------------------------------------------------------------------------- 1 | # Neural Scaling Laws 2 | 3 | - [Scaling Laws for LLMs: From GPT-3 to o3](https://cameronrwolfe.substack.com/p/llm-scaling-laws) 4 | - [Deep Learning Scaling is Predictable, Empirically](https://arxiv.org/abs/1712.00409) 5 | - [Scaling Laws for Neural Language Models](https://arxiv.org/abs/2001.08361) 6 | - 
[Training Compute-Optimal Large Language Models](https://arxiv.org/abs/2203.15556) -------------------------------------------------------------------------------- /wiki/scripts.md: -------------------------------------------------------------------------------- 1 | # Python Scripts 2 | 3 | - [1.2. Model - Quadratic loss](https://github.com/aburkov/theLMbook/blob/main/quadratic_loss.py) -------------------------------------------------------------------------------- /wiki/security.md: -------------------------------------------------------------------------------- 1 | # Security 2 | 3 | - [Common prompt injection attacks](https://docs.aws.amazon.com/prescriptive-guidance/latest/llm-prompt-engineering-best-practices/common-attacks.html) 4 | - [Systematically Analyzing Prompt Injection Vulnerabilities in Diverse LLM Architectures](https://arxiv.org/abs/2410.23308) 5 | - [Prompt injection](https://learn.snyk.io/lesson/prompt-injection/) 6 | - [Deceptive Delight: Jailbreak LLMs Through Camouflage and Distraction](https://unit42.paloaltonetworks.com/jailbreak-llms-through-camouflage-distraction/) 7 | - [Awesome-Jailbreak-on-LLMs](https://github.com/yueliu1999/Awesome-Jailbreak-on-LLMs) -------------------------------------------------------------------------------- /wiki/test.md: -------------------------------------------------------------------------------- 1 | # Test 2 | 3 | [gimmick: math]() 4 | 5 | Problem with & in a bmatrix: 6 | 7 | $$ 8 | \mathbf{A} \stackrel{\text{def}}{=} \begin{bmatrix} 9 | a_{1,1} & a_{1,2} & \cdots & a_{1,n} \\ 10 | a_{2,1} & a_{2,2} & \cdots & a_{2,n} \\ 11 | \vdots & \vdots & \ddots & \vdots \\ 12 | a_{m,1} & a_{m,2} & \cdots & a_{m,n} 13 | \end{bmatrix} 14 | $$ 15 | 16 | Problem with & in an array: 17 | 18 | $$ 19 | \begin{array}{cccc} \frac{\partial \text{l}_i}{\partial \tilde{y}_i}, & \frac{\partial \tilde{y}_i}{\partial z_i}, &\text{ and } & \frac{\partial z_i}{\partial w^{(j)}}\end{array}.
20 | $$ 21 | 22 | Problem with & in aligned: 23 | 24 | $$ 25 | \begin{aligned} 26 | J(2.58,-91.76) &= \frac{(2.58\cdot150 - 91.76 - 200)^2}{3} + \frac{(2.58\cdot200-91.76-600)^2}{3} \\ 27 | &+ \frac{(2.58\cdot260-91.76 - 500)^2}{3} = 15403.19 28 | \end{aligned} 29 | $$ 30 | -------------------------------------------------------------------------------- /wiki/tokenization.md: -------------------------------------------------------------------------------- 1 | # Tokenization 2 | 3 | - [WordPiece tokenization](https://huggingface.co/learn/nlp-course/en/chapter6/6) 4 | - [Summary of the tokenizers](https://huggingface.co/docs/transformers/en/tokenizer_summary) 5 | - [Building a tokenizer, block by block](https://huggingface.co/learn/nlp-course/en/chapter6/8) 6 | - [Fast WordPiece Tokenization](https://arxiv.org/pdf/2012.15524) --------------------------------------------------------------------------------