├── .gitignore ├── README.md ├── RESOURCES.md ├── Slides.odp ├── Slides.pdf ├── chat.py ├── colab └── HowToTrainOnColab.ipynb ├── data ├── DummyData.txt ├── WhatsAppTemplate.md └── supervised_fine_tuning.json ├── images ├── back_prop_exploding_gradients.svg ├── back_prop_vanishing_gradients.svg ├── big_bird_inside_one_head.svg ├── bo_dmagh_logo.svg ├── chunked_local_attention_masked_attention_scores_matrix_window_size_3.svg ├── conversation_1.png ├── conversation_2.png ├── conversation_template.svg ├── conversational_format.svg ├── course_thumbnail .png ├── dilated_local_attention_masked_attention_scores_matrix_window_size_3.svg ├── forward_pass.svg ├── get_dataset_from_hugging_face.png ├── global_attention.svg ├── global_plus_local_attention_masked_attention_scores_matrix_window_size_3.svg ├── gqa_design.svg ├── gqa_inside_attention_layer.svg ├── linear_attention_inside_attention_layer.svg ├── local_attention_inside_one_head.svg ├── local_attention_masked_attention_scores_matrix_window_size_3.svg ├── lora_vs_full_fine_tuning.svg ├── masking.svg ├── mha_attention_scores_matrix.svg ├── mha_design.svg ├── mha_inside_one_head.svg ├── mha_masked_attention_scores_matrix.svg ├── mha_mqa_comparison.svg ├── mla_inside_one_head.svg ├── mqa_design.svg ├── mqa_inside_one_head.svg ├── networks_being_fine_tuned.svg ├── neural_network.svg ├── parts_of_the_transformer_architecture_phase_0.svg ├── parts_of_the_transformer_architecture_phase_1.svg ├── parts_of_the_transformer_architecture_phase_2.svg ├── parts_of_the_transformer_architecture_phase_3.svg ├── parts_of_the_transformer_architecture_phase_4.svg ├── parts_of_the_transformer_architecture_phase_5.svg ├── peft_fine_tuning.svg ├── post_normalization.svg ├── pre_normalization.svg ├── qa_format.svg ├── random_attention.svg ├── relative_positional_encoding_visualized.svg ├── rope_preserved_relative_positions.svg ├── rope_rotation_explanation.svg ├── rotary_positional_encoding_visualized.svg ├── sinusoidal_positional_encoding_smooth.svg ├── sparse_attention_masked.svg ├── transformer_step_1.svg ├── transformer_step_2.svg ├── transformer_step_3.svg └── underfitting_overfitting_loss_comparison.svg ├── minbpe ├── __init__.py ├── base.py ├── basic.py ├── gpt4.py └── regex.py ├── notebooks ├── 1_DataCleaning.ipynb ├── 2_BytePairEncoding.ipynb ├── 3_TransformerModel.ipynb ├── 4_1_ModelTrainingAllBatches.ipynb ├── 4_2_ModelTrainingRandomBatches.ipynb ├── 5_FineTuningDataset.ipynb ├── 6_1_FineTuningNoContext.ipynb ├── 6_2_FineTuningWithContext.ipynb ├── 7_ParameterEfficientFineTuning.ipynb ├── 8_1_LetsScaleHuggingFaceDataset.ipynb ├── 8_2_LetsScaleTokenizer.ipynb ├── 8_3_LetsScaleEncoding.ipynb ├── 8_4_LetsScalePreTraining.ipynb ├── 8_5_1_LetsScaleFineTuningQA.ipynb ├── 8_5_2_LetsScaleFineTuningConversationMasking.ipynb ├── 8_5_2_LetsScaleFineTuningConversationNoMasking.ipynb ├── 9_1_1_ImprovingTransformerAbsolutePositionalEncoding.ipynb ├── 9_1_2_ImprovingTransformerRotaryPositionalEncoding.ipynb ├── 9_1_3_ImprovingTransformerSinusoidalPositionalEncoding.ipynb ├── 9_1_4_ImprovingTransformerRelativePositionalEncoding.ipynb ├── 9_1_5_ImprovingTransformerNoPositionalEncoding.ipynb ├── 9_2_1_ImprovingTransformerLocalAttention.ipynb ├── 9_2_2_ImprovingTransformerMultiQueryAttention.ipynb ├── 9_2_3_ImprovingTransformerGroupedQueryAttention.ipynb ├── 9_2_4_ImprovingTransformerLinearAttention.ipynb ├── 9_2_5_ImprovingTransformerBigBirdAttention.ipynb ├── 9_2_6_ImprovingTransformerMultiHeadLatentAttention.ipynb ├── 9_3_1_ImprovingTransformerGeLUActivationFuncion.ipynb 
├── 9_3_2_ImprovingTransformerSwiGLUActivationFuncion.ipynb ├── 9_4_1_ImprovingTransformerRMSNorm.ipynb ├── 9_4_2_ImprovingTransformerPostNormalization.ipynb ├── 9_4_3_ImprovingTransformerPostNormalizationNoDropout.ipynb ├── 9_5_1_ImprovingTransformerBestModelPhase1.ipynb ├── 9_5_2_ImprovingTransformerBestModelPhase2.ipynb ├── 9_5_3_ImprovingTransformerBestModelPhase3.ipynb └── 9_5_4_ImprovingTransformerBestModelPhase4.ipynb ├── requirements.txt ├── scripts ├── before_and_after_normalization.py ├── generate_sinusoidal_positional_encoding_image.py └── plot_activation_functions.py └── transformer ├── __init__.py ├── best_model_phase_1.py ├── best_model_phase_2.py ├── best_model_phase_3.py ├── best_model_phase_4.py ├── lora.py ├── model.py ├── model_big_bird.py ├── model_grouped_query_attention.py ├── model_linear_attention.py ├── model_local_attention.py ├── model_multi_head_latent_attention.py ├── model_multi_head_latent_attention_gelu.py ├── model_multi_head_latent_attention_swiglu.py ├── model_multi_head_latent_attention_swiglu_post_normalization.py ├── model_multi_head_latent_attention_swiglu_post_normalization_no_dropout.py ├── model_multi_head_latent_attention_swiglu_rms_norm.py ├── model_multi_query_attention.py ├── model_no_positional_encoding.py ├── model_relative_positional_encoding.py ├── model_rope.py └── model_sinusoidal_positional_encoding.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

data/private/
output/
checkpoints/
AtlaSetCombined.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Train your language model course

We’ve all used Large Language Models (LLMs) and been amazed by what they can do. I wanted to understand how these models are built, so I created this course.

I’m from Morocco and speak Moroccan Darija. Most LLMs today understand it a little, but they can't hold proper conversations in Darija. So, as a challenge, I decided to train a language model from scratch using my own WhatsApp conversations in Darija.

I've made a YouTube playlist documenting every step. You can watch it at your own pace. If anything is unclear, feel free to open an issue in this repository. I’ll be happy to help!

[![course_thumbnail](./images/course_thumbnail%20.png)](https://www.youtube.com/playlist?list=PLMSb3cZXtIfptKdr56uEdiM5pR6HDMoUX)

## What is in this repository?

- `notebooks/`: Jupyter notebooks for each step in the pipeline.
- `Slides.odp`: Presentation slides used in the YouTube series.
- `data/`: Sample data and templates.
- `transformer/`: Scripts for the Transformer and LoRA implementations.
- `minbpe/`: A tokenizer from [Andrej Karpathy's repo](https://github.com/karpathy/minbpe), included here since it's not available as a package.

## Setup

To get started, install [Python](https://www.python.org/downloads/) and the required dependencies by running:

```bash
pip install -r requirements.txt
```

## What will you learn?

This course covers:

1. Extracting data from WhatsApp.
2. Tokenizing text using the BPE algorithm.
3. Understanding Transformer models.
4. Pre-training the model.
5. Creating a fine-tuning dataset.
6. Fine-tuning the model (Instruction tuning and LoRA fine-tuning).

Each topic has a video in the [YouTube playlist](https://www.youtube.com/playlist?list=PLMSb3cZXtIfptKdr56uEdiM5pR6HDMoUX) and a Jupyter notebook in the [`notebooks/`](./notebooks/) folder.
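If you want to try step 2 outside the notebooks, here is a minimal sketch of how the bundled `minbpe` tokenizer can be trained and used. The training file name and vocabulary size below are illustrative placeholders, not course settings; the full walkthrough lives in [`notebooks/2_BytePairEncoding.ipynb`](./notebooks/2_BytePairEncoding.ipynb).

```python
from minbpe import RegexTokenizer

# Illustrative sketch: train a byte-level BPE tokenizer on your cleaned chat text.
# "cleaned_chats.txt" and vocab_size=1024 are placeholder values.
with open("cleaned_chats.txt", "r", encoding="utf-8") as f:
    text = f.read()

tokenizer = RegexTokenizer()
tokenizer.train(text, vocab_size=1024)  # vocab_size must be >= 256 (the raw byte vocabulary)

ids = tokenizer.encode("hello there")   # list of token ids
print(tokenizer.decode(ids))            # decodes back to the original string

tokenizer.save("darija_tokenizer")      # writes darija_tokenizer.model and darija_tokenizer.vocab
```

The same `RegexTokenizer` class is what `chat.py` loads at inference time.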
## My experience

Like you, I had never trained a language model before. After following the steps in this course, I built my own 42M parameter model called **BoDmagh**. In Moroccan Darija, **BoDmagh** can mean **someone with a brain**. Putting **Bo** before a noun means that the noun is deep inside you, so **BoDmagh** can also mean a smart person.

Here are two example conversations I had with the model:

![conversation_1](./images/conversation_1.png)
![conversation_2](./images/conversation_2.png)

The [Supervised Fine-Tuning (SFT) dataset](https://github.com/ImadSaddik/BoDmaghDataset) I created really helped improve the model’s ability to hold a conversation.

### Limitations

The model doesn’t always give correct answers. If I try to discuss many different topics, it struggles. This is likely because both the model and the SFT dataset are small. Training on more data and using a larger model could improve the results. I might explore this in the future.

## Contributions

We welcome contributions! If you find any issues or have suggestions for improvements, please open an issue or submit a pull request.

## Need help?

You can reach me through:

- **YouTube** – Leave a comment on the videos.
- **LinkedIn** – [Connect with me](https://www.linkedin.com/in/imadsaddik/).
- **Email** – [simad3647@gmail.com](mailto:simad3647@gmail.com).

--------------------------------------------------------------------------------
/RESOURCES.md:
--------------------------------------------------------------------------------
# Useful resources

Here are some helpful resources I used while making this course. I’m sharing them with you in case they help. Enjoy!
4 | 5 | ## Videos 6 | 7 | - [Deep Dive into LLMs like ChatGPT (Theory)](https://www.youtube.com/watch?v=7xTGNNLPyMI&t=7652s) 8 | - [Building GPT from Scratch (Practice 1)](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=1s) 9 | - [Building a GPT Tokenizer (Practice 2)](https://www.youtube.com/watch?v=zduSFxRajkE&t=440s) 10 | - [Reproducing GPT-2 (Practice 3)](https://www.youtube.com/watch?v=l8pRSuU81PU) 11 | 12 | ## GitHub repositories 13 | 14 | - [Train a language model using WhatsApp group chats](https://github.com/bernhard-pfann/lad-gpt) 15 | - [Create an AI version of yourself using WhatsApp chats](https://github.com/kinggongzilla/ai-clone-whatsapp) 16 | - [Extract WhatsApp key/database without root access](https://github.com/YuvrajRaghuvanshiS/WhatsApp-Key-Database-Extractor) 17 | 18 | ## Articles 19 | 20 | - [Train a language model on your WhatsApp chats](https://towardsdatascience.com/build-a-language-model-on-your-whatsapp-chats-31264a9ced90/) 21 | - [Fine-tune an LLM to create your digital twin](https://medium.com/better-programming/unleash-your-digital-twin-how-fine-tuning-llm-can-create-your-perfect-doppelganger-b5913e7dda2e) 22 | - [Fine-tuning LLMs using QLoRA](https://dassum.medium.com/fine-tune-large-language-model-llm-on-a-custom-dataset-with-qlora-fb60abdeba07) 23 | - [Guide to fine-tuning large language models](https://www.datacamp.com/tutorial/fine-tuning-large-language-models) 24 | - [Understanding LoRA and DoRA for model tuning](https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch) 25 | - [Step-by-step guide to fine-tune LLMs with LoRA](https://medium.com/@manindersingh120996/practical-guide-to-fine-tune-llms-with-lora-c835a99d7593) 26 | - [Fine-tuning LLMs for multi-turn conversations](https://www.together.ai/blog/fine-tuning-llms-for-multi-turn-conversations-a-technical-deep-dive#:~:text=Fine-tuning%20LLMs%20for%20multi-turn%20conversations%20requires%20careful%20attention,while%20managing%20computational%20resources%20efficiently.) 
27 | - [Fine-tuning mixtral 7bx8 with LoRA](https://medium.com/@prakharsaxena11111/finetuning-mixtral-7bx8-6071b0ebf114) 28 | - [Positional Encoding Explained: A Deep Dive into Transformer PE](https://medium.com/thedeephub/positional-encoding-explained-a-deep-dive-into-transformer-pe-65cfe8cfe10b) 29 | - [You could have designed state of the art positional encoding](https://huggingface.co/blog/designing-positional-encoding) 30 | - [Relative Positional Encoding](https://jaketae.github.io/study/relative-positional-encoding/) 31 | - [What is grouped query attention (GQA)?](https://www.ibm.com/think/topics/grouped-query-attention) 32 | - [Linear Attention Is All You Need](https://medium.com/data-science/linear-attention-is-all-you-need-5fa9c845c1b5) 33 | - [Attention Variations — MQA vs GQA vs MHA vs MLA](https://verticalserve.medium.com/group-query-attention-58283b337c65) 34 | - [Understanding Multi-Head Latent Attention](https://planetbanatt.net/articles/mla.html) 35 | - [DeepSeek's Multi-Head Latent Attention](https://liorsinai.github.io/machine-learning/2025/02/22/mla.html#multi-head-latent-attention) 36 | - [Sliding Window Attention](https://medium.com/@manojkumal/sliding-window-attention-565f963a1ffd) 37 | - [Can LLMs learn from a single example?](https://www.fast.ai/posts/2023-09-04-learning-jumps/) 38 | - [Normalization Layer Placement (Pre-LN vs Post-LN)](https://apxml.com/courses/how-to-build-a-large-language-model/chapter-11-scaling-transformers-architectural-choices/normalization-layer-placement) 39 | - [Batch Normalization, Layer Normalization and Root Mean Square Layer Normalization: A Comprehensive Guide with Python Implementations](https://afterhoursresearch.hashnode.dev/batch-normalization-layer-normalization-and-root-mean-square-layer-normalization-a-comprehensive-guide-with-python-implementations) 40 | - [Deep Dive into Deep Learning: Layers, RMSNorm, and Batch Normalization](https://2020machinelearning.medium.com/deep-dive-into-deep-learning-layers-rmsnorm-and-batch-normalization-b2423552be9f) 41 | - [Exploring SwiGLU : The Activation Function Powering Modern LLMs](https://medium.com/@s_boudefel/exploring-swiglu-the-activation-function-powering-modern-llms-9697f88221e7) 42 | - [All the Activation Functions (and a history of deep learning)](https://dublog.net/blog/all-the-activations/) 43 | 44 | ## Reddit discussions 45 | 46 | - [Finetuned Llama 2-7B using WhatsApp chats](https://www.reddit.com/r/LocalLLaMA/comments/18ny05c/finetuned_llama_27b_on_my_whatsapp_chats/) 47 | - [How to train your model](https://www.reddit.com/r/Oobabooga/comments/19480dr/how_to_train_your_dra_model/?share_id=FandRNmK84MItOJYIynap&utm_medium=android_app&utm_name=androidcss&utm_source=share&utm_term=1) 48 | - [Exporting full WhatsApp chat history](https://www.reddit.com/r/DataHoarder/comments/a7c0yq/full_whatsapp_chat_export_40000_messages/) 49 | - [Normalization in transformers](https://www.reddit.com/r/MachineLearning/comments/1ecict8/d_normalization_in_transformers/) 50 | 51 | ## Notebooks 52 | 53 | - [Unsloth AI Notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks) 54 | - [LLMs-from-Scratch: Chapter 6 Notebook](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb?utm_source=substack&utm_medium=email) 55 | - [LoRA Implementation from Scratch](https://www.kaggle.com/code/aisuko/lora-from-scratch) 56 | 57 | ## Scripts 58 | 59 | - [Multi-Head Latent Attention (MLA) 
Implementation](https://github.com/ambisinister/mla-experiments/blob/main/modeling/attention/mla.py) 60 | - [RMSNorm Implementation](https://github.com/meta-llama/llama/blob/main/llama/model.py#L34-L77) 61 | 62 | ## Research papers 63 | 64 | - [DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model](https://arxiv.org/pdf/2405.04434) 65 | - [Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention](https://arxiv.org/pdf/2006.16236) 66 | - [LINEAR ATTENTION IS (MAYBE) ALL YOU NEED (TO UNDERSTAND TRANSFORMER OPTIMIZATION)](https://arxiv.org/pdf/2310.01082) 67 | - [Attention Is All You Need](https://arxiv.org/pdf/1706.03762) 68 | - [Longformer: The Long-Document Transformer](https://arxiv.org/pdf/2004.05150) 69 | - [Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/pdf/1508.04025) 70 | - [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/pdf/2305.13245) 71 | - [Big Bird: Transformers for Longer Sequences](https://arxiv.org/pdf/2007.14062) 72 | - [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467) 73 | - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) 74 | - [On Layer Normalization in the Transformer Architecture](https://arxiv.org/pdf/2002.04745) 75 | - [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467) 76 | - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167) 77 | - [Layer Normalization](https://arxiv.org/pdf/1607.06450) 78 | - [Dropout Reduces Underfitting](https://arxiv.org/pdf/2303.01500) 79 | - [GLU Variants Improve Transformer](https://arxiv.org/pdf/2002.05202v1) 80 | -------------------------------------------------------------------------------- /Slides.odp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImadSaddik/Train_Your_Language_Model_Course/e9e8e01b46e1376406bd1c3a0e1692b64ba660ea/Slides.odp -------------------------------------------------------------------------------- /Slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImadSaddik/Train_Your_Language_Model_Course/e9e8e01b46e1376406bd1c3a0e1692b64ba660ea/Slides.pdf -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from minbpe import RegexTokenizer 4 | from transformer.model import GPTLanguageModel 5 | 6 | TOKENS = { 7 | "start": "<|start_turn|>", 8 | "end": "<|end_turn|>", 9 | "separator": "<|separator|>", 10 | "eos": "<|endoftext|>" 11 | } 12 | 13 | 14 | def get_vocab_size(tokenizer: RegexTokenizer) -> int: 15 | vocab = tokenizer.vocab 16 | special_tokens = tokenizer.special_tokens 17 | 18 | return len(vocab) + len(special_tokens) 19 | 20 | 21 | def get_input_tokens(turns: list[dict], tokenizer: RegexTokenizer, device: str) -> torch.Tensor: 22 | formatted_input = "".join( 23 | f"{TOKENS['start']}{turn['role']}{TOKENS['separator']}{turn['content']}{TOKENS['end']}" 24 | for turn in turns 25 | ) 26 | formatted_input += f"{TOKENS['start']}assistant{TOKENS['separator']}" 27 | input_tokens = tokenizer.encode(formatted_input, allowed_special="all") 28 | return torch.tensor(input_tokens, 
dtype=torch.long).unsqueeze(0).to(device) 29 | 30 | 31 | def get_generated_message( 32 | input_tokens: torch.Tensor, 33 | model: GPTLanguageModel, 34 | tokenizer: RegexTokenizer, 35 | block_size: int 36 | ) -> str: 37 | model.eval() 38 | model_answer = "" 39 | while True: 40 | try: 41 | output_tokens = model.advanced_generation( 42 | input_tokens=input_tokens, max_new_tokens=1, temperature=0.9, top_k=50, top_p=None 43 | ) 44 | last_generated_token = output_tokens[0, -1].item() 45 | 46 | if last_generated_token in {tokenizer.special_tokens["<|endoftext|>"], tokenizer.special_tokens["<|end_turn|>"]}: 47 | break 48 | 49 | input_tokens = torch.cat( 50 | (input_tokens, output_tokens[:, -1:]), dim=1) 51 | model_answer += tokenizer.decode([last_generated_token]) 52 | 53 | if input_tokens.size(1) > block_size: 54 | break 55 | except Exception: 56 | continue 57 | return model_answer.strip() 58 | 59 | 60 | def get_system_message() -> str: 61 | return "سميتك بودماغ صاوبك عماد الصاديق باش تعاون الناس بالإجابة على الأسئلة ديالهوم. حاول تكون ضريف معاهم، جاوبهم بلطف، او الى شي حد بانلك معصب اولا كيخسر فالهضرة حاول أنك تهدنو او متعصبش عليه." 62 | 63 | 64 | def get_model( 65 | block_size: int, 66 | device: str, 67 | vocab_size: int, 68 | n_embd: int, 69 | n_head: int, 70 | n_layer: int, 71 | dropout: float, 72 | ignore_index: int, 73 | ) -> GPTLanguageModel: 74 | return GPTLanguageModel( 75 | vocab_size=vocab_size, 76 | block_size=block_size, 77 | n_embd=n_embd, 78 | n_head=n_head, 79 | n_layer=n_layer, 80 | dropout=dropout, 81 | device=device, 82 | ignore_index=ignore_index, 83 | ).to(device) 84 | 85 | 86 | def load_checkpoint(model: GPTLanguageModel, checkpoint_path: str) -> GPTLanguageModel: 87 | checkpoint = torch.load(checkpoint_path, weights_only=True) 88 | model_state_dict = checkpoint["model_state_dict"] 89 | model.load_state_dict(model_state_dict) 90 | return model 91 | 92 | 93 | def get_tokenizer(tokenizer_path: str) -> RegexTokenizer: 94 | tokenizer = RegexTokenizer() 95 | tokenizer.load(model_file=tokenizer_path) 96 | return tokenizer 97 | 98 | 99 | if __name__ == "__main__": 100 | tokenizer = get_tokenizer("./output/tokenizer/darija_tokenizer.model") 101 | vocab_size = get_vocab_size(tokenizer) 102 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 103 | model = get_model( 104 | block_size=1024, 105 | device=device, 106 | vocab_size=vocab_size, 107 | n_embd=512, 108 | n_head=12, 109 | n_layer=8, 110 | dropout=0.2, 111 | ignore_index=tokenizer.special_tokens["<|padding|>"] 112 | ) 113 | 114 | checkpoint_path = "./output/fine_tuning/qa/base/run_2/checkpoint_50.pth" 115 | model = load_checkpoint(model, checkpoint_path=checkpoint_path) 116 | 117 | turns = [{"role": "system", "content": get_system_message()}] 118 | while True: 119 | user_message = input("You: ") 120 | if user_message.lower() == "quit": 121 | print("Goodbye!") 122 | break 123 | 124 | turns.append({"role": "user", "content": user_message}) 125 | input_tokens = get_input_tokens(turns, tokenizer, device) 126 | model_answer = get_generated_message( 127 | input_tokens, model, tokenizer, 1024) 128 | turns.append({"role": "assistant", "content": model_answer}) 129 | 130 | print(f"Assistant: {model_answer}\n") 131 | -------------------------------------------------------------------------------- /data/DummyData.txt: -------------------------------------------------------------------------------- 1 | 26/02/2025, 09:15 - Person 1: Hey 2 | 26/02/2025, 09:16 - Person 1: How are you? 3 | [26/02/2025, 9:18 AM] ~ Person 2: Hey! 
I’m good, what about you? 4 | [26/02/2025, 9:20 AM] ~ Person 1: I’m good too, just a bit tired 5 | 26/02/2025, 09:21 - Person 2: Long night? 6 | [26/02/2025, 9:22 AM] ~ Person 1: Yeah, had to finish some work 7 | 26/02/2025, 09:23 - Person 1: It took longer than expected 8 | 26/02/2025, 09:24 - Person 2: I get that, I slept late too 9 | [26/02/2025, 9:25 AM] ~ Person 1: What were you doing? 10 | [26/02/2025, 9:26 AM] ~ Person 2: Watching a show, then got stuck on YouTube 😂 11 | 26/02/2025, 09:27 - Person 1: Hahaha classic, which show? 12 | 26/02/2025, 09:29 - Person 2: Breaking Bad, ever watched it? 13 | [26/02/2025, 9:30 AM] ~ Person 1: Oh yeah, one of my favorites! 14 | 26/02/2025, 09:31 - Person 1: It just keeps getting better 15 | [26/02/2025, 9:32 AM] ~ Person 2: Same! I’m at season 4 now 16 | [26/02/2025, 9:33 AM] ~ Person 1: Oh man, things are about to get intense 17 | [26/02/2025, 9:34 AM] ~ Person 1: You’re in for a ride 18 | [26/02/2025, 9:35 AM] ~ Person 2: I know, I can’t wait 😬 19 | 26/02/2025, 09:36 - Person 1: Btw, are we still meeting up later? 20 | 26/02/2025, 09:37 - Person 2: Yeah, what time? 21 | [26/02/2025, 9:38 AM] ~ Person 1: How about 5 PM? 22 | [26/02/2025, 9:39 AM] ~ Person 1: Or do you prefer later? 23 | 26/02/2025, 09:40 - Person 2: 5 PM works for me 24 | 26/02/2025, 09:41 - Person 1: Cool! The usual café? 25 | 26/02/2025, 09:42 - Person 2: Sounds good, see you then! 26 | 26/02/2025, 09:46 - Person 2: Maybe, I’ll show you later 27 | 26/02/2025, 09:47 - Person 1: Sure, I’ll take a look 28 | 26/02/2025, 09:48 - Person 2: Appreciate it! 29 | 26/02/2025, 09:49 - Person 1: No problem 😎 30 | 26/02/2025, 09:50 - Person 2: By the way, did you see the new trailer? 31 | 26/02/2025, 09:51 - Person 1: Which one? 32 | 26/02/2025, 09:52 - Person 2: The new sci-fi movie, it looks amazing! 33 | 26/02/2025, 09:53 - Person 1: Oh yeah, I saw it! I can’t wait for that 34 | 26/02/2025, 09:54 - Person 1: We have to watch it together when it’s out 35 | 26/02/2025, 09:55 - Person 2: Definitely! 36 | 26/02/2025, 09:56 - Person 2: Also, did you hear about the new game release? 37 | 26/02/2025, 09:57 - Person 1: No, which one? 38 | 26/02/2025, 09:58 - Person 2: The new open-world RPG, the one we were waiting for 39 | 26/02/2025, 09:59 - Person 2: It’s finally dropping next month! 40 | 26/02/2025, 10:00 - Person 1: No way! I need to pre-order it 41 | 26/02/2025, 10:01 - Person 1: We should do a co-op session when it comes out 42 | 26/02/2025, 10:02 - Person 2: 100%, I’m in! 43 | 26/02/2025, 17:00 - Person 1: I’m here, where are you? 44 | 26/02/2025, 17:02 - Person 2: Just parking, give me a sec 45 | 26/02/2025, 17:03 - Person 1: No rush, I got us a table 46 | 26/02/2025, 17:05 - Person 2: Nice, I’m coming in 47 | 26/02/2025, 17:06 - Person 1: Cool, want a coffee? 48 | 26/02/2025, 17:07 - Person 2: Yes please, black as always 49 | 26/02/2025, 17:08 - Person 1: Got it! 50 | 26/02/2025, 17:10 - Person 2: So, what’s up? 51 | 26/02/2025, 17:12 - Person 1: Not much, just been busy with work 52 | 26/02/2025, 17:13 - Person 2: Same here, it never ends 😅 53 | 26/02/2025, 17:14 - Person 1: I know right? We need a vacation 54 | 26/02/2025, 17:15 - Person 2: 100%! Where should we go? 55 | 26/02/2025, 17:16 - Person 1: Somewhere with a beach maybe? 
26/02/2025, 17:17 - Person 2: That sounds perfect
26/02/2025, 17:18 - Person 1: Let’s plan something soon
26/02/2025, 17:19 - Person 2: Definitely, we need a break
26/02/2025, 17:20 - Person 1: For real 😆
26/02/2025, 17:21 - Person 2: Anyway, about that project, here take a look
26/02/2025, 17:23 - Person 1: Hmm, yeah I see what’s missing
26/02/2025, 17:24 - Person 1: You need to adjust this part
26/02/2025, 17:25 - Person 2: Oh, I didn’t think of that!
26/02/2025, 17:26 - Person 1: Give it a try, might work
26/02/2025, 17:27 - Person 2: I will, thanks man!
26/02/2025, 17:28 - Person 1: No worries
26/02/2025, 17:30 - Person 2: Alright, ready to head out?
26/02/2025, 17:31 - Person 1: Yeah, let’s go!
26/02/2025, 17:32 - Person 2: I’ll drive
26/02/2025, 17:33 - Person 1: You sure? I can drive if you’re tired
26/02/2025, 17:34 - Person 2: Nah, I got this
26/02/2025, 17:35 - Person 1: Alright, let’s roll!

--------------------------------------------------------------------------------
/data/WhatsAppTemplate.md:
--------------------------------------------------------------------------------
# WhatsApp chat format

When you export a WhatsApp chat, it follows this format:

```text
DD/MM/YYYY, HH:MM - Person X: Message
```

If you export text only, media messages will appear like this:

```text
DD/MM/YYYY, HH:MM - Person X: <Media omitted>
```

Before running the [1_DataCleaning.ipynb](/notebooks/1_DataCleaning.ipynb) notebook, make sure your exported files match this format.
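A line in this format can be parsed with a fairly simple regular expression. The snippet below is only an illustrative sketch for the Android-style format shown above; the data-cleaning notebook uses a more complete pattern that also handles the iOS export (square-bracketed timestamps, AM/PM, optional seconds).

```python
import re

# Illustrative only: parse one Android-style exported line.
line = "26/02/2025, 09:15 - Person 1: Hey"
pattern = r"^(\d{1,2}/\d{1,2}/\d{2,4}), (\d{1,2}:\d{2}) - (.*?): (.*)$"

match = re.match(pattern, line)
if match:
    date, time, sender, message = match.groups()
    print(date, time, sender, message)  # 26/02/2025 09:15 Person 1 Hey
```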
--------------------------------------------------------------------------------
/images/back_prop_exploding_gradients.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; visible text labels: "Input", "Output", "Hidden layers"]

--------------------------------------------------------------------------------
/images/back_prop_vanishing_gradients.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; visible text labels: "Input", "Output", "Hidden layers"]

--------------------------------------------------------------------------------
/images/bo_dmagh_logo.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; visible text: "BoDmagh"]

--------------------------------------------------------------------------------
/images/conversation_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImadSaddik/Train_Your_Language_Model_Course/e9e8e01b46e1376406bd1c3a0e1692b64ba660ea/images/conversation_1.png

--------------------------------------------------------------------------------
/images/conversation_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImadSaddik/Train_Your_Language_Model_Course/e9e8e01b46e1376406bd1c3a0e1692b64ba660ea/images/conversation_2.png

--------------------------------------------------------------------------------
/images/course_thumbnail .png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImadSaddik/Train_Your_Language_Model_Course/e9e8e01b46e1376406bd1c3a0e1692b64ba660ea/images/course_thumbnail .png

--------------------------------------------------------------------------------
/images/forward_pass.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; visible text labels: "Input", "Output", "Hidden layers"]

--------------------------------------------------------------------------------
/images/get_dataset_from_hugging_face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImadSaddik/Train_Your_Language_Model_Course/e9e8e01b46e1376406bd1c3a0e1692b64ba660ea/images/get_dataset_from_hugging_face.png

--------------------------------------------------------------------------------
/images/mha_mqa_comparison.svg:
--------------------------------------------------------------------------------
[SVG chart markup omitted in this dump; title: "Comparison of MHA and MQA inference times"; x-axis: "Number of tokens to generate" (100, 200, 500, 1000, 2000); y-axis: "Inference time (s)" (0, 5, 10, 15, 20); series: MHA, MQA; annotations: x7, x6.8, x6, x1.7, x0.8]

--------------------------------------------------------------------------------
/images/mqa_design.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; visible text labels: "Q", "K", "V"]

--------------------------------------------------------------------------------
/images/neural_network.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; no visible text labels]

--------------------------------------------------------------------------------
/images/post_normalization.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; visible text labels: "X", "Attention layer", "Normalization"]

--------------------------------------------------------------------------------
/images/pre_normalization.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; visible text labels: "X", "Attention layer", "Normalization"]

--------------------------------------------------------------------------------
/images/rope_rotation_explanation.svg:
--------------------------------------------------------------------------------
[SVG markup omitted in this dump; visible text: "Imad", "Hi", "My friend is Mr", "θ", "X", "Y"]

--------------------------------------------------------------------------------
/minbpe/__init__.py:
--------------------------------------------------------------------------------
from .base import Tokenizer
from .basic import BasicTokenizer
from .regex import RegexTokenizer
from .gpt4 import GPT4Tokenizer

--------------------------------------------------------------------------------
/minbpe/base.py:
--------------------------------------------------------------------------------
"""
Contains the base Tokenizer class and a few common helper functions.
The base class also contains the (common) save/load functionality.
It would be possible to be a lot more strict about the interface and
e.g. isolating all regex/pattern parts to the RegexTokenizer, but
some concessions are made for simplicity.
7 | """ 8 | import unicodedata 9 | 10 | # ----------------------------------------------------------------------------- 11 | # a few helper functions useful for both BasicTokenizer and RegexTokenizer 12 | 13 | 14 | def get_stats(ids, counts=None): 15 | """ 16 | Given a list of integers, return a dictionary of counts of consecutive pairs 17 | Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1} 18 | Optionally allows to update an existing dictionary of counts 19 | """ 20 | counts = {} if counts is None else counts 21 | for pair in zip(ids, ids[1:]): # iterate consecutive elements 22 | counts[pair] = counts.get(pair, 0) + 1 23 | return counts 24 | 25 | 26 | def merge(ids, pair, idx): 27 | """ 28 | In the list of integers (ids), replace all consecutive occurrences 29 | of pair with the new integer token idx 30 | Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4] 31 | """ 32 | newids = [] 33 | i = 0 34 | while i < len(ids): 35 | # if not at the very last position AND the pair matches, replace it 36 | if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]: 37 | newids.append(idx) 38 | i += 2 39 | else: 40 | newids.append(ids[i]) 41 | i += 1 42 | return newids 43 | 44 | # first two helper functions... 45 | 46 | 47 | def replace_control_characters(s: str) -> str: 48 | # we don't want to print control characters 49 | # which distort the output (e.g. \n or much worse) 50 | # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117 51 | # http://www.unicode.org/reports/tr44/#GC_Values_Table 52 | chars = [] 53 | for ch in s: 54 | if unicodedata.category(ch)[0] != "C": 55 | chars.append(ch) # this character is ok 56 | else: 57 | chars.append(f"\\u{ord(ch):04x}") # escape 58 | return "".join(chars) 59 | 60 | 61 | def render_token(t: bytes) -> str: 62 | # pretty print a token, escaping control characters 63 | s = t.decode('utf-8', errors='replace') 64 | s = replace_control_characters(s) 65 | return s 66 | 67 | # ----------------------------------------------------------------------------- 68 | # the base Tokenizer class 69 | 70 | 71 | class Tokenizer: 72 | """Base class for Tokenizers""" 73 | 74 | def __init__(self): 75 | # default: vocab size of 256 (all bytes), no merges, no patterns 76 | self.merges = {} # (int, int) -> int 77 | self.pattern = "" # str 78 | self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257} 79 | self.vocab = self._build_vocab() # int -> bytes 80 | 81 | def train(self, text, vocab_size, verbose=False): 82 | # Tokenizer can train a vocabulary of size vocab_size from text 83 | raise NotImplementedError 84 | 85 | def encode(self, text): 86 | # Tokenizer can encode a string into a list of integers 87 | raise NotImplementedError 88 | 89 | def decode(self, ids): 90 | # Tokenizer can decode a list of integers into a string 91 | raise NotImplementedError 92 | 93 | def _build_vocab(self): 94 | # vocab is simply and deterministically derived from merges 95 | vocab = {idx: bytes([idx]) for idx in range(256)} 96 | for (p0, p1), idx in self.merges.items(): 97 | vocab[idx] = vocab[p0] + vocab[p1] 98 | for special, idx in self.special_tokens.items(): 99 | vocab[idx] = special.encode("utf-8") 100 | return vocab 101 | 102 | def save(self, file_prefix): 103 | """ 104 | Saves two files: file_prefix.vocab and file_prefix.model 105 | This is inspired (but not equivalent to!) 
sentencepiece's model saving: 106 | - model file is the critical one, intended for load() 107 | - vocab file is just a pretty printed version for human inspection only 108 | """ 109 | # write the model: to be used in load() later 110 | model_file = file_prefix + ".model" 111 | with open(model_file, 'w') as f: 112 | # write the version, pattern and merges, that's all that's needed 113 | f.write("minbpe v1\n") 114 | f.write(f"{self.pattern}\n") 115 | # write the special tokens, first the number of them, then each one 116 | f.write(f"{len(self.special_tokens)}\n") 117 | for special, idx in self.special_tokens.items(): 118 | f.write(f"{special} {idx}\n") 119 | # the merges dict 120 | for idx1, idx2 in self.merges: 121 | f.write(f"{idx1} {idx2}\n") 122 | # write the vocab: for the human to look at 123 | vocab_file = file_prefix + ".vocab" 124 | inverted_merges = {idx: pair for pair, idx in self.merges.items()} 125 | with open(vocab_file, "w", encoding="utf-8") as f: 126 | for idx, token in self.vocab.items(): 127 | # note: many tokens may be partial utf-8 sequences 128 | # and cannot be decoded into valid strings. Here we're using 129 | # errors='replace' to replace them with the replacement char �. 130 | # this also means that we couldn't possibly use .vocab in load() 131 | # because decoding in this way is a lossy operation! 132 | s = render_token(token) 133 | # find the children of this token, if any 134 | if idx in inverted_merges: 135 | # if this token has children, render it nicely as a merge 136 | idx0, idx1 = inverted_merges[idx] 137 | s0 = render_token(self.vocab[idx0]) 138 | s1 = render_token(self.vocab[idx1]) 139 | f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n") 140 | else: 141 | # otherwise this is leaf token, just print it 142 | # (this should just be the first 256 tokens, the bytes) 143 | f.write(f"[{s}] {idx}\n") 144 | 145 | def load(self, model_file): 146 | """Inverse of save() but only for the model file""" 147 | assert model_file.endswith(".model") 148 | # read the model file 149 | merges = {} 150 | special_tokens = {} 151 | idx = 256 152 | with open(model_file, 'r', encoding="utf-8") as f: 153 | # read the version 154 | version = f.readline().strip() 155 | assert version == "minbpe v1" 156 | # read the pattern 157 | self.pattern = f.readline().strip() 158 | # read the special tokens 159 | num_special = int(f.readline().strip()) 160 | for _ in range(num_special): 161 | special, special_idx = f.readline().strip().split() 162 | special_tokens[special] = int(special_idx) 163 | # read the merges 164 | for line in f: 165 | idx1, idx2 = map(int, line.split()) 166 | merges[(idx1, idx2)] = idx 167 | idx += 1 168 | self.merges = merges 169 | self.special_tokens = special_tokens 170 | self.vocab = self._build_vocab() 171 | -------------------------------------------------------------------------------- /minbpe/basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal (byte-level) Byte Pair Encoding tokenizer. 3 | 4 | Algorithmically follows along the GPT tokenizer: 5 | https://github.com/openai/gpt-2/blob/master/src/encoder.py 6 | 7 | But: 8 | - Does not handle the regular expression splitting pattern. 9 | - Does not handle any special tokens. 
10 | """ 11 | 12 | from tqdm import tqdm 13 | from .base import Tokenizer, get_stats, merge 14 | 15 | 16 | class BasicTokenizer(Tokenizer): 17 | 18 | def __init__(self): 19 | super().__init__() 20 | 21 | def train(self, text, vocab_size, verbose=False): 22 | assert vocab_size >= 256 23 | num_merges = vocab_size - 256 24 | 25 | # input text preprocessing 26 | text_bytes = text.encode("utf-8") # raw bytes 27 | ids = list(text_bytes) # list of integers in range 0..255 28 | 29 | # iteratively merge the most common pairs to create new tokens 30 | merges = {} # (int, int) -> int 31 | vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes 32 | for i in tqdm(range(num_merges), total=num_merges): 33 | # count up the number of times every consecutive pair appears 34 | stats = get_stats(ids) 35 | # find the pair with the highest count 36 | pair = max(stats, key=stats.get) 37 | # mint a new token: assign it the next available id 38 | idx = 256 + i 39 | # replace all occurrences of pair in ids with idx 40 | ids = merge(ids, pair, idx) 41 | # save the merge 42 | merges[pair] = idx 43 | vocab[idx] = vocab[pair[0]] + vocab[pair[1]] 44 | # prints 45 | if verbose: 46 | print( 47 | f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") 48 | 49 | # save class variables 50 | self.merges = merges # used in encode() 51 | self.vocab = vocab # used in decode() 52 | 53 | def decode(self, ids): 54 | # given ids (list of integers), return Python string 55 | text_bytes = b"".join(self.vocab[idx] for idx in ids) 56 | text = text_bytes.decode("utf-8", errors="replace") 57 | return text 58 | 59 | def encode(self, text): 60 | # given a string text, return the token ids 61 | text_bytes = text.encode("utf-8") # raw bytes 62 | ids = list(text_bytes) # list of integers in range 0..255 63 | while len(ids) >= 2: 64 | # find the pair with the lowest merge index 65 | stats = get_stats(ids) 66 | pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) 67 | # subtle: if there are no more merges available, the key will 68 | # result in an inf for every single pair, and the min will be 69 | # just the first pair in the list, arbitrarily 70 | # we can detect this terminating case by a membership check 71 | if pair not in self.merges: 72 | break # nothing else can be merged anymore 73 | # otherwise let's merge the best pair (lowest merge index) 74 | idx = self.merges[pair] 75 | ids = merge(ids, pair, idx) 76 | return ids 77 | -------------------------------------------------------------------------------- /minbpe/gpt4.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the GPT-4 Tokenizer as a light wrapper around the RegexTokenizer. 3 | Note that this is a pretrained tokenizer. By default and inside init(), it 4 | loads the pretrained tokenizer from the `cl100k_base` tokenizer of tiktoken. 
5 | """ 6 | 7 | import tiktoken 8 | from .regex import RegexTokenizer 9 | 10 | 11 | def bpe(mergeable_ranks, token, max_rank): 12 | # helper function used in get_gpt4_merges() to reconstruct the merge forest 13 | parts = [bytes([b]) for b in token] 14 | while True: 15 | min_idx = None 16 | min_rank = None 17 | for i, pair in enumerate(zip(parts[:-1], parts[1:])): 18 | rank = mergeable_ranks.get(pair[0] + pair[1]) 19 | if rank is not None and (min_rank is None or rank < min_rank): 20 | min_idx = i 21 | min_rank = rank 22 | if min_rank is None or (max_rank is not None and min_rank >= max_rank): 23 | break 24 | assert min_idx is not None 25 | parts = parts[:min_idx] + [parts[min_idx] + 26 | parts[min_idx + 1]] + parts[min_idx + 2:] 27 | return parts 28 | 29 | 30 | def recover_merges(mergeable_ranks): 31 | # the `merges` are already the byte sequences in their merged state. 32 | # so we have to recover the original pairings. We can do this by doing 33 | # a small BPE training run on all the tokens, in their order. 34 | # also see https://github.com/openai/tiktoken/issues/60 35 | # also see https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306 36 | merges = {} 37 | for token, rank in mergeable_ranks.items(): 38 | if len(token) == 1: 39 | continue # skip raw bytes 40 | pair = tuple(bpe(mergeable_ranks, token, max_rank=rank)) 41 | assert len(pair) == 2 42 | # recover the integer ranks of the pair 43 | ix0 = mergeable_ranks[pair[0]] 44 | ix1 = mergeable_ranks[pair[1]] 45 | merges[(ix0, ix1)] = rank 46 | 47 | return merges 48 | 49 | 50 | GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" 51 | GPT4_SPECIAL_TOKENS = { 52 | '<|endoftext|>': 100257, 53 | '<|fim_prefix|>': 100258, 54 | '<|fim_middle|>': 100259, 55 | '<|fim_suffix|>': 100260, 56 | '<|endofprompt|>': 100276 57 | } 58 | 59 | 60 | class GPT4Tokenizer(RegexTokenizer): 61 | """Lightweight wrapper on RegexTokenizer that matches GPT-4's tokenizer.""" 62 | 63 | def __init__(self): 64 | super().__init__(pattern=GPT4_SPLIT_PATTERN) 65 | # get the official tokenizer and its merges 66 | enc = tiktoken.get_encoding("cl100k_base") 67 | mergeable_ranks = enc._mergeable_ranks 68 | # the merges are those of gpt4, but we have to recover them 69 | self.merges = recover_merges(mergeable_ranks) 70 | # reconstruct the vocab from the merges 71 | vocab = {idx: bytes([idx]) for idx in range(256)} 72 | for (p0, p1), idx in self.merges.items(): 73 | vocab[idx] = vocab[p0] + vocab[p1] 74 | self.vocab = vocab 75 | # now here is another tricky part. 76 | # for some reason, the tokens corresponding to individual bytes 77 | # are permuted in a different order. This is completely non-sensical 78 | # and probably historical, but therefore we have to deal with it here. 
79 | self.byte_shuffle = { 80 | i: mergeable_ranks[bytes([i])] for i in range(256)} 81 | self.inverse_byte_shuffle = { 82 | v: k for k, v in self.byte_shuffle.items()} 83 | # finally register the special tokens 84 | self.register_special_tokens(GPT4_SPECIAL_TOKENS) 85 | 86 | def _encode_chunk(self, text_bytes): 87 | # before we start processing bytes, we have to permute them 88 | text_bytes = bytes(self.byte_shuffle[b] for b in text_bytes) 89 | ids = super()._encode_chunk(text_bytes) 90 | return ids 91 | 92 | def decode(self, ids): 93 | # we have to un-permute the bytes before we decode 94 | text_bytes = b"".join(self.vocab[idx] for idx in ids) 95 | text_bytes = bytes(self.inverse_byte_shuffle[b] for b in text_bytes) 96 | text = text_bytes.decode("utf-8", errors="replace") 97 | return text 98 | 99 | # this is a pretrained tokenizer, it is not intended to be trained 100 | def train(self, text, vocab_size, verbose=False): 101 | raise NotImplementedError 102 | 103 | # save/load would require some thought. 104 | # we'd have to change save/load of base to add support for byte_shuffle... 105 | # alternatively, we could move byte_shuffle to base class, but that would 106 | # mean that we're making ugly our beautiful Tokenizer just to support 107 | # the GPT-4 tokenizer and its weird historical quirks around byte_shuffle. 108 | def save(self, file_prefix): 109 | raise NotImplementedError("GPT4Tokenizer cannot be saved.") 110 | 111 | def load(self, model_file): 112 | raise NotImplementedError("GPT4Tokenizer cannot be loaded.") 113 | 114 | def save_vocab(self, vocab_file): 115 | # just for visualization purposes let's output the GPT-4 tokens 116 | # in the exact same format as the base class would. 117 | # simple run as: 118 | # python -c "from minbpe import GPT4Tokenizer; GPT4Tokenizer().save_vocab('gpt4.vocab')" 119 | from .base import render_token 120 | # build vocab being mindful of the byte shuffle 121 | vocab = {idx: bytes([self.inverse_byte_shuffle[idx]]) 122 | for idx in range(256)} 123 | for (p0, p1), idx in self.merges.items(): 124 | vocab[idx] = vocab[p0] + vocab[p1] 125 | # now merge the shuffled bytes and write to file 126 | inverted_merges = {idx: pair for pair, idx in self.merges.items()} 127 | with open(vocab_file, "w", encoding="utf-8") as f: 128 | for idx, token in vocab.items(): 129 | s = render_token(token) 130 | if idx in inverted_merges: 131 | idx0, idx1 = inverted_merges[idx] 132 | s0 = render_token(vocab[idx0]) 133 | s1 = render_token(vocab[idx1]) 134 | f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n") 135 | else: 136 | f.write(f"[{s}] {idx}\n") 137 | -------------------------------------------------------------------------------- /minbpe/regex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal (byte-level) Byte Pair Encoding tokenizer. 3 | 4 | Algorithmically follows along the GPT tokenizer: 5 | https://github.com/openai/gpt-2/blob/master/src/encoder.py 6 | 7 | Unlike BasicTokenizer: 8 | - RegexTokenizer handles an optional regex splitting pattern. 9 | - RegexTokenizer handles optional special tokens. 
10 | """ 11 | 12 | import regex as re 13 | from tqdm import tqdm 14 | from .base import Tokenizer, get_stats, merge 15 | 16 | 17 | # the main GPT text split patterns, see 18 | # https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py 19 | GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" 20 | GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" 21 | 22 | 23 | class RegexTokenizer(Tokenizer): 24 | 25 | def __init__(self, pattern=None): 26 | """ 27 | - pattern: optional string to override the default (GPT-4 split pattern) 28 | - special_tokens: str -> int dictionary of special tokens 29 | example: {'<|endoftext|>': 100257} 30 | """ 31 | super().__init__() 32 | self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern 33 | self.compiled_pattern = re.compile(self.pattern) 34 | self.special_tokens = {} 35 | self.inverse_special_tokens = {} 36 | 37 | def train(self, text, vocab_size, verbose=False): 38 | assert vocab_size >= 256 39 | num_merges = vocab_size - 256 40 | 41 | # split the text up into text chunks 42 | text_chunks = re.findall(self.compiled_pattern, text) 43 | 44 | # input text preprocessing 45 | ids = [list(ch.encode("utf-8")) for ch in text_chunks] 46 | 47 | # iteratively merge the most common pairs to create new tokens 48 | merges = {} # (int, int) -> int 49 | vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes 50 | for i in tqdm(range(num_merges), total=num_merges): 51 | # count the number of times every consecutive pair appears 52 | stats = {} 53 | for chunk_ids in ids: 54 | # passing in stats will update it in place, adding up counts 55 | get_stats(chunk_ids, stats) 56 | # find the pair with the highest count 57 | pair = max(stats, key=stats.get) 58 | # mint a new token: assign it the next available id 59 | idx = 256 + i 60 | # replace all occurrences of pair in ids with idx 61 | ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids] 62 | # save the merge 63 | merges[pair] = idx 64 | vocab[idx] = vocab[pair[0]] + vocab[pair[1]] 65 | # prints 66 | if verbose: 67 | print( 68 | f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") 69 | 70 | # save class variables 71 | self.merges = merges # used in encode() 72 | self.vocab = vocab # used in decode() 73 | 74 | def register_special_tokens(self, special_tokens): 75 | # special_tokens is a dictionary of str -> int 76 | # example: {"<|endoftext|>": 100257} 77 | self.special_tokens = special_tokens 78 | self.inverse_special_tokens = {v: k for k, v in special_tokens.items()} 79 | 80 | def decode(self, ids): 81 | # given ids (list of integers), return Python string 82 | part_bytes = [] 83 | for idx in ids: 84 | if idx in self.vocab: 85 | part_bytes.append(self.vocab[idx]) 86 | elif idx in self.inverse_special_tokens: 87 | part_bytes.append( 88 | self.inverse_special_tokens[idx].encode("utf-8")) 89 | else: 90 | raise ValueError(f"invalid token id: {idx}") 91 | text_bytes = b"".join(part_bytes) 92 | text = text_bytes.decode("utf-8", errors="replace") 93 | return text 94 | 95 | def _encode_chunk(self, text_bytes): 96 | # return the token ids 97 | # let's begin. 
first, convert all bytes to integers in range 0..255 98 | ids = list(text_bytes) 99 | while len(ids) >= 2: 100 | # find the pair with the lowest merge index 101 | stats = get_stats(ids) 102 | pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) 103 | # subtle: if there are no more merges available, the key will 104 | # result in an inf for every single pair, and the min will be 105 | # just the first pair in the list, arbitrarily 106 | # we can detect this terminating case by a membership check 107 | if pair not in self.merges: 108 | break # nothing else can be merged anymore 109 | # otherwise let's merge the best pair (lowest merge index) 110 | idx = self.merges[pair] 111 | ids = merge(ids, pair, idx) 112 | return ids 113 | 114 | def encode_ordinary(self, text): 115 | """Encoding that ignores any special tokens.""" 116 | # split text into chunks of text by categories defined in regex pattern 117 | text_chunks = re.findall(self.compiled_pattern, text) 118 | # all chunks of text are encoded separately, then results are joined 119 | ids = [] 120 | for chunk in text_chunks: 121 | chunk_bytes = chunk.encode("utf-8") # raw bytes 122 | chunk_ids = self._encode_chunk(chunk_bytes) 123 | ids.extend(chunk_ids) 124 | return ids 125 | 126 | def encode(self, text, allowed_special="none_raise"): 127 | """ 128 | Unlike encode_ordinary, this function handles special tokens. 129 | allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens 130 | if none_raise, then an error is raised if any special token is encountered in text 131 | this is the default tiktoken behavior right now as well 132 | any other behavior is either annoying, or a major footgun 133 | """ 134 | # decode the user desire w.r.t. handling of special tokens 135 | special = None 136 | if allowed_special == "all": 137 | special = self.special_tokens 138 | elif allowed_special == "none": 139 | special = {} 140 | elif allowed_special == "none_raise": 141 | special = {} 142 | assert all(token not in text for token in self.special_tokens) 143 | elif isinstance(allowed_special, set): 144 | special = {k: v for k, v in self.special_tokens.items() 145 | if k in allowed_special} 146 | else: 147 | raise ValueError( 148 | f"allowed_special={allowed_special} not understood") 149 | if not special: 150 | # shortcut: if no special tokens, just use the ordinary encoding 151 | return self.encode_ordinary(text) 152 | # otherwise, we have to be careful with potential special tokens in text 153 | # we handle special tokens by splitting the text 154 | # based on the occurrence of any exact match with any of the special tokens 155 | # we can use re.split for this. 
note that surrounding the pattern with () 156 | # makes it into a capturing group, so the special tokens will be included 157 | special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")" 158 | special_chunks = re.split(special_pattern, text) 159 | # now all the special characters are separated from the rest of the text 160 | # all chunks of text are encoded separately, then results are joined 161 | ids = [] 162 | for part in special_chunks: 163 | if part in special: 164 | # this is a special token, encode it separately as a special case 165 | ids.append(special[part]) 166 | else: 167 | # this is an ordinary sequence, encode it normally 168 | ids.extend(self.encode_ordinary(part)) 169 | return ids 170 | -------------------------------------------------------------------------------- /notebooks/1_DataCleaning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Reading files" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "We read the `.txt` files line by line and apply the following filters:\n", 15 | "\n", 16 | "1. **Remove lines containing a WhatsApp encryption notice** \n", 17 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Messages and calls are end-to-end encrypted. No one outside of this chat, not even WhatsApp, can read or listen to them. Tap to learn more.` \n", 18 | " - ✅ **After:** *(Removed)* \n", 19 | "\n", 20 | "2. **Remove lines with ``** \n", 21 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person: ` \n", 22 | " - ✅ **After:** *(Removed)* \n", 23 | "\n", 24 | "3. **Remove lines containing email addresses** \n", 25 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person: example@gmail.com` \n", 26 | " - ✅ **After:** *(Removed)* \n", 27 | "\n", 28 | "4. **Remove lines containing links** \n", 29 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person: https://www.example.com/` \n", 30 | " - ✅ **After:** *(Removed)* \n", 31 | "\n", 32 | "5. **Replace `` with an empty string** \n", 33 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person: hey, how are you? ` \n", 34 | " - ✅ **After:** `dd/mm/yyyy, hh:mm - Person: hey, how are you?`\n", 35 | "\n", 36 | "6. **Remove lines with the text `You deleted this message`** \n", 37 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person: You deleted this message` \n", 38 | " - ✅ **After:** *(Removed)* \n", 39 | "\n", 40 | "7. **Remove lines with the text `null`** \n", 41 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person: null` \n", 42 | " - ✅ **After:** *(Removed)* \n", 43 | "\n", 44 | "8. **Remove lines with the text `created group`** \n", 45 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person created group \"group name\"` \n", 46 | " - ✅ **After:** *(Removed)* \n", 47 | "\n", 48 | "9. **Remove lines with the text `added you`** \n", 49 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person added you` \n", 50 | " - ✅ **After:** *(Removed)* \n", 51 | "\n", 52 | "10. **Replace tagging (`@person`) with an empty string** \n", 53 | " - ❌ **Before:** `dd/mm/yyyy, hh:mm - Person: @person are you coming?` \n", 54 | " - ✅ **After:** `dd/mm/yyyy, hh:mm - Person: are you coming?` \n", 55 | "\n", 56 | "After filtering, we normalize the content:\n", 57 | "\n", 58 | "- **Replace narrow no-break spaces** (`\\u202F`) with a regular space (`\" \"`) — often found in iOS exports. 
\n", 59 | "- **Remove square brackets around timestamps** (iOS format): \n", 60 | " - ❌ `[dd/mm/yyyy, hh:mm AM/PM]` → ✅ `dd/mm/yyyy, hh:mm AM/PM` \n", 61 | "- **Strip invisible Unicode characters** like `\\u200E` (Left-to-Right Mark) and `\\u200F` (Right-to-Left Mark).\n", 62 | "\n", 63 | "These steps ensure reliable timestamp parsing and consistent regex behavior across both Android and iOS WhatsApp exports." 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "import re\n", 73 | "import pandas as pd\n", 74 | "\n", 75 | "\n", 76 | "def read_whatsapp_chat(file_path: str) -> pd.DataFrame:\n", 77 | " # Define filtering patterns\n", 78 | " encryption_message = \"Messages and calls are end-to-end encrypted. No one outside of this chat, not even WhatsApp, can read or listen to them. Tap to learn more.\"\n", 79 | " media_pattern = \"\"\n", 80 | " email_pattern = r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}'\n", 81 | " url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'\n", 82 | " edited_message = \"\"\n", 83 | " deleted_message = \"You deleted this message\"\n", 84 | " null_message = \"null\"\n", 85 | " created_group_message = \"created group\"\n", 86 | " added_you_to_group_message = \"added you\"\n", 87 | " tagging_pattern = r'@[\\w]+'\n", 88 | "\n", 89 | " with open(file_path, 'r', encoding='utf-8') as f:\n", 90 | " lines = f.readlines()\n", 91 | "\n", 92 | " # Apply filters to remove unwanted lines\n", 93 | " filtered_lines = []\n", 94 | " for line in lines:\n", 95 | " if (\n", 96 | " encryption_message not in line and\n", 97 | " deleted_message not in line and\n", 98 | " null_message != line.split(\" \")[-1] and\n", 99 | " media_pattern not in line and\n", 100 | " created_group_message not in line and\n", 101 | " added_you_to_group_message not in line and\n", 102 | " not re.search(email_pattern, line) and\n", 103 | " not re.search(url_pattern, line)\n", 104 | " ):\n", 105 | " line = line.replace(edited_message, \"\").strip()\n", 106 | " line = re.sub(tagging_pattern, \"\", line).strip()\n", 107 | " filtered_lines.append(line)\n", 108 | "\n", 109 | " # Normalize content:\n", 110 | " content = '\\n'.join(filtered_lines)\n", 111 | " # Replace narrow no-break space (iOS specific)\n", 112 | " content = content.replace('\\u202f', ' ')\n", 113 | " # Remove square brackets if they surround the timestamp (only for iOS)\n", 114 | " content = re.sub(\n", 115 | " r'\\[(\\d{1,2}/\\d{1,2}/\\d{2,4}, \\d{1,2}:\\d{2}(?::\\d{2})?\\s?[APap][Mm])\\]',\n", 116 | " r'\\1',\n", 117 | " content\n", 118 | " )\n", 119 | " # Remove LRM and RLM characters (Left-to-Right Mark and Right-to-Left Mark)\n", 120 | " content = content.replace('\\u200E', '').replace('\\u200F', '')\n", 121 | "\n", 122 | " # Updated regex pattern to match both iOS and Android WhatsApp exports.\n", 123 | " pattern = r'(\\d{1,2}/\\d{1,2}/\\d{2,4}, \\d{1,2}:\\d{2}(?::\\d{2})?(?:\\s?[APap][Mm])?)\\s?(?:-|\\~)?\\s?(.*?): (.*?)(?=\\n\\d{1,2}/\\d{1,2}/\\d{2,4}, \\d{1,2}:\\d{2}|$)'\n", 124 | " messages = re.findall(pattern, content, re.DOTALL)\n", 125 | " df = pd.DataFrame(messages, columns=['timestamp', 'sender', 'message'])\n", 126 | "\n", 127 | " timestamps = []\n", 128 | " for timestamp in df['timestamp']:\n", 129 | " try:\n", 130 | " timestamp = pd.to_datetime(\n", 131 | " timestamp, format='mixed', errors='coerce')\n", 132 | " except Exception as e:\n", 133 | " print(f\"Error parsing timestamp '{timestamp}': 
{e}\")\n", 134 | " timestamp = pd.NaT\n", 135 | " timestamps.append(timestamp)\n", 136 | "\n", 137 | " df['timestamp'] = timestamps\n", 138 | " return df" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "The `all_chats` dictionary holds the content of each file as a dataframe with three columns: `timestamp`, `sender`, and `message`. " 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "from pathlib import Path\n", 155 | "\n", 156 | "all_chats = {}\n", 157 | "data_directory = Path(\"../data/private\")\n", 158 | "for file in data_directory.glob('*.txt'):\n", 159 | " file_name = file.stem\n", 160 | " all_chats[file_name] = read_whatsapp_chat(file)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## Text sequence" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "The text should be merged into a single sequence to prepare it for the next step, where the BPE algorithm will be applied and the text will be encoded." 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "text_sequence = \"\"\n", 184 | "for file_name in all_chats.keys():\n", 185 | " text_sequence += \" \".join(all_chats[file_name]['message'].values)\n", 186 | "\n", 187 | "len(text_sequence)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "with open(\"../output/combined_text.txt\", \"w\", encoding=\"utf-8\") as f:\n", 197 | " f.write(text_sequence)" 198 | ] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "vincent", 204 | "language": "python", 205 | "name": "python3" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 3 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython3", 217 | "version": "3.11.10" 218 | } 219 | }, 220 | "nbformat": 4, 221 | "nbformat_minor": 2 222 | } 223 | -------------------------------------------------------------------------------- /notebooks/2_BytePairEncoding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Load the sequence" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "with open(\"../output/combined_text.txt\", \"r\", encoding=\"utf-8\") as f:\n", 17 | " text_sequence = f.read()\n", 18 | "\n", 19 | "len(text_sequence)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "## BPE algorithm" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "I am using the [minBPE](https://github.com/karpathy/minbpe) repository to tokenize the sequence of text." 
34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 5, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import sys\n", 43 | "sys.path.append('..')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "Start by training the tokenizer on the text sequence that you saved in the previous notebook." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 6, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "from minbpe import BasicTokenizer\n", 60 | "\n", 61 | "tokenizer = BasicTokenizer()\n", 62 | "tokenizer.train(text_sequence, vocab_size=1024)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "Visualize the vocabulary." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "vocab = tokenizer.vocab\n", 79 | "vocab" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "Test the tokenizer." 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "tokenizer.encode(\"Salam labas\")" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 9, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "'Salam labas'" 107 | ] 108 | }, 109 | "execution_count": 9, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "tokenizer.decode([702, 310, 346, 115])" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "Add special tokens to the vocabulary. These tokens are going to be used a lot in the fine-tuning step." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "max_vocab_id = list(tokenizer.vocab.keys())[-1]\n", 132 | "tokenizer.special_tokens = {\n", 133 | " \"<|startoftext|>\": max_vocab_id + 1,\n", 134 | " \"<|separator|>\": max_vocab_id + 2,\n", 135 | " \"<|endoftext|>\": max_vocab_id + 3,\n", 136 | " \"<|unk|>\": max_vocab_id + 4,\n", 137 | " \"<|padding|>\": max_vocab_id + 5,\n", 138 | "}" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "I have more than 618K tokens for training and validation. This is pretty good, but if you can add more, that would be even better." 
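The token count above covers both training and validation. As a rough sketch (the 90/10 ratio here is an assumption, not something fixed by the course), the encoded sequence could be split like this:

```python
import torch

encoded = tokenizer.encode(text_sequence)
data = torch.tensor(encoded, dtype=torch.long)

split_index = int(0.9 * len(data))  # assumed 90/10 train/validation split
train_data = data[:split_index]
validation_data = data[split_index:]
print(len(train_data), len(validation_data))
```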
146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "len(tokenizer.encode(text_sequence))" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "Save the tokenizer" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 12, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "tokenizer.save(file_prefix=\"../output/tokenizer/my_tokenizer\")" 171 | ] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "vincent", 177 | "language": "python", 178 | "name": "python3" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | "name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": "python", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.11.10" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 2 195 | } 196 | -------------------------------------------------------------------------------- /notebooks/5_FineTuningDataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Read the file" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "file_path = \"../data/private/fine_tuning.txt\"\n", 17 | "with open(file_path, 'r', encoding='utf-8') as f:\n", 18 | " lines = f.readlines()\n", 19 | "\n", 20 | "len(lines)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "## Clean the conversation" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import re\n", 37 | "\n", 38 | "encryption_message = \"Messages and calls are end-to-end encrypted. No one outside of this chat, not even WhatsApp, can read or listen to them. 
Tap to learn more.\"\n", 39 | "media_pattern = \"\"\n", 40 | "email_pattern = r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}'\n", 41 | "url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'\n", 42 | "edited_message = \"\"\n", 43 | "deleted_message = \"You deleted this message\"\n", 44 | "null_message = \"null\"\n", 45 | "created_group_message = \"created group\"\n", 46 | "added_you_to_group_message = \"added you\"\n", 47 | "tagging_pattern = r'@[\\w]+'\n", 48 | "\n", 49 | "\n", 50 | "filtered_lines = []\n", 51 | "for line in lines:\n", 52 | " if (\n", 53 | " encryption_message not in line and\n", 54 | " deleted_message not in line and\n", 55 | " null_message != line.split(\" \")[-1] and\n", 56 | " media_pattern not in line and\n", 57 | " created_group_message not in line and\n", 58 | " added_you_to_group_message not in line and\n", 59 | " not re.search(email_pattern, line) and\n", 60 | " not re.search(url_pattern, line)\n", 61 | " ):\n", 62 | " line = line.replace(edited_message, \"\").strip()\n", 63 | " line = re.sub(tagging_pattern, \"\", line).strip()\n", 64 | " filtered_lines.append(line)\n", 65 | "\n", 66 | "pattern = r'(\\d{2}/\\d{2}/\\d{4}, \\d{2}:\\d{2}) - (.*?): (.*?)(?=\\n\\d{2}/\\d{2}/\\d{4}, \\d{2}:\\d{2} -|$)'\n", 67 | "content = '\\n'.join(filtered_lines)\n", 68 | "messages = re.findall(pattern, content, re.DOTALL)\n", 69 | "\n", 70 | "lines_removed = len(lines) - len(filtered_lines)\n", 71 | "print(f\"Lines removed: {lines_removed}\")" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "## Create the dataset" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "### 1. Group messages by sender" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "If a conversation is structured as follows: \n", 93 | "\n", 94 | "```\n", 95 | "User 1: Hey! \n", 96 | "User 1: How are you? \n", 97 | "User 2: I am fine \n", 98 | "User 2: And you? \n", 99 | "User 1: Good. \n", 100 | "```\n", 101 | "\n", 102 | "We want to transform it into: \n", 103 | "\n", 104 | "```\n", 105 | "User 1: Hey!\\nHow are you? \n", 106 | "User 2: I am fine\\nAnd you? \n", 107 | "User 1: Good \n", 108 | "```" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "grouped_messages = []\n", 118 | "\n", 119 | "for _, sender, message in messages:\n", 120 | " if grouped_messages and grouped_messages[-1][\"sender\"] == sender:\n", 121 | " grouped_messages[-1][\"message\"] += \"\\n\" + message\n", 122 | " else:\n", 123 | " grouped_messages.append({\n", 124 | " \"sender\": sender,\n", 125 | " \"message\": message\n", 126 | " })\n", 127 | "\n", 128 | "len(grouped_messages)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "### 2. 
Include special tokens" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "Each message follows this format: \n", 143 | "```\n", 144 | "<|startoftext|>Sender<|separator|>Message<|endoftext|>\n", 145 | "```" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# Define special tokens\n", 155 | "start_of_text_token = \"<|startoftext|>\"\n", 156 | "end_of_text_token = \"<|endoftext|>\"\n", 157 | "separator_token = \"<|separator|>\"\n", 158 | "\n", 159 | "fine_tuning_data = []\n", 160 | "\n", 161 | "for message in grouped_messages:\n", 162 | " sender = message[\"sender\"]\n", 163 | " message_text = message[\"message\"]\n", 164 | " input_sequence = f\"{start_of_text_token}{sender}{separator_token}{message_text}{end_of_text_token}\"\n", 165 | " fine_tuning_data.append(input_sequence)\n", 166 | "\n", 167 | "len(fine_tuning_data)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "### 3. Save the data" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 6, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "import json\n", 184 | "\n", 185 | "save_path = \"../output/fine_tuning/data/fine_tuning.json\"\n", 186 | "with open(save_path, 'w', encoding='utf-8') as f:\n", 187 | " json.dump(fine_tuning_data, f, ensure_ascii=False, indent=4)" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "vincent", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.11.10" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 2 212 | } 213 | -------------------------------------------------------------------------------- /notebooks/8_1_LetsScaleHuggingFaceDataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Download a dataset" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In some cases, you need to log in before you can download a dataset. Run this cell if that is the case." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from huggingface_hub import login\n", 24 | "login()" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "On Hugging Face, click on the copy button to get the name of the dataset."
32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "![dataset_from_hugging_face](../images/get_dataset_from_hugging_face.png)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "And provide that name to `load_dataset`" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "from datasets import load_dataset\n", 55 | "\n", 56 | "dataset = load_dataset(\"atlasia/Atlaset\")" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Create the sequence of text" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "Merge the splits into a single text sequence. In [AtlaSet](https://huggingface.co/datasets/atlasia/Atlaset), there were only two splits (train and test), but if a validation set is included, add an extra loop to append its text to the data list as well." 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "from tqdm import tqdm\n", 80 | "\n", 81 | "\n", 82 | "data = []\n", 83 | "rows = dataset[\"train\"][\"text\"]\n", 84 | "for row in tqdm(rows):\n", 85 | " data.append(row)\n", 86 | "\n", 87 | "rows = dataset[\"test\"][\"text\"]\n", 88 | "for row in tqdm(rows):\n", 89 | " data.append(row)\n", 90 | "\n", 91 | "print(len(\" \".join(data)))" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "Finally store the sequence of text on disk." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "with open(\"../data/AtlaSetCombined.txt\", \"w\") as f:\n", 108 | " f.write(\" \".join(data))" 109 | ] 110 | } 111 | ], 112 | "metadata": { 113 | "kernelspec": { 114 | "display_name": "vincent", 115 | "language": "python", 116 | "name": "python3" 117 | }, 118 | "language_info": { 119 | "codemirror_mode": { 120 | "name": "ipython", 121 | "version": 3 122 | }, 123 | "file_extension": ".py", 124 | "mimetype": "text/x-python", 125 | "name": "python", 126 | "nbconvert_exporter": "python", 127 | "pygments_lexer": "ipython3", 128 | "version": "3.11.10" 129 | } 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 2 133 | } 134 | -------------------------------------------------------------------------------- /notebooks/8_2_LetsScaleTokenizer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Load the sequence" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "This time, we'll load a sample from the text sequence instead of the entire dataset to prevent excessive RAM usage. If the RAM is full, the BPE algorithm won't function properly due to a lack of available memory. \n", 15 | "\n", 16 | "Adjust the `number_of_characters_to_read` value to find the optimal setting for your system." 
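If you prefer a starting point over pure trial and error, one possible heuristic (only a sketch, and it assumes the `psutil` package is installed, which is not part of the course requirements) is to derive the sample size from the RAM that is currently free:

```python
import psutil

available_bytes = psutil.virtual_memory().available
# Training a BPE tokenizer needs several times the text size in working
# memory, so only feed it a small fraction of the free RAM.
number_of_characters_to_read = min(10_000_000, available_bytes // 20)
print(number_of_characters_to_read)
```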
17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "with open(\"../data/AtlaSetCombined.txt\", \"r\") as f:\n", 26 | " number_of_characters_to_read = 10_000_000\n", 27 | " text_sequence = f.read(number_of_characters_to_read)\n", 28 | "\n", 29 | "len(text_sequence)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## BPE algorithm" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "I am using the [minBPE](https://github.com/karpathy/minbpe) repository to tokenize the sequence of text." 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "import sys\n", 53 | "sys.path.append('..')" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "Start by training the tokenizer on the text sequence that you saved in the previous notebook." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "from minbpe import RegexTokenizer\n", 70 | "\n", 71 | "tokenizer = RegexTokenizer()\n", 72 | "tokenizer.train(text_sequence, vocab_size=16_384)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "Visualize the vocabulary." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "vocab = tokenizer.vocab\n", 89 | "vocab" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "Test the tokenizer." 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "tokenizer.encode(\"Salam labas\")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "tokenizer.decode([83, 1813, 3363, 32, 7312, 3770, 115])" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "Add special tokens to the vocabulary. These tokens are going to be used a lot in the fine-tuning step." 
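The next code cell assigns `tokenizer.special_tokens` directly. `RegexTokenizer` also provides `register_special_tokens` (defined earlier in `minbpe/regex.py`), which additionally fills the inverse map that `decode` uses for special ids. A small sketch of how that pairs with `encode(..., allowed_special="all")`:

```python
special_tokens = {
    "<|startoftext|>": max_vocab_id + 1,
    "<|separator|>": max_vocab_id + 2,
    "<|endoftext|>": max_vocab_id + 3,
    "<|unk|>": max_vocab_id + 4,
    "<|padding|>": max_vocab_id + 5,
}
tokenizer.register_special_tokens(special_tokens)

# Special tokens are now encoded to single ids and decoded back to text.
ids = tokenizer.encode("<|startoftext|>Salam<|endoftext|>", allowed_special="all")
print(ids)
print(tokenizer.decode(ids))
```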
122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "max_vocab_id = list(tokenizer.vocab.keys())[-1]\n", 131 | "tokenizer.special_tokens = {\n", 132 | " \"<|startoftext|>\": max_vocab_id + 1,\n", 133 | " \"<|separator|>\": max_vocab_id + 2,\n", 134 | " \"<|endoftext|>\": max_vocab_id + 3,\n", 135 | " \"<|unk|>\": max_vocab_id + 4,\n", 136 | " \"<|padding|>\": max_vocab_id + 5\n", 137 | "}" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "Save the tokenizer" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "tokenizer.save(file_prefix=\"../output/tokenizer/darija_tokenizer\")" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "vincent", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.11.10" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 2 178 | } 179 | -------------------------------------------------------------------------------- /notebooks/8_3_LetsScaleEncoding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Encoding the sequence of text" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "First train the tokenizer with this notebook [2_BytePairEncoding](./2_BytePairEncoding.ipynb) and save it to disk." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import sys\n", 24 | "sys.path.append('..')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "Load the new tokenizer." 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "from minbpe import RegexTokenizer\n", 41 | "\n", 42 | "tokenizer = RegexTokenizer()\n", 43 | "tokenizer.load(model_file=\"../output/tokenizer/darija_tokenizer.model\")" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "Encode the data in batches, otherwise you will get an OOM (Out Of Memory) error. Experiment with the `batch_size` value until you find the one that uses the most of your RAM without crashing VSCode."
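If even the Python list of token ids grows too large, a variation on the cell below (only a sketch, the shard paths are made up for illustration) is to save each encoded batch as its own `.npy` shard and concatenate them afterwards:

```python
import numpy as np

shard_paths = []
batch_size = 100_000_000
with open("../data/AtlaSetCombined.txt", "r") as f:
    shard_index = 0
    while True:
        chunk = f.read(batch_size)
        if not chunk:
            break
        tokens = np.array(tokenizer.encode(chunk), dtype=np.int64)
        shard_path = f"../output/encoded_data/shard_{shard_index}.npy"
        np.save(shard_path, tokens)
        shard_paths.append(shard_path)
        shard_index += 1

# Concatenate later, when the full array is actually needed.
encoded_text_sequence = np.concatenate([np.load(path) for path in shard_paths])
```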
51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "encoded_text_sequence = []\n", 60 | "batch_size = 100_000_000\n", 61 | "file_path = \"../data/AtlaSetCombined.txt\"\n", 62 | "\n", 63 | "with open(file_path, \"r\") as f:\n", 64 | " while True:\n", 65 | " chunk = f.read(batch_size)\n", 66 | " if not chunk:\n", 67 | " break\n", 68 | " batch_tokens = tokenizer.encode(chunk)\n", 69 | " encoded_text_sequence.extend(batch_tokens)\n", 70 | " print(f\"Processed {len(encoded_text_sequence)} tokens so far.\")\n", 71 | "\n", 72 | "print(f\"Total tokens: {len(encoded_text_sequence)}\")" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "Save the encoded data so that we can load it later." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "import numpy as np\n", 89 | "\n", 90 | "output_path = \"../output/encoded_data/encoded_atlaset.npy\"\n", 91 | "np.save(output_path, np.array(encoded_text_sequence, dtype=np.int64))\n", 92 | "\n", 93 | "# Free up memory\n", 94 | "del encoded_text_sequence" 95 | ] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "vincent", 101 | "language": "python", 102 | "name": "python3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.11.10" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 2 119 | } 120 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.10.1 2 | numpy==2.2.4 3 | tiktoken==0.9.0 4 | torch==2.5.0 5 | tqdm==4.66.5 6 | -------------------------------------------------------------------------------- /scripts/before_and_after_normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from sklearn.preprocessing import MinMaxScaler 5 | 6 | np.random.seed(42) 7 | feature1_original = np.random.normal(loc=50, scale=10, size=100) 8 | feature2_original = 0.05 * feature1_original + np.random.normal(loc=3, scale=0.8, size=100) 9 | original_data = np.column_stack((feature1_original, feature2_original)) 10 | 11 | scaler = MinMaxScaler() 12 | normalized_data = scaler.fit_transform(original_data) 13 | 14 | marker_size = 50 15 | alpha_value = 0.7 16 | plt.figure(figsize=(12, 6)) 17 | 18 | # Plot 1: Before Normalization 19 | plt.subplot(1, 2, 1) 20 | plt.scatter(original_data[:, 0], original_data[:, 1], color='red', s=marker_size) 21 | plt.title('Before normalization') 22 | plt.xlabel('Feature 1') 23 | plt.ylabel('Feature 2') 24 | plt.grid(True, linestyle='--', alpha=alpha_value) 25 | 26 | # Plot 2: After Normalization 27 | plt.subplot(1, 2, 2) 28 | plt.scatter(normalized_data[:, 0], normalized_data[:, 1], color='blue', s=marker_size) 29 | plt.title('After normalization') 30 | plt.xlabel('Feature 1 (Normalized)') 31 | plt.ylabel('Feature 2 (Normalized)') 32 | plt.grid(True, linestyle='--', alpha=alpha_value) 33 | 34 | plt.tight_layout() 35 | plt.savefig('before_and_after_normalization.png', dpi=300, bbox_inches='tight') 36 | 
plt.show() 37 | -------------------------------------------------------------------------------- /scripts/generate_sinusoidal_positional_encoding_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | block_size = 128 5 | embedding_size = 64 6 | dimensions_to_plot = 16 7 | plot_density = 10 8 | 9 | if embedding_size % 2 != 0: 10 | raise ValueError("embedding_size must be an even number.") 11 | 12 | if dimensions_to_plot > embedding_size: 13 | dimensions_to_plot = embedding_size 14 | elif dimensions_to_plot <= 0: 15 | raise ValueError("dimensions_to_plot must be a positive number.") 16 | 17 | # Corresponds to 1 / (10000^(2i / embedding_size)) or exp(-(2i / embedding_size) * log(10000)). 18 | even_numbers = np.arange(0, embedding_size, 2, dtype=np.float32) 19 | denominator = np.exp(even_numbers * -(np.log(10000.0) / embedding_size)) 20 | 21 | # Generate a denser range of positions for smoother plotting 22 | positions_smooth = np.linspace( 23 | 0, 24 | block_size - 1, 25 | int(block_size * plot_density) 26 | ) 27 | 28 | try: 29 | plt.style.use('seaborn-v0_8-whitegrid') 30 | except IOError: 31 | print("Style 'seaborn-v0_8-whitegrid' not found, using default.") 32 | 33 | figure, axis = plt.subplots(figsize=(15, 10)) 34 | figure.patch.set_facecolor('white') 35 | axis.set_facecolor('white') 36 | 37 | colors = ['#0077be', '#d95319'] 38 | vertical_offset_factor = 2.5 39 | 40 | for i in range(0, dimensions_to_plot): 41 | color_index = i % 2 42 | denominator_value = denominator[i // 2] 43 | 44 | if i % 2 == 0: # Even dimension index (0, 2, 4...) -> Sine 45 | pe_smooth = np.sin(positions_smooth * denominator_value) 46 | else: # Odd dimension index (1, 3, 5...) 
-> Cosine 47 | pe_smooth = np.cos(positions_smooth * denominator_value) 48 | 49 | # Calculate the vertical offset for stacking lines visually 50 | offset = (dimensions_to_plot - 1 - i) * vertical_offset_factor 51 | axis.plot( 52 | positions_smooth, 53 | pe_smooth + offset, 54 | color=colors[color_index], 55 | linewidth=1.5, 56 | ) 57 | 58 | axis.set_title( 59 | f'Sinusoidal positional encoding (First {dimensions_to_plot} dimensions)', 60 | fontsize=20 61 | ) 62 | axis.set_xlabel('Position in sequence', fontsize=18) 63 | axis.set_ylabel('Dimension index', fontsize=18) 64 | 65 | axis.set_yticks([]) 66 | axis.set_yticklabels([]) 67 | 68 | axis.set_xlim(0, block_size - 1) 69 | axis.set_ylim( 70 | -vertical_offset_factor, 71 | (dimensions_to_plot + 0.5) * vertical_offset_factor 72 | ) 73 | 74 | plt.tight_layout() 75 | plt.savefig( 76 | "../images/sinusoidal_positional_encoding_smooth.svg", 77 | format="svg", 78 | dpi=300, 79 | bbox_inches="tight" 80 | ) 81 | plt.show() 82 | -------------------------------------------------------------------------------- /scripts/plot_activation_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from numpy import ndarray 5 | 6 | 7 | def sigmoid(x: ndarray) -> ndarray: 8 | # Sigmoid: f(x) = 1 / (1 + exp(-x)) 9 | return 1 / (1 + np.exp(-x)) 10 | 11 | 12 | def relu(x: ndarray) -> ndarray: 13 | # ReLU: f(x) = max(0, x) 14 | return np.maximum(0, x) 15 | 16 | 17 | def gelu(x: ndarray) -> ndarray: 18 | # GeLU (Gaussian Error Linear Unit approximation): 19 | # f(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) 20 | return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3))) 21 | 22 | 23 | def leaky_relu(x: ndarray, alpha: float = 0.1) -> ndarray: 24 | # LeakyReLU: f(x) = alpha * x if x < 0 else x 25 | return np.where(x > 0, x, x * alpha) 26 | 27 | 28 | def silu(x: ndarray) -> ndarray: 29 | # SiLU (Sigmoid Linear Unit, also known as Swish): 30 | # f(x) = x * sigmoid(x) 31 | return x * sigmoid(x) 32 | 33 | 34 | max_absolute_value = 3.5 35 | x_values = np.linspace(-max_absolute_value, max_absolute_value, 400) 36 | 37 | y_sigmoid = sigmoid(x_values) 38 | y_relu = relu(x_values) 39 | y_gelu = gelu(x_values) 40 | y_leaky_relu = leaky_relu(x_values) 41 | y_silu = silu(x_values) 42 | 43 | plt.figure(figsize=(12, 8)) 44 | 45 | line_width = 3 46 | plt.plot(x_values, y_sigmoid, label='Sigmoid', linewidth=line_width, color='blue') 47 | plt.plot(x_values, y_relu, label='ReLU', linewidth=line_width, color='red') 48 | plt.plot(x_values, y_gelu, label='GeLU', linewidth=line_width, color='green') 49 | plt.plot(x_values, y_leaky_relu, label=r'LeakyReLU ($\alpha=0.1$)', linewidth=line_width, color='orange') 50 | plt.plot(x_values, y_silu, label=r'SiLU / Swish ($x\sigma(x)$)', linewidth=line_width, color='purple') 51 | 52 | plt.title('Activation functions') 53 | plt.xlabel('x') 54 | plt.ylabel('f(x)') 55 | 56 | plt.legend(loc='upper left') 57 | plt.grid(True, linestyle='--', alpha=0.7) 58 | 59 | plt.axhline(0, color='black', linewidth=0.5) 60 | plt.axvline(0, color='black', linewidth=0.5) 61 | 62 | plt.savefig('activation_functions.png', dpi=300, bbox_inches='tight') 63 | plt.show() 64 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ImadSaddik/Train_Your_Language_Model_Course/e9e8e01b46e1376406bd1c3a0e1692b64ba660ea/transformer/__init__.py -------------------------------------------------------------------------------- /transformer/lora.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | 5 | from transformer.model import GPTLanguageModel 6 | 7 | 8 | class LoRALayer(nn.Module): 9 | def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float) -> None: 10 | super().__init__() 11 | std_dev = 1/torch.sqrt(torch.tensor(rank).float()) 12 | self.A = nn.Parameter(torch.randn(in_dim, rank)*std_dev) 13 | self.B = nn.Parameter(torch.zeros(rank, out_dim)) 14 | self.alpha = alpha 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | x = self.alpha*(x@self.A@self.B) 18 | return x 19 | 20 | 21 | class LinearWithLoRA(nn.Module): 22 | def __init__(self, linear: nn.Linear, rank: int, alpha: float) -> None: 23 | super().__init__() 24 | self.linear = linear 25 | self.lora = LoRALayer( 26 | linear.in_features, linear.out_features, rank, alpha 27 | ) 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | return self.linear(x) + self.lora(x) 31 | 32 | 33 | def print_trainable_parameters(model: GPTLanguageModel) -> None: 34 | trainable_parameters = 0 35 | all_parameters = 0 36 | for _, param in model.named_parameters(): 37 | all_parameters += param.numel() 38 | if param.requires_grad: 39 | trainable_parameters += param.numel() 40 | 41 | print( 42 | f"All parameters: {all_parameters/1e6:.2f}M | " 43 | f"Trainable parameters: {trainable_parameters/1e6:.2f}M | " 44 | f"Trainable %: {100 * trainable_parameters / all_parameters:.2f}%" 45 | ) 46 | 47 | 48 | def get_lora_model(model: GPTLanguageModel, lora_config: dict, device: str) -> GPTLanguageModel: 49 | lora_model = copy.deepcopy(model) 50 | _replace_linear_layers_with_lora_layers(lora_model, lora_config) 51 | _freeze_non_lora_layers(lora_model) 52 | lora_model = lora_model.to(device) 53 | return lora_model 54 | 55 | 56 | def _replace_linear_layers_with_lora_layers(module: nn.Module, lora_config: dict) -> None: 57 | rank = lora_config.get('rank', 4) 58 | alpha = lora_config.get('alpha', 8) 59 | 60 | for name, child in list(module.named_children()): 61 | if isinstance(child, nn.Linear): 62 | setattr(module, name, LinearWithLoRA( 63 | child, rank=rank, alpha=alpha)) 64 | else: 65 | _replace_linear_layers_with_lora_layers( 66 | child, lora_config) 67 | 68 | 69 | def _freeze_non_lora_layers(model: GPTLanguageModel) -> None: 70 | for name, param in model.named_parameters(): 71 | if 'lora' in name: 72 | param.requires_grad = True 73 | else: 74 | param.requires_grad = False 75 | -------------------------------------------------------------------------------- /transformer/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import Optional, Tuple 5 | from torch.nn import functional as F 6 | 7 | 8 | class Head(nn.Module): 9 | """ one head of self-attention """ 10 | 11 | def __init__(self, n_embd: int, head_size: int, block_size: int, dropout: float) -> None: 12 | super().__init__() 13 | self.key = nn.Linear(n_embd, head_size, bias=False) 14 | self.query = nn.Linear(n_embd, head_size, bias=False) 15 | self.value = nn.Linear(n_embd, head_size, bias=False) 16 | self.register_buffer('tril', torch.tril( 17 | torch.ones(block_size, block_size))) 18 
| self.dropout = nn.Dropout(dropout) 19 | 20 | def forward(self, x: torch.Tensor) -> torch.Tensor: 21 | _, T, _ = x.shape 22 | k = self.key(x) # (B,T,hs) 23 | q = self.query(x) # (B,T,hs) 24 | weights = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 25 | weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) 26 | weights = F.softmax(weights, dim=-1) 27 | weights = self.dropout(weights) 28 | v = self.value(x) 29 | out = weights @ v 30 | return out 31 | 32 | 33 | class MultiHeadAttention(nn.Module): 34 | """ multiple heads of self-attention in parallel """ 35 | 36 | def __init__(self, n_embd: int, num_heads: int, head_size: int, block_size: int, dropout: float) -> None: 37 | super().__init__() 38 | self.heads = nn.ModuleList([ 39 | Head(n_embd, head_size, block_size, dropout) 40 | for _ in range(num_heads) 41 | ]) 42 | self.projection = nn.Linear(head_size * num_heads, n_embd) 43 | self.dropout = nn.Dropout(dropout) 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | out = torch.cat([h(x) for h in self.heads], dim=-1) 47 | out = self.dropout(self.projection(out)) 48 | return out 49 | 50 | 51 | class FeedForward(nn.Module): 52 | """ a simple linear layer followed by a non-linearity """ 53 | 54 | def __init__(self, n_embd: int, dropout: float) -> None: 55 | super().__init__() 56 | self.net = nn.Sequential( 57 | nn.Linear(n_embd, 4 * n_embd), 58 | nn.ReLU(), 59 | nn.Linear(4 * n_embd, n_embd), 60 | nn.Dropout(dropout), 61 | ) 62 | 63 | def forward(self, x: torch.Tensor) -> torch.Tensor: 64 | return self.net(x) 65 | 66 | 67 | class Block(nn.Module): 68 | """ Transformer block: communication followed by computation """ 69 | 70 | def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float) -> None: 71 | super().__init__() 72 | head_size = n_embd // n_head 73 | error_message = f"n_embd {n_embd} must be divisible by n_head {n_head}" 74 | assert head_size * n_head == n_embd, error_message 75 | self.self_attention = MultiHeadAttention( 76 | n_embd=n_embd, 77 | num_heads=n_head, 78 | head_size=head_size, 79 | block_size=block_size, 80 | dropout=dropout 81 | ) 82 | self.feed_forward = FeedForward(n_embd, dropout) 83 | self.layer_norm_1 = nn.LayerNorm(n_embd) 84 | self.layer_norm_2 = nn.LayerNorm(n_embd) 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | x = x + self.self_attention(self.layer_norm_1(x)) 88 | x = x + self.feed_forward(self.layer_norm_2(x)) 89 | return x 90 | 91 | 92 | class GPTLanguageModel(nn.Module): 93 | def __init__( 94 | self, 95 | vocab_size: int, 96 | n_embd: int, 97 | n_head: int, 98 | block_size: int, 99 | n_layer: int, 100 | dropout: float, 101 | device: str, 102 | ignore_index: int = -100 103 | ) -> None: 104 | super().__init__() 105 | self.ignore_index = ignore_index 106 | self.block_size = block_size 107 | self.device = device 108 | 109 | self.token_embedding_table = nn.Embedding(vocab_size, n_embd) 110 | self.position_embedding_table = nn.Embedding(block_size, n_embd) 111 | self.blocks = nn.Sequential(*[ 112 | Block(n_embd, n_head, block_size, dropout) 113 | for _ in range(n_layer) 114 | ]) 115 | self.final_layer_norm = nn.LayerNorm(n_embd) 116 | self.final_linear_layer = nn.Linear(n_embd, vocab_size) 117 | 118 | self.apply(self._init_weights) 119 | self.to(device) 120 | 121 | def _init_weights(self, module: nn.Module) -> None: 122 | if isinstance(module, nn.Linear): 123 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 124 | if module.bias is not None: 125 | torch.nn.init.zeros_(module.bias) 126 | elif 
isinstance(module, nn.Embedding): 127 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 128 | 129 | def forward(self, input_tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 130 | B, T = input_tokens.shape 131 | 132 | token_embedding = self.token_embedding_table(input_tokens) 133 | positional_embedding = self.position_embedding_table( 134 | torch.arange(T, device=self.device)) 135 | x = token_embedding + positional_embedding 136 | x = self.blocks(x) 137 | x = self.final_layer_norm(x) 138 | logits = self.final_linear_layer(x) 139 | 140 | if targets is None: 141 | loss = None 142 | else: 143 | B, T, C = logits.shape 144 | logits = logits.view(B*T, C) 145 | targets = targets.view(B*T) 146 | loss = F.cross_entropy( 147 | logits, targets, ignore_index=self.ignore_index) 148 | 149 | return logits, loss 150 | 151 | def generate(self, input_tokens: torch.Tensor, max_new_tokens: int) -> torch.Tensor: 152 | """ 153 | Generates new tokens from the model. 154 | 155 | Args: 156 | input_tokens: The initial input tokens. 157 | max_new_tokens: The maximum number of tokens to generate. 158 | 159 | Returns: 160 | The generated tokens. 161 | """ 162 | for _ in range(max_new_tokens): 163 | cropped_input = input_tokens[:, -self.block_size:] 164 | logits, _ = self(cropped_input) 165 | logits = logits[:, -1, :] 166 | probs = F.softmax(logits, dim=-1) 167 | idx_next = torch.multinomial(probs, num_samples=1) 168 | input_tokens = torch.cat((input_tokens, idx_next), dim=1) 169 | return input_tokens 170 | 171 | def advanced_generation( 172 | self, 173 | input_tokens: torch.Tensor, 174 | max_new_tokens: int, 175 | temperature: float = 1.0, 176 | top_k: Optional[int] = None, 177 | top_p: Optional[float] = None 178 | ) -> torch.Tensor: 179 | """ 180 | Generates new tokens from the model. 181 | 182 | Args: 183 | input_tokens: The initial input tokens. 184 | max_new_tokens: The maximum number of tokens to generate. 185 | temperature: Controls randomness (higher = more random). 186 | top_k: Limits generation to the top-k most likely tokens. 187 | top_p: Limits generation to tokens with cumulative probability <= top_p. 188 | 189 | Returns: 190 | The generated tokens. 
191 | """ 192 | for _ in range(max_new_tokens): 193 | cropped_input = input_tokens[:, -self.block_size:] 194 | logits, _ = self(cropped_input) 195 | logits = logits[:, -1, :] / temperature 196 | 197 | if top_k is not None: 198 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 199 | logits[logits < v[:, [-1]]] = -float('Inf') 200 | 201 | probs = F.softmax(logits, dim=-1) 202 | 203 | if top_p is not None: 204 | sorted_probs, sorted_indices = torch.sort( 205 | probs, descending=True) 206 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1) 207 | sorted_indices_to_remove = cumulative_probs > top_p 208 | sorted_indices_to_remove[..., 209 | 1:] = sorted_indices_to_remove[..., :-1].clone() 210 | sorted_indices_to_remove[..., 0] = 0 211 | indices_to_remove = torch.zeros_like(logits).scatter_( 212 | 1, sorted_indices, sorted_indices_to_remove) 213 | probs[indices_to_remove] = 0.0 214 | probs = probs / probs.sum(dim=-1, keepdim=True) 215 | 216 | idx_next = torch.multinomial(probs, num_samples=1) 217 | input_tokens = torch.cat((input_tokens, idx_next), dim=1) 218 | 219 | return input_tokens 220 | 221 | 222 | if __name__ == "__main__": 223 | # Example usage 224 | vocab_size = 16394 225 | embedding_size = 512 226 | number_of_heads = 8 227 | block_size = 1024 228 | number_of_blocks = 1 229 | dropout = 0.2 230 | head_size = embedding_size // number_of_heads 231 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 232 | 233 | model = GPTLanguageModel( 234 | vocab_size=vocab_size, 235 | n_embd=embedding_size, 236 | n_head=number_of_heads, 237 | block_size=block_size, 238 | n_layer=number_of_blocks, 239 | dropout=dropout, 240 | device=device 241 | ) 242 | 243 | model_size = sum(p.numel() for p in model.parameters() if p.requires_grad) 244 | print(f"Model size: {model_size / 1e6:.2f}M parameters") 245 | 246 | print( 247 | f"Model created with {embedding_size=}, {number_of_heads=}, head_size={embedding_size//number_of_heads}") 248 | 249 | # Create dummy input 250 | input_tokens = torch.randint(0, vocab_size, (2, 50), device=device) 251 | 252 | # Test forward pass 253 | # Use input as target for testing shape 254 | logits, loss = model(input_tokens, targets=input_tokens) 255 | if loss is not None: 256 | print("Loss:", loss.item()) 257 | 258 | # Test generation 259 | print("Generating...") 260 | # Start generation from first 10 tokens 261 | generated_tokens = model.generate(input_tokens[:, :10], max_new_tokens=20) 262 | print("Generated tokens shape:", generated_tokens.shape) 263 | print("Generated sequence example (first batch):\n", 264 | generated_tokens[0].tolist()) 265 | 266 | # Test advanced generation 267 | print("\nAdvanced Generating (top_k=5, temp=0.8)...") 268 | generated_tokens_adv = model.advanced_generation( 269 | input_tokens[:, :10], 270 | max_new_tokens=20, 271 | temperature=0.8, 272 | top_k=10 273 | ) 274 | print("Generated tokens shape (adv):", generated_tokens_adv.shape) 275 | print("Generated sequence example (adv, first batch):\n", 276 | generated_tokens_adv[0].tolist()) 277 | -------------------------------------------------------------------------------- /transformer/model_no_positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import Optional, Tuple 5 | from torch.nn import functional as F 6 | 7 | 8 | class Head(nn.Module): 9 | """ one head of self-attention """ 10 | 11 | def __init__(self, n_embd: int, head_size: int, block_size: int, dropout: float) -> None: 
12 | super().__init__() 13 | self.key = nn.Linear(n_embd, head_size, bias=False) 14 | self.query = nn.Linear(n_embd, head_size, bias=False) 15 | self.value = nn.Linear(n_embd, head_size, bias=False) 16 | self.register_buffer('tril', torch.tril( 17 | torch.ones(block_size, block_size))) 18 | self.dropout = nn.Dropout(dropout) 19 | 20 | def forward(self, x: torch.Tensor) -> torch.Tensor: 21 | _, T, _ = x.shape 22 | k = self.key(x) # (B,T,hs) 23 | q = self.query(x) # (B,T,hs) 24 | weights = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 25 | weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) 26 | weights = F.softmax(weights, dim=-1) 27 | weights = self.dropout(weights) 28 | v = self.value(x) 29 | out = weights @ v 30 | return out 31 | 32 | 33 | class MultiHeadAttention(nn.Module): 34 | """ multiple heads of self-attention in parallel """ 35 | 36 | def __init__(self, n_embd: int, num_heads: int, head_size: int, block_size: int, dropout: float) -> None: 37 | super().__init__() 38 | self.heads = nn.ModuleList([ 39 | Head(n_embd, head_size, block_size, dropout) 40 | for _ in range(num_heads) 41 | ]) 42 | self.projection = nn.Linear(head_size * num_heads, n_embd) 43 | self.dropout = nn.Dropout(dropout) 44 | 45 | def forward(self, x: torch.Tensor) -> torch.Tensor: 46 | out = torch.cat([h(x) for h in self.heads], dim=-1) 47 | out = self.dropout(self.projection(out)) 48 | return out 49 | 50 | 51 | class FeedForward(nn.Module): 52 | """ a simple linear layer followed by a non-linearity """ 53 | 54 | def __init__(self, n_embd: int, dropout: float) -> None: 55 | super().__init__() 56 | self.net = nn.Sequential( 57 | nn.Linear(n_embd, 4 * n_embd), 58 | nn.ReLU(), 59 | nn.Linear(4 * n_embd, n_embd), 60 | nn.Dropout(dropout), 61 | ) 62 | 63 | def forward(self, x: torch.Tensor) -> torch.Tensor: 64 | return self.net(x) 65 | 66 | 67 | class Block(nn.Module): 68 | """ Transformer block: communication followed by computation """ 69 | 70 | def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float) -> None: 71 | super().__init__() 72 | head_size = n_embd // n_head 73 | error_message = f"n_embd {n_embd} must be divisible by n_head {n_head}" 74 | assert head_size * n_head == n_embd, error_message 75 | self.self_attention = MultiHeadAttention( 76 | n_embd=n_embd, 77 | num_heads=n_head, 78 | head_size=head_size, 79 | block_size=block_size, 80 | dropout=dropout 81 | ) 82 | self.feed_forward = FeedForward(n_embd, dropout) 83 | self.layer_norm_1 = nn.LayerNorm(n_embd) 84 | self.layer_norm_2 = nn.LayerNorm(n_embd) 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | x = x + self.self_attention(self.layer_norm_1(x)) 88 | x = x + self.feed_forward(self.layer_norm_2(x)) 89 | return x 90 | 91 | 92 | class GPTLanguageModel(nn.Module): 93 | def __init__( 94 | self, 95 | vocab_size: int, 96 | n_embd: int, 97 | n_head: int, 98 | block_size: int, 99 | n_layer: int, 100 | dropout: float, 101 | device: str, 102 | ignore_index: int = -100 103 | ) -> None: 104 | super().__init__() 105 | self.ignore_index = ignore_index 106 | self.block_size = block_size 107 | self.device = device 108 | 109 | self.token_embedding_table = nn.Embedding(vocab_size, n_embd) 110 | self.blocks = nn.Sequential(*[ 111 | Block(n_embd, n_head, block_size, dropout) 112 | for _ in range(n_layer) 113 | ]) 114 | self.final_layer_norm = nn.LayerNorm(n_embd) 115 | self.final_linear_layer = nn.Linear(n_embd, vocab_size) 116 | 117 | self.apply(self._init_weights) 118 | self.to(device) 119 | 120 | def _init_weights(self, 
module: nn.Module) -> None: 121 | if isinstance(module, nn.Linear): 122 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 123 | if module.bias is not None: 124 | torch.nn.init.zeros_(module.bias) 125 | elif isinstance(module, nn.Embedding): 126 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 127 | 128 | def forward(self, input_tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 129 | B, T = input_tokens.shape 130 | 131 | x = self.token_embedding_table(input_tokens) 132 | x = self.blocks(x) 133 | x = self.final_layer_norm(x) 134 | logits = self.final_linear_layer(x) 135 | 136 | if targets is None: 137 | loss = None 138 | else: 139 | B, T, C = logits.shape 140 | logits = logits.view(B*T, C) 141 | targets = targets.view(B*T) 142 | loss = F.cross_entropy( 143 | logits, targets, ignore_index=self.ignore_index) 144 | 145 | return logits, loss 146 | 147 | def generate(self, input_tokens: torch.Tensor, max_new_tokens: int) -> torch.Tensor: 148 | """ 149 | Generates new tokens from the model. 150 | 151 | Args: 152 | input_tokens: The initial input tokens. 153 | max_new_tokens: The maximum number of tokens to generate. 154 | 155 | Returns: 156 | The generated tokens. 157 | """ 158 | for _ in range(max_new_tokens): 159 | cropped_input = input_tokens[:, -self.block_size:] 160 | logits, _ = self(cropped_input) 161 | logits = logits[:, -1, :] 162 | probs = F.softmax(logits, dim=-1) 163 | idx_next = torch.multinomial(probs, num_samples=1) 164 | input_tokens = torch.cat((input_tokens, idx_next), dim=1) 165 | return input_tokens 166 | 167 | def advanced_generation( 168 | self, 169 | input_tokens: torch.Tensor, 170 | max_new_tokens: int, 171 | temperature: float = 1.0, 172 | top_k: Optional[int] = None, 173 | top_p: Optional[float] = None 174 | ) -> torch.Tensor: 175 | """ 176 | Generates new tokens from the model. 177 | 178 | Args: 179 | input_tokens: The initial input tokens. 180 | max_new_tokens: The maximum number of tokens to generate. 181 | temperature: Controls randomness (higher = more random). 182 | top_k: Limits generation to the top-k most likely tokens. 183 | top_p: Limits generation to tokens with cumulative probability <= top_p. 184 | 185 | Returns: 186 | The generated tokens. 
187 | """ 188 | for _ in range(max_new_tokens): 189 | cropped_input = input_tokens[:, -self.block_size:] 190 | logits, _ = self(cropped_input) 191 | logits = logits[:, -1, :] / temperature 192 | 193 | if top_k is not None: 194 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 195 | logits[logits < v[:, [-1]]] = -float('Inf') 196 | 197 | probs = F.softmax(logits, dim=-1) 198 | 199 | if top_p is not None: 200 | sorted_probs, sorted_indices = torch.sort( 201 | probs, descending=True) 202 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1) 203 | sorted_indices_to_remove = cumulative_probs > top_p 204 | sorted_indices_to_remove[..., 205 | 1:] = sorted_indices_to_remove[..., :-1].clone() 206 | sorted_indices_to_remove[..., 0] = 0 207 | indices_to_remove = torch.zeros_like(logits).scatter_( 208 | 1, sorted_indices, sorted_indices_to_remove) 209 | probs[indices_to_remove] = 0.0 210 | probs = probs / probs.sum(dim=-1, keepdim=True) 211 | 212 | idx_next = torch.multinomial(probs, num_samples=1) 213 | input_tokens = torch.cat((input_tokens, idx_next), dim=1) 214 | 215 | return input_tokens 216 | 217 | 218 | if __name__ == "__main__": 219 | # Example usage 220 | vocab_size = 16394 221 | embedding_size = 512 222 | number_of_heads = 8 223 | block_size = 1024 224 | number_of_blocks = 1 225 | dropout = 0.2 226 | head_size = embedding_size // number_of_heads 227 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 228 | 229 | model = GPTLanguageModel( 230 | vocab_size=vocab_size, 231 | n_embd=embedding_size, 232 | n_head=number_of_heads, 233 | block_size=block_size, 234 | n_layer=number_of_blocks, 235 | dropout=dropout, 236 | device=device 237 | ) 238 | 239 | model_size = sum(p.numel() for p in model.parameters() if p.requires_grad) 240 | print(f"Model size: {model_size / 1e6:.2f}M parameters") 241 | 242 | print( 243 | f"Model created with {embedding_size=}, {number_of_heads=}, head_size={embedding_size//number_of_heads}") 244 | 245 | # Create dummy input 246 | input_tokens = torch.randint(0, vocab_size, (2, 50), device=device) 247 | 248 | # Test forward pass 249 | # Use input as target for testing shape 250 | logits, loss = model(input_tokens, targets=input_tokens) 251 | if loss is not None: 252 | print("Loss:", loss.item()) 253 | 254 | # Test generation 255 | print("Generating...") 256 | # Start generation from first 10 tokens 257 | generated_tokens = model.generate(input_tokens[:, :10], max_new_tokens=20) 258 | print("Generated tokens shape:", generated_tokens.shape) 259 | print("Generated sequence example (first batch):\n", 260 | generated_tokens[0].tolist()) 261 | 262 | # Test advanced generation 263 | print("\nAdvanced Generating (top_k=5, temp=0.8)...") 264 | generated_tokens_adv = model.advanced_generation( 265 | input_tokens[:, :10], 266 | max_new_tokens=20, 267 | temperature=0.8, 268 | top_k=10 269 | ) 270 | print("Generated tokens shape (adv):", generated_tokens_adv.shape) 271 | print("Generated sequence example (adv, first batch):\n", 272 | generated_tokens_adv[0].tolist()) 273 | -------------------------------------------------------------------------------- /transformer/model_relative_positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Optional, Tuple 4 | from torch.nn import functional as F 5 | 6 | 7 | class Head(nn.Module): 8 | """ one head of self-attention with RPE bias""" 9 | 10 | def __init__(self, n_embd: int, head_size: int, block_size: int, dropout: 
    def __init__(self, n_embd: int, head_size: int, block_size: int, dropout: float) -> None:
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        tril = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', tril)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, head_bias: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor (B, T, C)
            head_bias: Relative position bias for this specific head (T, T)
        """
        _, T, _ = x.shape
        k = self.key(x)  # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)

        # Compute attention scores ("affinities")
        # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
        weights = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5

        # head_bias shape is (T, T). Unsqueeze to add batch dim for broadcasting -> (1, T, T).
        # (B, T, T) + (1, T, T) -> (B, T, T)
        weights = weights + head_bias.unsqueeze(0)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        v = self.value(x)
        # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
        out = weights @ v
        return out


class MultiHeadAttention(nn.Module):
    """ Multiple heads of self-attention in parallel, using Head """

    def __init__(
        self,
        n_embd: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        dropout: float,
        max_relative_distance: int = 16
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.max_relative_distance = max_relative_distance
        self.heads = nn.ModuleList([
            Head(
                n_embd=n_embd,
                head_size=head_size,
                block_size=block_size,
                dropout=dropout
            )
            for _ in range(num_heads)
        ])
        self.projection = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.num_buckets = 2 * self.max_relative_distance + 1
        self.relative_attention_bias = nn.Embedding(
            self.num_buckets,
            self.num_heads
        )

    def _compute_relative_position_bias(self, sequence_length: int, device: torch.device) -> torch.Tensor:
        """ Computes the relative position bias tensor for ALL heads.
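
        Each relative offset (key_pos - query_pos) is shifted by
        max_relative_distance and clamped into [0, 2 * max_relative_distance],
        so every offset falls into one of num_buckets learned buckets per head.
        For example, with max_relative_distance=16 there are 33 buckets: an
        offset of -3 maps to bucket 13, +3 maps to bucket 19, and offsets
        beyond +/-16 share the edge buckets 0 and 32.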
""" 80 | query_positions = torch.arange(sequence_length, device=device) 81 | key_positions = torch.arange(sequence_length, device=device) 82 | 83 | # Shape (T, T) 84 | relative_position = key_positions[None, :] - query_positions[:, None] 85 | # Shift range to positive values [0, 2 * max_relative_distance] 86 | relative_indices = relative_position + self.max_relative_distance 87 | # Clamp to ensure indices are within the range [0, num_buckets - 1] 88 | relative_indices = torch.clamp( 89 | input=relative_indices, 90 | min=0, 91 | max=self.num_buckets - 1 92 | ) 93 | # Lookup biases for all heads: (T, T) -> (T, T, num_heads) 94 | bias = self.relative_attention_bias(relative_indices) 95 | # (T, T, num_heads) -> (num_heads, T, T) 96 | bias = bias.permute(2, 0, 1) 97 | return bias 98 | 99 | def forward(self, x: torch.Tensor) -> torch.Tensor: 100 | _, T, _ = x.shape 101 | relative_bias = self._compute_relative_position_bias(T, x.device) 102 | 103 | head_outputs = [] 104 | for i, head_module in enumerate(self.heads): 105 | # Slice shape: (T, T) 106 | head_bias_slice = relative_bias[i] 107 | head_output = head_module(x=x, head_bias=head_bias_slice) 108 | head_outputs.append(head_output) 109 | 110 | # Shape (B, T, num_heads * head_size) -> (B, T, n_embd) 111 | out = torch.cat(head_outputs, dim=-1) 112 | out = self.dropout(self.projection(out)) 113 | return out 114 | 115 | 116 | class FeedForward(nn.Module): 117 | def __init__(self, n_embd: int, dropout: float) -> None: 118 | super().__init__() 119 | self.net = nn.Sequential( 120 | nn.Linear(n_embd, 4 * n_embd), nn.ReLU(), 121 | nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout), 122 | ) 123 | 124 | def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) 125 | 126 | 127 | class Block(nn.Module): 128 | """ Transformer block using the RPE Attention """ 129 | 130 | def __init__( 131 | self, 132 | n_embd: int, 133 | n_head: int, 134 | block_size: int, 135 | dropout: float, 136 | max_relative_distance: int 137 | ) -> None: 138 | super().__init__() 139 | head_size = n_embd // n_head 140 | error_message = f"n_embd {n_embd} must be divisible by n_head {n_head}" 141 | assert head_size * n_head == n_embd, error_message 142 | 143 | self.self_attention = MultiHeadAttention( 144 | n_embd=n_embd, 145 | num_heads=n_head, 146 | head_size=head_size, 147 | block_size=block_size, 148 | dropout=dropout, 149 | max_relative_distance=max_relative_distance 150 | ) 151 | self.feed_forward = FeedForward(n_embd, dropout) 152 | self.layer_norm_1 = nn.LayerNorm(n_embd) 153 | self.layer_norm_2 = nn.LayerNorm(n_embd) 154 | 155 | def forward(self, x: torch.Tensor) -> torch.Tensor: 156 | x = x + self.self_attention(self.layer_norm_1(x)) 157 | x = x + self.feed_forward(self.layer_norm_2(x)) 158 | return x 159 | 160 | 161 | class GPTLanguageModel(nn.Module): 162 | def __init__( 163 | self, 164 | vocab_size: int, 165 | n_embd: int, 166 | n_head: int, 167 | block_size: int, 168 | n_layer: int, 169 | dropout: float, 170 | device: str, 171 | ignore_index: int = -100, 172 | max_relative_distance: int = 16 173 | ) -> None: 174 | super().__init__() 175 | self.ignore_index = ignore_index 176 | self.block_size = block_size 177 | self.device = device 178 | 179 | self.token_embedding_table = nn.Embedding(vocab_size, n_embd) 180 | self.blocks = nn.Sequential(*[ 181 | Block( 182 | n_embd=n_embd, 183 | n_head=n_head, 184 | block_size=block_size, 185 | dropout=dropout, 186 | max_relative_distance=max_relative_distance 187 | ) for _ in range(n_layer) 188 | ]) 189 | 
        self.final_layer_norm = nn.LayerNorm(n_embd)
        self.final_linear_layer = nn.Linear(n_embd, vocab_size)

        self.apply(self._init_weights)
        self.to(device)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Skip re-initializing the relative position bias table so it keeps
            # its default initialization. This shape-based check is crude; a more
            # robust approach would tag the bias embedding with an attribute.
            is_rpe_bias = (hasattr(module, 'weight') and
                           module.weight.shape == (2 * self.blocks[0].self_attention.max_relative_distance + 1,
                                                   self.blocks[0].self_attention.num_heads))
            if not is_rpe_bias:
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        B, T = input_tokens.shape

        x = self.token_embedding_table(input_tokens)
        x = self.blocks(x)
        x = self.final_layer_norm(x)
        logits = self.final_linear_layer(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(
                logits, targets, ignore_index=self.ignore_index)

        return logits, loss

    def generate(self, input_tokens: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        for _ in range(max_new_tokens):
            cropped_input = input_tokens[:, -self.block_size:]
            logits, _ = self(cropped_input)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            input_tokens = torch.cat((input_tokens, idx_next), dim=1)
        return input_tokens

    def advanced_generation(
        self, input_tokens: torch.Tensor, max_new_tokens: int, temperature: float = 1.0,
        top_k: Optional[int] = None, top_p: Optional[float] = None
    ) -> torch.Tensor:
        for _ in range(max_new_tokens):
            cropped_input = input_tokens[:, -self.block_size:]
            logits, _ = self(cropped_input)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            if top_p is not None:
                sorted_probs, sorted_indices = torch.sort(
                    probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
                    1, sorted_indices, sorted_indices_to_remove)
                probs[indices_to_remove] = 0.0
                probs = probs / probs.sum(dim=-1, keepdim=True)
            idx_next = torch.multinomial(probs, num_samples=1)
            input_tokens = torch.cat((input_tokens, idx_next), dim=1)
        return input_tokens


if __name__ == "__main__":
    vocab_size = 10000
    embedding_size = 128  # n_embd
    number_of_heads = 4  # n_head
    block_size = 64  # Max context length
    number_of_blocks = 2  # n_layer
    dropout = 0.1
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    max_relative_distance = 8

    print(f"Using device: {device}")
    print(f"Initializing GPTLanguageModel with {max_relative_distance=}")

    model = GPTLanguageModel(
        vocab_size=vocab_size,
        n_embd=embedding_size,
        n_head=number_of_heads,
        block_size=block_size,
        n_layer=number_of_blocks,
        dropout=dropout,
        device=device,
        max_relative_distance=max_relative_distance
    )

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model size: {model_size / 1e6:.2f}M parameters")

    B = 2  # Batch size
    T = 30  # Sequence length (<= block_size)
    input_tokens = torch.randint(0, vocab_size, (B, T), device=device)

    print(f"\nTesting forward pass with input shape: {input_tokens.shape}")
    logits, loss = model(input_tokens, targets=input_tokens)
    if loss is not None:
        print(f"Loss: {loss.item():.4f}")
    else:
        print("Forward pass completed, no loss calculated.")
    print(f"Logits shape: {logits.shape}")

    print("\nTesting generation...")
    gen_input = input_tokens[:, :5]
    print(f"Generation input shape: {gen_input.shape}")
    generated_tokens = model.generate(gen_input, max_new_tokens=10)
    print(f"Generated tokens shape: {generated_tokens.shape}")
    print("Generated sequence example (first batch):\n",
          generated_tokens[0].tolist())

--------------------------------------------------------------------------------