├── .gitignore ├── CODE_LICENSE.txt ├── README.md ├── errata └── README.md ├── img └── cover.jpg └── supplementary ├── 00-1_python-setup-guide ├── README.md └── figures │ ├── download.png │ ├── miniforge-install.png │ └── new-env.png ├── README.md ├── q10-random-sources ├── data-sampling.ipynb ├── dropout.ipynb └── random-weights.ipynb ├── q11-conv-size └── q11-conv-size.ipynb ├── q12-fc-cnn-equivalence ├── img │ ├── fc-cnn-equivalent-1.png │ ├── fc-cnn-equivalent-2.png │ └── fc-cnn-equivalent-3.png └── q12-fc-cnn-equivalence.ipynb ├── q15-text-augment ├── backtranslation.ipynb ├── noise-injection.ipynb ├── sentence-order-shuffling.ipynb ├── synonym-replacement.ipynb ├── synthetic-data.ipynb ├── word-deletion.ipynb └── word-position-swapping.ipynb ├── q18-using-llms ├── 01_classifier-finetuning │ ├── 1_feature-extractor.ipynb │ ├── 2_finetune-last-layers.ipynb │ ├── 3_finetuning-all-layers.ipynb │ ├── figures │ │ ├── 1_feature-based.png │ │ ├── 2_finetune-last.png │ │ └── 3_finetune-all.png │ └── local_dataset_utilities.py ├── 02_prompting │ └── prompting.ipynb ├── 03_retrieval-augmented-generation │ ├── images │ │ └── rag-1.webp │ ├── retrieval-augmented-generation.ipynb │ └── sample-data │ │ └── Basic-Scientific-Food-Preparation-Lab-Manual.txt ├── 04_adapter │ ├── finetune-using-adapter-layers.ipynb │ └── local_dataset_utilities.py └── 05_lora │ ├── lora-llm.ipynb │ └── lora-mlp.ipynb ├── q19-evaluation-llms ├── BERTScore.ipynb ├── bleu.ipynb ├── perplexity.ipynb └── rouge.ipynb ├── q25_confidence-intervals ├── 1_four-methods.ipynb └── 2_four-methods-vs-true-value.ipynb ├── q26_conformal-prediction └── conformal_prediction.ipynb └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | 
dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /CODE_LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2023 Sebastian Raschka 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *Machine Learning Q and AI Beyond the Basics* Book 2 | 3 | 4 | 5 | The Supplementary Materials for the [Machine Learning Q and AI](https://nostarch.com/machine-learning-q-and-ai) book by [Sebastian Raschka](http://sebastianraschka.com). 6 | 7 | Please use the [Discussions](https://github.com/rasbt/ml-q-and-ai/discussions) for any questions about the book! 8 | 9 | 2023-ml-qai-cover 10 | 11 |
12 | 13 | #### About the Book 14 | 15 | If you’ve locked down the basics of machine learning and AI and want a fun way to address lingering knowledge gaps, this book is for you. This rapid-fire series of short chapters addresses 30 essential questions in the field, helping you stay current on the latest technologies you can implement in your own work. 16 | 17 | Each chapter of *Machine Learning Q and AI* asks and answers a central question, with diagrams to explain new concepts and ample references for further reading. Topics covered include: 18 | 19 | - Multi-GPU training paradigms 20 | - Finetuning transformers 21 | - Differences between encoder- and decoder-style LLMs 22 | - Concepts behind vision transformers 23 | - Confidence intervals for ML 24 | - And many more! 25 | 26 |

27 | This book is a fully edited and revised version of Machine Learning Q and AI, which was available on Leanpub. 28 |

29 | 30 |
31 | 32 | #### Reviews 33 | 34 | > “One could hardly ask for a better guide than Sebastian, who is, without exaggeration, the best machine learning educator currently in the field. On each page, Sebastian not only imparts his extensive knowledge but also shares the passion and curiosity that mark true expertise.”
35 | **-- Chris Albon, Director of Machine Learning, The Wikimedia Foundation** 36 | 37 |
38 | 39 | #### Links 40 | 41 | - [Preorder directly from No Starch Press](https://nostarch.com/machine-learning-q-and-ai) 42 | - [Preorder directly from Amazon](https://www.amazon.com/Machine-Learning-AI-Essential-Questions/dp/1718503768) 43 | - [Supplementary Materials and Discussions](https://github.com/rasbt/MachineLearning-QandAI-book) 44 | 45 |
46 |
47 | 48 | ## Table of Contents 49 | 50 | | Chapter | Title | Supplementary Code | 51 | |---------|-------|----------| 52 | | 1 | Embeddings, Representations, and Latent Space | | 53 | | 2 | Self-Supervised Learning | | 54 | | 3 | Few-Shot Learning | | 55 | | 4 | The Lottery Ticket Hypothesis | | 56 | | 5 | Reducing Overfitting with Data | | 57 | | 6 | Reducing Overfitting with Model Modifications | | 58 | | 7 | Multi-GPU Training Paradigms | | 59 | | 8 | The Keys to the Success of Transformers | | 60 | | 9 | Generative AI Models | | 61 | | 10 | Sources of Randomness | [data-sampling.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q10-random-sources/data-sampling.ipynb)
[dropout.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q10-random-sources/dropout.ipynb)
[random-weights.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q10-random-sources/random-weights.ipynb)| 62 | || PART II: COMPUTER VISION | | 63 | | 11 | Calculating the Number of Parameters | [conv-size.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q11-conv-size/q11-conv-size.ipynb)| 64 | | 12 | The Equivalence of Fully Connected and Convolutional Layers | [fc-cnn-equivalence.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q12-fc-cnn-equivalence/q12-fc-cnn-equivalence.ipynb)| 65 | | 13 | Large Training Sets for Vision Transformers | | 66 | || PART III: NATURAL LANGUAGE PROCESSING | | 67 | | 14 | The Distributional Hypothesis | | 68 | | 15 | Data Augmentation for Text | [backtranslation.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q15-text-augment/backtranslation.ipynb)
[noise-injection.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q15-text-augment/noise-injection.ipynb)
[sentence-order-shuffling.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q15-text-augment/sentence-order-shuffling.ipynb)
[synonym-replacement.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q15-text-augment/synonym-replacement.ipynb)
[synthetic-data.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q15-text-augment/synthetic-data.ipynb)
[word-deletion.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q15-text-augment/word-deletion.ipynb)
[word-position-swapping.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q15-text-augment/word-position-swapping.ipynb)| 69 | | 16 | “Self”-Attention | | 70 | | 17 | Encoder- And Decoder-Style Transformers | | 71 | | 18 | Using and Finetuning Pretrained Transformers | | 72 | | 19 | Evaluating Generative Large Language Models | [BERTScore.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q19-evaluation-llms/BERTScore.ipynb)
[bleu.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q19-evaluation-llms/bleu.ipynb)
[perplexity.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q19-evaluation-llms/perplexity.ipynb)
[rouge.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q19-evaluation-llms/rouge.ipynb) | 73 | || PART IV: PRODUCTION AND DEPLOYMENT | | 74 | | 20 | Stateless And Stateful Training | | 75 | | 21 | Data-Centric AI | | 76 | | 22 | Speeding Up Inference | | 77 | | 23 | Data Distribution Shifts | | 78 | | | PART V: PREDICTIVE PERFORMANCE AND MODEL EVALUATION | | 79 | | 24 | Poisson and Ordinal Regression | | 80 | | 25 | Confidence Intervals | [four-methods.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q25_confidence-intervals/1_four-methods.ipynb)
[four-methods-vs-true-value.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q25_confidence-intervals/2_four-methods-vs-true-value.ipynb)| 81 | | 26 | Confidence Intervals Versus Conformal Predictions | [conformal_prediction.ipynb](https://github.com/rasbt/MachineLearning-QandAI-book/blob/main/supplementary/q26_conformal-prediction/conformal_prediction.ipynb) | 82 | | 27 | Proper Metrics | | 83 | | 28 | The K in K-Fold Cross-Validation | | 84 | | 29 | Training and Test Set Discordance | | 85 | | 30 | Limited Labeled Data | | 86 | 87 | -------------------------------------------------------------------------------- /errata/README.md: -------------------------------------------------------------------------------- 1 | # Errata 2 | 3 | 4 | #### Chapter 8 5 | 6 | The following sentence in Chapter 8 7 | 8 | > Transformers are easy to parallelize because they take a fixed-length sequence of word or image tokens as input. 9 | 10 | Is misleading because we only work with fixed-size sequences specifically during pretraining, finetuning, and batched inference. I.e., where we collect multiple sequences in a batch. A better explanation could be the following: 11 | 12 | > Like other deep learning architectures, transformers facilitate parallelization in batch training by handling sequences of word or image tokens. Although they can process variable-length sequences, in practice, sequences are often padded or truncated to fixed lengths for efficient parallel computation across multiple sequences. 13 | 14 | #### Chapter 12 15 | 16 | On the first page, just above the first figure (12-1), it says "two input and four output units" but should be "four inputs and two outputs" to match the figure caption of Figure 12-1. 17 | 18 | 19 | #### Chapter 17 20 | 21 | (p109) In figure 17-3, "next-sentence prediction" should be "next-word prediction". 
22 | 23 | #### Chapter 18 24 | 25 | (p121, p122) In figure 18-6 and 18-7, "Fully connected layer" box that follows "Multihead self-attention" box should be removed. 26 | 27 | #### Chapter 19 28 | 29 | (p129) in the 2nd row of the second equation, that $\sum$ should be removed. 30 | 31 | (p129, 132) 2nd line from the bottom, "q15-text-augment subfolder" should be "q19-evaluation-llms subfolder" 32 | 33 | #### Chapter 25 34 | 35 | (p166) In note `stats.zscore` should be `stats.norm.ppf`. 36 | 37 | #### Chapter 27 38 | 39 | (p180) Below Figure 27-2, "AB", "BC", "AC" should be "A", "B", "C" respectively. 40 | 41 | (p181) In third paragraph, "How about the second criterion" should be "How about the second part of the first criterion". 42 | 43 | (p182) 4th and 9th line from the bottom, ">=" should be "<=". To correct the example, "p=0.5" should be changed, for example "p=0.1". 44 | 45 | #### Chapter 28 46 | 47 | (p188) In Exercise 28-1, "1(99 percent)" should be "1(100 percent)". 48 | 49 | #### Chapter 30 50 | 51 | (p194) 2nd line below Transfer Learning, "pretrained target dataset" should be "pretrained model on a target dataset". 52 | 53 | (p203) In Figure 30-11, "label more training" should be "label more training data". 54 | 55 | #### Appendix: Answer to the Exercises 56 | 57 | (p208) In 3-1, in the last line, "example per class" is better than "example per image". 58 | 59 | (p212) In last sentence of 10-2 answer, "deterministic dropout" should be "random dropout". 60 | 61 | (p217) In first and last sentences of 20-2 answer, "stateful retraining" should be "stateful training".
62 | -------------------------------------------------------------------------------- /img/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/MachineLearning-QandAI-book/9c9994d3b0c320e428a441c2234d4043fa95110d/img/cover.jpg -------------------------------------------------------------------------------- /supplementary/00-1_python-setup-guide/README.md: -------------------------------------------------------------------------------- 1 | # Python Setup Tips 2 | 3 | 4 | 5 | There are several different ways you can install Python and set up your computing environment. Here, I am illustrating my personal preference. 6 | 7 | (I am using computers running macOS, but this workflow is similar for Linux machines and may work for other operating systems as well.) 8 | 9 | 10 | 11 | ## 1. Download and install Miniforge 12 | 13 | Download miniforge from the GitHub repository [here](https://github.com/conda-forge/miniforge). 14 | 15 | download 16 | 17 | Depending on your operating system, this should download either an `.sh` (macOS, Linux) or `.exe` file (Windows). 18 | 19 | For the `.sh` file, open your command line terminal and execute the following command 20 | 21 | ```bash 22 | sh ~/Desktop/Miniforge3-MacOSX-arm64.sh 23 | ``` 24 | 25 | where `Desktop/` is the folder where the Miniforge installer was downloaded to. On your computer, you may have to replace it with `Downloads/`. 26 | 27 | miniforge-install 28 | 29 | Next, step through the download instructions, confirming with "Enter". 30 | 31 | ## 2. 
Create a new virtual environment 32 | 33 | After the installation was successfully completed, I recommend creating a new virtual environment called `book`, which you can do by executing 34 | 35 | ```bash 36 | conda create -n book python=3.9 37 | ``` 38 | 39 | new-env 40 | 41 | Next, activate your new virtual environment (you have to do it every time you open a new terminal window or tab): 42 | 43 | ```bash 44 | conda activate book 45 | ``` 46 | 47 | 48 | 49 | ## Optional: styling your terminal 50 | 51 | If you want to style your terminal similar to mine so that you can see which virtual environment is active, check out the [Oh My Zsh](https://github.com/ohmyzsh/ohmyzsh) project. 52 | 53 | 54 | 55 | ## 3. Install new Python libraries 56 | 57 | 58 | 59 | To install new Python libraries, you can now use the `conda` package installer. For example, you can install [JupyterLab](https://jupyter.org/install) and [watermark](https://github.com/rasbt/watermark) as follows: 60 | 61 | ```bash 62 | conda install jupyterlab watermark 63 | ``` 64 | 65 | 66 | 67 | Alternatively, you can also use `pip` to install libraries instead. By default, `pip` should be linked to your new `book` conda environment: 68 | 69 | ```bash 70 | pip install jupyterlab watermark 71 | ``` 72 | 73 | 74 | --- 75 | 76 | 77 | 78 | 79 | Any questions? Please feel free to reach out in the [Discussion Forum](https://github.com/rasbt/MachineLearning-QandAI-book/discussions).
-------------------------------------------------------------------------------- /supplementary/00-1_python-setup-guide/figures/download.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/MachineLearning-QandAI-book/9c9994d3b0c320e428a441c2234d4043fa95110d/supplementary/00-1_python-setup-guide/figures/download.png -------------------------------------------------------------------------------- /supplementary/00-1_python-setup-guide/figures/miniforge-install.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/MachineLearning-QandAI-book/9c9994d3b0c320e428a441c2234d4043fa95110d/supplementary/00-1_python-setup-guide/figures/miniforge-install.png -------------------------------------------------------------------------------- /supplementary/00-1_python-setup-guide/figures/new-env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/MachineLearning-QandAI-book/9c9994d3b0c320e428a441c2234d4043fa95110d/supplementary/00-1_python-setup-guide/figures/new-env.png -------------------------------------------------------------------------------- /supplementary/README.md: -------------------------------------------------------------------------------- 1 | # Readme 2 | 3 | To install the Python code requirements for the respective chapters, I recommend creating a new virtual environment first: 4 | 5 | ```bash 6 | conda create -n book python=3.9 7 | conda activate book 8 | ``` 9 | 10 | I prefer `conda`, and for more step-by-step instructions please also see the [00-1_python-setup-guide](00-1_python-setup-guide) folder for my personal Python setup preferences. 11 | 12 | However, you can also use any other virtual environment manager of your choice, for example, [venv](https://docs.python.org/3/library/venv.html).
13 | 14 | 15 | 16 | Then, you can run the following pip command to install the code requirements: 17 | 18 | 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | -------------------------------------------------------------------------------- /supplementary/q10-random-sources/data-sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e1096748-2cbb-4bf1-aa28-fd18dda97022", 6 | "metadata": {}, 7 | "source": [ 8 | "# Data Sampling and Shuffling" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "ecdba48e-dbeb-4fd3-901f-88e71aec70d1", 14 | "metadata": {}, 15 | "source": [ 16 | "### Data Splitting without Random Seed" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 10, 22 | "id": "e569e619-5e76-492a-b110-0e3b36cb9149", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import numpy as np\n", 27 | "\n", 28 | "x_toydata = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9.])\n", 29 | "y_labels = np.array([ 0, 1, 0, 1, 1, 0, 0, 1, 0 ])" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 11, 35 | "id": "c85040c1-81e3-4c08-9f07-644cd74cca8b", 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "X_train [9. 2. 4. 6. 1. 5.]\n", 43 | "X_test [7. 3. 
8.]\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "from sklearn.model_selection import train_test_split\n", 49 | "\n", 50 | "X_train, X_test, y_train, y_test = train_test_split(\n", 51 | " x_toydata, y_labels, test_size=0.3, shuffle=True, stratify=y_labels\n", 52 | ")\n", 53 | "\n", 54 | "print(\"X_train\", X_train)\n", 55 | "print(\"X_test\", X_test)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 12, 61 | "id": "8631e484-791a-4adf-857a-2a0529a6bc1f", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "X_train [5. 4. 6. 3. 1. 2.]\n", 69 | "X_test [7. 8. 9.]\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "X_train, X_test, y_train, y_test = train_test_split(\n", 75 | " x_toydata, y_labels, test_size=0.3, shuffle=True, stratify=y_labels\n", 76 | ")\n", 77 | "\n", 78 | "print(\"X_train\", X_train)\n", 79 | "print(\"X_test\", X_test)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "3e8ecec5-b9d8-487c-872b-34537ab7c441", 85 | "metadata": {}, 86 | "source": [ 87 | "### Data Splitting with Random Seed" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 13, 93 | "id": "d49b4ee6-b4ff-4d5b-adb4-3fc541fb0b9a", 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "X_train [9. 2. 8. 3. 5. 7.]\n", 101 | "X_test [4. 1. 
6.]\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "X_train, X_test, y_train, y_test = train_test_split(\n", 107 | " x_toydata, y_labels, test_size=0.3, shuffle=True, stratify=y_labels,\n", 108 | " random_state=123\n", 109 | ")\n", 110 | "\n", 111 | "print(\"X_train\", X_train)\n", 112 | "print(\"X_test\", X_test)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 14, 118 | "id": "e2a87370-de68-4c92-9d99-3faed33657f4", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "X_train [9. 2. 8. 3. 5. 7.]\n", 126 | "X_test [4. 1. 6.]\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "X_train, X_test, y_train, y_test = train_test_split(\n", 132 | " x_toydata, y_labels, test_size=0.3, shuffle=True, stratify=y_labels,\n", 133 | " random_state=123\n", 134 | ")\n", 135 | "\n", 136 | "print(\"X_train\", X_train)\n", 137 | "print(\"X_test\", X_test)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "id": "2a5650fd-30b7-4d69-bbce-77f2c36ad272", 143 | "metadata": {}, 144 | "source": [ 145 | "### K-fold without Random Seed" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 15, 151 | "id": "9e234e41-2ac8-4aa9-b165-c30ee1a12448", 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "Feature values [2. 3. 4. 5. 6. 7.]\n", 159 | "Feature values [1. 2. 4. 7. 8. 9.]\n", 160 | "Feature values [1. 3. 5. 6. 8. 
9.]\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "from sklearn.model_selection import StratifiedKFold\n", 166 | "\n", 167 | "cv = StratifiedKFold(n_splits=3, shuffle=True)\n", 168 | "\n", 169 | "for train_idx, valid_idx in cv.split(x_toydata, y_labels):\n", 170 | " print(\"Feature values\", x_toydata[train_idx])" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 16, 176 | "id": "ba10c301-3346-4c21-913e-4c240d932084", 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "Feature values [2. 3. 4. 7. 8. 9.]\n", 184 | "Feature values [1. 2. 4. 5. 6. 9.]\n", 185 | "Feature values [1. 3. 5. 6. 7. 8.]\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "for train_idx, valid_idx in cv.split(x_toydata, y_labels):\n", 191 | " print(\"Feature values\", x_toydata[train_idx])" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "id": "1961da6d-50ea-4247-82d8-c0ee3f592edd", 197 | "metadata": {}, 198 | "source": [ 199 | "### K-fold with Random Seed" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 17, 205 | "id": "4f0390cb-ff36-4441-968e-cf1e2bde21ac", 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | "Feature values [3. 4. 5. 6. 8. 9.]\n", 213 | "Feature values [1. 2. 4. 5. 6. 7.]\n", 214 | "Feature values [1. 2. 3. 7. 8. 9.]\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "cv = StratifiedKFold(n_splits=3, random_state=123, shuffle=True)\n", 220 | "\n", 221 | "for train_idx, valid_idx in cv.split(x_toydata, y_labels):\n", 222 | " print(\"Feature values\", x_toydata[train_idx])" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 18, 228 | "id": "dd766cf8-c818-4f43-9b4c-a0d51ef8ec21", 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "Feature values [3. 4. 5. 
6. 8. 9.]\n", 236 | "Feature values [1. 2. 4. 5. 6. 7.]\n", 237 | "Feature values [1. 2. 3. 7. 8. 9.]\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "cv = StratifiedKFold(n_splits=3, random_state=123, shuffle=True)\n", 243 | "\n", 244 | "for train_idx, valid_idx in cv.split(x_toydata, y_labels):\n", 245 | " print(\"Feature values\", x_toydata[train_idx])" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "id": "6c9b9986", 251 | "metadata": {}, 252 | "source": [ 253 | "## Dataset Loading without Random Seed" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 1, 259 | "id": "82de5768", 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "name": "stdout", 264 | "output_type": "stream", 265 | "text": [ 266 | "tensor([7, 6, 2, 5, 3, 8, 8, 6])\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "from torchvision import datasets, transforms\n", 272 | "from torch.utils.data import DataLoader\n", 273 | "\n", 274 | "\n", 275 | "train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)\n", 276 | "train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)\n", 277 | "\n", 278 | "for inputs, labels in train_loader:\n", 279 | " pass\n", 280 | "print(labels)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 2, 286 | "id": "8def12a3", 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "tensor([4, 6, 5, 2, 5, 2, 9, 3])\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)\n", 299 | "\n", 300 | "for inputs, labels in train_loader:\n", 301 | " pass\n", 302 | "print(labels)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "id": "18d33c0c", 308 | "metadata": {}, 309 | "source": [ 310 | "## Dataset Loading with Random Seed" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | 
"execution_count": 3, 316 | "id": "30627a3d", 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "tensor([1, 8, 8, 7, 2, 5, 4, 1])\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "import torch\n", 329 | "\n", 330 | "torch.manual_seed(123)\n", 331 | "train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)\n", 332 | "\n", 333 | "for inputs, labels in train_loader:\n", 334 | " pass\n", 335 | "print(labels)" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 4, 341 | "id": "e346084e", 342 | "metadata": {}, 343 | "outputs": [ 344 | { 345 | "name": "stdout", 346 | "output_type": "stream", 347 | "text": [ 348 | "tensor([1, 8, 8, 7, 2, 5, 4, 1])\n" 349 | ] 350 | } 351 | ], 352 | "source": [ 353 | "torch.manual_seed(123)\n", 354 | "train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)\n", 355 | "\n", 356 | "for inputs, labels in train_loader:\n", 357 | " pass\n", 358 | "print(labels)" 359 | ] 360 | } 361 | ], 362 | "metadata": { 363 | "kernelspec": { 364 | "display_name": "Python 3 (ipykernel)", 365 | "language": "python", 366 | "name": "python3" 367 | }, 368 | "language_info": { 369 | "codemirror_mode": { 370 | "name": "ipython", 371 | "version": 3 372 | }, 373 | "file_extension": ".py", 374 | "mimetype": "text/x-python", 375 | "name": "python", 376 | "nbconvert_exporter": "python", 377 | "pygments_lexer": "ipython3", 378 | "version": "3.10.6" 379 | } 380 | }, 381 | "nbformat": 4, 382 | "nbformat_minor": 5 383 | } 384 | -------------------------------------------------------------------------------- /supplementary/q10-random-sources/dropout.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f08ed48d-0a05-4b3e-bb52-d63600ed6772", 6 | "metadata": {}, 7 | "source": [ 8 | "# Dropout" 9 | ] 10 | }, 11 | { 12 | "cell_type": 
"code", 13 | "execution_count": 1, 14 | "id": "202c9deb-f13f-4ca1-92fc-1922282cf0e8", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "\n", 20 | "\n", 21 | "class MLP(torch.nn.Module):\n", 22 | " def __init__(self, num_features, num_classes):\n", 23 | " super().__init__()\n", 24 | "\n", 25 | " self.all_layers = torch.nn.Sequential(\n", 26 | " torch.nn.Linear(num_features, 10),\n", 27 | " torch.nn.ReLU(),\n", 28 | " torch.nn.Dropout(0.5),\n", 29 | " \n", 30 | " # output layer\n", 31 | " torch.nn.Linear(10, num_classes),\n", 32 | " )\n", 33 | "\n", 34 | " def forward(self, x):\n", 35 | " logits = self.all_layers(x)\n", 36 | " return logits" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "1b3c1d6c-e39c-4b56-9e69-7bd55b8d97ab", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "torch.manual_seed(123)\n", 47 | "\n", 48 | "model = MLP(num_features=5, num_classes=2)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "08f6a192-5212-4d21-9ed5-b83c81c1b609", 54 | "metadata": {}, 55 | "source": [ 56 | "## Dropout during training" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "id": "d4f54a58-dd64-4cba-9375-7d4698292002", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "x = torch.tensor([1., 0.3, 2.4, -1.1, -0.8])" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "id": "f104b59c-8f36-43d3-b3e0-3dcd4858c31a", 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "tensor([-0.1564, -0.2977], grad_fn=)" 79 | ] 80 | }, 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "model(x)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "id": "e8dea15d-e784-4492-97a9-329d96a4dd1d", 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "text/plain": [ 99 | 
"tensor([0.1359, 0.0523], grad_fn=)" 100 | ] 101 | }, 102 | "execution_count": 5, 103 | "metadata": {}, 104 | "output_type": "execute_result" 105 | } 106 | ], 107 | "source": [ 108 | "model(x)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "0fa5b0d3-b63c-4361-9176-3a3189caa540", 114 | "metadata": {}, 115 | "source": [ 116 | "## Disable dropout during inference" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 6, 122 | "id": "22705321-a2b0-4a38-adc9-a3c02a714ad9", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "model.eval();" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "id": "a289774a-b3f6-457b-9ae6-341e6a8081ff", 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "data": { 137 | "text/plain": [ 138 | "tensor([-0.0458, -0.1777], grad_fn=)" 139 | ] 140 | }, 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | } 145 | ], 146 | "source": [ 147 | "model(x)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 8, 153 | "id": "71adeb90-b0d5-491b-98d2-65801d3e650c", 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "tensor([-0.0458, -0.1777], grad_fn=)" 160 | ] 161 | }, 162 | "execution_count": 8, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "model(x)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "id": "84bb2fd0-37a2-4ebb-9984-eacdee0c166c", 174 | "metadata": {}, 175 | "source": [ 176 | "Note: during inference, it's also recommended to use either `torch.no_grad()` or `torch.inference_mode()` context so that gradient tracking is disabled. (Not used above to demonstrate that `.eval()` disables dropout during inference." 
177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 9, 182 | "id": "636ddddb-e198-4ec5-acee-c52bf4f6a6ea", 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "tensor([-0.0458, -0.1777])\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "with torch.inference_mode():\n", 195 | " print(model(x))" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "5dc8ffa6-6883-4b74-b0b9-32c98d3cf2fd", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3 (ipykernel)", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.11.4" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 5 228 | } 229 | -------------------------------------------------------------------------------- /supplementary/q10-random-sources/random-weights.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7b759a05-9483-4df8-8787-c01a8ddf8f4b", 6 | "metadata": {}, 7 | "source": [ 8 | "# Random Weight Initialization" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "98b6b513-4095-44e1-9f7d-65a298d01c02", 14 | "metadata": {}, 15 | "source": [ 16 | "### Without Random Seed" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "6134dabb-ecae-47cc-9a51-0a08d3acb520", 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "Parameter containing:\n", 30 | "tensor([[ 0.1458, 
-0.5813],\n", 31 | " [-0.3033, 0.0979],\n", 32 | " [-0.0911, -0.2116]], requires_grad=True)\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "import torch \n", 38 | "\n", 39 | "\n", 40 | "layer = torch.nn.Linear(2, 3)\n", 41 | "print(layer.weight)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "id": "bd12f614-c46a-4c43-a22c-7e45088dcd6a", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "Parameter containing:\n", 55 | "tensor([[ 0.1760, -0.6285],\n", 56 | " [-0.1398, -0.0959],\n", 57 | " [ 0.3127, 0.4303]], requires_grad=True)\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "layer = torch.nn.Linear(2, 3)\n", 63 | "print(layer.weight)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "a5a897b2-f3b7-48e9-a23a-adb7b016fd83", 69 | "metadata": {}, 70 | "source": [ 71 | "### With Random Seed" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "id": "5c4388d8-f30c-41d8-90ff-1e88547beb87", 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Parameter containing:\n", 85 | "tensor([[-0.2883, 0.0234],\n", 86 | " [-0.3512, 0.2667],\n", 87 | " [-0.6025, 0.5183]], requires_grad=True)\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "torch.manual_seed(123)\n", 93 | "\n", 94 | "layer = torch.nn.Linear(2, 3)\n", 95 | "print(layer.weight)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "id": "26f89be0-1adf-4c3d-b3c3-a8316e7f50d9", 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "Parameter containing:\n", 109 | "tensor([[-0.2883, 0.0234],\n", 110 | " [-0.3512, 0.2667],\n", 111 | " [-0.6025, 0.5183]], requires_grad=True)\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "torch.manual_seed(123)\n", 117 | "\n", 118 | "layer = torch.nn.Linear(2, 3)\n", 119 | 
"print(layer.weight)" 120 | ] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "Python 3 (ipykernel)", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.11.4" 140 | } 141 | }, 142 | "nbformat": 4, 143 | "nbformat_minor": 5 144 | } 145 | -------------------------------------------------------------------------------- /supplementary/q11-conv-size/q11-conv-size.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c2b92e14-aa49-4b4c-b0c3-af07314eb8a7", 6 | "metadata": {}, 7 | "source": [ 8 | "# Supplementary material for Q11" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "a99e7ac1-9681-45c1-a4ef-715491d22828", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "PyTorch version: 2.0.1\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "import torch\n", 27 | "print(f\"PyTorch version: {torch.__version__}\")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "94a0ce0c-2b97-4c1b-a1a8-8edaa547c5e1", 33 | "metadata": {}, 34 | "source": [ 35 | "## 1) Convolutional neural network architecture" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 68, 41 | "id": "c0cfee5b-2923-4f52-8e82-7ea0a306bcc9", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "class PyTorchCNN(torch.nn.Module):\n", 46 | " def __init__(self, num_classes):\n", 47 | " super().__init__()\n", 48 | "\n", 49 | " self.num_classes = num_classes\n", 50 | " self.features = torch.nn.Sequential(\n", 51 | " torch.nn.Conv2d(3, 5, kernel_size=5, 
stride=1), # 5 * (5*5 * 3) + 5\n", 52 | " torch.nn.ReLU(),\n", 53 | " torch.nn.MaxPool2d(kernel_size=5, stride=2),\n", 54 | " torch.nn.Conv2d(5, 12, kernel_size=3, stride=1), # 12 * (3*3 * 5) + 12\n", 55 | " torch.nn.ReLU(),\n", 56 | " torch.nn.AvgPool2d(kernel_size=3, stride=2),\n", 57 | " torch.nn.ReLU(),\n", 58 | " )\n", 59 | "\n", 60 | " self.classifier = torch.nn.Sequential(\n", 61 | " torch.nn.Flatten(),\n", 62 | " torch.nn.Linear(192, 128), # 192 * 128 + 128\n", 63 | " torch.nn.ReLU(),\n", 64 | " torch.nn.Linear(128, num_classes), # 128 * 10 + 10\n", 65 | " )\n", 66 | "\n", 67 | " def forward(self, x):\n", 68 | " x = self.features(x)\n", 69 | " x = self.classifier(x)\n", 70 | " return x" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "811d96c1-2269-4cc5-9b40-d3ddbc872693", 76 | "metadata": {}, 77 | "source": [ 78 | "## 2) Computing the number of parameters" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "8473aea9-c497-46d6-8374-d4c8e044ed31", 84 | "metadata": {}, 85 | "source": [ 86 | "### 2 a) By hand" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 69, 92 | "id": "6e0a7164-846f-4054-9493-ccacc519ee51", 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "932\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "# convolutional part\n", 105 | "conv_part = (5 * (5*5 * 3) + 5 ) + ( 12 * (3*3 * 5) + 12 )\n", 106 | "print(conv_part)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 70, 112 | "id": "313597c8-3295-4373-b0dc-f6273733afc9", 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "25994\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "# fully connected part\n", 125 | "fc_part = 192*128+128 + 128*10+10\n", 126 | "print(fc_part)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 71, 132 | "id": 
"a7d77806-33c1-4b0a-bf29-22643ec4938f", 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "26926\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "# total\n", 145 | "print(conv_part + fc_part)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "dd5e430f-1f76-477c-86e7-22372d6dbbdd", 151 | "metadata": {}, 152 | "source": [ 153 | "### 2 b) Adding .parameters() programmatically" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 72, 159 | "id": "d1f1d38c-d19e-41bb-b8b2-c9cfa19a8b42", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "model = PyTorchCNN(10)\n", 164 | "\n", 165 | "def count_parameters(model): \n", 166 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 73, 172 | "id": "9c6f9f5f-422e-4154-9db8-335bab3e5a7a", 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "data": { 177 | "text/plain": [ 178 | "932" 179 | ] 180 | }, 181 | "execution_count": 73, 182 | "metadata": {}, 183 | "output_type": "execute_result" 184 | } 185 | ], 186 | "source": [ 187 | "# convolutional part\n", 188 | "count_parameters(model.features)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 74, 194 | "id": "413ec8c3-e557-4304-9798-d227c0091adc", 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/plain": [ 200 | "25994" 201 | ] 202 | }, 203 | "execution_count": 74, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | } 207 | ], 208 | "source": [ 209 | "# fully connected part\n", 210 | "count_parameters(model.classifier)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 75, 216 | "id": "5c6bbda1-16a5-4e04-b0f0-6cb65e550824", 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "data": { 221 | "text/plain": [ 222 | "26926" 223 | ] 224 | }, 225 | 
"execution_count": 75, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "# total\n", 232 | "count_parameters(model)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "id": "417c1976-4c5b-49bd-954f-b62819c9be64", 238 | "metadata": {}, 239 | "source": [ 240 | "### 2 c) Computer the memory size" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 97, 246 | "id": "ac9ca70a-e239-4ecb-bacf-247ed08e5ce4", 247 | "metadata": {}, 248 | "outputs": [ 249 | { 250 | "name": "stdout", 251 | "output_type": "stream", 252 | "text": [ 253 | " 0.11 Mb\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "import sys\n", 259 | "\n", 260 | "def calculate_size(model): \n", 261 | " return sum(p.element_size()*p.numel() for p in model.parameters())\n", 262 | "\n", 263 | "size_in_bytes = calculate_size(model)\n", 264 | "size_in_megabytes = size_in_bytes * 1e-6\n", 265 | "\n", 266 | "print(f\"{size_in_megabytes: .2f} Mb\")" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "id": "f4c99a55-0b72-4247-8d1b-d06ea4726b32", 272 | "metadata": {}, 273 | "source": [ 274 | "### 2 d) Using the torchinfo library" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 99, 280 | "id": "2967e0b4-692e-4858-a8e3-3586b1744481", 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "name": "stdout", 285 | "output_type": "stream", 286 | "text": [ 287 | "Torchinfo version: 1.7.2\n" 288 | ] 289 | }, 290 | { 291 | "data": { 292 | "text/plain": [ 293 | "==========================================================================================\n", 294 | "Layer (type:depth-idx) Output Shape Param #\n", 295 | "==========================================================================================\n", 296 | "PyTorchCNN [16, 10] --\n", 297 | "├─Sequential: 1-1 [16, 12, 4, 4] --\n", 298 | "│ └─Conv2d: 2-1 [16, 5, 28, 28] 380\n", 299 | "│ └─ReLU: 2-2 [16, 5, 28, 28] --\n", 300 | "│ └─MaxPool2d: 2-3 
[16, 5, 12, 12] --\n", 301 | "│ └─Conv2d: 2-4 [16, 12, 10, 10] 552\n", 302 | "│ └─ReLU: 2-5 [16, 12, 10, 10] --\n", 303 | "│ └─AvgPool2d: 2-6 [16, 12, 4, 4] --\n", 304 | "│ └─ReLU: 2-7 [16, 12, 4, 4] --\n", 305 | "├─Sequential: 1-2 [16, 10] --\n", 306 | "│ └─Flatten: 2-8 [16, 192] --\n", 307 | "│ └─Linear: 2-9 [16, 128] 24,704\n", 308 | "│ └─ReLU: 2-10 [16, 128] --\n", 309 | "│ └─Linear: 2-11 [16, 10] 1,290\n", 310 | "==========================================================================================\n", 311 | "Total params: 26,926\n", 312 | "Trainable params: 26,926\n", 313 | "Non-trainable params: 0\n", 314 | "Total mult-adds (M): 6.07\n", 315 | "==========================================================================================\n", 316 | "Input size (MB): 0.20\n", 317 | "Forward/backward pass size (MB): 0.67\n", 318 | "Params size (MB): 0.11\n", 319 | "Estimated Total Size (MB): 0.98\n", 320 | "==========================================================================================" 321 | ] 322 | }, 323 | "execution_count": 99, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "# using https://github.com/TylerYep/torchinfo\n", 330 | "# pip install torchinfo\n", 331 | "\n", 332 | "import torchinfo\n", 333 | "\n", 334 | "print(f\"Torchinfo version: {torchinfo.__version__}\")\n", 335 | "\n", 336 | "batch_size = 16\n", 337 | "torchinfo.summary(model, input_size=(batch_size, 3, 32, 32))" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "id": "4cedef9f-1d7a-4edf-854c-75cdb967f55e", 343 | "metadata": {}, 344 | "source": [ 345 | "## 3) ADAM optimizer" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 47, 351 | "id": "0a967f58-2808-4c66-b6dc-1493c50c263c", 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "optimizer = torch.optim.Adam(model.parameters())" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 56, 361 | "id": 
"0b1af818-0275-488d-a08f-c6d6a5a267ed", 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "data": { 366 | "text/plain": [ 367 | "" 368 | ] 369 | }, 370 | "execution_count": 56, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "optimizer.param_groups.count" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 104, 382 | "id": "538c2a18-89f9-422c-80d5-7afe9abf95c9", 383 | "metadata": {}, 384 | "outputs": [ 385 | { 386 | "data": { 387 | "text/plain": [ 388 | "26926" 389 | ] 390 | }, 391 | "execution_count": 104, 392 | "metadata": {}, 393 | "output_type": "execute_result" 394 | } 395 | ], 396 | "source": [ 397 | "sum(p.numel() for p in optimizer.param_groups[0]['params'])" 398 | ] 399 | }, 400 | { 401 | "cell_type": "markdown", 402 | "id": "c910ddaf-ceb9-4e0c-830b-c69c12a87b61", 403 | "metadata": {}, 404 | "source": [ 405 | "## 4) BatchNorm" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 107, 411 | "id": "8ff1530f-df95-4201-9cb3-b1392752590c", 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "class PyTorchCNN(torch.nn.Module):\n", 416 | " def __init__(self, num_classes):\n", 417 | " super().__init__()\n", 418 | "\n", 419 | " self.num_classes = num_classes\n", 420 | " self.features = torch.nn.Sequential(\n", 421 | " torch.nn.Conv2d(3, 5, kernel_size=5, stride=1),\n", 422 | " torch.nn.BatchNorm2d(5), # NEW!\n", 423 | " torch.nn.ReLU(),\n", 424 | " torch.nn.MaxPool2d(kernel_size=5, stride=2),\n", 425 | " torch.nn.Conv2d(5, 12, kernel_size=3, stride=1),\n", 426 | " torch.nn.BatchNorm2d(12), # NEW!\n", 427 | " torch.nn.ReLU(),\n", 428 | " torch.nn.AvgPool2d(kernel_size=3, stride=2),\n", 429 | " torch.nn.ReLU(),\n", 430 | " )\n", 431 | "\n", 432 | " self.classifier = torch.nn.Sequential(\n", 433 | " torch.nn.Flatten(),\n", 434 | " torch.nn.Linear(192, 128),\n", 435 | " torch.nn.BatchNorm1d(128), # NEW!\n", 436 | " torch.nn.ReLU(),\n", 437 | " 
torch.nn.Linear(128, num_classes),\n", 438 | " )\n", 439 | "\n", 440 | " def forward(self, x):\n", 441 | " x = self.features(x)\n", 442 | " x = self.classifier(x)\n", 443 | " return x" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 108, 449 | "id": "5561e44b-0d3d-4a19-ac74-f5b033b6bdb1", 450 | "metadata": {}, 451 | "outputs": [ 452 | { 453 | "data": { 454 | "text/plain": [ 455 | "==========================================================================================\n", 456 | "Layer (type:depth-idx) Output Shape Param #\n", 457 | "==========================================================================================\n", 458 | "PyTorchCNN [16, 10] --\n", 459 | "├─Sequential: 1-1 [16, 12, 4, 4] --\n", 460 | "│ └─Conv2d: 2-1 [16, 5, 28, 28] 380\n", 461 | "│ └─BatchNorm2d: 2-2 [16, 5, 28, 28] 10\n", 462 | "│ └─ReLU: 2-3 [16, 5, 28, 28] --\n", 463 | "│ └─MaxPool2d: 2-4 [16, 5, 12, 12] --\n", 464 | "│ └─Conv2d: 2-5 [16, 12, 10, 10] 552\n", 465 | "│ └─BatchNorm2d: 2-6 [16, 12, 10, 10] 24\n", 466 | "│ └─ReLU: 2-7 [16, 12, 10, 10] --\n", 467 | "│ └─AvgPool2d: 2-8 [16, 12, 4, 4] --\n", 468 | "│ └─ReLU: 2-9 [16, 12, 4, 4] --\n", 469 | "├─Sequential: 1-2 [16, 10] --\n", 470 | "│ └─Flatten: 2-10 [16, 192] --\n", 471 | "│ └─Linear: 2-11 [16, 128] 24,704\n", 472 | "│ └─BatchNorm1d: 2-12 [16, 128] 256\n", 473 | "│ └─ReLU: 2-13 [16, 128] --\n", 474 | "│ └─Linear: 2-14 [16, 10] 1,290\n", 475 | "==========================================================================================\n", 476 | "Total params: 27,216\n", 477 | "Trainable params: 27,216\n", 478 | "Non-trainable params: 0\n", 479 | "Total mult-adds (M): 6.07\n", 480 | "==========================================================================================\n", 481 | "Input size (MB): 0.20\n", 482 | "Forward/backward pass size (MB): 1.34\n", 483 | "Params size (MB): 0.11\n", 484 | "Estimated Total Size (MB): 1.65\n", 485 | 
"==========================================================================================" 486 | ] 487 | }, 488 | "execution_count": 108, 489 | "metadata": {}, 490 | "output_type": "execute_result" 491 | } 492 | ], 493 | "source": [ 494 | "model = PyTorchCNN(10)\n", 495 | "\n", 496 | "torchinfo.summary(model, input_size=(batch_size, 3, 32, 32))" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": null, 502 | "id": "a2194290-c90b-4716-a515-e1ce3aa2dfb8", 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [] 506 | } 507 | ], 508 | "metadata": { 509 | "kernelspec": { 510 | "display_name": "Python 3 (ipykernel)", 511 | "language": "python", 512 | "name": "python3" 513 | }, 514 | "language_info": { 515 | "codemirror_mode": { 516 | "name": "ipython", 517 | "version": 3 518 | }, 519 | "file_extension": ".py", 520 | "mimetype": "text/x-python", 521 | "name": "python", 522 | "nbconvert_exporter": "python", 523 | "pygments_lexer": "ipython3", 524 | "version": "3.11.4" 525 | } 526 | }, 527 | "nbformat": 4, 528 | "nbformat_minor": 5 529 | } 530 | -------------------------------------------------------------------------------- /supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/MachineLearning-QandAI-book/9c9994d3b0c320e428a441c2234d4043fa95110d/supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-1.png -------------------------------------------------------------------------------- /supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/MachineLearning-QandAI-book/9c9994d3b0c320e428a441c2234d4043fa95110d/supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-2.png -------------------------------------------------------------------------------- 
/supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/MachineLearning-QandAI-book/9c9994d3b0c320e428a441c2234d4043fa95110d/supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-3.png -------------------------------------------------------------------------------- /supplementary/q12-fc-cnn-equivalence/q12-fc-cnn-equivalence.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c2b92e14-aa49-4b4c-b0c3-af07314eb8a7", 6 | "metadata": {}, 7 | "source": [ 8 | "# Supplementary material for Q12" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "b9622508-9b82-4d12-9918-aefc068af4a0", 14 | "metadata": {}, 15 | "source": [ 16 | "Under which circumstances are fully connected and convolutional layers equivalent?" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "a99e7ac1-9681-45c1-a4ef-715491d22828", 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "PyTorch version: 2.0.0\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import torch\n", 35 | "print(f\"PyTorch version: {torch.__version__}\")" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "846a864a-76d8-4d64-a429-960bd0feb29a", 41 | "metadata": {}, 42 | "source": [ 43 | "## 1) Reference fully connected layer" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "f3eb52f1-0e48-4423-a8dc-a0b9c380cb6d", 49 | "metadata": {}, 50 | "source": [ 51 | "" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 2, 57 | "id": "036c0cdd-dfa4-49cf-9572-a6a727e4170f", 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "tensor([[-0.4775, -2.1469]])\n" 65 | ] 66 | } 67 | ], 68 | 
"source": [ 69 | "torch.manual_seed(123)\n", 70 | "\n", 71 | "fc = torch.nn.Linear(4, 2)\n", 72 | "\n", 73 | "inputs = torch.tensor([[1., 2., 3., 4.]])\n", 74 | "\n", 75 | "with torch.no_grad():\n", 76 | " out1 = fc(inputs)\n", 77 | " \n", 78 | "print(out1)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "id": "59e6fb1a-86cc-467c-a401-52ac19ae5744", 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "Parameter containing:\n", 91 | "tensor([[-0.2039, 0.0166, -0.2483, 0.1886],\n", 92 | " [-0.4260, 0.3665, -0.3634, -0.3975]], requires_grad=True)" 93 | ] 94 | }, 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "fc.weight" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "id": "db888c2a-3f28-4e22-91ed-1f92d25a26d3", 107 | "metadata": {}, 108 | "source": [ 109 | "## 2) Scenario 1: The kernel size is equal to the input size" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "342b9ac2-56b2-4b8b-b7be-d6272b9abd52", 115 | "metadata": {}, 116 | "source": [ 117 | "" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "0f9bf351-c5d1-4b99-98ad-9d427e8205c0", 123 | "metadata": {}, 124 | "source": [ 125 | "Convolutional layers in PyTorch expect inputs on NCHW format by default, where\n", 126 | "\n", 127 | "- N = batch size\n", 128 | "- C = channels\n", 129 | "- H = height\n", 130 | "- W = width" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 4, 136 | "id": "ba3dbf8d-76dc-4a1a-8f5e-f3e09db14e62", 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "tensor([[[[1., 2.],\n", 143 | " [3., 4.]]]])" 144 | ] 145 | }, 146 | "execution_count": 4, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "reshaped = inputs.reshape(-1, 1, 2, 2)\n", 153 | "reshaped" 154 | ] 155 | }, 156 | { 157 
| "cell_type": "code", 158 | "execution_count": 5, 159 | "id": "12fa563f-c33e-4463-8bc3-012a0178dcaa", 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "torch.Size([2, 1, 2, 2])" 166 | ] 167 | }, 168 | "execution_count": 5, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | } 172 | ], 173 | "source": [ 174 | "conv = torch.nn.Conv2d(\n", 175 | " in_channels=1,\n", 176 | " out_channels=2,\n", 177 | " kernel_size=2\n", 178 | ")\n", 179 | "\n", 180 | "conv.weight.shape" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "id": "1da7fe7e-3fb3-4211-9c50-3197901cb559", 186 | "metadata": {}, 187 | "source": [ 188 | "Note that weights in Conv2d are also initialized randomly, so to get the exact same results, we overwrite the random weights in the convolutional layer with those in the fully connected layer." 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 6, 194 | "id": "fd107812-7a0f-49af-84c3-2703af1d9dd5", 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "tensor([[[[-0.4775]],\n", 202 | "\n", 203 | " [[-2.1469]]]])\n" 204 | ] 205 | } 206 | ], 207 | "source": [ 208 | "with torch.no_grad():\n", 209 | " conv.weight[0][0] = fc.weight[0].reshape(1, 2, 2)\n", 210 | " conv.weight[1][0] = fc.weight[1].reshape(1, 2, 2)\n", 211 | " conv.bias[0] = fc.bias[0]\n", 212 | " conv.bias[1] = fc.bias[1]\n", 213 | " \n", 214 | " out2 = conv(reshaped)\n", 215 | " \n", 216 | "print(out2)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 7, 222 | "id": "2d55e289-a616-4128-9ec2-bd3b40616dd3", 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "data": { 227 | "text/plain": [ 228 | "tensor([True, True])" 229 | ] 230 | }, 231 | "execution_count": 7, 232 | "metadata": {}, 233 | "output_type": "execute_result" 234 | } 235 | ], 236 | "source": [ 237 | "out1.flatten() == out2.flatten()" 238 | ] 239 
| }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "6c830a5a-197e-4940-8de9-a6b1dae98b17", 243 | "metadata": {}, 244 | "source": [ 245 | "## 3) Scenario 2: The kernel has size one" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "id": "e8aaebca-4369-474c-9244-0d6a18b42f27", 251 | "metadata": {}, 252 | "source": [ 253 | "" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 8, 259 | "id": "de63bca7-17a2-489f-958f-41e96cfa21ed", 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "data": { 264 | "text/plain": [ 265 | "tensor([[[[1.]],\n", 266 | "\n", 267 | " [[2.]],\n", 268 | "\n", 269 | " [[3.]],\n", 270 | "\n", 271 | " [[4.]]]])" 272 | ] 273 | }, 274 | "execution_count": 8, 275 | "metadata": {}, 276 | "output_type": "execute_result" 277 | } 278 | ], 279 | "source": [ 280 | "reshaped2 = inputs.reshape(-1, 4, 1, 1)\n", 281 | "reshaped2" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 9, 287 | "id": "50f69097-94c3-4d85-835d-1ae5e369c67a", 288 | "metadata": {}, 289 | "outputs": [ 290 | { 291 | "data": { 292 | "text/plain": [ 293 | "torch.Size([2, 4, 1, 1])" 294 | ] 295 | }, 296 | "execution_count": 9, 297 | "metadata": {}, 298 | "output_type": "execute_result" 299 | } 300 | ], 301 | "source": [ 302 | "conv = torch.nn.Conv2d(\n", 303 | " in_channels=4,\n", 304 | " out_channels=2,\n", 305 | " kernel_size=1\n", 306 | ")\n", 307 | "\n", 308 | "conv.weight.shape" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 10, 314 | "id": "b7e805d1-73de-4ecf-97df-5e7c2ba410e5", 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "tensor([[[[-0.4775]],\n", 322 | "\n", 323 | " [[-2.1469]]]])\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "with torch.no_grad():\n", 329 | " conv.weight[0] = fc.weight[0].reshape(4, 1, 1)\n", 330 | " conv.weight[1] = fc.weight[1].reshape(4, 1, 1)\n", 331 | " conv.bias[0] = 
fc.bias[0]\n", 332 | " conv.bias[1] = fc.bias[1]\n", 333 | " \n", 334 | " out3 = conv(reshaped2)\n", 335 | " \n", 336 | "print(out3)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 11, 342 | "id": "721f88fd-b174-4801-8ed3-be2b85a2c41f", 343 | "metadata": {}, 344 | "outputs": [ 345 | { 346 | "data": { 347 | "text/plain": [ 348 | "tensor([True, True])" 349 | ] 350 | }, 351 | "execution_count": 11, 352 | "metadata": {}, 353 | "output_type": "execute_result" 354 | } 355 | ], 356 | "source": [ 357 | "out1.flatten() == out3.flatten()" 358 | ] 359 | } 360 | ], 361 | "metadata": { 362 | "kernelspec": { 363 | "display_name": "Python 3 (ipykernel)", 364 | "language": "python", 365 | "name": "python3" 366 | }, 367 | "language_info": { 368 | "codemirror_mode": { 369 | "name": "ipython", 370 | "version": 3 371 | }, 372 | "file_extension": ".py", 373 | "mimetype": "text/x-python", 374 | "name": "python", 375 | "nbconvert_exporter": "python", 376 | "pygments_lexer": "ipython3", 377 | "version": "3.11.4" 378 | } 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 5 382 | } 383 | -------------------------------------------------------------------------------- /supplementary/q15-text-augment/backtranslation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d28595f6-33ab-4d4d-8373-b4db36681366", 6 | "metadata": {}, 7 | "source": [ 8 | "## Backtranslation for Data Augmentation" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "0bef5a6b-bf9e-4b41-aace-d0d98ac42ce4", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Author: Sebastian Raschka\n", 22 | "\n", 23 | "Python implementation: CPython\n", 24 | "Python version : 3.10.6\n", 25 | "IPython version : 8.12.0\n", 26 | "\n", 27 | "transformers: 4.27.2\n", 28 | "\n" 29 | ] 30 | } 31 | ], 32 | 
"source": [ 33 | "%load_ext watermark\n", 34 | "%watermark -a 'Sebastian Raschka' -v -p transformers" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 7, 40 | "id": "01a56e71-a68d-4c99-ba42-d32b6f626621", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "from transformers import MarianMTModel, MarianTokenizer\n", 45 | "\n", 46 | "def back_translate(text):\n", 47 | " # English to German\n", 48 | " en_to_de_model_name = \"Helsinki-NLP/opus-mt-en-de\"\n", 49 | " en_to_de_tokenizer = MarianTokenizer.from_pretrained(en_to_de_model_name)\n", 50 | " en_to_de_model = MarianMTModel.from_pretrained(en_to_de_model_name)\n", 51 | " \n", 52 | " inputs = en_to_de_tokenizer([text], return_tensors=\"pt\")\n", 53 | " translated_german_tokens = en_to_de_model.generate(**inputs)\n", 54 | " translated_german_text = en_to_de_tokenizer.decode(translated_german_tokens[0], skip_special_tokens=True)\n", 55 | " \n", 56 | " # German to English\n", 57 | " de_to_en_model_name = 'Helsinki-NLP/opus-mt-de-en'\n", 58 | " de_to_en_tokenizer = MarianTokenizer.from_pretrained(de_to_en_model_name)\n", 59 | " de_to_en_model = MarianMTModel.from_pretrained(de_to_en_model_name)\n", 60 | "\n", 61 | " inputs_back = de_to_en_tokenizer([translated_german_text], return_tensors=\"pt\")\n", 62 | " translated_english_tokens = de_to_en_model.generate(**inputs_back)\n", 63 | " translated_back_english_text = de_to_en_tokenizer.decode(translated_english_tokens[0], skip_special_tokens=True)\n", 64 | "\n", 65 | " return translated_german_text, translated_back_english_text" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 8, 71 | "id": "d1807b7e-88da-41d4-9d75-27bbfc14278d", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "Original text:\n", 79 | "Despite the intermittent rain showers, Amelia decided to venture outside with her new umbrella, hoping to enjoy the fresh air and perhaps bump 
into some old friends at the local café down the street.\n", 80 | "--------------------------\n", 81 | "Translated text:\n", 82 | "Trotz der periodischen Regenschauer entschied sich Amelia, sich mit ihrem neuen Regenschirm nach draußen zu wagen, in der Hoffnung, die frische Luft zu genießen und vielleicht einige alte Freunde im örtlichen Café auf der Straße zu treffen.\n", 83 | "--------------------------\n", 84 | "Backtranslated text:\n", 85 | "Despite the periodic rain showers, Amelia decided to venture outside with her new umbrella, hoping to enjoy the fresh air and perhaps meet some old friends in the local café on the street.\n", 86 | "--------------------------\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "text = (\"Despite the intermittent rain showers, \"\n", 92 | " \"Amelia decided to venture outside with \"\n", 93 | " \"her new umbrella, hoping to enjoy the fresh \"\n", 94 | " \"air and perhaps bump into some old friends \"\n", 95 | " \"at the local café down the street.\"\n", 96 | " )\n", 97 | "\n", 98 | "translated_text, back_translated_text = back_translate(text)\n", 99 | "\n", 100 | "print(\"Original text:\")\n", 101 | "print(text)\n", 102 | "print(\"--------------------------\")\n", 103 | "\n", 104 | "print(\"Translated text:\")\n", 105 | "print(translated_text)\n", 106 | "print(\"--------------------------\")\n", 107 | " \n", 108 | "print(\"Backtranslated text:\")\n", 109 | "print(back_translated_text)\n", 110 | "print(\"--------------------------\")" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 11, 116 | "id": "370ccc24-102a-4df5-a1b3-04448e2e6669", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | " Despite\n", 124 | " the\n", 125 | "- intermittent\n", 126 | "+ periodic\n", 127 | " rain\n", 128 | " showers,\n", 129 | " Amelia\n", 130 | " decided\n", 131 | " to\n", 132 | " venture\n", 133 | " outside\n", 134 | " with\n", 135 | " her\n", 136 
| " new\n", 137 | " umbrella,\n", 138 | " hoping\n", 139 | " to\n", 140 | " enjoy\n", 141 | " the\n", 142 | " fresh\n", 143 | " air\n", 144 | " and\n", 145 | " perhaps\n", 146 | "+ meet\n", 147 | "- bump\n", 148 | "- into\n", 149 | " some\n", 150 | " old\n", 151 | " friends\n", 152 | "- at\n", 153 | "+ in\n", 154 | " the\n", 155 | " local\n", 156 | " café\n", 157 | "- down\n", 158 | "+ on\n", 159 | " the\n", 160 | " street.\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "import difflib\n", 166 | "\n", 167 | "\n", 168 | "d = difflib.Differ()\n", 169 | "diff = d.compare(text.split(), \n", 170 | " back_translated_text.split())\n", 171 | "\n", 172 | "print('\\n'.join(diff))" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "59e92098-cd36-4878-b371-c3296b467bf0", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "Python 3 (ipykernel)", 187 | "language": "python", 188 | "name": "python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.10.6" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 5 205 | } 206 | -------------------------------------------------------------------------------- /supplementary/q15-text-augment/noise-injection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b050edae-48cf-47ba-b2ae-bf9fc9098bb4", 6 | "metadata": {}, 7 | "source": [ 8 | "## Noise Injection for Text Augmentation" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "0ab61d1f-466f-4469-99da-b172a959c2cb", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 
def random_character_insertion(text, insertion_rate=0.1):
    """Insert random ASCII letters at random positions of ``text``.

    The number of insertions is ``int(len(text) * insertion_rate)``, computed
    from the length of the input before any character is added.

    Returns:
        The augmented string.
    """
    chars = list(text)
    remaining = int(len(chars) * insertion_rate)

    while remaining > 0:
        # randint is inclusive on both ends, so len(chars) means "append".
        slot = random.randint(0, len(chars))
        chars.insert(slot, random.choice(string.ascii_letters))
        remaining -= 1

    return "".join(chars)
def random_character_deletion(text, deletion_rate=0.1):
    """Delete randomly chosen characters from ``text``.

    ``int(len(text) * deletion_rate)`` deletions are attempted; the loop stops
    early if the string runs out of characters.

    Returns:
        The augmented string.
    """
    chars = list(text)

    for _ in range(int(len(chars) * deletion_rate)):
        if not chars:
            # Nothing left to delete (only reachable when the rate exceeds 1).
            break
        chars.pop(random.randint(0, len(chars) - 1))

    return "".join(chars)
def typo_introduction(text, introduction_rate=0.1):
    """Simulate typos by swapping randomly chosen pairs of adjacent characters.

    ``int(len(text) * introduction_rate)`` swaps are performed; strings shorter
    than two characters are returned unchanged.

    Returns:
        The augmented string (same length as the input).
    """
    chars = list(text)
    swaps_left = int(len(chars) * introduction_rate)

    # A swap needs a neighbour to the right, hence the two-character minimum.
    while swaps_left > 0 and len(chars) >= 2:
        left = random.randint(0, len(chars) - 2)
        chars[left], chars[left + 1] = chars[left + 1], chars[left]
        swaps_left -= 1

    return "".join(chars)
# Demo: introduce adjacent-character swaps ("typos") into a sample sentence.
random.seed(1)  # fixed seed so the printed example is reproducible


text = "The cat jumped over the dog."
augmented_text = typo_introduction(text)
# Label corrected: this cell demonstrates typo introduction, not the
# character insertion shown earlier in the notebook.
print("Typo Introduction:", augmented_text)


# Character-level diff between the original and the augmented string.
import difflib


d = difflib.Differ()
diff = d.compare(text,
                 augmented_text)

print('\n'.join(diff))
def word_position_swapping(sentence, swapping_rate=0.1):
    """Swap randomly chosen pairs of words in ``sentence``.

    The sentence is split on whitespace and ``int(len(words) * swapping_rate)``
    swaps are performed, each exchanging two distinct random positions.

    Args:
        sentence: Input text.
        swapping_rate: Number of swaps as a fraction of the word count.

    Returns:
        The words re-joined with single spaces.
    """
    words = sentence.split()
    num_swaps = int(len(words) * swapping_rate)

    # random.sample needs at least two indices to draw from; with zero or one
    # word there is nothing to swap, and sampling would raise ValueError.
    if len(words) < 2:
        return " ".join(words)

    for _ in range(num_swaps):
        # Select two distinct random indices and exchange the words there.
        index1, index2 = random.sample(range(len(words)), 2)
        words[index1], words[index2] = words[index2], words[index1]

    return " ".join(words)
"metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "The brown quick fox jumped over the lazy dog.\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "random.seed(1)\n", 83 | "sentence = \"The quick brown fox jumped over the lazy dog.\"\n", 84 | "augmented_sentence = word_position_swapping(sentence, swapping_rate=0.2)\n", 85 | "print(augmented_sentence)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "47edf286-4825-4274-94b4-d10423767646", 91 | "metadata": {}, 92 | "source": [ 93 | "**Show difference before and after**" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 15, 99 | "id": "823a4b6b-5c33-4fd5-ade7-3b718d902ad4", 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | " The\n", 107 | "+ brown\n", 108 | " quick\n", 109 | "- brown\n", 110 | " fox\n", 111 | " jumped\n", 112 | " over\n", 113 | " the\n", 114 | " lazy\n", 115 | " dog.\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "import difflib\n", 121 | "\n", 122 | "\n", 123 | "d = difflib.Differ()\n", 124 | "diff = d.compare(sentence.split(), augmented_sentence.split())\n", 125 | "\n", 126 | "print('\\n'.join(diff))" 127 | ] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3 (ipykernel)", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | "codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.10.6" 147 | } 148 | }, 149 | "nbformat": 4, 150 | "nbformat_minor": 5 151 | } 152 | -------------------------------------------------------------------------------- /supplementary/q15-text-augment/synonym-replacement.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b050edae-48cf-47ba-b2ae-bf9fc9098bb4", 6 | "metadata": {}, 7 | "source": [ 8 | "## Synonym Replacement for Text Augmentation" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "0ab61d1f-466f-4469-99da-b172a959c2cb", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Author: Sebastian Raschka\n", 22 | "\n", 23 | "Python implementation: CPython\n", 24 | "Python version : 3.10.6\n", 25 | "IPython version : 8.12.0\n", 26 | "\n", 27 | "nltk: 3.8.1\n", 28 | "\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "%load_ext watermark\n", 34 | "%watermark -a 'Sebastian Raschka' -v -p nltk" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "id": "2e4539cb-2c81-47e2-8be9-03206f750771", 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stderr", 45 | "output_type": "stream", 46 | "text": [ 47 | "[nltk_data] Downloading package wordnet to\n", 48 | "[nltk_data] /Users/sebastian/nltk_data...\n", 49 | "[nltk_data] Package wordnet is already up-to-date!\n" 50 | ] 51 | }, 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "True" 56 | ] 57 | }, 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "import nltk\n", 65 | "\n", 66 | "nltk.download('wordnet')" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "id": "81874f62-821b-40b3-94a0-f22e45d82352", 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stderr", 77 | "output_type": "stream", 78 | "text": [ 79 | "[nltk_data] Downloading package wordnet to\n", 80 | "[nltk_data] /Users/sebastian/nltk_data...\n", 81 | "[nltk_data] Package wordnet is already up-to-date!\n" 82 | ] 83 | }, 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "['quickly',\n", 88 | " 
def get_synonyms(word):
    """Collect the lemma names of every WordNet synset found for ``word``.

    Names are returned in synset order and are not de-duplicated, so the same
    name can appear more than once.
    """
    return [
        lemma.name()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()
    ]
def get_position_tags(text):
    """Tokenize ``text`` and return NLTK's ``(token, part-of-speech tag)`` pairs."""
    words = nltk.word_tokenize(text)
    pos_tags = nltk.pos_tag(words)
    return pos_tags


def synonym_replacement(text, num_replacement=2):
    """Replace randomly chosen adverbs/adjectives in ``text`` with synonyms.

    Args:
        text: The sentence to augment.
        num_replacement: How many candidate words to replace.

    Returns:
        The augmented string. If the sentence contains fewer than
        ``num_replacement`` candidates, ``text`` is returned unchanged.
    """
    import re  # local import so the block does not rely on other cells

    words = nltk.word_tokenize(text)

    # Tag nouns, adjectives, etc.
    pos_tags = nltk.pos_tag(words)

    # Only replace adverbs (RB) and adjectives (JJ) for simplicity here
    candidates = [word for word, pos in pos_tags if pos in ['RB', 'JJ']]

    if len(candidates) < num_replacement:
        # Return the string, not the token list: callers join the results of
        # several sentences with ' '.join(...), which needs str elements.
        return text

    # Randomly choose the words to be replaced
    words_to_replace = random.sample(candidates, num_replacement)

    # For each word to replace, we get its synonyms and choose one randomly
    for word in words_to_replace:
        synonyms = get_synonyms(word)
        if synonyms:
            # WordNet joins multi-word lemma names with underscores.
            synonym = random.choice(synonyms).replace("_", " ")
            # Match the word only as a whole token, so that e.g. "so" is not
            # replaced inside "also".
            pattern = r"(?<!\w)" + re.escape(word) + r"(?!\w)"
            # A callable replacement keeps backslashes in the synonym literal.
            text = re.sub(pattern, lambda _match: synonym, text, count=1)

    return text
| ] 240 | } 241 | ], 242 | "source": [ 243 | "text = \"\"\"\n", 244 | "The cat quickly jumped over the lazy dog.\n", 245 | "\"\"\"\n", 246 | "\n", 247 | "sentences = nltk.sent_tokenize(text)\n", 248 | "augmented_sentences = [synonym_replacement(sentence) for sentence in sentences]\n", 249 | "augmented_paragraph = ' '.join(augmented_sentences)\n", 250 | "\n", 251 | "print(augmented_paragraph)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "id": "3e107faa-8b4c-43c2-82eb-e2d9d82eab05", 257 | "metadata": {}, 258 | "source": [ 259 | "**Compare original with augmented text**" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 8, 265 | "id": "a8773c0b-2f11-42bc-b03e-ed850b1ae9bf", 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "name": "stdout", 270 | "output_type": "stream", 271 | "text": [ 272 | " The\n", 273 | " cat\n", 274 | "- quickly\n", 275 | "+ rapidly\n", 276 | " jumped\n", 277 | " over\n", 278 | " the\n", 279 | "- lazy\n", 280 | "+ work-shy\n", 281 | " dog.\n" 282 | ] 283 | } 284 | ], 285 | "source": [ 286 | "import difflib\n", 287 | "\n", 288 | "\n", 289 | "d = difflib.Differ()\n", 290 | "diff = d.compare(text.split(), augmented_paragraph.split())\n", 291 | "\n", 292 | "print('\\n'.join(diff))" 293 | ] 294 | } 295 | ], 296 | "metadata": { 297 | "kernelspec": { 298 | "display_name": "Python 3 (ipykernel)", 299 | "language": "python", 300 | "name": "python3" 301 | }, 302 | "language_info": { 303 | "codemirror_mode": { 304 | "name": "ipython", 305 | "version": 3 306 | }, 307 | "file_extension": ".py", 308 | "mimetype": "text/x-python", 309 | "name": "python", 310 | "nbconvert_exporter": "python", 311 | "pygments_lexer": "ipython3", 312 | "version": "3.10.6" 313 | } 314 | }, 315 | "nbformat": 4, 316 | "nbformat_minor": 5 317 | } 318 | -------------------------------------------------------------------------------- /supplementary/q15-text-augment/synthetic-data.ipynb: 
def generate_synthetic_text(prompt, num_samples=1):
    """Generate continuations of ``prompt`` with the pretrained GPT-2 model.

    Args:
        prompt: Text the model should continue.
        num_samples: Number of continuations to return.

    Returns:
        A list with ``num_samples`` decoded strings (prompt included).

    Note:
        A single sample is produced with greedy decoding, as before. For more
        than one sample, sampling is switched on; greedy decoding is
        deterministic and would return the same text ``num_samples`` times.
    """
    if num_samples < 1:
        return []

    model_name = "gpt2"
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)

    # The prompt does not change between samples, so tokenize it only once.
    inputs = tokenizer(prompt, return_tensors="pt")

    # `early_stopping` was dropped: it only affects beam search, which is not
    # used here.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=100,            # upper bound on the generated length
        min_length=30,             # lower bound on the generated length
        no_repeat_ngram_size=2,    # prevent repeating 2-grams in the output
        num_return_sequences=num_samples,
        do_sample=num_samples > 1,
        # GPT-2 has no pad token; generate() falls back to EOS anyway, and
        # passing it explicitly silences the warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    return [
        tokenizer.decode(sequence, skip_special_tokens=True)
        for sequence in outputs
    ]
Parking is very easy and you can\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "# Example prompt\n", 99 | "prompt = \"The weather was nice and I enjoyed\"\n", 100 | "\n", 101 | "# Generate synthetic data\n", 102 | "synthetic_data = generate_synthetic_text(prompt)\n", 103 | "for text in synthetic_data:\n", 104 | " print(text)" 105 | ] 106 | } 107 | ], 108 | "metadata": { 109 | "kernelspec": { 110 | "display_name": "Python 3 (ipykernel)", 111 | "language": "python", 112 | "name": "python3" 113 | }, 114 | "language_info": { 115 | "codemirror_mode": { 116 | "name": "ipython", 117 | "version": 3 118 | }, 119 | "file_extension": ".py", 120 | "mimetype": "text/x-python", 121 | "name": "python", 122 | "nbconvert_exporter": "python", 123 | "pygments_lexer": "ipython3", 124 | "version": "3.10.6" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 5 129 | } 130 | -------------------------------------------------------------------------------- /supplementary/q15-text-augment/word-deletion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b050edae-48cf-47ba-b2ae-bf9fc9098bb4", 6 | "metadata": {}, 7 | "source": [ 8 | "## Word Deletion for Text Augmentation" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "0ab61d1f-466f-4469-99da-b172a959c2cb", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Author: Sebastian Raschka\n", 22 | "\n", 23 | "Python implementation: CPython\n", 24 | "Python version : 3.10.6\n", 25 | "IPython version : 8.12.0\n", 26 | "\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "%load_ext watermark\n", 32 | "%watermark -a 'Sebastian Raschka' -v" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "1f2926e6-aeac-4a9b-b762-f8a4a891d5b9", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import 
def word_deletion(sentence, deletion_rate=0.1):
    """Delete randomly chosen words from ``sentence``.

    The sentence is split on whitespace and
    ``int(len(words) * deletion_rate)`` words are removed.

    Args:
        sentence: Input text.
        deletion_rate: Number of deletions as a fraction of the word count.

    Returns:
        The remaining words joined with single spaces.
    """
    words = sentence.split()
    num_words_to_delete = int(len(words) * deletion_rate)

    for _ in range(num_words_to_delete):
        if not words:
            # A rate above 1 asks for more deletions than there are words;
            # stop instead of calling randint on an empty range.
            break
        index_to_delete = random.randint(0, len(words) - 1)
        del words[index_to_delete]

    return " ".join(words)
def sentence_order_shuffling(text):
    """Return ``text`` with its sentences in random order.

    Sentences are split at whitespace that follows ``.``, ``!`` or ``?``.
    The punctuation stays attached to its sentence, so every sentence keeps
    its own terminator after shuffling.

    Returns:
        The shuffled sentences joined with single spaces.
    """
    # The lookbehind splits after the terminator without consuming it.
    # Previously the terminator was dropped and ". " re-inserted on join,
    # which doubled one period and lost the final one.
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    random.shuffle(sentences)
    return ' '.join(sentences)
\"\n", 85 | " \"The children were playing, and the birds were singing.\")\n", 86 | "\n", 87 | "augmented_paragraph = sentence_order_shuffling(paragraph)\n", 88 | "\n", 89 | "print(\"Original Paragraph:\\n\", paragraph)\n", 90 | "print(\"\\nAugmented Paragraph:\\n\", augmented_paragraph)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "id": "47edf286-4825-4274-94b4-d10423767646", 96 | "metadata": {}, 97 | "source": [ 98 | "**Show difference before and after**" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "id": "823a4b6b-5c33-4fd5-ade7-3b718d902ad4", 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "- The\n", 112 | "- cat\n", 113 | "- quickly\n", 114 | "- jumped\n", 115 | "- over\n", 116 | "- the\n", 117 | "- lazy\n", 118 | "- dog.\n", 119 | " It\n", 120 | " was\n", 121 | " a\n", 122 | " sunny\n", 123 | " day,\n", 124 | " and\n", 125 | " the\n", 126 | " park\n", 127 | " was\n", 128 | " full\n", 129 | " of\n", 130 | " people.\n", 131 | " The\n", 132 | " children\n", 133 | " were\n", 134 | " playing,\n", 135 | " and\n", 136 | " the\n", 137 | " birds\n", 138 | " were\n", 139 | "- singing.\n", 140 | "+ singing..\n", 141 | "? 
+\n", 142 | "\n", 143 | "+ The\n", 144 | "+ cat\n", 145 | "+ quickly\n", 146 | "+ jumped\n", 147 | "+ over\n", 148 | "+ the\n", 149 | "+ lazy\n", 150 | "+ dog\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "import difflib\n", 156 | "\n", 157 | "\n", 158 | "d = difflib.Differ()\n", 159 | "diff = d.compare(paragraph.split(), augmented_paragraph.split())\n", 160 | "\n", 161 | "print('\\n'.join(diff))" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "71f2c7d2-f99f-4fb2-9c80-1c62bb543b0c", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": "Python 3 (ipykernel)", 176 | "language": "python", 177 | "name": "python3" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.10.6" 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 5 194 | } 195 | -------------------------------------------------------------------------------- /supplementary/q18-using-llms/01_classifier-finetuning/1_feature-extractor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3c5d72f4", 6 | "metadata": {}, 7 | "source": [ 8 | "# LLM as Feature Extractor" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "bb9d0299-8fc0-48f0-9b02-4c19214d479a", 14 | "metadata": {}, 15 | "source": [ 16 | "In this feature-based approach, we are using the embeddings from a pretrained transformer to train a random forest and logistic regression model in scikit-learn:\n", 17 | "\n", 18 | "" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "6fd9cda8", 25 | "metadata": { 26 | "tags": [] 27
| }, 28 | "outputs": [], 29 | "source": [ 30 | "# pip install transformers datasets" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "id": "df18e3de-577a-43c5-8b9d-868397a6d7da", 37 | "metadata": { 38 | "tags": [] 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "# conda install scikit-learn --yes" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "id": "033b75c5", 49 | "metadata": { 50 | "tags": [] 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "torch : 2.0.0\n", 58 | "transformers: 4.27.4\n", 59 | "datasets : 2.11.0\n", 60 | "sklearn : 1.2.2\n", 61 | "\n", 62 | "conda environment: finetuning-blog\n", 63 | "\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "%load_ext watermark\n", 69 | "%watermark --conda -p torch,transformers,datasets,sklearn" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "id": "602ba8a0", 76 | "metadata": { 77 | "tags": [] 78 | }, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "cuda:0\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "import torch\n", 90 | "\n", 91 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 92 | "print(device)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "id": "4cfd724d", 98 | "metadata": { 99 | "tags": [] 100 | }, 101 | "source": [ 102 | "# 1 Loading the Dataset" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "id": "e39e2228-5f0b-4fb9-b762-df26c2052b45", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# pip install datasets\n", 113 | "\n", 114 | "import os.path as op\n", 115 | "\n", 116 | "from datasets import load_dataset\n", 117 | "\n", 118 | "import lightning as L\n", 119 | "from lightning.pytorch.loggers import CSVLogger\n", 120 | "from lightning.pytorch.callbacks import ModelCheckpoint\n", 121 | "\n", 122 | "import
numpy as np\n", 123 | "import pandas as pd\n", 124 | "import torch\n", 125 | "\n", 126 | "from sklearn.feature_extraction.text import CountVectorizer\n", 127 | "\n", 128 | "from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset\n", 129 | "from local_dataset_utilities import IMDBDataset" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "id": "fb31ac90-9e3a-41d0-baf1-8e613043924b", 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stderr", 140 | "output_type": "stream", 141 | "text": [ 142 | "100%|███████████████████████████████████████████| 50000/50000 [00:25<00:00, 1973.05it/s]\n" 143 | ] 144 | }, 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "Class distribution:\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "download_dataset()\n", 155 | "\n", 156 | "df = load_dataset_into_to_dataframe()\n", 157 | "partition_dataset(df)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 7, 163 | "id": "221f30a1-b433-4304-a18d-8d03abd42b58", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "df_train = pd.read_csv(\"train.csv\")\n", 168 | "df_val = pd.read_csv(\"val.csv\")\n", 169 | "df_test = pd.read_csv(\"test.csv\")" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "id": "846d83b1", 175 | "metadata": {}, 176 | "source": [ 177 | "# 2 Tokenization and Numericalization" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 8, 183 | "id": "21114d27-2697-4132-9714-b259bd63f5a1", 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "Downloading and preparing dataset csv/default to /home/sebastian/.cache/huggingface/datasets/csv/default-2417067d5b75d213/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...\n" 191 | ] 192 | }, 193 | { 194 | "data": { 195 | 
def partition_dataset(df):
    """Shuffle ``df`` and write fixed-size train/val/test splits as CSV files.

    The rows are shuffled with a fixed seed and ``reset_index()`` keeps the
    previous index as an ``index`` column. The first 35,000 shuffled rows go
    to ``train.csv``, the next 5,000 to ``val.csv`` and the remainder to
    ``test.csv``, all written to the current working directory.

    Args:
        df: DataFrame to partition.
    """
    shuffled = df.sample(frac=1, random_state=1).reset_index()

    # (output file, positional row range) for each split, in write order.
    splits = (
        ("train.csv", slice(None, 35_000)),
        ("val.csv", slice(35_000, 40_000)),
        ("test.csv", slice(40_000, None)),
    )
    for filename, rows in splits:
        shuffled.iloc[rows].to_csv(filename, index=False, encoding="utf-8")
learning: Provide examples of the input text within the context" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "ef144502-f9a1-4a81-9a18-ae5a47a8ca48", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# pip install transformers" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 6, 32 | "id": "ec8ef079-b4ca-487a-85cd-1975f4f49838", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "from transformers import pipeline\n", 37 | "\n", 38 | "generator = pipeline(\"text-generation\", model=\"EleutherAI/pythia-1.4b\")" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "ed5c2ab9-2798-4187-afab-68de05e0d415", 44 | "metadata": {}, 45 | "source": [ 46 | "> Note: Pythia is a small GPT-like LLM. For more information, please refer to the original paper, [Pythia: A Suite for Analyzing Large Language Models Across Training and Scaling](https://arxiv.org/abs/2304.01373)." 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "2b6bef96-6819-4af6-b65d-7ce3068af47d", 52 | "metadata": {}, 53 | "source": [ 54 | "**Without in-context examples**" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 7, 60 | "id": "d4a3a8f9-52b3-4b58-be0b-ae01d8d3801c", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" 68 | ] 69 | }, 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "Translate this sentence:\n", 75 | "German: 'Wo ist die naechste Bushaltestelle?'\n", 76 | "English: 'Where is the next bus stop?'\n", 77 | "\n", 78 | "A:\n", 79 | "\n", 80 | "The German word for \"next\" is \"nach\", which is the past tense of \"sein\".\n", 81 | "The English word for \"next\" is \"nach\", which is the past tense of \"be\".\n", 82 | "The German word for \"next\" is \"nach\", which is the past tense\n" 83 | ] 84 | } 85 
| ], 86 | "source": [ 87 | "prompt = \"\"\"Translate this sentence:\n", 88 | "German: 'Wo ist die naechste Bushaltestelle?'\"\"\"\n", 89 | "\n", 90 | "generated_text = generator(prompt, max_length=100, temperature=1, top_k=0, top_p=0)\n", 91 | "print(generated_text[0][\"generated_text\"])" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "e7e8f984-2361-4fb3-b5c1-0302357a01e6", 97 | "metadata": {}, 98 | "source": [ 99 | "**With in-context examples**" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 8, 105 | "id": "b877ab43-199f-4997-9a6f-1d5597bece1a", 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stderr", 110 | "output_type": "stream", 111 | "text": [ 112 | "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" 113 | ] 114 | }, 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "Translate the following German sentences into English:\n", 120 | "\n", 121 | "Example 1:\n", 122 | "German: \"Ich liebe Pfannkuchen.\"\n", 123 | "English: \"I love pancakes.\"\n", 124 | "\n", 125 | "Example 2:\n", 126 | "German: \"Das Wetter ist heute schoen.\"\n", 127 | "English: \"The weather is nice today.\"\n", 128 | "\n", 129 | "Translate this sentence:\n", 130 | "German: \"Wo ist die naechste Bushaltestelle?\"\n", 131 | "English: \"Where is the next bus stop?\"\n", 132 | "\n", 133 | "\n" 134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "prompt = \"\"\"Translate the following German sentences into English:\n", 139 | "\n", 140 | "Example 1:\n", 141 | "German: \"Ich liebe Pfannkuchen.\"\n", 142 | "English: \"I love pancakes.\"\n", 143 | "\n", 144 | "Example 2:\n", 145 | "German: \"Das Wetter ist heute schoen.\"\n", 146 | "English: \"The weather is nice today.\"\n", 147 | "\n", 148 | "Translate this sentence:\n", 149 | "German: \"Wo ist die naechste Bushaltestelle?\"\n", 150 | "\"\"\"\n", 151 | "\n", 152 | "generated_text = generator(prompt, max_length=102, temperature=1, 
top_k=0, top_p=0)\n", 153 | "print(generated_text[0][\"generated_text\"])" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "id": "91d7a6f2-410a-4966-9c1a-d4f90a1a5e83", 159 | "metadata": {}, 160 | "source": [ 161 | "### Prompt tuning" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 9, 167 | "id": "363ae482-fea0-4ae1-9875-c3c3a85b4cb5", 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stderr", 172 | "output_type": "stream", 173 | "text": [ 174 | "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" 175 | ] 176 | }, 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "Translate the following German sentences into English:\n", 182 | "\n", 183 | "Example 1:\n", 184 | "Translate the German sentence 'Ich liebe Pfannkuchen.' into English: 'I love pancakes.' \n", 185 | "\n", 186 | "Example 2:\n", 187 | "Translate the German sentence 'Das Wetter ist heute schoen.' into English: 'The weather is nice today.'\n", 188 | "\n", 189 | "Translate the German sentence 'Wo ist die naechste Bushaltestelle?' into English: \n", 190 | "\n", 191 | "Where is the next bus stop?\n", 192 | "\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "# \"Translate the German sentence '{german_sentence}' into English: \"\n", 198 | "\n", 199 | "prompt = \"\"\"Translate the following German sentences into English:\n", 200 | "\n", 201 | "Example 1:\n", 202 | "Translate the German sentence 'Ich liebe Pfannkuchen.' into English: 'I love pancakes.' \n", 203 | "\n", 204 | "Example 2:\n", 205 | "Translate the German sentence 'Das Wetter ist heute schoen.' into English: 'The weather is nice today.'\n", 206 | "\n", 207 | "Translate the German sentence 'Wo ist die naechste Bushaltestelle?' 
into English: \n", 208 | "\"\"\"\n", 209 | "\n", 210 | "generated_text = generator(prompt, max_length=105, temperature=1, top_k=0, top_p=0)\n", 211 | "print(generated_text[0][\"generated_text\"])" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 10, 217 | "id": "95d20e4b-dc99-496b-b110-765138cc990c", 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "name": "stderr", 222 | "output_type": "stream", 223 | "text": [ 224 | "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" 225 | ] 226 | }, 227 | { 228 | "name": "stdout", 229 | "output_type": "stream", 230 | "text": [ 231 | "Translate the following German sentences into English:\n", 232 | "\n", 233 | "Example 1:\n", 234 | "German: 'Ich liebe Pfannkuchen.' | English: 'I love pancakes.' \n", 235 | "\n", 236 | "Example 2:\n", 237 | "German: 'Das Wetter ist heute schoen.' | English: 'The weather is nice today.'\n", 238 | "\n", 239 | "German: 'Wo ist die naechste Bushaltestelle?' | English: \n", 240 | "\n", 241 | "'Where is the next bus stop?'\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "# \"German: '{german_sentence}' | English: \"\n", 247 | "\n", 248 | "prompt = \"\"\"Translate the following German sentences into English:\n", 249 | "\n", 250 | "Example 1:\n", 251 | "German: 'Ich liebe Pfannkuchen.' | English: 'I love pancakes.' \n", 252 | "\n", 253 | "Example 2:\n", 254 | "German: 'Das Wetter ist heute schoen.' | English: 'The weather is nice today.'\n", 255 | "\n", 256 | "German: 'Wo ist die naechste Bushaltestelle?' 
| English: \n", 257 | "\"\"\"\n", 258 | "\n", 259 | "generated_text = generator(prompt, max_length=96, temperature=1, top_k=0, top_p=0)\n", 260 | "print(generated_text[0][\"generated_text\"])" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 11, 266 | "id": "a7b3bf49-d875-49a8-9689-ed4b37367810", 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stderr", 271 | "output_type": "stream", 272 | "text": [ 273 | "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" 274 | ] 275 | }, 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | "Example 1:\n", 281 | "From German to English: 'Ich liebe Pfannkuchen.' -> 'I love pancakes.' \n", 282 | "\n", 283 | "Example 2:\n", 284 | "From German to English: 'Das Wetter ist heute schoen.' -> 'The weather is nice today.'\n", 285 | "\n", 286 | "Example 3:\n", 287 | "From German to English: 'Wo ist die naechste Bushaltestelle?' -> : \n", 288 | "\n", 289 | "Example 4:\n", 290 | "From\n" 291 | ] 292 | } 293 | ], 294 | "source": [ 295 | "# \"From German to English: '{german_sentence}' -> \"\n", 296 | "\n", 297 | "prompt = \"\"\"Example 1:\n", 298 | "From German to English: 'Ich liebe Pfannkuchen.' -> 'I love pancakes.' \n", 299 | "\n", 300 | "Example 2:\n", 301 | "From German to English: 'Das Wetter ist heute schoen.' -> 'The weather is nice today.'\n", 302 | "\n", 303 | "Example 3:\n", 304 | "From German to English: 'Wo ist die naechste Bushaltestelle?' 
-> : \n", 305 | "\"\"\"\n", 306 | "generated_text = generator(prompt, max_length=90, temperature=1, top_k=0, top_p=0)\n", 307 | "print(generated_text[0][\"generated_text\"])" 308 | ] 309 | } 310 | ], 311 | "metadata": { 312 | "kernelspec": { 313 | "display_name": "Python 3 (ipykernel)", 314 | "language": "python", 315 | "name": "python3" 316 | }, 317 | "language_info": { 318 | "codemirror_mode": { 319 | "name": "ipython", 320 | "version": 3 321 | }, 322 | "file_extension": ".py", 323 | "mimetype": "text/x-python", 324 | "name": "python", 325 | "nbconvert_exporter": "python", 326 | "pygments_lexer": "ipython3", 327 | "version": "3.11.4" 328 | } 329 | }, 330 | "nbformat": 4, 331 | "nbformat_minor": 5 332 | } 333 | -------------------------------------------------------------------------------- /supplementary/q18-using-llms/03_retrieval-augmented-generation/images/rag-1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/MachineLearning-QandAI-book/9c9994d3b0c320e428a441c2234d4043fa95110d/supplementary/q18-using-llms/03_retrieval-augmented-generation/images/rag-1.webp -------------------------------------------------------------------------------- /supplementary/q18-using-llms/03_retrieval-augmented-generation/retrieval-augmented-generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Retrieval Augmented Generation Example with Llama Index" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "### 1) Load the embedding model and LLM:" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "data": { 31 | "application/vnd.jupyter.widget-view+json": { 
class IMDBDataset(Dataset):
    """Dataset wrapper that exposes one partition of a dataset dictionary."""

    def __init__(self, dataset_dict, partition_key="train"):
        # Only the selected partition is kept; the other entries of
        # ``dataset_dict`` are not referenced by this object.
        self.partition = dataset_dict[partition_key]

    def __len__(self):
        # The partition reports its own size through ``num_rows``.
        return self.partition.num_rows

    def __getitem__(self, index):
        # Indexing is delegated unchanged to the wrapped partition.
        return self.partition[index]
([https://arxiv.org/abs/2106.09685](https://arxiv.org/abs/2106.09685)) works by implementing it from scratch in the context of a multilayer perceptron (not LLM) to illustrate it with a simple example.\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "c1c52f02-94fb-4f45-902e-79126e27347d", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import time\n", 27 | "import numpy as np\n", 28 | "from torchvision import datasets\n", 29 | "from torchvision import transforms\n", 30 | "from torch.utils.data import DataLoader\n", 31 | "import torch.nn.functional as F\n", 32 | "import torch.nn as nn\n", 33 | "import torch\n", 34 | "\n", 35 | "\n", 36 | "if torch.cuda.is_available():\n", 37 | " torch.backends.cudnn.deterministic = True" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "629ec66a-eb81-40a5-ae3d-d5c1d2a7e390", 43 | "metadata": {}, 44 | "source": [ 45 | "## Settings and Dataset" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "id": "4ade5e86-8bd8-4a35-8db1-44451601b292", 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "name": "stdout", 56 | "output_type": "stream", 57 | "text": [ 58 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", 59 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n" 60 | ] 61 | }, 62 | { 63 | "name": "stderr", 64 | "output_type": "stream", 65 | "text": [ 66 | "100%|████████████████████████████| 9912422/9912422 [00:01<00:00, 8702295.18it/s]\n" 67 | ] 68 | }, 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n", 74 | "\n", 75 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", 76 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n" 77 | ] 78 | }, 79 | { 
80 | "name": "stderr", 81 | "output_type": "stream", 82 | "text": [ 83 | "100%|███████████████████████████████| 28881/28881 [00:00<00:00, 31844293.85it/s]" 84 | ] 85 | }, 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n", 91 | "\n", 92 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n" 93 | ] 94 | }, 95 | { 96 | "name": "stderr", 97 | "output_type": "stream", 98 | "text": [ 99 | "\n" 100 | ] 101 | }, 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" 107 | ] 108 | }, 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | "100%|████████████████████████████| 1648877/1648877 [00:00<00:00, 4452323.08it/s]\n" 114 | ] 115 | }, 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n", 121 | "\n", 122 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", 123 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" 124 | ] 125 | }, 126 | { 127 | "name": "stderr", 128 | "output_type": "stream", 129 | "text": [ 130 | "100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 9097673.72it/s]" 131 | ] 132 | }, 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n", 138 | "\n", 139 | "Image batch dimensions: torch.Size([64, 1, 28, 28])\n", 140 | "Image label dimensions: torch.Size([64])\n" 141 | ] 142 | }, 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | "\n" 148 | ] 149 | } 150 | ], 151 | "source": [ 152 | "##########################\n", 
153 | "### SETTINGS\n", 154 | "##########################\n", 155 | "\n", 156 | "# Device\n", 157 | "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 158 | "BATCH_SIZE = 64\n", 159 | "\n", 160 | "##########################\n", 161 | "### MNIST DATASET\n", 162 | "##########################\n", 163 | "\n", 164 | "# Note transforms.ToTensor() scales input images\n", 165 | "# to 0-1 range\n", 166 | "train_dataset = datasets.MNIST(root='data', \n", 167 | " train=True, \n", 168 | " transform=transforms.ToTensor(),\n", 169 | " download=True)\n", 170 | "\n", 171 | "test_dataset = datasets.MNIST(root='data', \n", 172 | " train=False, \n", 173 | " transform=transforms.ToTensor())\n", 174 | "\n", 175 | "\n", 176 | "train_loader = DataLoader(dataset=train_dataset, \n", 177 | " batch_size=BATCH_SIZE, \n", 178 | " shuffle=True)\n", 179 | "\n", 180 | "test_loader = DataLoader(dataset=test_dataset, \n", 181 | " batch_size=BATCH_SIZE, \n", 182 | " shuffle=False)\n", 183 | "\n", 184 | "# Checking the dataset\n", 185 | "for images, labels in train_loader: \n", 186 | " print('Image batch dimensions:', images.shape)\n", 187 | " print('Image label dimensions:', labels.shape)\n", 188 | " break" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "id": "394e5da8-2978-40f0-bca7-b79e8e35734f", 194 | "metadata": {}, 195 | "source": [ 196 | "# Multilayer Perceptron Model (Without LoRA)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 3, 202 | "id": "7e905c42-7f59-4a08-b6c5-a10f99f33e9e", 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "##########################\n", 207 | "### MODEL\n", 208 | "##########################\n", 209 | "\n", 210 | "# Hyperparameters\n", 211 | "random_seed = 123\n", 212 | "learning_rate = 0.005\n", 213 | "num_epochs = 2\n", 214 | "\n", 215 | "# Architecture\n", 216 | "num_features = 784\n", 217 | "num_hidden_1 = 128\n", 218 | "num_hidden_2 = 256\n", 219 | "num_classes = 10\n", 220 
def compute_accuracy(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader`` in percent.

    The model is switched to evaluation mode (and left in it), every batch is
    flattened to rows of 28*28 features, and the predicted class is the index
    of the largest logit.

    Args:
        model: Module mapping a ``(batch, 784)`` tensor to per-class logits.
        data_loader: Iterable of ``(features, targets)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A 0-dim float tensor holding the accuracy in percent.

    Raises:
        ValueError: If ``data_loader`` yields no examples, because accuracy
            is undefined in that case.
    """
    model.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.view(-1, 28*28).to(device)
            targets = targets.to(device)
            logits = model(features)
            _, predicted_labels = torch.max(logits, 1)
            num_examples += targets.size(0)
            # Becomes a tensor after the first batch (int + tensor).
            correct_pred += (predicted_labels == targets).sum()

    # With no batches ``correct_pred`` is still the int 0, so ``.float()``
    # would fail with a misleading AttributeError; report the cause instead.
    if num_examples == 0:
        raise ValueError("data_loader yielded no examples; accuracy is undefined")

    return correct_pred.float()/num_examples * 100
batch_idx, (features, targets) in enumerate(train_loader):\n", 279 | "\n", 280 | " features = features.view(-1, 28*28).to(device)\n", 281 | " targets = targets.to(device)\n", 282 | "\n", 283 | " # FORWARD AND BACK PROP\n", 284 | " logits = model(features)\n", 285 | " loss = F.cross_entropy(logits, targets)\n", 286 | " optimizer.zero_grad()\n", 287 | "\n", 288 | " loss.backward()\n", 289 | "\n", 290 | " # UPDATE MODEL PARAMETERS\n", 291 | " optimizer.step()\n", 292 | "\n", 293 | " # LOGGING\n", 294 | " if not batch_idx % 400:\n", 295 | " print('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f'\n", 296 | " % (epoch+1, num_epochs, batch_idx,\n", 297 | " len(train_loader), loss))\n", 298 | "\n", 299 | " with torch.set_grad_enabled(False):\n", 300 | " print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n", 301 | " epoch+1, num_epochs,\n", 302 | " compute_accuracy(model, train_loader, device)))\n", 303 | "\n", 304 | " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", 305 | "\n", 306 | " print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 5, 312 | "id": "f47cfe4e-65eb-440e-b922-17c6dee7d7e2", 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "name": "stdout", 317 | "output_type": "stream", 318 | "text": [ 319 | "Epoch: 001/002 | Batch 000/938 | Loss: 2.2971\n", 320 | "Epoch: 001/002 | Batch 400/938 | Loss: 0.2258\n", 321 | "Epoch: 001/002 | Batch 800/938 | Loss: 0.1612\n", 322 | "Epoch: 001/002 training accuracy: 95.71%\n", 323 | "Time elapsed: 0.04 min\n", 324 | "Epoch: 002/002 | Batch 000/938 | Loss: 0.0593\n", 325 | "Epoch: 002/002 | Batch 400/938 | Loss: 0.0588\n", 326 | "Epoch: 002/002 | Batch 800/938 | Loss: 0.0556\n", 327 | "Epoch: 002/002 training accuracy: 97.40%\n", 328 | "Time elapsed: 0.08 min\n", 329 | "Total Training Time: 0.08 min\n", 330 | "Test accuracy: 96.55%\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "train(num_epochs, 
class LoRALayer(nn.Module):
    """Low-rank adapter that computes ``alpha * (x @ A @ B)``.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Inner dimension shared by ``A`` (in_dim x rank) and
            ``B`` (rank x out_dim).
        alpha: Scalar multiplier applied to the low-rank product.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        # A is random and B is all zeros, so A @ B is the zero matrix right
        # after construction and the adapter initially contributes nothing.
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x


class LinearWithLoRA(nn.Module):
    """Wrap a linear layer and add a LoRA branch to its output.

    Args:
        linear: The layer to wrap; its ``in_features`` / ``out_features``
            size the adapter.
        rank: Rank passed to :class:`LoRALayer`.
        alpha: Scale passed to :class:`LoRALayer`.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)


# This LoRA code is equivalent to LinearWithLoRA
class LinearWithLoRAMerged(nn.Module):
    """Same computation as :class:`LinearWithLoRA`, with the adapter folded
    into the weight matrix before a single ``F.linear`` call.

    Args:
        linear: The layer to wrap; its ``in_features`` / ``out_features``
            size the adapter.
        rank: Rank passed to :class:`LoRALayer`.
        alpha: Scale passed to :class:`LoRALayer`.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        # A @ B is (in_features, out_features); nn.Linear stores its weight
        # as (out_features, in_features), hence the transpose.
        lora = self.lora.A @ self.lora.B
        # The alpha factor must be applied here as well, otherwise this
        # module only matches LinearWithLoRA while B is still all zeros.
        combined_weight = self.linear.weight + self.lora.alpha * lora.T
        return F.linear(x, combined_weight, self.linear.bias)
Linear(in_features=784, out_features=128, bias=True)\n", 476 | " (1): ReLU()\n", 477 | " (2): Linear(in_features=128, out_features=256, bias=True)\n", 478 | " (3): ReLU()\n", 479 | " (4): Linear(in_features=256, out_features=10, bias=True)\n", 480 | " )\n", 481 | ")" 482 | ] 483 | }, 484 | "execution_count": 10, 485 | "metadata": {}, 486 | "output_type": "execute_result" 487 | } 488 | ], 489 | "source": [ 490 | "model_pretrained" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 11, 496 | "id": "b00a7e8f-09ff-499e-b593-3f6dac87f1bf", 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "import copy\n", 501 | "\n", 502 | "model_lora = copy.deepcopy(model_pretrained)\n", 503 | "model_dora = copy.deepcopy(model_pretrained)" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 12, 509 | "id": "b1e3ef6f-255c-4c71-9da5-7d06d8c439a3", 510 | "metadata": {}, 511 | "outputs": [ 512 | { 513 | "data": { 514 | "text/plain": [ 515 | "MultilayerPerceptron(\n", 516 | " (layers): Sequential(\n", 517 | " (0): LinearWithLoRA(\n", 518 | " (linear): Linear(in_features=784, out_features=128, bias=True)\n", 519 | " (lora): LoRALayer()\n", 520 | " )\n", 521 | " (1): ReLU()\n", 522 | " (2): LinearWithLoRA(\n", 523 | " (linear): Linear(in_features=128, out_features=256, bias=True)\n", 524 | " (lora): LoRALayer()\n", 525 | " )\n", 526 | " (3): ReLU()\n", 527 | " (4): LinearWithLoRA(\n", 528 | " (linear): Linear(in_features=256, out_features=10, bias=True)\n", 529 | " (lora): LoRALayer()\n", 530 | " )\n", 531 | " )\n", 532 | ")" 533 | ] 534 | }, 535 | "execution_count": 12, 536 | "metadata": {}, 537 | "output_type": "execute_result" 538 | } 539 | ], 540 | "source": [ 541 | "model_lora.layers[0] = LinearWithLoRA(model_lora.layers[0], rank=4, alpha=8)\n", 542 | "model_lora.layers[2] = LinearWithLoRA(model_lora.layers[2], rank=4, alpha=8)\n", 543 | "model_lora.layers[4] = LinearWithLoRA(model_lora.layers[4], rank=4, alpha=8)\n", 
def freeze_linear_layers(model):
    """Disable gradients for every ``nn.Linear`` nested below ``model``.

    Walks the module tree starting from ``model``'s direct children. Each
    ``nn.Linear`` found has ``requires_grad`` switched off on all of its
    parameters; any other module is searched further. ``model`` itself is
    never treated as a candidate, only its descendants.
    """
    pending = list(model.children())
    while pending:
        module = pending.pop()
        if not isinstance(module, nn.Linear):
            # Not a linear layer: keep descending into its submodules.
            pending.extend(module.children())
            continue
        for weight in module.parameters():
            weight.requires_grad = False
"output_type": "stream", 612 | "text": [ 613 | "layers.0.linear.weight: False\n", 614 | "layers.0.linear.bias: False\n", 615 | "layers.0.lora.A: True\n", 616 | "layers.0.lora.B: True\n", 617 | "layers.2.linear.weight: False\n", 618 | "layers.2.linear.bias: False\n", 619 | "layers.2.lora.A: True\n", 620 | "layers.2.lora.B: True\n", 621 | "layers.4.linear.weight: False\n", 622 | "layers.4.linear.bias: False\n", 623 | "layers.4.lora.A: True\n", 624 | "layers.4.lora.B: True\n" 625 | ] 626 | } 627 | ], 628 | "source": [ 629 | "freeze_linear_layers(model_lora)\n", 630 | "\n", 631 | "# Check if linear layers are frozen\n", 632 | "for name, param in model_lora.named_parameters():\n", 633 | " print(f\"{name}: {param.requires_grad}\")" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 16, 639 | "id": "1b807c7b-8d4a-4a1e-8a56-42bbdbc82fed", 640 | "metadata": {}, 641 | "outputs": [ 642 | { 643 | "name": "stdout", 644 | "output_type": "stream", 645 | "text": [ 646 | "Epoch: 001/002 | Batch 000/938 | Loss: 0.0843\n", 647 | "Epoch: 001/002 | Batch 400/938 | Loss: 0.2096\n", 648 | "Epoch: 001/002 | Batch 800/938 | Loss: 0.0998\n", 649 | "Epoch: 001/002 training accuracy: 97.46%\n", 650 | "Time elapsed: 0.04 min\n", 651 | "Epoch: 002/002 | Batch 000/938 | Loss: 0.0840\n", 652 | "Epoch: 002/002 | Batch 400/938 | Loss: 0.1385\n", 653 | "Epoch: 002/002 | Batch 800/938 | Loss: 0.0056\n", 654 | "Epoch: 002/002 training accuracy: 97.79%\n", 655 | "Time elapsed: 0.07 min\n", 656 | "Total Training Time: 0.07 min\n", 657 | "Test accuracy LoRA finetune: 96.88%\n" 658 | ] 659 | } 660 | ], 661 | "source": [ 662 | "optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)\n", 663 | "train(num_epochs, model_lora, optimizer_lora, train_loader, DEVICE)\n", 664 | "print(f'Test accuracy LoRA finetune: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')" 665 | ] 666 | } 667 | ], 668 | "metadata": { 669 | "kernelspec": { 670 | "display_name": 
"Python 3 (ipykernel)", 671 | "language": "python", 672 | "name": "python3" 673 | }, 674 | "language_info": { 675 | "codemirror_mode": { 676 | "name": "ipython", 677 | "version": 3 678 | }, 679 | "file_extension": ".py", 680 | "mimetype": "text/x-python", 681 | "name": "python", 682 | "nbconvert_exporter": "python", 683 | "pygments_lexer": "ipython3", 684 | "version": "3.11.4" 685 | } 686 | }, 687 | "nbformat": 4, 688 | "nbformat_minor": 5 689 | } 690 | -------------------------------------------------------------------------------- /supplementary/q19-evaluation-llms/bleu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "85b9cad9-14cc-4145-a71a-6fdfa7b9b044", 6 | "metadata": {}, 7 | "source": [ 8 | "# BLEU Score for Unigrams" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "d152300f-00fd-4f31-bb2c-a776f1822a2e", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "original = \"Der schnelle braune Fuchs sprang ueber den faulen Hund\"\n", 19 | "\n", 20 | "reference = \"The quick brown fox jumped over the lazy dog\"\n", 21 | "candidate_1 = \"The fast brown fox leaped over the dog\"\n", 22 | "candidate_2 = \"The swift brown fox jumped over the lazy dog\"\n", 23 | "candidate_3 = \"The swift tawny fox leaped over the indolent canine.\"" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "3677a7a3-d4d1-49a0-ab90-d02b526cde04", 29 | "metadata": {}, 30 | "source": [ 31 | "### NLTK" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "id": "6a159771-e2ee-4adc-b53d-d349d69a6a63", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "#pip install nltk" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "id": "cb45c369-e04f-43ea-a31f-e2380efa3e79", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | 
"text": [ 54 | "BLEU score for example 1: 0.66\n", 55 | "BLEU score for example 2: 0.89\n", 56 | "BLEU score for example 3: 0.44\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "from nltk.translate.bleu_score import sentence_bleu\n", 62 | "\n", 63 | "bleu_nltk_1 = sentence_bleu([reference.split()], candidate_1.split(), weights=[1.])\n", 64 | "bleu_nltk_2 = sentence_bleu([reference.split()], candidate_2.split(), weights=[1.])\n", 65 | "bleu_nltk_3 = sentence_bleu([reference.split()], candidate_3.split(), weights=[1.])\n", 66 | "\n", 67 | "print(f\"BLEU score for example 1: {bleu_nltk_1:.2f}\")\n", 68 | "print(f\"BLEU score for example 2: {bleu_nltk_2:.2f}\")\n", 69 | "print(f\"BLEU score for example 3: {bleu_nltk_3:.2f}\")" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "d89def98-71e9-4ee4-96ac-017dd9bf2a28", 75 | "metadata": {}, 76 | "source": [ 77 | "### TorchMetrics" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "id": "01198d15-c16c-4aa6-a8ed-65705ab743cf", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "BLEU score for example 1: 0.66\n", 91 | "BLEU score for example 2: 0.89\n", 92 | "BLEU score for example 3: 0.44\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "from torchmetrics import BLEUScore\n", 98 | "\n", 99 | "bleu = BLEUScore(n_gram=1)\n", 100 | "\n", 101 | "# Calculate BLEU scores\n", 102 | "bleu_tm_1 = bleu(target=[[reference]], preds=[candidate_1])\n", 103 | "bleu_tm_2 = bleu(target=[[reference]], preds=[candidate_2])\n", 104 | "bleu_tm_3 = bleu(target=[[reference]], preds=[candidate_3])\n", 105 | "\n", 106 | "print(f\"BLEU score for example 1: {bleu_tm_1:.2f}\")\n", 107 | "print(f\"BLEU score for example 2: {bleu_tm_2:.2f}\")\n", 108 | "print(f\"BLEU score for example 3: {bleu_tm_3:.2f}\")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "2fe6a608-168a-4f7c-87a0-d82b9eac21ac", 114 | "metadata": {}, 115 | 
def ngrams(sentence, n):
    """Return every contiguous run of ``n`` tokens in ``sentence`` as a tuple."""
    grams = []
    for start in range(len(sentence) - n + 1):
        grams.append(tuple(sentence[start:start + n]))
    return grams


def modified_precision(reference, candidate, n):
    """Clipped n-gram precision of ``candidate`` against one ``reference``.

    Each candidate n-gram is credited at most as many times as it occurs in
    the reference. Returns 0 when the candidate has no n-grams of order ``n``.
    """
    in_reference = Counter(ngrams(reference, n))
    in_candidate = Counter(ngrams(candidate, n))

    total = sum(in_candidate.values())
    if total <= 0:
        return 0

    # Counter intersection keeps the per-key minimum of the two counts.
    clipped = sum((in_candidate & in_reference).values())
    return clipped / total


def brevity_penalty(reference, candidate):
    """Penalty factor for candidates that are not longer than the reference."""
    target, produced = len(reference), len(candidate)

    if produced > target:
        return 1
    if produced == 0:
        return 0
    return math.exp(1 - target / produced)


def bleu_score_unigram(reference, candidate):
    """Unigram BLEU: brevity penalty times clipped unigram precision."""
    penalty = brevity_penalty(reference, candidate)
    return penalty * modified_precision(reference, candidate, n=1)
| "print(f\"BLEU score for example 1: {bleu_scratch_1:.2f}\")\n", 174 | "print(f\"BLEU score for example 2: {bleu_scratch_2:.2f}\")\n", 175 | "print(f\"BLEU score for example 3: {bleu_scratch_3:.2f}\")" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "b7b3782e-704e-40aa-8dee-062282ba1e37", 181 | "metadata": {}, 182 | "source": [ 183 | "# BLEU Score for 4-grams (\"default\" BLEU)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 7, 189 | "id": "cf2bb62a-e1fd-4d5b-8d1d-3beadb72395b", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "# Example 1\n", 194 | "candidate_1 = \"The quick brown dog jumps over the lazy fox\"\n", 195 | "references_1 = [\n", 196 | " \"The quick brown fox jumps over the lazy dog\",\n", 197 | " \"The fast brown fox leaps over the lazy dog\",\n", 198 | "]\n", 199 | "\n", 200 | "# Example 2\n", 201 | "candidate_2 = \"The small red car drives quickly down the road\"\n", 202 | "references_2 = [\n", 203 | " \"The small red car races quickly along the road\",\n", 204 | " \"A small red car speeds rapidly down the avenue\",\n", 205 | "]" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "id": "288de8e8-a79a-4b8c-9822-be523a2bcbd3", 211 | "metadata": {}, 212 | "source": [ 213 | "## NLTK" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 8, 219 | "id": "2206ae02-71b4-44f8-8c9c-d939d1239e78", 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "name": "stdout", 224 | "output_type": "stream", 225 | "text": [ 226 | "BLEU score for example 1: 0.46\n", 227 | "BLEU score for example 2: 0.40\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "from nltk.translate.bleu_score import sentence_bleu\n", 233 | "\n", 234 | "bleu_nltk_1 = sentence_bleu([r.split() for r in references_1], candidate_1.split())\n", 235 | "bleu_nltk_2 = sentence_bleu([r.split() for r in references_2], candidate_2.split())\n", 236 | "\n", 237 | "print(f\"BLEU score for example 1: 
def tokenize(sentence):
    """Lower-case ``sentence`` and split it on whitespace."""
    return sentence.lower().split()


def ngrams(tokens, n):
    """Return every contiguous run of ``n`` tokens as a tuple."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def modified_precision(candidate, references, n):
    """Clipped n-gram precision of ``candidate`` against several references.

    Each candidate n-gram is credited at most as many times as it occurs in
    the single reference where it is most frequent.

    Args:
        candidate: Token list of the candidate sentence.
        references: List of token lists, one per reference sentence.
        n: n-gram order.

    Returns:
        A ``Fraction`` in [0, 1]; ``Fraction(0)`` when the candidate has no
        n-grams of this order.
    """
    candidate_counts = Counter(ngrams(candidate, n))
    total = sum(candidate_counts.values())
    if total == 0:
        return Fraction(0)

    # Counter union keeps the per-key maximum across the references.
    max_reference_counts = Counter()
    for reference in references:
        max_reference_counts |= Counter(ngrams(reference, n))

    # Counter intersection keeps the per-key minimum, i.e. the clipped count.
    clipped = candidate_counts & max_reference_counts
    return Fraction(sum(clipped.values()), total)


def closest_reference_length(candidate, references):
    """Length of the reference closest in length to ``candidate``.

    Ties are resolved in favour of the shorter reference.
    """
    candidate_len = len(candidate)
    return min(
        (len(reference) for reference in references),
        key=lambda ref_len: (abs(ref_len - candidate_len), ref_len),
    )


def brevity_penalty(candidate, references):
    """Penalty factor for candidates not longer than the closest reference.

    Returns 0 for an empty candidate instead of dividing by zero.
    """
    candidate_length = len(candidate)
    if candidate_length == 0:
        return 0

    closest_ref_len = closest_reference_length(candidate, references)
    if candidate_length > closest_ref_len:
        return 1
    return math.exp(1 - closest_ref_len / candidate_length)


def sentence_bleu_scratch(candidate, references, weights=(0.25, 0.25, 0.25, 0.25)):
    """Sentence-level BLEU of ``candidate`` against ``references``.

    Args:
        candidate: Candidate sentence as a string.
        references: List of reference sentences as strings.
        weights: One weight per n-gram order, starting at unigrams; the
            number of weights sets the highest order used.

    Returns:
        The brevity penalty times the weighted geometric mean of the
        modified precisions, capped at 1. Returns 0 as soon as any order
        has no matching n-gram, because the geometric mean is then zero.
    """
    candidate_tokens = tokenize(candidate)
    reference_tokens = [tokenize(reference) for reference in references]

    precisions = [
        modified_precision(candidate_tokens, reference_tokens, order)
        for order in range(1, len(weights) + 1)
    ]

    # A single zero precision zeroes the geometric mean. Skipping it instead
    # would report an inflated score for candidates with no higher-order
    # n-gram overlap.
    if not precisions or any(p == 0 for p in precisions):
        return 0

    log_mean = sum(w * math.log(float(p)) for w, p in zip(weights, precisions))
    bp = brevity_penalty(candidate_tokens, reference_tokens)
    bleu = bp * math.exp(log_mean)

    return min(bleu, 1)  # Ensure the BLEU score is between 0 and 1
def cross_entropy(p, q):
    """Base-2 cross entropy ``-sum(p * log2(q))``.

    ``q`` is clipped to the range [1e-10, 1.0] first so that a zero entry
    cannot produce ``log2(0)``.
    """
    safe_q = np.clip(q, 1e-10, 1.0)
    log_q = np.log2(safe_q)
    return -np.sum(p * log_q)
"fd3a9a63-2f89-45e9-a088-c97edd6744e5", 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "0.7163562924630626" 107 | ] 108 | }, 109 | "execution_count": 5, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "cross_entropy(np.ones(n), s1_word_proba)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 6, 121 | "id": "77d3c672-5688-4131-a093-6ed099398f10", 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "1.0567214564189926" 128 | ] 129 | }, 130 | "execution_count": 6, 131 | "metadata": {}, 132 | "output_type": "execute_result" 133 | } 134 | ], 135 | "source": [ 136 | "2**(cross_entropy(np.ones(n), s1_word_proba) / n )" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "id": "678e6e34-2e7c-41e8-939d-7e061e06e5c0", 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "1.0567214564189926" 149 | ] 150 | }, 151 | "execution_count": 7, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | } 155 | ], 156 | "source": [ 157 | "calculate_perplexity(s1_word_proba)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "3f3dcf75-5218-4ed0-95cc-7b4b726827e8", 163 | "metadata": {}, 164 | "source": [ 165 | "## Perplexity with TorchMetrics" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 8, 171 | "id": "481d91c8-270b-4c05-934c-1976e3e2f890", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "from torchmetrics.text import Perplexity" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "c9669c09-8409-4d30-abb6-67676ea8a545", 181 | "metadata": {}, 182 | "source": [ 183 | "Torchmetrics' perplexity takes in a `predictions` and a `target` variable. 
\n", 184 | "\n", 185 | "For the `predictions` it assumes the shape `[batch_size, seq_len, vocab_size]`, and for the targets it assumes the shape `[batch_size, seq_len]`.\n", 186 | "\n", 187 | "If we are only looking at one sentence, we have a batch size of 1.\n" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 9, 193 | "id": "dc1de00f-d527-46cb-8db9-4bd6a59463c3", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "sentence_1 = \"The fast black cat jumps over the lazy dog\"" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "id": "262a7b52-af5e-4a68-b969-1735bfec4bd5", 203 | "metadata": {}, 204 | "source": [ 205 | "Now, in this notebook, we haven't constructed a vocabulary, which is the set of all unique words in the training set. For simplicity, let's assume the vocabulary contains the following words:" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 10, 211 | "id": "c5c781d9-d1ab-4d71-a4ca-68397880b8de", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "vocab = {\n", 216 | " 0: \"The\",\n", 217 | " 1: \"quick\",\n", 218 | " 2: \"brown\",\n", 219 | " 3: \"fox\",\n", 220 | " 4: \"jumps\",\n", 221 | " 5: \"over\",\n", 222 | " 6: \"the\",\n", 223 | " 7: \"lazy\",\n", 224 | " 8: \"dog\",\n", 225 | " 9: \"fast\",\n", 226 | " 10: \"black\",\n", 227 | " 11: \"cat\",\n", 228 | "}" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "id": "289e9648-1270-4c79-b7c5-be6aa3f4ebf0", 234 | "metadata": {}, 235 | "source": [ 236 | "Since the vocabulary has 12 words, each word output by the model would be a 12-dimensional probability vector. 
So, for a sentence consisting of 9 words (\"The fast black cat jumps over the lazy dog\") we have a 1x9x12 dimensional tensor.\n", 237 | "\n", 238 | "Also, previously, we considered the word probabilities \n", 239 | "\n", 240 | "```python\n", 241 | "s1_word_proba = [0.99, 0.85, 0.89, 0.94, 0.99, 0.99, 0.99, 0.99, 0.90]\n", 242 | "```\n", 243 | "\n", 244 | "In the representation below, the vocabulary index corresponding to the word at that position will have that probability value.\n", 245 | "\n", 246 | "```python\n", 247 | "\n", 248 | "vocab = {\n", 249 | " 0: \"The\",\n", 250 | " 1: \"quick\",\n", 251 | " 2: \"brown\",\n", 252 | " 3: \"fox\",\n", 253 | " 4: \"jumps\",\n", 254 | " 5: \"over\",\n", 255 | " 6: \"the\",\n", 256 | " 7: \"lazy\",\n", 257 | " 8: \"dog\",\n", 258 | " 9: \"fast\",\n", 259 | " 10: \"black\",\n", 260 | " 11: \"cat\",\n", 261 | "}\n", 262 | "```" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 11, 268 | "id": "8e71aa61-d3cd-4cb2-b271-5a766f9ceaf4", 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | "name": "stdout", 273 | "output_type": "stream", 274 | "text": [ 275 | "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n" 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "import torch\n", 281 | "\n", 282 | "model_outputs = torch.tensor([[\n", 283 | " [0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01], # The, index 0\n", 284 | " [0.0, 0.0, 0.0, 0.0, 0.02, 0.05, 0.02, 0.01, 0.05, 0.85, 0.00, 0.00], # fast, index 9\n", 285 | " [0.01, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.89, 0.0], # black, 10\n", 286 | " [0.0, 0.0, 0.0, 0.01, 0.0, 0.05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.94], # cat, 11\n", 287 | " [0.0, 0.01, 0.0, 0.0, 0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # jumps, 4\n", 288 | " [0.0, 0.0, 0.005, 0.005, 0.0, 0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # over, 5\n", 289 | " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.99, 0.0, 0.0, 0.01, 0.0, 0.0], # the, 6\n", 290 | " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.99, 0.01, 0.0, 0.0, 
0.0], # lazy, 7\n", 291 | " [0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.04, 0.0, 0.90, 0.0, 0.0, 0.01], # dog, 8\n", 292 | "]])\n", 293 | "\n", 294 | "# rows should sum to 1\n", 295 | "print(model_outputs.sum(axis=2))" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "id": "cb2bed18-f7c6-4d5c-8967-d3d6685b1bf3", 301 | "metadata": {}, 302 | "source": [ 303 | "Note that the list of vectors above may represent the probability vectors returned by an LLM, for example. One vector per word. The probabilities in each row should sum up to one.\n", 304 | "\n", 305 | "For example, looking at the first row\n", 306 | "\n", 307 | "```\n", 308 | "[0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01], # The, index 0\n", 309 | "```\n", 310 | "\n", 311 | "this means the model assigns a probability of 0.99 to the first word in the vocabulary (\"The\"). The probabilities for the other words are 0 except for the last word (\"cat\"), which is 0.01 in this case.\n", 312 | "\n", 313 | "**Note that these probabilities are arbitrarily assigned by me. 
In an application, they would be returned by an actual LLM, which we omit here for simplicity.**\n", 314 | "\n", 315 | "\n", 316 | "\n", 317 | "Then, with the target vector containing the word indices, we can gather the probabilities corresponding to the target word indices:" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 12, 323 | "id": "df3e9250-f033-4a19-906a-267db2508f26", 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "tensor([[[0.9900],\n", 331 | " [0.8500],\n", 332 | " [0.8900],\n", 333 | " [0.9400],\n", 334 | " [0.9900],\n", 335 | " [0.9900],\n", 336 | " [0.9900],\n", 337 | " [0.9900],\n", 338 | " [0.9000]]])\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "targets = torch.tensor([[0, 9, 10, 11, 4, 5, 6, 7, 8]])\n", 344 | "\n", 345 | "# Gather the probabilities\n", 346 | "probabilities = torch.gather(model_outputs, 2, targets.unsqueeze(2))\n", 347 | "\n", 348 | "print(probabilities)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "id": "c34d790d-3353-4185-a388-688534604875", 354 | "metadata": {}, 355 | "source": [ 356 | "According to the [TorchMetrics perplexity documentation](https://torchmetrics.readthedocs.io/en/stable/text/perplexity.html), the input is a tensor of probabilities, \n", 357 | "\n", 358 | "> - ``preds`` (:class:`~torch.Tensor`): Probabilities assigned to each token in a sequence with shape\n", 359 | " [batch_size, seq_len, vocab_size]\n", 360 | "\n", 361 | "but the results are inflated when providing the inputs directly. 
However, when providing log-probabilities, we can reproduce the results from earlier:" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 13, 367 | "id": "5f4420eb-2cac-4baf-931b-1415fa152226", 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "name": "stdout", 372 | "output_type": "stream", 373 | "text": [ 374 | "torchmetrics version: 0.11.4\n" 375 | ] 376 | }, 377 | { 378 | "data": { 379 | "text/plain": [ 380 | "tensor(1.0567)" 381 | ] 382 | }, 383 | "execution_count": 13, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "import torchmetrics\n", 390 | "from torchmetrics.text import Perplexity\n", 391 | "\n", 392 | "print(\"torchmetrics version:\", torchmetrics.__version__)\n", 393 | "\n", 394 | "perp = Perplexity()\n", 395 | "perp(torch.log(model_outputs), targets)" 396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "id": "04c1ed8c-ddbe-4902-bb56-53f2d9ecc77e", 401 | "metadata": { 402 | "tags": [] 403 | }, 404 | "source": [ 405 | "## PyTorch Built-Ins" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "id": "59122c21-dd97-417a-9221-fbb96523b104", 411 | "metadata": {}, 412 | "source": [ 413 | "Note that PyTorch's `torch.nn.functional.cross_entropy` expects logits. Since we have probabilities here, we instead use the negative log-likelihood loss (`torch.nn.functional.nll_loss`), which expects log-probabilities as inputs (usually obtained via `torch.log_softmax(logits, dim=-1)`; below, we apply `torch.log` to the probabilities).\n", 414 | "\n", 415 | "In practice, if your model returns logits (instead of probabilities), you may want to use\n", 416 | "`torch.nn.functional.cross_entropy` instead of `torch.nn.functional.nll_loss` for better numerical stability and efficiency."
417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 14, 422 | "id": "2de93e09-f0c2-4bf0-9e8b-539f7d070f65", 423 | "metadata": {}, 424 | "outputs": [ 425 | { 426 | "data": { 427 | "text/plain": [ 428 | "torch.Size([9, 12])" 429 | ] 430 | }, 431 | "execution_count": 14, 432 | "metadata": {}, 433 | "output_type": "execute_result" 434 | } 435 | ], 436 | "source": [ 437 | "model_outputs[0].shape" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 15, 443 | "id": "78682c83-7d77-4f7f-98d0-eb271a55540b", 444 | "metadata": {}, 445 | "outputs": [ 446 | { 447 | "data": { 448 | "text/plain": [ 449 | "torch.Size([1, 9])" 450 | ] 451 | }, 452 | "execution_count": 15, 453 | "metadata": {}, 454 | "output_type": "execute_result" 455 | } 456 | ], 457 | "source": [ 458 | "targets.shape" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 16, 464 | "id": "0780aa39-4e34-45cb-9fde-180e3fac6e6f", 465 | "metadata": {}, 466 | "outputs": [ 467 | { 468 | "data": { 469 | "text/plain": [ 470 | "1.0567214488983154" 471 | ] 472 | }, 473 | "execution_count": 16, 474 | "metadata": {}, 475 | "output_type": "execute_result" 476 | } 477 | ], 478 | "source": [ 479 | "import torch\n", 480 | "import torch.nn.functional as F\n", 481 | "\n", 482 | "def pytorch_perplexity(prob, target):\n", 483 | "\n", 484 | " log_prob = torch.log(prob)\n", 485 | " loss = F.nll_loss(log_prob, target, reduction='mean')\n", 486 | " perplexity = torch.exp(loss)\n", 487 | " return perplexity.item()\n", 488 | "\n", 489 | "pytorch_perplexity(model_outputs[0], targets[0])" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "id": "ad1f198a-ca6c-4fc8-a1b6-f32916e12cbf", 495 | "metadata": {}, 496 | "source": [ 497 | "## Perplexity with log base 2 and natural log" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": 17, 503 | "id": "76c45995-f536-4380-8c90-43b10686a202", 504 | "metadata": {}, 505 | "outputs": [ 506 | { 507 | 
"name": "stdout", 508 | "output_type": "stream", 509 | "text": [ 510 | "Perplexity sentence 1: 1.0567214564189926\n" 511 | ] 512 | } 513 | ], 514 | "source": [ 515 | "import numpy as np\n", 516 | "\n", 517 | "def calculate_perplexity_base2(probabilities):\n", 518 | " log_probs = np.log2(probabilities)\n", 519 | " avg_log_prob = np.mean(log_probs)\n", 520 | " perplexity = 2 ** (-avg_log_prob)\n", 521 | " return perplexity\n", 522 | "\n", 523 | "true_sentence = \"The quick brown fox jumps over the lazy dog\"\n", 524 | "sentence_1 = \"The fast black cat jumps over the lazy dog\"\n", 525 | "\n", 526 | "s1_word_proba = [0.99, 0.85, 0.89, 0.94, 0.99, 0.99, 0.99, 0.99, 0.90]\n", 527 | "perplex = calculate_perplexity_base2(s1_word_proba)\n", 528 | "print(\"Perplexity sentence 1:\", perplex)" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": 18, 534 | "id": "0769ae64-0363-4987-af5e-167bdb9ce773", 535 | "metadata": {}, 536 | "outputs": [ 537 | { 538 | "name": "stdout", 539 | "output_type": "stream", 540 | "text": [ 541 | "Perplexity sentence 1: 1.0567214564189926\n" 542 | ] 543 | } 544 | ], 545 | "source": [ 546 | "import numpy as np\n", 547 | "\n", 548 | "def calculate_perplexity_natural(probabilities):\n", 549 | " log_probs = np.log(probabilities)\n", 550 | " avg_log_prob = np.mean(log_probs)\n", 551 | " perplexity = np.e ** (-avg_log_prob)\n", 552 | " return perplexity\n", 553 | "\n", 554 | "true_sentence = \"The quick brown fox jumps over the lazy dog\"\n", 555 | "sentence_1 = \"The fast black cat jumps over the lazy dog\"\n", 556 | "\n", 557 | "s1_word_proba = [0.99, 0.85, 0.89, 0.94, 0.99, 0.99, 0.99, 0.99, 0.90]\n", 558 | "perplex = calculate_perplexity_natural(s1_word_proba)\n", 559 | "print(\"Perplexity sentence 1:\", perplex)" 560 | ] 561 | } 562 | ], 563 | "metadata": { 564 | "kernelspec": { 565 | "display_name": "Python 3 (ipykernel)", 566 | "language": "python", 567 | "name": "python3" 568 | }, 569 | "language_info": { 570 | 
"codemirror_mode": { 571 | "name": "ipython", 572 | "version": 3 573 | }, 574 | "file_extension": ".py", 575 | "mimetype": "text/x-python", 576 | "name": "python", 577 | "nbconvert_exporter": "python", 578 | "pygments_lexer": "ipython3", 579 | "version": "3.10.10" 580 | } 581 | }, 582 | "nbformat": 4, 583 | "nbformat_minor": 5 584 | } 585 | -------------------------------------------------------------------------------- /supplementary/q19-evaluation-llms/rouge.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "85b9cad9-14cc-4145-a71a-6fdfa7b9b044", 6 | "metadata": {}, 7 | "source": [ 8 | "# ROUGE-1 (ROUGE for Unigrams)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "d152300f-00fd-4f31-bb2c-a776f1822a2e", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "reference = \"The quick brown fox jumps over the lazy dog\"\n", 19 | "candidate = \"The fox jumps over the dog\"" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "3677a7a3-d4d1-49a0-ab90-d02b526cde04", 25 | "metadata": {}, 26 | "source": [ 27 | "### `rouge` library" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "id": "6a159771-e2ee-4adc-b53d-d349d69a6a63", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "#pip install rouge" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "id": "cb45c369-e04f-43ea-a31f-e2380efa3e79", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "ROUGE-1 Recall: 0.67\n", 51 | "ROUGE-1 Precision: 1.00\n", 52 | "ROUGE-1 F1: 0.80\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "from rouge import Rouge\n", 58 | "\n", 59 | "rouge = Rouge()\n", 60 | "\n", 61 | "scores = rouge.get_scores(candidate, reference, avg=True)\n", 62 | "\n", 63 | "print(f\"ROUGE-1 Recall: 
{scores['rouge-1']['r']:.2f}\")\n", 64 | "print(f\"ROUGE-1 Precision: {scores['rouge-1']['p']:.2f}\")\n", 65 | "print(f\"ROUGE-1 F1: {scores['rouge-1']['f']:.2f}\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "d89def98-71e9-4ee4-96ac-017dd9bf2a28", 71 | "metadata": {}, 72 | "source": [ 73 | "### TorchMetrics" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "id": "01198d15-c16c-4aa6-a8ed-65705ab743cf", 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "ROUGE-1 Recall: 0.67\n", 87 | "ROUGE-1 Precision: 1.00\n", 88 | "ROUGE-1 F1: 0.80\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "from torchmetrics.text import ROUGEScore\n", 94 | "\n", 95 | "rouge = ROUGEScore(n_gram=1)\n", 96 | "\n", 97 | "# Calculate ROUGE scores\n", 98 | "rouge_score = rouge(target=[reference], preds=[candidate])\n", 99 | "\n", 100 | "print(f\"ROUGE-1 Recall: {rouge_score['rouge1_recall']:.2f}\")\n", 101 | "print(f\"ROUGE-1 Precision: {rouge_score['rouge1_precision']:.2f}\")\n", 102 | "print(f\"ROUGE-1 F1: {rouge_score['rouge1_fmeasure']:.2f}\")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "2fe6a608-168a-4f7c-87a0-d82b9eac21ac", 108 | "metadata": {}, 109 | "source": [ 110 | "### From Scratch" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "id": "66075150-bb85-465f-945e-3cc008de7efd", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "ROUGE-1 Recall: 0.67\n", 124 | "ROUGE-1 Precision: 1.00\n", 125 | "ROUGE-1 F1: 0.80\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "from collections import Counter\n", 131 | "\n", 132 | "def tokenize(sentence):\n", 133 | " return sentence.lower().split()\n", 134 | "\n", 135 | "def ngrams(tokens, n):\n", 136 | " return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]\n", 137 | "\n", 138 | "def 
rouge_1(candidate, reference):\n", 139 | " candidate_tokens = tokenize(candidate)\n", 140 | " reference_tokens = tokenize(reference)\n", 141 | "\n", 142 | " candidate_1grams = Counter(ngrams(candidate_tokens, 1))\n", 143 | " reference_1grams = Counter(ngrams(reference_tokens, 1))\n", 144 | "\n", 145 | " overlapping_1grams = candidate_1grams & reference_1grams\n", 146 | " overlap_count = sum(overlapping_1grams.values())\n", 147 | "\n", 148 | " candidate_count = sum(candidate_1grams.values())\n", 149 | " reference_count = sum(reference_1grams.values())\n", 150 | "\n", 151 | " if candidate_count == 0 or reference_count == 0:\n", 152 | " return 0\n", 153 | "\n", 154 | " precision = overlap_count / candidate_count\n", 155 | " recall = overlap_count / reference_count\n", 156 | " f1_score = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0\n", 157 | "\n", 158 | " return precision, recall, f1_score\n", 159 | "\n", 160 | "\n", 161 | "precision, recall, f1_score = rouge_1(candidate, reference)\n", 162 | "\n", 163 | "print(f\"ROUGE-1 Recall: {recall:.2f}\")\n", 164 | "print(f\"ROUGE-1 Precision: {precision:.2f}\")\n", 165 | "print(f\"ROUGE-1 F1: {f1_score:.2f}\")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "59d70b91-bd09-43a8-839a-82174973b546", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "29882a0b-2b34-4b6b-b371-835ffe7f3cbe", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3 (ipykernel)", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | 
"pygments_lexer": "ipython3", 201 | "version": "3.10.6" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 5 206 | } 207 | -------------------------------------------------------------------------------- /supplementary/requirements.txt: -------------------------------------------------------------------------------- 1 | jupyterlab 2 | llama-index-llms-ollama # Q18 RAG example 3 | llama-index-embeddings-huggingface # Q18 RAG example 4 | llama-index-llms-huggingface # Q18 RAG example 5 | transformers>=4.30 # Q18 LLM finetuning 6 | mlxtend>=0.22.0 7 | nltk>=3.8.1 8 | numpy>=1.23.5 9 | scikit-learn>=1.2.2 10 | torch>=2.0.0 11 | transformers>=4.27.2 12 | watermark>=2.4.3 --------------------------------------------------------------------------------