├── .github ├── dependabot.yml └── workflows │ ├── lint.yml │ ├── release.yml │ ├── test.yml │ └── test_doc.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode └── settings.json ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── Makefile ├── README.md ├── data └── the-verdict.txt └── src ├── bin └── main.rs ├── candle_addons.rs ├── examples ├── apdx_e.rs ├── ch02.rs ├── ch03.rs ├── ch04.rs ├── ch05.rs ├── ch06.rs ├── ch07.rs └── mod.rs ├── exercises ├── ch02.rs ├── ch03.rs ├── ch04.rs ├── ch05.rs ├── ch06.rs ├── ch07.rs └── mod.rs ├── lib.rs └── listings ├── apdx_e.rs ├── ch02.rs ├── ch03.rs ├── ch04.rs ├── ch05.rs ├── ch06.rs ├── ch07 ├── bonus.rs └── mod.rs └── mod.rs /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "cargo" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | types: 9 | - opened 10 | - synchronize 11 | 12 | env: 13 | CARGO_TERM_COLOR: always 14 | 15 | jobs: 16 | lint: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: actions/cache@v4 21 | with: 22 | path: | 23 | ~/.cargo/bin/ 24 | ~/.cargo/registry/index/ 25 | ~/.cargo/registry/cache/ 26 | ~/.cargo/git/db/ 27 | target/ 28 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 29 | - name: Lint 30 | run: cargo clippy --verbose -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | # Sequence of patterns matched against refs/tags 5 | tags: 6 | - 'v*' # Push events to matching v*, i.e. 
v1.0, v20.15.10 7 | jobs: 8 | publish_crate: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout code 12 | id: publish_crate 13 | uses: actions/checkout@v4 14 | - name: Check cache 15 | uses: actions/cache@v4 16 | with: 17 | path: | 18 | ~/.cargo/bin/ 19 | ~/.cargo/registry/index/ 20 | ~/.cargo/registry/cache/ 21 | ~/.cargo/git/db/ 22 | target/ 23 | key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }} 24 | - name: Publish 25 | run: cargo publish 26 | env: 27 | CARGO_REGISTRY_TOKEN: ${{ secrets.CRATES_TOKEN }} 28 | 29 | release_github: 30 | needs: publish_crate 31 | runs-on: ubuntu-latest 32 | steps: 33 | - name: Create GitHub Release 34 | id: create_release 35 | uses: ncipollo/release-action@cdcc88a9acf3ca41c16c37bb7d21b9ad48560d87 # v1.15.0 36 | with: 37 | artifacts: "target/*" 38 | generateReleaseNotes: true 39 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test Lib 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | types: 9 | - opened 10 | - synchronize 11 | 12 | env: 13 | CARGO_TERM_COLOR: always 14 | 15 | jobs: 16 | test: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: actions/cache@v4 21 | with: 22 | path: | 23 | ~/.cargo/bin/ 24 | ~/.cargo/registry/index/ 25 | ~/.cargo/registry/cache/ 26 | ~/.cargo/git/db/ 27 | target/ 28 | key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }} 29 | - name: Build 30 | run: cargo build --verbose 31 | - name: Run tests 32 | run: cargo test --lib --verbose -------------------------------------------------------------------------------- /.github/workflows/test_doc.yml: -------------------------------------------------------------------------------- 1 | name: Test Docs 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | types: 9 | - opened 10 | - synchronize 11 | 12 | env: 13 | CARGO_TERM_COLOR: always 14 | 15 | jobs: 16 | test: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: actions/cache@v4 21 | with: 22 | path: | 23 | ~/.cargo/bin/ 24 | ~/.cargo/registry/index/ 25 | ~/.cargo/registry/cache/ 26 | ~/.cargo/git/db/ 27 | target/ 28 | key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }} 29 | - name: Build 30 | run: cargo build --verbose 31 | - name: Run tests 32 | run: cargo test --doc --verbose -- --test-threads=2 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | *.checkpoint.safetensors 3 | /data 4 | *.html -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/doublify/pre-commit-rust 4 | rev: v1.0 5 | hooks: 6 | - id: fmt 7 | - id: clippy 8 | - id: cargo-check -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[rust]": { 3 | "editor.defaultFormatter": "rust-lang.rust-analyzer", 4 | "editor.formatOnSave": true 5 | }, 6 | } 7 | -------------------------------------------------------------------------------- /CHANGELOG.md: 
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 | The format is based on [Keep a Changelog](https://keepachangelog.com/),
5 | and this project adheres to [Semantic Versioning](https://semver.org/).
6 |
7 | ## Unreleased
8 |
9 | ## [0.1.5] - 2025-06-05
10 |
11 | ### Changed
12 |
13 | - Updated candle deps (#382)
14 |
15 | ## [0.1.4] - 2025-02-27
16 |
17 | ### Added
18 |
19 | - EG 07.23 - example usage of train_model_dpo_simple (#339)
20 | - train_model_dpo_simple (#338)
21 | - evaluate_dpo_loss_loader (#337)
22 | - EG 07.22 (#335)
23 | - compute_dpo_loss_loader (#334)
24 | - EG 07.21 - example usage of compute_dpo_loss_batch (#333)
25 | - dpo_loss_batch (#332)
26 | - compute_logprobs (#331)
27 | - compute_dpo_loss (#326)
28 | - preference data loader (#319)
29 | - [Fix] rejected/chosen masking & update EG 07.19 (#318)
30 | - EG 07.19 Example usage of PreferenceDataCollator (#314)
31 | - PreferenceDatasetCollator (#309)
32 | - PreferenceDataset (#308)
33 | - EncodedPreferenceExample and InstructionExample trait (#303)
34 | - generate_preference_dataset + EG 07.18 (#299)
35 | - Bonus DPO - Use Ollama to generate chosen/rejected responses for an instruction entry + EG 07.17 (#294)
36 |
37 | ### Changed
38 |
39 | - Exercise 7.4 + use of GPT trait in listings::ch05 instead of GPTModel (#289)
40 |
41 | ## [0.1.3] - 2025-01-23
42 |
43 | ### Added
44 |
45 | - docs Appendix E (#287)
46 | - EG E.07 (#286)
47 | - EG E.06 (#285)
48 | - Listing E.7 re-export of train_classifier_simple (#284)
49 | - EG E.05 (#283)
50 | - EG E.04 (#282)
51 | - EG E.03 (LoRA model loading) (#281)
52 | - GPTModelWithLoRA (#279)
53 | - TransformerBlockWithLoRA (#278)
54 | - FeedForwardWithLoRA (#277)
55 | - MultiHeadAttentionWithLoRA (#270)
56 | - Listing E.6 LinearWithLoRA (#269)
57 | - Listing E.5 (LoRALayer) (#268)
58 | - EG E.02 and set listing E.4 as re-export (#267)
59 | - Listing E.4 (#266)
60 | - Example E.01 (#265)
61 | - Listing E.3 (#264)
62 | - Listing E.2 (#263)
63 | - Listing E.1 (#262)
64 |
65 | ### Changed
66 |
67 | - parametrize batch size (#285)
68 | - GPT trait to "consolidate" GPTModel and GPTModelWithLoRA (#279)
69 | - Rip out Sequential and SequentialT in favor of explicit sequential-like structs (#276)
70 |
71 | ## [0.1.2] - 2025-01-13
72 |
73 | ### Added
74 |
75 | - Exercise 7.3 (#252)
76 |
77 | ### Fixed
78 |
79 | - [docs] make listings more visible for ch07 (#260)
80 |
81 | ## [0.1.1] - 2025-01-11
82 |
83 | ### Added
84 |
85 | - Exercise 7.2 (#248)
86 |
87 | ### Changed
88 |
89 | - Make `listings::ch07::InstructionDataBatcher` more generic (#248)
90 | - Add associated type to `listings::ch07::CustomCollator` (#248)
91 |
92 | ### Fixed
93 |
94 | - Incorrect cast of `keep` indices to `u8` in `calc_loss_loader` (#250)
95 | - Missing `ignore_index` in `calc_loss_loader` fn params (#250)
96 |
97 | ## [0.1.0] - 2025-01-09
98 |
99 | ### Added
100 |
101 | - Listings ch02 up to ch07
102 | - Examples ch02 up to ch07
103 | - Exercise ch02 up to ch06 and Exercise 7.1
104 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing 🦀
2 |
3 | We welcome contributions in the form of adding translations for new examples and
4 | bonus material that Sebastian and co. produce from time to time.
You can check
5 | out the [project board](https://github.com/users/nerdai/projects/8) for code
6 | listings, exercises, examples, and bonus materials still left to be translated
7 | into Rust (candle).
8 |
9 | In addition to contributing new Rust translations, we also welcome improvements
10 | to the current Rust implementations of all listings, examples, and exercises.
11 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "llms-from-scratch-rs"
3 | description = "Rust (candle) code for Build a LLM From Scratch by Sebastian Raschka"
4 | version = "0.1.5"
5 | edition = "2021"
6 | repository = "https://github.com/nerdai/llms-from-scratch-rs"
7 | authors = ["Val Andrei Fajardo "]
8 | keywords = ["machine-learning", "llms", "gpt"]
9 | categories = ["science"]
10 | license = "MIT"
11 |
12 | exclude = [
13 |     "data/*",
14 | ]
15 |
16 | [dependencies]
17 | anyhow = "1.0.98"
18 | bytes = "1.10.1"
19 | candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.9.1" }
20 | candle-datasets = { git = "https://github.com/huggingface/candle.git", version = "0.9.1" }
21 | candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.9.1" }
22 | clap = { version = "4.5.39", features = ["derive"] }
23 | comfy-table = "7.1.4"
24 | fancy-regex = "0.14.0"
25 | hf-hub = "0.4.2"
26 | itertools = "0.14.0"
27 | lexical-core = "1.0.5"
28 | ndarray = "0.16.1"
29 | phf = { version = "0.11.3", features = ["macros"] }
30 | plotly = "0.12.1"
31 | polars = { version = "0.48.1", features = ["csv", "dtype-struct", "lazy", "parquet", "rows"] }
32 | rand = "0.9.1"
33 | reqwest = { version = "0.12.19", features = ["blocking", "json"] }
34 | rstest = "0.25.0"
35 | serde = { version = "1.0.219", features = ["derive"] }
36 | serde_json = "1.0.140"
37 | serde_with = "3.12.0"
38 | sysinfo = "0.35.1"
39 | tempfile = "3.20.0"
40 | tiktoken-rs = "0.6.0"
41 | tokenizers = "0.21.1"
42 | tqdm = "0.7.0"
43 | zip = "4.0.0"
44 |
45 | [features]
46 | cuda = ["candle-core/cuda", "candle-nn/cuda"]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Andrei Fajardo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | help: ## Show all Makefile targets.
2 | 	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
3 |
4 | format: ## Run code autoformatters (rustfmt via pre-commit).
5 | 	pre-commit install
6 | 	pre-commit run fmt
7 |
8 | lint: ## Run linters: pre-commit (clippy).
9 | 	pre-commit install && pre-commit run clippy
10 |
11 | check: ## Run cargo-check via pre-commit.
12 | 	pre-commit install && pre-commit run cargo-check
13 |
14 | test: ## Run all tests.
15 | 	cargo test --verbose
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLMs from scratch - Rust
2 |
3 |
4 | [cover image]
5 |
6 |
7 | This project aims to provide Rust code that follows the incredible text,
8 | Build a Large Language Model (From Scratch) by Sebastian Raschka. The book provides
9 | arguably the clearest step-by-step walkthrough for building a GPT-style LLM. Listed
10 | below are the titles of the book's seven chapters.
11 |
12 | 1. Understanding large language models
13 | 2. Working with text data
14 | 3. Coding attention mechanisms
15 | 4. Implementing a GPT model from scratch to generate text
16 | 5. Pretraining on unlabeled data
17 | 6. Fine-tuning for classification
18 | 7. Fine-tuning to follow instructions
19 |
20 | The code (see the associated [GitHub repo](https://github.com/rasbt/LLMs-from-scratch))
21 | provided in the book is all written in PyTorch (understandably so). In this
22 | project, we translate all of the PyTorch code into Rust using the
23 | [Candle](https://github.com/huggingface/candle) crate, a minimalist ML
24 | framework.
25 |
26 | ## Usage
27 |
28 | The recommended way to use this project is to clone this repo and use
29 | Cargo to run the examples and exercises.
30 |
31 | ```sh
32 | # SSH
33 | git clone git@github.com:nerdai/llms-from-scratch-rs.git
34 |
35 | # HTTPS
36 | git clone https://github.com/nerdai/llms-from-scratch-rs.git
37 | ```
38 |
39 | Note that we use the same dataset that Sebastian uses in his book. Use the
40 | command below to download the data into a subfolder called `data/`, which is
41 | then used by the examples and exercises of the book.
42 |
43 | ```sh
44 | mkdir -p 'data/'
45 | wget 'https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt' -O 'data/the-verdict.txt'
46 | ```
47 |
48 | ### Navigating the code
49 |
50 | Users have the option of reading the code via their chosen IDE and the cloned
51 | repo, or by using the project's [docs](https://docs.rs/llms-from-scratch-rs/latest/llms_from_scratch_rs/).
52 |
53 | NOTE: The import style used in all of the `examples` and `exercises` modules is
54 | unconventional. Specifically, relevant imports are made inside the `main()` method
55 | of every `Example` and `Exercise` implementation. This is done for educational
56 | purposes, to show the reader of the book precisely which imports are
57 | needed for the example/exercise at hand.
58 |
59 | ### Running `Examples` and `Exercises`
60 |
61 | After cloning the repo, you can cd to the project's root directory and execute
62 | the `main` binary.
63 |
64 | ```sh
65 | # Run code for Example 05.07
66 | cargo run example 05.07
67 |
68 | # Run code for Exercise 5.5
69 | cargo run exercise 5.5
70 | ```
71 |
72 | If using a cuda-enabled device, you can turn on the `cuda` feature via the
73 | `--features cuda` flag:
74 |
75 | ```sh
76 | # Run code for Example 05.07
77 | cargo run --features cuda example 05.07
78 |
79 | # Run code for Exercise 5.5
80 | cargo run --features cuda exercise 5.5
81 | ```
82 |
83 | ### Listing `Examples`
84 |
85 | To list the `Examples`, use the following command:
86 |
87 | ```sh
88 | cargo run list --examples
89 | ```
90 |
91 | A snippet of the output is pasted below.
92 |
93 | ```sh
94 | EXAMPLES:
95 | +-------+----------------------------------------------------------------------+
96 | | Id    | Description                                                          |
97 | +==============================================================================+
98 | | 02.01 | Example usage of `listings::ch02::sample_read_text`                 |
99 | |-------+----------------------------------------------------------------------|
100 | | 02.02 | Use candle to generate an Embedding Layer.                           |
101 | |-------+----------------------------------------------------------------------|
102 | | 02.03 | Create absolute positional embeddings.                               |
103 | |-------+----------------------------------------------------------------------|
104 | | 03.01 | Computing attention scores as a dot product.                         |
105 | ...
106 | |-------+----------------------------------------------------------------------|
107 | | 06.13 | Example usage of `train_classifier_simple` and `plot_values`         |
108 | |       | function.                                                            |
109 | |-------+----------------------------------------------------------------------|
110 | | 06.14 | Loading fine-tuned model and calculate performance on whole train,   |
111 | |       | val and test sets.                                                   |
112 | |-------+----------------------------------------------------------------------|
113 | | 06.15 | Example usage of `classify_review`.                                  |
114 | +-------+----------------------------------------------------------------------+
115 | ```
116 |
117 | ### Listing `Exercises`
118 |
119 | One can similarly list the `Exercises` using:
120 |
121 | ```sh
122 | cargo run list --exercises
123 | ```
124 |
125 | ```sh
126 | # first few lines of output
127 | EXERCISES:
128 | +-----+------------------------------------------------------------------------+
129 | | Id  | Statement                                                              |
130 | +==============================================================================+
131 | | 2.1 | Byte pair encoding of unknown words                                    |
132 | |     |                                                                        |
133 | |     | Try the BPE tokenizer from the tiktoken library on the unknown words   |
134 | |     | 'Akwirw ier' and print the individual token IDs. Then, call the decode |
135 | |     | function on each of the resulting integers in this list to reproduce   |
136 | |     | the mapping shown in figure 2.11. Lastly, call the decode method on    |
137 | |     | the token IDs to check whether it can reconstruct the original input,  |
138 | |     | 'Akwirw ier.'                                                          |
139 | |-----+------------------------------------------------------------------------|
140 | | 2.2 | Data loaders with different strides and context sizes                  |
141 | |     |                                                                        |
142 | |     | To develop more intuition for how the data loader works, try to run it |
143 | |     | with different settings such as `max_length=2` and `stride=2`, and     |
144 | |     | `max_length=8` and `stride=2`.                                         |
145 | |-----+------------------------------------------------------------------------|
146 | ...
147 | |-----+------------------------------------------------------------------------|
148 | | 6.2 | Fine-tuning the whole model                                            |
149 | |     |                                                                        |
150 | |     | Instead of fine-tuning just the final transformer block, fine-tune the |
151 | |     | entire model and assess the effect on predictive performance.          |
152 | |-----+------------------------------------------------------------------------|
153 | | 6.3 | Fine-tuning the first vs. last token                                   |
154 | |     |                                                                        |
155 | |     | Try fine-tuning the first output token. Notice the changes in          |
156 | |     | predictive performance compared to fine-tuning the last output token.
|
157 | +-----+------------------------------------------------------------------------+
158 | ```
159 |
160 | ## [Alternative Usage] Installing from `crates.io`
161 |
162 | Alternatively, users have the option of installing this crate directly via
163 | `cargo install` (_Be sure to have Rust and Cargo installed first. See
164 | [here](https://doc.rust-lang.org/cargo/getting-started/installation.html) for
165 | installation instructions._):
166 |
167 | ```sh
168 | cargo install llms-from-scratch-rs
169 | ```
170 |
171 | Once installed, users can run the main binary to run the various
172 | Examples and Exercises.
173 |
174 | ```sh
175 | # Run code for Example 05.07
176 | cargo run example 05.07
177 |
178 | # Run code for Exercise 5.5
179 | cargo run exercise 5.5
180 | ```
181 |
--------------------------------------------------------------------------------
/data/the-verdict.txt:
--------------------------------------------------------------------------------
1 | I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no great surprise to me to hear that, in the height of his glory, he had dropped his painting, married a rich widow, and established himself in a villa on the Riviera. (Though I rather thought it would have been Rome or Florence.)
2 |
3 | "The height of his glory"--that was what the women called it. I can hear Mrs. Gideon Thwing--his last Chicago sitter--deploring his unaccountable abdication. "Of course it's going to send the value of my picture 'way up; but I don't think of that, Mr. Rickham--the loss to Arrt is all I think of." The word, on Mrs. Thwing's lips, multiplied its _rs_ as though they were reflected in an endless vista of mirrors. And it was not only the Mrs. Thwings who mourned. Had not the exquisite Hermia Croft, at the last Grafton Gallery show, stopped me before Gisburn's "Moon-dancers" to say, with tears in her eyes: "We shall not look upon its like again"?
4 |
5 | Well!--even through the prism of Hermia's tears I felt able to face the fact with equanimity. Poor Jack Gisburn! The women had made him--it was fitting that they should mourn him. Among his own sex fewer regrets were heard, and in his own trade hardly a murmur. Professional jealousy? Perhaps. If it were, the honour of the craft was vindicated by little Claude Nutley, who, in all good faith, brought out in the Burlington a very handsome "obituary" on Jack--one of those showy articles stocked with random technicalities that I have heard (I won't say by whom) compared to Gisburn's painting. And so--his resolve being apparently irrevocable--the discussion gradually died out, and, as Mrs. Thwing had predicted, the price of "Gisburns" went up.
6 |
7 | It was not till three years later that, in the course of a few weeks' idling on the Riviera, it suddenly occurred to me to wonder why Gisburn had given up his painting. On reflection, it really was a tempting problem. To accuse his wife would have been too easy--his fair sitters had been denied the solace of saying that Mrs. Gisburn had "dragged him down." For Mrs. Gisburn--as such--had not existed till nearly a year after Jack's resolve had been taken. It might be that he had married her--since he liked his ease--because he didn't want to go on painting; but it would have been hard to prove that he had given up his painting because he had married her.
8 |
9 | Of course, if she had not dragged him down, she had equally, as Miss Croft contended, failed to "lift him up"--she had not led him back to the easel.
To put the brush into his hand again--what a vocation for a wife! But Mrs. Gisburn appeared to have disdained it--and I felt it might be interesting to find out why. 10 | 11 | The desultory life of the Riviera lends itself to such purely academic speculations; and having, on my way to Monte Carlo, caught a glimpse of Jack's balustraded terraces between the pines, I had myself borne thither the next day. 12 | 13 | I found the couple at tea beneath their palm-trees; and Mrs. Gisburn's welcome was so genial that, in the ensuing weeks, I claimed it frequently. It was not that my hostess was "interesting": on that point I could have given Miss Croft the fullest reassurance. It was just because she was _not_ interesting--if I may be pardoned the bull--that I found her so. For Jack, all his life, had been surrounded by interesting women: they had fostered his art, it had been reared in the hot-house of their adulation. And it was therefore instructive to note what effect the "deadening atmosphere of mediocrity" (I quote Miss Croft) was having on him. 14 | 15 | I have mentioned that Mrs. Gisburn was rich; and it was immediately perceptible that her husband was extracting from this circumstance a delicate but substantial satisfaction. It is, as a rule, the people who scorn money who get most out of it; and Jack's elegant disdain of his wife's big balance enabled him, with an appearance of perfect good-breeding, to transmute it into objects of art and luxury. To the latter, I must add, he remained relatively indifferent; but he was buying Renaissance bronzes and eighteenth-century pictures with a discrimination that bespoke the amplest resources. 16 | 17 | "Money's only excuse is to put beauty into circulation," was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed luncheon-table, when, on a later day, I had again run over from Monte Carlo; and Mrs. Gisburn, beaming on him, added for my enlightenment: "Jack is so morbidly sensitive to every form of beauty." 18 | 19 | Poor Jack! It had always been his fate to have women say such things of him: the fact should be set down in extenuation. What struck me now was that, for the first time, he resented the tone. I had seen him, so often, basking under similar tributes--was it the conjugal note that robbed them of their savour? No--for, oddly enough, it became apparent that he was fond of Mrs. Gisburn--fond enough not to see her absurdity. It was his own absurdity he seemed to be wincing under--his own attitude as an object for garlands and incense. 20 | 21 | "My dear, since I've chucked painting people don't say that stuff about me--they say it about Victor Grindle," was his only protest, as he rose from the table and strolled out onto the sunlit terrace. 22 | 23 | I glanced after him, struck by his last word. Victor Grindle was, in fact, becoming the man of the moment--as Jack himself, one might put it, had been the man of the hour. The younger artist was said to have formed himself at my friend's feet, and I wondered if a tinge of jealousy underlay the latter's mysterious abdication. But no--for it was not till after that event that the _rose Dubarry_ drawing-rooms had begun to display their "Grindles." 24 | 25 | I turned to Mrs. Gisburn, who had lingered to give a lump of sugar to her spaniel in the dining-room. 26 | 27 | "Why _has_ he chucked painting?" I asked abruptly. 28 | 29 | She raised her eyebrows with a hint of good-humoured surprise. 
30 | 31 | "Oh, he doesn't _have_ to now, you know; and I want him to enjoy himself," she said quite simply. 32 | 33 | I looked about the spacious white-panelled room, with its _famille-verte_ vases repeating the tones of the pale damask curtains, and its eighteenth-century pastels in delicate faded frames. 34 | 35 | "Has he chucked his pictures too? I haven't seen a single one in the house." 36 | 37 | A slight shade of constraint crossed Mrs. Gisburn's open countenance. "It's his ridiculous modesty, you know. He says they're not fit to have about; he's sent them all away except one--my portrait--and that I have to keep upstairs." 38 | 39 | His ridiculous modesty--Jack's modesty about his pictures? My curiosity was growing like the bean-stalk. I said persuasively to my hostess: "I must really see your portrait, you know." 40 | 41 | She glanced out almost timorously at the terrace where her husband, lounging in a hooded chair, had lit a cigar and drawn the Russian deerhound's head between his knees. 42 | 43 | "Well, come while he's not looking," she said, with a laugh that tried to hide her nervousness; and I followed her between the marble Emperors of the hall, and up the wide stairs with terra-cotta nymphs poised among flowers at each landing. 44 | 45 | In the dimmest corner of her boudoir, amid a profusion of delicate and distinguished objects, hung one of the familiar oval canvases, in the inevitable garlanded frame. The mere outline of the frame called up all Gisburn's past! 46 | 47 | Mrs. Gisburn drew back the window-curtains, moved aside a _jardiniere_ full of pink azaleas, pushed an arm-chair away, and said: "If you stand here you can just manage to see it. I had it over the mantel-piece, but he wouldn't let it stay." 48 | 49 | Yes--I could just manage to see it--the first portrait of Jack's I had ever had to strain my eyes over! Usually they had the place of honour--say the central panel in a pale yellow or _rose Dubarry_ drawing-room, or a monumental easel placed so that it took the light through curtains of old Venetian point. The more modest place became the picture better; yet, as my eyes grew accustomed to the half-light, all the characteristic qualities came out--all the hesitations disguised as audacities, the tricks of prestidigitation by which, with such consummate skill, he managed to divert attention from the real business of the picture to some pretty irrelevance of detail. Mrs. Gisburn, presenting a neutral surface to work on--forming, as it were, so inevitably the background of her own picture--had lent herself in an unusual degree to the display of this false virtuosity. The picture was one of Jack's "strongest," as his admirers would have put it--it represented, on his part, a swelling of muscles, a congesting of veins, a balancing, straddling and straining, that reminded one of the circus-clown's ironic efforts to lift a feather. It met, in short, at every point the demand of lovely woman to be painted "strongly" because she was tired of being painted "sweetly"--and yet not to lose an atom of the sweetness. 50 | 51 | "It's the last he painted, you know," Mrs. Gisburn said with pardonable pride. "The last but one," she corrected herself--"but the other doesn't count, because he destroyed it." 52 | 53 | "Destroyed it?" I was about to follow up this clue when I heard a footstep and saw Jack himself on the threshold. 
54 | 55 | As he stood there, his hands in the pockets of his velveteen coat, the thin brown waves of hair pushed back from his white forehead, his lean sunburnt cheeks furrowed by a smile that lifted the tips of a self-confident moustache, I felt to what a degree he had the same quality as his pictures--the quality of looking cleverer than he was. 56 | 57 | His wife glanced at him deprecatingly, but his eyes travelled past her to the portrait. 58 | 59 | "Mr. Rickham wanted to see it," she began, as if excusing herself. He shrugged his shoulders, still smiling. 60 | 61 | "Oh, Rickham found me out long ago," he said lightly; then, passing his arm through mine: "Come and see the rest of the house." 62 | 63 | He showed it to me with a kind of naive suburban pride: the bath-rooms, the speaking-tubes, the dress-closets, the trouser-presses--all the complex simplifications of the millionaire's domestic economy. And whenever my wonder paid the expected tribute he said, throwing out his chest a little: "Yes, I really don't see how people manage to live without that." 64 | 65 | Well--it was just the end one might have foreseen for him. Only he was, through it all and in spite of it all--as he had been through, and in spite of, his pictures--so handsome, so charming, so disarming, that one longed to cry out: "Be dissatisfied with your leisure!" as once one had longed to say: "Be dissatisfied with your work!" 66 | 67 | But, with the cry on my lips, my diagnosis suffered an unexpected check. 68 | 69 | "This is my own lair," he said, leading me into a dark plain room at the end of the florid vista. It was square and brown and leathery: no "effects"; no bric-a-brac, none of the air of posing for reproduction in a picture weekly--above all, no least sign of ever having been used as a studio. 70 | 71 | The fact brought home to me the absolute finality of Jack's break with his old life. 72 | 73 | "Don't you ever dabble with paint any more?" I asked, still looking about for a trace of such activity. 74 | 75 | "Never," he said briefly. 76 | 77 | "Or water-colour--or etching?" 78 | 79 | His confident eyes grew dim, and his cheeks paled a little under their handsome sunburn. 80 | 81 | "Never think of it, my dear fellow--any more than if I'd never touched a brush." 82 | 83 | And his tone told me in a flash that he never thought of anything else. 84 | 85 | I moved away, instinctively embarrassed by my unexpected discovery; and as I turned, my eye fell on a small picture above the mantel-piece--the only object breaking the plain oak panelling of the room. 86 | 87 | "Oh, by Jove!" I said. 88 | 89 | It was a sketch of a donkey--an old tired donkey, standing in the rain under a wall. 90 | 91 | "By Jove--a Stroud!" I cried. 92 | 93 | He was silent; but I felt him close behind me, breathing a little quickly. 94 | 95 | "What a wonder! Made with a dozen lines--but on everlasting foundations. You lucky chap, where did you get it?" 96 | 97 | He answered slowly: "Mrs. Stroud gave it to me." 98 | 99 | "Ah--I didn't know you even knew the Strouds. He was such an inflexible hermit." 100 | 101 | "I didn't--till after. . . . She sent for me to paint him when he was dead." 102 | 103 | "When he was dead? You?" 104 | 105 | I must have let a little too much amazement escape through my surprise, for he answered with a deprecating laugh: "Yes--she's an awful simpleton, you know, Mrs. Stroud. Her only idea was to have him done by a fashionable painter--ah, poor Stroud! 
She thought it the surest way of proclaiming his greatness--of forcing it on a purblind public. And at the moment I was _the_ fashionable painter." 106 | 107 | "Ah, poor Stroud--as you say. Was _that_ his history?" 108 | 109 | "That was his history. She believed in him, gloried in him--or thought she did. But she couldn't bear not to have all the drawing-rooms with her. She couldn't bear the fact that, on varnishing days, one could always get near enough to see his pictures. Poor woman! She's just a fragment groping for other fragments. Stroud is the only whole I ever knew." 110 | 111 | "You ever knew? But you just said--" 112 | 113 | Gisburn had a curious smile in his eyes. 114 | 115 | "Oh, I knew him, and he knew me--only it happened after he was dead." 116 | 117 | I dropped my voice instinctively. "When she sent for you?" 118 | 119 | "Yes--quite insensible to the irony. She wanted him vindicated--and by me!" 120 | 121 | He laughed again, and threw back his head to look up at the sketch of the donkey. "There were days when I couldn't look at that thing--couldn't face it. But I forced myself to put it here; and now it's cured me--cured me. That's the reason why I don't dabble any more, my dear Rickham; or rather Stroud himself is the reason." 122 | 123 | For the first time my idle curiosity about my companion turned into a serious desire to understand him better. 124 | 125 | "I wish you'd tell me how it happened," I said. 126 | 127 | He stood looking up at the sketch, and twirling between his fingers a cigarette he had forgotten to light. Suddenly he turned toward me. 128 | 129 | "I'd rather like to tell you--because I've always suspected you of loathing my work." 130 | 131 | I made a deprecating gesture, which he negatived with a good-humoured shrug. 132 | 133 | "Oh, I didn't care a straw when I believed in myself--and now it's an added tie between us!" 134 | 135 | He laughed slightly, without bitterness, and pushed one of the deep arm-chairs forward. "There: make yourself comfortable--and here are the cigars you like." 136 | 137 | He placed them at my elbow and continued to wander up and down the room, stopping now and then beneath the picture. 138 | 139 | "How it happened? I can tell you in five minutes--and it didn't take much longer to happen. . . . I can remember now how surprised and pleased I was when I got Mrs. Stroud's note. Of course, deep down, I had always _felt_ there was no one like him--only I had gone with the stream, echoed the usual platitudes about him, till I half got to think he was a failure, one of the kind that are left behind. By Jove, and he _was_ left behind--because he had come to stay! The rest of us had to let ourselves be swept along or go under, but he was high above the current--on everlasting foundations, as you say. 140 | 141 | "Well, I went off to the house in my most egregious mood--rather moved, Lord forgive me, at the pathos of poor Stroud's career of failure being crowned by the glory of my painting him! Of course I meant to do the picture for nothing--I told Mrs. Stroud so when she began to stammer something about her poverty. I remember getting off a prodigious phrase about the honour being _mine_--oh, I was princely, my dear Rickham! I was posing to myself like one of my own sitters. 142 | 143 | "Then I was taken up and left alone with him. I had sent all my traps in advance, and I had only to set up the easel and get to work. 
He had been dead only twenty-four hours, and he died suddenly, of heart disease, so that there had been no preliminary work of destruction--his face was clear and untouched. I had met him once or twice, years before, and thought him insignificant and dingy. Now I saw that he was superb. 144 | 145 | "I was glad at first, with a merely aesthetic satisfaction: glad to have my hand on such a 'subject.' Then his strange life-likeness began to affect me queerly--as I blocked the head in I felt as if he were watching me do it. The sensation was followed by the thought: if he _were_ watching me, what would he say to my way of working? My strokes began to go a little wild--I felt nervous and uncertain. 146 | 147 | "Once, when I looked up, I seemed to see a smile behind his close grayish beard--as if he had the secret, and were amusing himself by holding it back from me. That exasperated me still more. The secret? Why, I had a secret worth twenty of his! I dashed at the canvas furiously, and tried some of my bravura tricks. But they failed me, they crumbled. I saw that he wasn't watching the showy bits--I couldn't distract his attention; he just kept his eyes on the hard passages between. Those were the ones I had always shirked, or covered up with some lying paint. And how he saw through my lies! 148 | 149 | "I looked up again, and caught sight of that sketch of the donkey hanging on the wall near his bed. His wife told me afterward it was the last thing he had done--just a note taken with a shaking hand, when he was down in Devonshire recovering from a previous heart attack. Just a note! But it tells his whole history. There are years of patient scornful persistence in every line. A man who had swum with the current could never have learned that mighty up-stream stroke. . . . 150 | 151 | "I turned back to my work, and went on groping and muddling; then I looked at the donkey again. I saw that, when Stroud laid in the first stroke, he knew just what the end would be. He had possessed his subject, absorbed it, recreated it. When had I done that with any of my things? They hadn't been born of me--I had just adopted them. . . . 152 | 153 | "Hang it, Rickham, with that face watching me I couldn't do another stroke. The plain truth was, I didn't know where to put it--_I had never known_. Only, with my sitters and my public, a showy splash of colour covered up the fact--I just threw paint into their faces. . . . Well, paint was the one medium those dead eyes could see through--see straight to the tottering foundations underneath. Don't you know how, in talking a foreign language, even fluently, one says half the time not what one wants to but what one can? Well--that was the way I painted; and as he lay there and watched me, the thing they called my 'technique' collapsed like a house of cards. He didn't sneer, you understand, poor Stroud--he just lay there quietly watching, and on his lips, through the gray beard, I seemed to hear the question: 'Are you sure you know where you're coming out?' 154 | 155 | "If I could have painted that face, with that question on it, I should have done a great thing. The next greatest thing was to see that I couldn't--and that grace was given me. But, oh, at that minute, Rickham, was there anything on earth I wouldn't have given to have Stroud alive before me, and to hear him say: 'It's not too late--I'll show you how'? 156 | 157 | "It _was_ too late--it would have been, even if he'd been alive. I packed up my traps, and went down and told Mrs. Stroud. 
Of course I didn't tell her _that_--it would have been Greek to her. I simply said I couldn't paint him, that I was too moved. She rather liked the idea--she's so romantic! It was that that made her give me the donkey. But she was terribly upset at not getting the portrait--she did so want him 'done' by some one showy! At first I was afraid she wouldn't let me off--and at my wits' end I suggested Grindle. Yes, it was I who started Grindle: I told Mrs. Stroud he was the 'coming' man, and she told somebody else, and so it got to be true. . . . And he painted Stroud without wincing; and she hung the picture among her husband's things. . . ."
158 |
159 | He flung himself down in the arm-chair near mine, laid back his head, and clasping his arms beneath it, looked up at the picture above the chimney-piece.
160 |
161 | "I like to fancy that Stroud himself would have given it to me, if he'd been able to say what he thought that day."
162 |
163 | And, in answer to a question I put half-mechanically--"Begin again?" he flashed out. "When the one thing that brings me anywhere near him is that I knew enough to leave off?"
164 |
165 | He stood up and laid his hand on my shoulder with a laugh. "Only the irony of it is that I _am_ still painting--since Grindle's doing it for me! The Strouds stand alone, and happen once--but there's no exterminating our kind of art."
--------------------------------------------------------------------------------
/src/bin/main.rs:
--------------------------------------------------------------------------------
1 | use anyhow::Result;
2 | use clap::{Parser, Subcommand};
3 | use comfy_table::{ContentArrangement, Table};
4 | use itertools::Itertools;
5 | use llms_from_scratch_rs::{examples, exercises, Example, Exercise};
6 | use std::collections::HashMap;
7 | use std::sync::LazyLock;
8 |
9 | static EXERCISE_REGISTRY: LazyLock<HashMap<&'static str, Box<dyn Exercise + Send + Sync>>> =
10 |     LazyLock::new(|| {
11 |         let mut m: HashMap<&'static str, Box<dyn Exercise + Send + Sync>> = HashMap::new();
12 |         // ch02
13 |         m.insert("2.1", Box::new(exercises::ch02::X1));
14 |         m.insert("2.2", Box::new(exercises::ch02::X2));
15 |         // ch03
16 |         m.insert("3.1", Box::new(exercises::ch03::X1));
17 |         m.insert("3.2", Box::new(exercises::ch03::X2));
18 |         m.insert("3.3", Box::new(exercises::ch03::X3));
19 |         // ch04
20 |         m.insert("4.1", Box::new(exercises::ch04::X1));
21 |         m.insert("4.2", Box::new(exercises::ch04::X2));
22 |         m.insert("4.3", Box::new(exercises::ch04::X3));
23 |         // ch05
24 |         m.insert("5.1", Box::new(exercises::ch05::X1));
25 |         m.insert("5.2", Box::new(exercises::ch05::X2));
26 |         m.insert("5.3", Box::new(exercises::ch05::X3));
27 |         m.insert("5.4", Box::new(exercises::ch05::X4));
28 |         m.insert("5.5", Box::new(exercises::ch05::X5));
29 |         m.insert("5.6", Box::new(exercises::ch05::X6));
30 |         // ch06
31 |         m.insert("6.1", Box::new(exercises::ch06::X1));
32 |         m.insert("6.2", Box::new(exercises::ch06::X2));
33 |         m.insert("6.3", Box::new(exercises::ch06::X3));
34 |         // ch07
35 |         m.insert("7.1", Box::new(exercises::ch07::X1));
36 |         m.insert("7.2", Box::new(exercises::ch07::X2));
37 |         m.insert("7.3", Box::new(exercises::ch07::X3));
38 |         m.insert("7.4", Box::new(exercises::ch07::X4));
39 |         m
40 |     });
41 |
42 | static EXAMPLE_REGISTRY: LazyLock<HashMap<&'static str, Box<dyn Example + Send + Sync>>> = LazyLock::new(|| {
43 |     let mut m: HashMap<&'static str, Box<dyn Example + Send + Sync>> = HashMap::new();
44 |     // ch02
45 |     m.insert("02.01", Box::new(examples::ch02::EG01));
46 |     m.insert("02.02", Box::new(examples::ch02::EG02));
47 |     m.insert("02.03", Box::new(examples::ch02::EG03));
48 |     m.insert("02.04", Box::new(examples::ch02::EG04));
49 |     // ch03
50 |
m.insert("03.01", Box::new(examples::ch03::EG01)); 51 | m.insert("03.02", Box::new(examples::ch03::EG02)); 52 | m.insert("03.03", Box::new(examples::ch03::EG03)); 53 | m.insert("03.04", Box::new(examples::ch03::EG04)); 54 | m.insert("03.05", Box::new(examples::ch03::EG05)); 55 | m.insert("03.06", Box::new(examples::ch03::EG06)); 56 | m.insert("03.07", Box::new(examples::ch03::EG07)); 57 | m.insert("03.08", Box::new(examples::ch03::EG08)); 58 | m.insert("03.09", Box::new(examples::ch03::EG09)); 59 | m.insert("03.10", Box::new(examples::ch03::EG10)); 60 | m.insert("03.11", Box::new(examples::ch03::EG11)); 61 | // ch04 62 | m.insert("04.01", Box::new(examples::ch04::EG01)); 63 | m.insert("04.02", Box::new(examples::ch04::EG02)); 64 | m.insert("04.03", Box::new(examples::ch04::EG03)); 65 | m.insert("04.04", Box::new(examples::ch04::EG04)); 66 | m.insert("04.05", Box::new(examples::ch04::EG05)); 67 | m.insert("04.06", Box::new(examples::ch04::EG06)); 68 | m.insert("04.07", Box::new(examples::ch04::EG07)); 69 | m.insert("04.08", Box::new(examples::ch04::EG08)); 70 | // ch05 71 | m.insert("05.01", Box::new(examples::ch05::EG01)); 72 | m.insert("05.02", Box::new(examples::ch05::EG02)); 73 | m.insert("05.03", Box::new(examples::ch05::EG03)); 74 | m.insert("05.04", Box::new(examples::ch05::EG04)); 75 | m.insert("05.05", Box::new(examples::ch05::EG05)); 76 | m.insert("05.06", Box::new(examples::ch05::EG06)); 77 | m.insert("05.07", Box::new(examples::ch05::EG07)); 78 | m.insert("05.08", Box::new(examples::ch05::EG08)); 79 | m.insert("05.09", Box::new(examples::ch05::EG09)); 80 | m.insert("05.10", Box::new(examples::ch05::EG10)); 81 | m.insert("05.11", Box::new(examples::ch05::EG11)); 82 | // ch06 83 | m.insert("06.01", Box::new(examples::ch06::EG01)); 84 | m.insert("06.02", Box::new(examples::ch06::EG02)); 85 | m.insert("06.03", Box::new(examples::ch06::EG03)); 86 | m.insert("06.04", Box::new(examples::ch06::EG04)); 87 | m.insert("06.05", Box::new(examples::ch06::EG05)); 88 | m.insert("06.06", Box::new(examples::ch06::EG06)); 89 | m.insert("06.07", Box::new(examples::ch06::EG07)); 90 | m.insert("06.08", Box::new(examples::ch06::EG08)); 91 | m.insert("06.09", Box::new(examples::ch06::EG09)); 92 | m.insert("06.10", Box::new(examples::ch06::EG10)); 93 | m.insert("06.11", Box::new(examples::ch06::EG11)); 94 | m.insert("06.12", Box::new(examples::ch06::EG12)); 95 | m.insert("06.13", Box::new(examples::ch06::EG13)); 96 | m.insert("06.14", Box::new(examples::ch06::EG14)); 97 | m.insert("06.15", Box::new(examples::ch06::EG15)); 98 | // ch07 99 | m.insert("07.01", Box::new(examples::ch07::EG01)); 100 | m.insert("07.02", Box::new(examples::ch07::EG02)); 101 | m.insert("07.03", Box::new(examples::ch07::EG03)); 102 | m.insert("07.04", Box::new(examples::ch07::EG04)); 103 | m.insert("07.05", Box::new(examples::ch07::EG05)); 104 | m.insert("07.06", Box::new(examples::ch07::EG06)); 105 | m.insert("07.07", Box::new(examples::ch07::EG07)); 106 | m.insert("07.08", Box::new(examples::ch07::EG08)); 107 | m.insert("07.09", Box::new(examples::ch07::EG09)); 108 | m.insert("07.10", Box::new(examples::ch07::EG10)); 109 | m.insert("07.11", Box::new(examples::ch07::EG11)); 110 | m.insert("07.12", Box::new(examples::ch07::EG12)); 111 | m.insert("07.13", Box::new(examples::ch07::EG13)); 112 | m.insert("07.14", Box::new(examples::ch07::EG14)); 113 | m.insert("07.15", Box::new(examples::ch07::EG15)); 114 | m.insert("07.16", Box::new(examples::ch07::EG16)); 115 | m.insert("07.17", Box::new(examples::ch07::EG17)); 116 | 
m.insert("07.18", Box::new(examples::ch07::EG18)); 117 | m.insert("07.19", Box::new(examples::ch07::EG19)); 118 | m.insert("07.20", Box::new(examples::ch07::EG20)); 119 | m.insert("07.21", Box::new(examples::ch07::EG21)); 120 | m.insert("07.22", Box::new(examples::ch07::EG22)); 121 | m.insert("07.23", Box::new(examples::ch07::EG23)); 122 | // apdx_e 123 | m.insert("E.01", Box::new(examples::apdx_e::EG01)); 124 | m.insert("E.02", Box::new(examples::apdx_e::EG02)); 125 | m.insert("E.03", Box::new(examples::apdx_e::EG03)); 126 | m.insert("E.04", Box::new(examples::apdx_e::EG04)); 127 | m.insert("E.05", Box::new(examples::apdx_e::EG05)); 128 | m.insert("E.06", Box::new(examples::apdx_e::EG06)); 129 | m.insert("E.07", Box::new(examples::apdx_e::EG07)); 130 | m 131 | }); 132 | 133 | /// CLI 134 | #[derive(Debug, Parser)] 135 | #[command(bin_name = "llms-from-scratch-rs")] 136 | #[command(about = "A CLI for running examples and exercises.", long_about = None)] 137 | struct Cli { 138 | #[command(subcommand)] 139 | command: Commands, 140 | } 141 | 142 | #[derive(Debug, Subcommand)] 143 | enum Commands { 144 | /// Run examples 145 | Example { 146 | /// The example to run 147 | id: String, 148 | }, 149 | /// Run exercises 150 | Exercise { 151 | /// The exercise to run 152 | id: String, 153 | }, 154 | /// List examples and exercises 155 | List { 156 | #[clap(long, action)] 157 | examples: bool, 158 | #[clap(long, action)] 159 | exercises: bool, 160 | }, 161 | } 162 | 163 | fn main() -> Result<()> { 164 | let exercise_registry = &*EXERCISE_REGISTRY; 165 | let example_registry = &*EXAMPLE_REGISTRY; 166 | let cli = Cli::parse(); 167 | 168 | match cli.command { 169 | Commands::Example { id } => { 170 | let eg = example_registry.get(&id[..]).unwrap(); 171 | eg.main() 172 | } 173 | Commands::Exercise { id } => { 174 | let ex = exercise_registry.get(&id[..]).unwrap(); 175 | ex.main() 176 | } 177 | Commands::List { 178 | examples, 179 | exercises, 180 | } => { 181 | if examples { 182 | let mut examples_table = Table::new(); 183 | examples_table 184 | .set_width(80) 185 | .set_content_arrangement(ContentArrangement::Dynamic) 186 | .set_header(vec!["Id", "Description"]); 187 | for key in example_registry.keys().sorted() { 188 | let eg = example_registry.get(key).unwrap(); 189 | examples_table.add_row(vec![key.to_string(), eg.description()]); 190 | } 191 | println!("EXAMPLES:\n{examples_table}"); 192 | } 193 | if exercises { 194 | let mut exercises_table = Table::new(); 195 | exercises_table 196 | .set_width(80) 197 | .set_content_arrangement(ContentArrangement::Dynamic) 198 | .set_header(vec!["Id", "Statement"]); 199 | for key in exercise_registry.keys().sorted() { 200 | let ex = exercise_registry.get(key).unwrap(); 201 | exercises_table.add_row(vec![ 202 | key.to_string(), 203 | format!("{}\n\n{}", ex.title(), ex.statement()), 204 | ]); 205 | } 206 | println!("EXERCISES:\n{exercises_table}"); 207 | } 208 | Ok(()) 209 | } 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /src/candle_addons.rs: -------------------------------------------------------------------------------- 1 | //! # Custom addons module to Candle 2 | //! 3 | //! #### Features 4 | //! - `SequentialT`: a version of `Sequential` that is `ModuleT` 5 | //! - `TopK`: a trait for extracting top-k elements and positions of a `Tensor` 6 | use candle_core::{Device, IndexOp, ModuleT, Result, Tensor}; 7 | 8 | /// A sequential layer combining multiple other layers. 
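///
/// A minimal usage sketch (hypothetical; assumes two `candle_nn` layers,
/// `layer1` and `layer2`, that implement `ModuleT`, and an input tensor `xs`):
///
/// ```ignore
/// let seq = seqt().add(layer1).add(layer2);
/// // run a forward pass in train mode (e.g., so dropout layers are active)
/// let out = seq.forward_t(&xs, true)?;
/// ```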
9 | pub struct SequentialT {
10 |     layers: Vec<Box<dyn ModuleT>>,
11 | }
12 |
13 | /// Creates a new empty sequential layer.
14 | pub fn seqt() -> SequentialT {
15 |     SequentialT { layers: vec![] }
16 | }
17 |
18 | impl SequentialT {
19 |     /// The number of sub-layers embedded in this layer.
20 |     pub fn len(&self) -> i64 {
21 |         self.layers.len() as i64
22 |     }
23 |
24 |     /// Returns true if this layer does not have any sub-layer.
25 |     pub fn is_empty(&self) -> bool {
26 |         self.layers.is_empty()
27 |     }
28 | }
29 |
30 | impl ModuleT for SequentialT {
31 |     fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
32 |         let mut xs = xs.clone();
33 |         for layer in self.layers.iter() {
34 |             xs = layer.forward_t(&xs, train)?
35 |         }
36 |         Ok(xs)
37 |     }
38 | }
39 |
40 | impl SequentialT {
41 |     /// Appends a layer after all the current layers.
42 |     #[allow(clippy::should_implement_trait)]
43 |     pub fn add<M: ModuleT + 'static>(mut self, layer: M) -> Self {
44 |         self.layers.push(Box::new(layer));
45 |         self
46 |     }
47 |
48 |     /// Appends a closure after all the current layers.
49 |     pub fn add_fn<F>(self, f: F) -> Self
50 |     where
51 |         F: 'static + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
52 |     {
53 |         self.add(candle_nn::func(f))
54 |     }
55 |
56 |     /// Applies the forward pass and returns the output for each layer.
57 |     pub fn forward_all(&self, xs: &Tensor, train: bool) -> Result<Vec<Tensor>> {
58 |         let mut vec = Vec::with_capacity(self.layers.len());
59 |         let mut xs = xs.clone();
60 |         for layer in self.layers.iter() {
61 |             xs = layer.forward_t(&xs, train)?;
62 |             vec.push(xs.clone())
63 |         }
64 |         Ok(vec)
65 |     }
66 | }
67 |
68 | /// Trait for returning top-k elements of a Tensor
69 | pub trait TopK {
70 |     /// Returns a `Tensor`'s top-k elements and its positions along dim 0
71 |     fn topk_last_dim0(&self, top_k: usize) -> Result<(Tensor, Tensor)>;
72 |
73 |     /// Returns a `Tensor`'s top-k elements and its positions along dim 1
74 |     fn topk_last_dim1(&self, top_k: usize) -> Result<(Tensor, Tensor)>;
75 | }
76 |
77 | impl TopK for Tensor {
78 |     fn topk_last_dim0(&self, top_k: usize) -> Result<(Tensor, Tensor)> {
79 |         let top_pos = self.arg_sort_last_dim(false)?;
80 |         let top_pos = top_pos.i(..top_k)?;
81 |         let top_els = self.i(top_pos.to_vec1::<u32>()?)?;
82 |         Ok((top_els, top_pos))
83 |     }
84 |
85 |     fn topk_last_dim1(&self, top_k: usize) -> Result<(Tensor, Tensor)> {
86 |         // we sometimes get a CUDA error when using `.arg_sort_last_dim`,
87 |         // so we move to the CPU to carry out the op
88 |         let top_pos = self.to_device(&Device::Cpu)?.arg_sort_last_dim(false)?;
89 |         let top_pos = top_pos.to_device(&Device::cuda_if_available(0)?)?;
90 |         let (batch_size, vocab_size) = top_pos.dims2()?;
91 |         let top_pos = top_pos.i((.., ..top_k))?.flatten_all()?;
92 |
93 |         // get appropriate sum starting index
94 |         let aux = Tensor::arange(0u32, batch_size as u32, self.device())?;
95 |         let aux = (vocab_size as f64 * aux.broadcast_left(top_k)?.t()?.flatten_all()?)?;
96 |         let top_pos = (top_pos + &aux)?;
97 |         let top_els = self.flatten_all()?.i(top_pos.to_vec1::<u32>()?)?;
98 |
99 |         // reshape
100 |         let top_els = top_els.reshape((batch_size, top_k))?;
101 |         let top_pos = (top_pos - &aux)?;
102 |         let top_pos = top_pos.reshape((batch_size, top_k))?;
103 |         Ok((top_els, top_pos))
104 |     }
105 | }
106 |
--------------------------------------------------------------------------------
/src/examples/apdx_e.rs:
--------------------------------------------------------------------------------
1 | //!
Examples from Appendix E 2 | 3 | use crate::Example; 4 | use anyhow::{Context, Result}; 5 | 6 | /// # Example usage of `create_candle_dataloaders` 7 | /// 8 | /// #### Id 9 | /// E.01 10 | /// 11 | /// #### Page 12 | /// This example starts on page 326 13 | /// 14 | /// #### CLI command 15 | /// ```sh 16 | /// # without cuda 17 | /// cargo run example E.01 18 | /// 19 | /// # with cuda 20 | /// cargo run --features cuda example E.01 21 | /// ``` 22 | pub struct EG01; 23 | 24 | impl Example for EG01 { 25 | fn description(&self) -> String { 26 | "Example usage of `create_candle_dataloaders`.".to_string() 27 | } 28 | 29 | fn page_source(&self) -> usize { 30 | 326_usize 31 | } 32 | 33 | fn main(&self) -> Result<()> { 34 | use crate::listings::apdx_e::create_candle_dataloaders; 35 | 36 | let batch_size = 8_usize; 37 | let (train_loader, val_loader, test_loader) = create_candle_dataloaders(batch_size)?; 38 | 39 | // print last batch of train loader 40 | let (input_batch, target_batch) = train_loader.batcher().last().unwrap()?; 41 | println!("Input batch dimensions: {:?}", input_batch.shape()); 42 | println!("Label batch dimensions: {:?}", target_batch.shape()); 43 | 44 | // print total number of batches in each data loader 45 | println!("{:?} training batches", train_loader.len()); 46 | println!("{:?} validation batches", val_loader.len()); 47 | println!("{:?} test batches", test_loader.len()); 48 | 49 | Ok(()) 50 | } 51 | } 52 | 53 | /// # Example usage of `download_and_load_gpt2` and attaching spam classification head 54 | /// 55 | /// #### Id 56 | /// E.02 57 | /// 58 | /// #### Page 59 | /// This example starts on page 327 60 | /// 61 | /// #### CLI command 62 | /// ```sh 63 | /// # without cuda 64 | /// cargo run example E.02 65 | /// 66 | /// # with cuda 67 | /// cargo run --features cuda example E.02 68 | /// ``` 69 | pub struct EG02; 70 | 71 | impl Example for EG02 { 72 | fn description(&self) -> String { 73 | "Example usage of `download_and_load_gpt2` and attaching spam classification head." 
74 | .to_string() 75 | } 76 | 77 | fn page_source(&self) -> usize { 78 | 327_usize 79 | } 80 | 81 | fn main(&self) -> Result<()> { 82 | use crate::listings::{ 83 | apdx_e::{create_candle_dataloaders, download_and_load_gpt2}, 84 | ch04::Config, 85 | ch05::{generate, text_to_token_ids, token_ids_to_text}, 86 | ch06::{calc_accuracy_loader, modify_out_head_for_classification, HF_GPT2_MODEL_ID}, 87 | }; 88 | use candle_core::{DType, Device}; 89 | use candle_nn::{VarBuilder, VarMap}; 90 | use rand::{rngs::StdRng, SeedableRng}; 91 | use tiktoken_rs::get_bpe_from_model; 92 | 93 | let mut cfg = Config::gpt2_124m(); 94 | cfg.qkv_bias = true; 95 | let varmap = VarMap::new(); 96 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 97 | let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?; 98 | 99 | // sample setup and load tokenizer 100 | let tokenizer = get_bpe_from_model("gpt2")?; 101 | let mut rng = StdRng::seed_from_u64(42_u64); 102 | 103 | // generate next tokens with model 104 | let text_1 = "Every effort moves you"; 105 | let token_ids = generate( 106 | &model, 107 | text_to_token_ids(text_1, &tokenizer, vb.device())?, 108 | 15_usize, 109 | cfg.context_length, 110 | None, 111 | None, 112 | None, 113 | &mut rng, 114 | )?; 115 | 116 | // decode the token ids to print the output text 117 | println!("{:?}", token_ids_to_text(token_ids, &tokenizer)); 118 | 119 | // attach spam classification head 120 | let num_classes = 2_usize; 121 | modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?; 122 | 123 | // calc classification accuracy 124 | let batch_size = 8_usize; 125 | let (train_loader, val_loader, test_loader) = create_candle_dataloaders(batch_size)?; 126 | 127 | // compute accuracies 128 | let num_batches = Some(10_usize); 129 | let train_accuracy = 130 | calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?; 131 | let val_accuracy = 132 | calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?; 133 | let test_accuracy = 134 | calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?; 135 | 136 | println!("Training accuracy: {}", train_accuracy); 137 | println!("Validation accuracy: {}", val_accuracy); 138 | println!("Test accuracy: {}", test_accuracy); 139 | 140 | Ok(()) 141 | } 142 | } 143 | 144 | /// # Example usage of `GPTModelWithLoRA::from_gpt()` and extracting the LoRA trainable vars 145 | /// 146 | /// #### Id 147 | /// E.03 148 | /// 149 | /// #### Page 150 | /// This example starts on page 331 151 | /// 152 | /// #### CLI command 153 | /// ```sh 154 | /// # without cuda 155 | /// cargo run example E.03 156 | /// 157 | /// # with cuda 158 | /// cargo run --features cuda example E.03 159 | /// ``` 160 | pub struct EG03; 161 | 162 | impl Example for EG03 { 163 | fn description(&self) -> String { 164 | let desc = "Example usage of `GPTModelWithLoRA::from_gpt()` and \ 165 | extracting the LoRA trainable vars"; 166 | desc.to_string() 167 | } 168 | 169 | fn page_source(&self) -> usize { 170 | 331_usize 171 | } 172 | 173 | fn main(&self) -> Result<()> { 174 | use crate::listings::{ 175 | apdx_e::{download_and_load_gpt2, GPTModelWithLoRA}, 176 | ch04::Config, 177 | ch06::{modify_out_head_for_classification, HF_GPT2_MODEL_ID}, 178 | }; 179 | use candle_core::{DType, Device}; 180 | use candle_nn::{VarBuilder, VarMap}; 181 | 182 | let mut cfg = Config::gpt2_124m(); 183 | cfg.qkv_bias = true; 184 | let varmap = VarMap::new(); 185 | 
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 186 | let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?; 187 | 188 | // modify to use classification head 189 | let num_classes = 2_usize; 190 | modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?; 191 | 192 | // get total number of params from the VarMap (todo: turn this into a util) 193 | let mut total_params_without_lora_weights = 0_usize; 194 | for t in varmap.all_vars().iter() { 195 | total_params_without_lora_weights += t.elem_count(); 196 | } 197 | println!( 198 | "Total number of parameters of original model: {}", 199 | total_params_without_lora_weights 200 | ); 201 | 202 | // convert to LoRA model 203 | let rank = 16_usize; 204 | let alpha = 16_f64; 205 | let _model = GPTModelWithLoRA::from_gpt_model(model, rank, alpha, vb.pp("model"))?; 206 | 207 | // extract only LoRA weights 208 | let mut total_training_params = 0_usize; // i.e., LoRA weights 209 | let tensor_data = varmap.data().lock().unwrap(); 210 | let var_names: Vec<&String> = tensor_data 211 | .keys() 212 | .filter(|k| k.contains("A") || k.contains("B")) 213 | .collect(); 214 | for var_name in var_names.into_iter() { 215 | let var = tensor_data.get(var_name).unwrap(); 216 | total_training_params += var.elem_count(); 217 | } 218 | drop(tensor_data); 219 | 220 | println!("Total trainable LoRA parameters: {}", total_training_params); 221 | 222 | Ok(()) 223 | } 224 | } 225 | 226 | /// # Printing GPTModelWithLoRA architecture 227 | /// 228 | /// #### Id 229 | /// E.04 230 | /// 231 | /// #### Page 232 | /// This example starts on page 332 233 | /// 234 | /// #### CLI command 235 | /// ```sh 236 | /// # without cuda 237 | /// cargo run example E.04 238 | /// 239 | /// # with cuda 240 | /// cargo run --features cuda example E.04 241 | /// ``` 242 | pub struct EG04; 243 | 244 | impl Example for EG04 { 245 | fn description(&self) -> String { 246 | "Printing GPTModelWithLoRA architecture.".to_string() 247 | } 248 | 249 | fn page_source(&self) -> usize { 250 | 332_usize 251 | } 252 | 253 | fn main(&self) -> Result<()> { 254 | use crate::listings::{ 255 | apdx_e::{download_and_load_gpt2, GPTModelWithLoRA}, 256 | ch04::Config, 257 | ch06::{modify_out_head_for_classification, HF_GPT2_MODEL_ID}, 258 | }; 259 | use candle_core::{DType, Device}; 260 | use candle_nn::{VarBuilder, VarMap}; 261 | 262 | let mut cfg = Config::gpt2_124m(); 263 | cfg.qkv_bias = true; 264 | let varmap = VarMap::new(); 265 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 266 | let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?; 267 | 268 | // modify to use classification head 269 | let num_classes = 2_usize; 270 | modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?; 271 | 272 | // convert to LoRA model 273 | let rank = 16_usize; 274 | let alpha = 16_f64; 275 | let model = GPTModelWithLoRA::from_gpt_model(model, rank, alpha, vb.pp("model"))?; 276 | 277 | // pretty debug print 278 | println!("{:#?}", model); 279 | 280 | Ok(()) 281 | } 282 | } 283 | 284 | /// # Calculating initial classification accuracies 285 | /// 286 | /// #### Id 287 | /// E.05 288 | /// 289 | /// #### Page 290 | /// This example starts on page 333 291 | /// 292 | /// #### CLI command 293 | /// ```sh 294 | /// # without cuda 295 | /// cargo run example E.05 296 | /// 297 | /// # with cuda 298 | /// cargo run --features 
cuda example E.05 299 | /// ``` 300 | pub struct EG05; 301 | 302 | impl Example for EG05 { 303 | fn description(&self) -> String { 304 | "Calculating initial classification accuracies.".to_string() 305 | } 306 | 307 | fn page_source(&self) -> usize { 308 | 333_usize 309 | } 310 | 311 | fn main(&self) -> Result<()> { 312 | use crate::listings::{ 313 | apdx_e::{create_candle_dataloaders, download_and_load_gpt2, GPTModelWithLoRA}, 314 | ch04::Config, 315 | ch06::{calc_accuracy_loader, modify_out_head_for_classification, HF_GPT2_MODEL_ID}, 316 | }; 317 | use candle_core::{DType, Device}; 318 | use candle_nn::{VarBuilder, VarMap}; 319 | 320 | let mut cfg = Config::gpt2_124m(); 321 | cfg.qkv_bias = true; 322 | let varmap = VarMap::new(); 323 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 324 | let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?; 325 | 326 | // modify to use classification head 327 | let num_classes = 2_usize; 328 | modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?; 329 | 330 | // convert to LoRA model 331 | let rank = 16_usize; 332 | let alpha = 16_f64; 333 | let model = GPTModelWithLoRA::from_gpt_model(model, rank, alpha, vb.pp("model"))?; 334 | 335 | // calc classification accuracy 336 | let batch_size = 8_usize; 337 | let (train_loader, val_loader, test_loader) = create_candle_dataloaders(batch_size)?; 338 | 339 | // compute accuracies 340 | let num_batches = Some(10_usize); 341 | let train_accuracy = 342 | calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?; 343 | let val_accuracy = 344 | calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?; 345 | let test_accuracy = 346 | calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?; 347 | 348 | println!("Training accuracy: {}", train_accuracy); 349 | println!("Validation accuracy: {}", val_accuracy); 350 | println!("Test accuracy: {}", test_accuracy); 351 | 352 | Ok(()) 353 | } 354 | } 355 |
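Several of the examples here construct `GPTModelWithLoRA` with `rank` and `alpha`; as a refresher, a LoRA-adapted linear layer leaves the frozen weight untouched and learns a low-rank update on top of it. A hedged sketch of the idea (illustrative only; the scaling convention varies between implementations -- the book multiplies by `alpha`, others use `alpha / rank`):

```rust
use candle_core::{Result, Tensor};

// y = x @ w + alpha * (x @ a @ b), with a: (d_in, rank), b: (rank, d_out).
// Only `a` and `b` are trained, which is why the fine-tuning example below
// filters the VarMap for variable names containing "A" or "B".
fn lora_forward(x: &Tensor, w: &Tensor, a: &Tensor, b: &Tensor, alpha: f64) -> Result<Tensor> {
    let base = x.matmul(w)?;
    let update = x.matmul(a)?.matmul(b)?;
    base + (update * alpha)?
}
```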
356 | /// # Fine-tuning a model with LoRA layers 357 | /// 358 | /// NOTE: technically this is Listing 7.1 in the book, but we felt it was better 359 | /// as an Example. 360 | /// 361 | /// #### Id 362 | /// E.06 363 | /// 364 | /// #### Page 365 | /// This example starts on page 334 366 | /// 367 | /// #### CLI command 368 | /// ```sh 369 | /// # without cuda 370 | /// cargo run example E.06 371 | /// 372 | /// # with cuda 373 | /// cargo run --features cuda example E.06 374 | /// ``` 375 | pub struct EG06; 376 | 377 | impl Example for EG06 { 378 | fn description(&self) -> String { 379 | "Fine-tuning a model with LoRA layers.".to_string() 380 | } 381 | 382 | fn page_source(&self) -> usize { 383 | 334_usize 384 | } 385 | 386 | fn main(&self) -> Result<()> { 387 | use crate::listings::{ 388 | apdx_e::{create_candle_dataloaders, train_classifier_simple, GPTModelWithLoRA}, 389 | ch04::Config, 390 | ch06::{ 391 | download_and_load_gpt2, modify_out_head_for_classification, plot_values, 392 | HF_GPT2_MODEL_ID, 393 | }, 394 | }; 395 | use candle_core::{DType, Device, Var}; 396 | use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap}; 397 | use ndarray::linspace; 398 | use std::path::Path; 399 | 400 | // get gpt model with classification head 401 | let mut cfg = Config::gpt2_124m(); 402 | cfg.qkv_bias = true; 403 | let varmap = VarMap::new(); 404 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 405 | let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?; 406 | 407 | // modify to use classification head 408 | let num_classes = 2_usize; 409 | modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?; 410 | 411 | // convert to LoRA model 412 | let rank = 16_usize; 413 | let alpha = 16_f64; 414 | let model = GPTModelWithLoRA::from_gpt_model(model, rank, alpha, vb.pp("model"))?; 415 | 416 | // data loaders 417 | let batch_size = 2_usize; // Get OOM on my Tesla P100 (12GB) with 8_usize 418 | let (train_loader, val_loader, _test_loader) = create_candle_dataloaders(batch_size)?; 419 | 420 | // extract only LoRA weights as trainable params 421 | let mut training_vars: Vec<Var> = vec![]; 422 | let tensor_data = varmap.data().lock().unwrap(); 423 | let var_names: Vec<&String> = tensor_data 424 | .keys() 425 | .filter(|k| k.contains("A") || k.contains("B")) 426 | .collect(); 427 | 428 | println!("Training variables: {:?}\n", var_names); 429 | 430 | for var_name in var_names.into_iter() { 431 | let var = tensor_data.get(var_name).unwrap(); 432 | training_vars.push(var.clone()); 433 | } 434 | drop(tensor_data); 435 | 436 | // train model 437 | let optimizer = AdamW::new( 438 | training_vars, 439 | ParamsAdamW { 440 | lr: 5e-5, 441 | weight_decay: 0.1, 442 | ..Default::default() 443 | }, 444 | )?; 445 | 446 | let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize); 447 | let (train_loss, val_loss, train_accs, val_accs, num_examples) = train_classifier_simple( 448 | &model, 449 | &train_loader, 450 | &val_loader, 451 | optimizer, 452 | vb.device(), 453 | num_epochs, 454 | eval_freq, 455 | eval_iter, 456 | None, 457 | )?; 458 | 459 | // save model 460 | println!("Saving weights to `./clf.gptwithlora.checkpoint.safetensors`"); 461 | varmap.save("clf.gptwithlora.checkpoint.safetensors")?; 462 | 463 | // prepare and save plots 464 | let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_loss.len())); 465 | let examples_seen = Vec::from_iter(linspace(0_f32, num_examples as f32, train_loss.len())); 466 | let label = "loss"; 467 | let save_path = Path::new(format!("plot_classification_gptwithlora_{label}.html").as_str()) 468 | .to_path_buf();
469 | plot_values( 470 | epochs_seen, 471 | examples_seen, 472 | train_loss, 473 | val_loss, 474 | label, 475 | save_path, 476 | )?; 477 | 478 | let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_accs.len())); 479 | let examples_seen = Vec::from_iter(linspace(0_f32, num_examples as f32, train_accs.len())); 480 | let label = "accuracy"; 481 | let save_path = Path::new(format!("plot_classification_gptwithlora_{label}.html").as_str()) 482 | .to_path_buf(); 483 | plot_values( 484 | epochs_seen, 485 | examples_seen, 486 | train_accs, 487 | val_accs, 488 | label, 489 | save_path, 490 | )?; 491 | 492 | Ok(()) 493 | } 494 | } 495 | 496 | /// # Evaluating trained LoRA model on train, validation, and test sets 497 | /// 498 | /// #### Id 499 | /// E.07 500 | /// 501 | /// #### Page 502 | /// This example starts on page 335 503 | /// 504 | /// #### CLI command 505 | /// ```sh 506 | /// # without cuda 507 | /// cargo run example E.07 508 | /// 509 | /// # with cuda 510 | /// cargo run --features cuda example E.07 511 | /// ``` 512 | pub struct EG07; 513 | 514 | impl Example for EG07 { 515 | fn description(&self) -> String { 516 | "Evaluating trained LoRA model on train, validation, and test sets.".to_string() 517 | } 518 | 519 | fn page_source(&self) -> usize { 520 | 335_usize 521 | } 522 | 523 | fn main(&self) -> Result<()> { 524 | use crate::listings::{ 525 | apdx_e::{create_candle_dataloaders, GPTModelWithLoRA}, 526 | ch04::Config, 527 | ch06::{ 528 | calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification, 529 | HF_GPT2_MODEL_ID, 530 | }, 531 | }; 532 | use candle_core::{DType, Device}; 533 | use candle_nn::{VarBuilder, VarMap}; 534 | 535 | // get gpt model with classification head 536 | let mut cfg = Config::gpt2_124m(); 537 | cfg.qkv_bias = true; 538 | let mut varmap = VarMap::new(); 539 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 540 | let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?; 541 | 542 | // modify to use classification head 543 | let num_classes = 2_usize; 544 | modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?; 545 | 546 | // convert to LoRA model 547 | let rank = 16_usize; 548 | let alpha = 16_f64; 549 | let model = GPTModelWithLoRA::from_gpt_model(model, rank, alpha, vb.pp("model"))?; 550 | 551 | // load safetensors 552 | varmap 553 | .load("clf.gptwithlora.checkpoint.safetensors") 554 | .with_context(|| { 555 | "Missing 'clf.gptwithlora.checkpoint.safetensors' file. Please run EG E.06." 
556 | })?; 557 | 558 | // data loaders 559 | let batch_size = 2_usize; // Get OOM on my Tesla P100 (12GB) with 8_usize 560 | let (train_loader, val_loader, test_loader) = create_candle_dataloaders(batch_size)?; 561 | 562 | // compute accuracies 563 | let num_batches = None; 564 | let train_accuracy = 565 | calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?; 566 | let val_accuracy = 567 | calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?; 568 | let test_accuracy = 569 | calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?; 570 | 571 | println!("Training accuracy: {}", train_accuracy); 572 | println!("Validation accuracy: {}", val_accuracy); 573 | println!("Test accuracy: {}", test_accuracy); 574 | 575 | Ok(()) 576 | } 577 | } 578 | -------------------------------------------------------------------------------- /src/examples/ch02.rs: -------------------------------------------------------------------------------- 1 | //! Examples from Chapter 2 2 | 3 | use crate::Example; 4 | use anyhow::Result; 5 | 6 | /// # Example of reading text files into Rust 7 | /// 8 | /// #### Id 9 | /// 02.01 10 | /// 11 | /// #### Page 12 | /// This example starts on page 22 13 | /// 14 | /// #### CLI command 15 | /// ```sh 16 | /// # without cuda 17 | /// cargo run example 02.01 18 | /// 19 | /// # with cuda 20 | /// cargo run --features cuda example 02.01 21 | /// ``` 22 | pub struct EG01; 23 | 24 | impl Example for EG01 { 25 | fn description(&self) -> String { 26 | String::from("Example usage of `listings::ch02::sample_read_text`") 27 | } 28 | 29 | fn page_source(&self) -> usize { 30 | 22_usize 31 | } 32 | 33 | fn main(&self) -> Result<()> { 34 | use crate::listings::ch02::sample_read_text; 35 | let _raw_text = sample_read_text(true)?; 36 | Ok(()) 37 | } 38 | } 39 | 40 | /// # Example of building a vocabulary 41 | /// 42 | /// #### Id 43 | /// 02.02 44 | /// 45 | /// #### Page 46 | /// This example starts on page 25 47 | /// 48 | /// #### CLI command 49 | /// ```sh 50 | /// # without cuda 51 | /// cargo run example 02.02 52 | /// 53 | /// # with cuda 54 | /// cargo run --features cuda example 02.02 55 | /// ``` 56 | pub struct EG02; 57 | 58 | impl Example for EG02 { 59 | fn description(&self) -> String { 60 | String::from("Example usage of `listings::ch02::sample_create_vocab`") 61 | } 62 | 63 | fn page_source(&self) -> usize { 64 | 25_usize 65 | } 66 | 67 | fn main(&self) -> Result<()> { 68 | use crate::listings::ch02::sample_create_vocab; 69 | 70 | let vocab = sample_create_vocab()?; 71 | // Note: this iter is not sorted 72 | for (i, item) in vocab.iter().enumerate() { 73 | println!("{:?}", item); 74 | if i >= 50 { 75 | break; 76 | } 77 | } 78 | Ok(()) 79 | } 80 | } 81 | 82 | /// # Use candle to generate an Embedding Layer 83 | /// 84 | /// #### Id 85 | /// 02.03 86 | /// 87 | /// #### Page 88 | /// This example starts on page 42 89 | /// 90 | /// #### CLI command 91 | /// ```sh 92 | /// # without cuda 93 | /// cargo run example 02.03 94 | /// 95 | /// # with cuda 96 | /// cargo run --features cuda example 02.03 97 | /// ``` 98 | pub struct EG03; 99 | 100 | impl Example for EG03 { 101 | fn description(&self) -> String { 102 | String::from("Use candle to generate an Embedding Layer.") 103 | } 104 | 105 | fn page_source(&self) -> usize { 106 | 42_usize 107 | } 108 | 109 | fn main(&self) -> Result<()> { 110 | use candle_core::{DType, Device, Tensor}; 111 | use candle_nn::{embedding, VarBuilder, VarMap}; 112 | 113 | let vocab_size = 
6_usize; 114 | let output_dim = 3_usize; 115 | let varmap = VarMap::new(); 116 | let dev = Device::cuda_if_available(0)?; 117 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 118 | let emb = embedding(vocab_size, output_dim, vs)?; 119 | 120 | println!("{:?}", emb.embeddings().to_vec2::<f32>()); 121 | // print specific embedding of a given token id 122 | let token_ids = Tensor::new(&[3u32], &dev)?; 123 | println!( 124 | "{:?}", 125 | emb.embeddings() 126 | .index_select(&token_ids, 0)? 127 | .to_vec2::<f32>() 128 | ); 129 | Ok(()) 130 | } 131 | } 132 | 133 | /// # Create absolute positional embeddings 134 | /// 135 | /// #### Id 136 | /// 02.04 137 | /// 138 | /// #### Page 139 | /// This example starts on page 47 140 | /// 141 | /// #### CLI command 142 | /// ```sh 143 | /// # without cuda 144 | /// cargo run example 02.04 145 | /// 146 | /// # with cuda 147 | /// cargo run --features cuda example 02.04 148 | /// ``` 149 | pub struct EG04; 150 | 151 | impl Example for EG04 { 152 | fn description(&self) -> String { 153 | String::from("Create absolute positional embeddings.") 154 | } 155 | 156 | fn page_source(&self) -> usize { 157 | 47_usize 158 | } 159 | 160 | fn main(&self) -> Result<()> { 161 | use crate::listings::ch02::{create_dataloader_v1, DataLoader}; 162 | use candle_core::{DType, Tensor}; 163 | use candle_nn::{embedding, VarBuilder, VarMap}; 164 | use std::fs; 165 | 166 | // create data batcher 167 | let raw_text = fs::read_to_string("data/the-verdict.txt").expect("Unable to read the file"); 168 | let max_length = 4_usize; 169 | let stride = max_length; 170 | let shuffle = false; 171 | let drop_last = false; 172 | let batch_size = 8_usize; 173 | let data_loader = create_dataloader_v1( 174 | &raw_text[..], 175 | batch_size, 176 | max_length, 177 | stride, 178 | shuffle, 179 | drop_last, 180 | ); 181 | 182 | let mut batch_iter = data_loader.batcher(); 183 | 184 | // get embeddings of first batch inputs 185 | match batch_iter.next() { 186 | Some(Ok((inputs, _targets))) => { 187 | let varmap = VarMap::new(); 188 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device()); 189 | 190 | let vocab_size = 50_257_usize; 191 | let output_dim = 256_usize; 192 | let mut final_dims = inputs.dims().to_vec(); 193 | final_dims.push(output_dim); 194 | 195 | // token embeddings of the current batch inputs 196 | let token_embedding_layer = embedding(vocab_size, output_dim, vs.pp("tok_emb"))?; 197 | let token_embeddings = token_embedding_layer 198 | .embeddings() 199 | .index_select(&inputs.flatten_all()?, 0)?; 200 | let token_embeddings = token_embeddings.reshape(final_dims)?; 201 | println!("token embeddings dims: {:?}", token_embeddings.dims()); 202 | 203 | // position embeddings 204 | let context_length = max_length; 205 | let pos_embedding_layer = embedding(context_length, output_dim, vs.pp("pos_emb"))?; 206 | let pos_ids = Tensor::arange(0u32, context_length as u32, inputs.device())?; 207 | let pos_embeddings = pos_embedding_layer.embeddings().index_select(&pos_ids, 0)?; 208 | println!("pos embeddings dims: {:?}", pos_embeddings.dims()); 209 | 210 | // incorporate positional embeddings 211 | let input_embeddings = token_embeddings.broadcast_add(&pos_embeddings)?; 212 | println!("input embeddings dims: {:?}", input_embeddings.dims()); 213 | } 214 | Some(Err(err)) => panic!("{}", err), 215 | None => panic!("None"), 216 | } 217 | Ok(()) 218 | } 219 | } 220 | --------------------------------------------------------------------------------
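Both chapter-2 examples above lean on `create_dataloader_v1`; its key idea is a sliding window over the token stream. A dependency-free sketch of the windowing (illustrative only, not the listing itself): inputs are `max_length`-token windows, targets are the same windows shifted right by one token, and consecutive windows start `stride` tokens apart.

```rust
// Sliding-window input/target pairs in the spirit of create_dataloader_v1's
// dataset (illustrative sketch): with stride == max_length the windows tile
// the text without overlap, as in example 02.04 above.
fn sliding_windows(tokens: &[u32], max_length: usize, stride: usize) -> Vec<(Vec<u32>, Vec<u32>)> {
    let mut pairs = Vec::new();
    let mut i = 0;
    while i + max_length < tokens.len() {
        let input = tokens[i..i + max_length].to_vec();
        let target = tokens[i + 1..i + max_length + 1].to_vec(); // shifted by one
        pairs.push((input, target));
        i += stride;
    }
    pairs
}
```
/src/examples/ch03.rs: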
-------------------------------------------------------------------------------- 1 | //! Examples from Chapter 3 2 | 3 | use crate::Example; 4 | use anyhow::Result; 5 | 6 | /// # Computing attention scores as a dot product 7 | /// 8 | /// #### Id 9 | /// 03.01 10 | /// 11 | /// #### Page 12 | /// This example starts on page 57 13 | /// 14 | /// #### CLI command 15 | /// ```sh 16 | /// # without cuda 17 | /// cargo run example 03.01 18 | /// 19 | /// # with cuda 20 | /// cargo run --features cuda example 03.01 21 | /// ``` 22 | pub struct EG01; 23 | 24 | impl Example for EG01 { 25 | fn description(&self) -> String { 26 | String::from("Computing attention scores as a dot product.") 27 | } 28 | 29 | fn page_source(&self) -> usize { 30 | 57_usize 31 | } 32 | 33 | fn main(&self) -> Result<()> { 34 | use candle_core::{IndexOp, Tensor}; 35 | use candle_nn::ops::softmax; 36 | 37 | let inputs = addons::get_inputs(); 38 | let dev = inputs.device().to_owned(); 39 | 40 | let query = inputs.index_select(&Tensor::new(&[1u32], &dev)?, 0)?; 41 | 42 | // compute attention scores 43 | let mut optional_attn_scores_2: Option<Tensor> = None; 44 | for i in 0..inputs.dims()[0] { 45 | let x_i = inputs.index_select(&Tensor::new(&[i as u32], &dev)?, 0)?; 46 | let a_i = x_i.matmul(&query.t()?)?.flatten_all()?; 47 | optional_attn_scores_2 = match optional_attn_scores_2 { 48 | Some(attn_scores_2) => Some(Tensor::cat(&[&attn_scores_2, &a_i], 0)?), 49 | None => Some(a_i), 50 | } 51 | } 52 | 53 | if let Some(attn_scores_2) = optional_attn_scores_2 { 54 | // raw attention scores 55 | println!("Raw attention scores: {:?}", attn_scores_2); 56 | 57 | // basic normalization 58 | let sum = attn_scores_2.sum_all()?; 59 | let normalized_attn_scores = (attn_scores_2.broadcast_div(&sum))?.to_vec1::<f32>(); 60 | println!("Normalized attention scores: {:?}", normalized_attn_scores); 61 | 62 | // naive softmax normalization 63 | let exponentiator = attn_scores_2.exp()?; 64 | let exponentiator_sum = exponentiator.sum_all()?; 65 | let naive_softmax_attn_scores = exponentiator.broadcast_div(&exponentiator_sum)?; 66 | println!( 67 | "Naive Softmax-normalized attention scores: {:?}", 68 | naive_softmax_attn_scores 69 | ); 70 | 71 | // candle softmax 72 | let softmax_attn_scores = softmax(&attn_scores_2, 0)?; 73 | println!( 74 | "Softmax-normalized attention scores: {:?}", 75 | softmax_attn_scores 76 | ); 77 | 78 | // compute second context vector 79 | let mut context_vec_2 = Tensor::zeros_like(&query)?; 80 | for i in 0..inputs.dims()[0] { 81 | let x_i = inputs.index_select(&Tensor::new(&[i as u32], &dev)?, 0)?; 82 | context_vec_2 = 83 | context_vec_2.add(&x_i.broadcast_mul(&softmax_attn_scores.i(i)?)?)?; 84 | } 85 | println!("Context vector 2: {:?}", context_vec_2.to_vec2::<f32>()); 86 | } 87 | Ok(()) 88 | } 89 | } 90 | 91 | /// # Manual computation of multiple context vectors simultaneously 92 | /// 93 | /// #### Id 94 | /// 03.02 95 | /// 96 | /// #### Page 97 | /// This example starts on page 62 98 | /// 99 | /// #### CLI command 100 | /// ```sh 101 | /// # without cuda 102 | /// cargo run example 03.02 103 | /// 104 | /// # with cuda 105 | /// cargo run --features cuda example 03.02 106 | /// ``` 107 | pub struct EG02; 108 | 109 | impl Example for EG02 { 110 | fn description(&self) -> String { 111 | String::from("Manual computation of multiple context vectors simultaneously.") 112 | } 113 | 114 | fn page_source(&self) -> usize { 115 | 62_usize 116 | } 117 | 118 | fn main(&self) -> Result<()> { 119 | use candle_nn::ops::softmax; 120 | 121 | let inputs =
addons::get_inputs(); 122 | 123 | // matmul to get attn scores 124 | let attn_scores = inputs.matmul(&inputs.t()?)?; 125 | 126 | // apply softmax 127 | let attn_weights = softmax(&attn_scores, 1)?; 128 | 129 | // check sums along rows equal to 1 130 | let sum = attn_weights.sum(1)?; 131 | 132 | // context vectors 133 | let all_context_vectors = attn_weights.matmul(&inputs)?; 134 | 135 | println!("Attention Weights: {:?}\n", attn_weights.to_vec2::<f32>()); 136 | println!("All Rows Sum: {:?}\n\n", sum.flatten_all()); 137 | println!( 138 | "Context Vectors: {:?}", 139 | all_context_vectors.to_vec2::<f32>() 140 | ); 141 | Ok(()) 142 | } 143 | } 144 | 145 | /// # Implementing the self-attention mechanism with trainable weights 146 | /// 147 | /// #### Id 148 | /// 03.03 149 | /// 150 | /// #### Page 151 | /// This example starts on page 66 152 | /// 153 | /// #### CLI command 154 | /// ```sh 155 | /// # without cuda 156 | /// cargo run example 03.03 157 | /// 158 | /// # with cuda 159 | /// cargo run --features cuda example 03.03 160 | /// ``` 161 | pub struct EG03; 162 | 163 | impl Example for EG03 { 164 | fn description(&self) -> String { 165 | let desc = "Implementing the self-attention mechanism with \ 166 | trainable weights to compute a single context vector."; 167 | String::from(desc) 168 | } 169 | 170 | fn page_source(&self) -> usize { 171 | 66_usize 172 | } 173 | 174 | fn main(&self) -> Result<()> { 175 | use candle_core::{DType, Tensor}; 176 | use candle_nn::init::DEFAULT_KAIMING_NORMAL; 177 | use candle_nn::ops::softmax; 178 | use candle_nn::{VarBuilder, VarMap}; 179 | 180 | let inputs = addons::get_inputs(); 181 | let dev = inputs.device().to_owned(); 182 | let varmap = VarMap::new(); 183 | let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 184 | 185 | let x_2 = inputs.index_select(&Tensor::new(&[1u32], &dev)?, 0)?; 186 | let d_in = x_2.dims()[1]; // input embedding dim 187 | let d_out = 2_usize; 188 | 189 | // projections 190 | let init = DEFAULT_KAIMING_NORMAL; 191 | let w_query = vs.get_with_hints((d_in, d_out), "query", init)?; 192 | let w_key = vs.get_with_hints((d_in, d_out), "key", init)?; 193 | let w_value = vs.get_with_hints((d_in, d_out), "value", init)?; 194 | 195 | // query, key, value vectors 196 | let query_2 = x_2.matmul(&w_query)?; 197 | let key_2 = x_2.matmul(&w_key)?; 198 | let value_2 = x_2.matmul(&w_value)?; 199 | 200 | println!("Query 2: {:?}", query_2.to_vec2::<f32>()); 201 | println!("Key 2: {:?}", key_2.to_vec2::<f32>()); 202 | println!("Value 2: {:?}", value_2.to_vec2::<f32>()); 203 | 204 | // key and value vectors for all input elements 205 | let keys = inputs.matmul(&w_key)?; 206 | let values = inputs.matmul(&w_value)?; 207 | 208 | println!("Keys shape: {:?}", keys); 209 | println!("Values shape: {:?}", values); 210 | 211 | // compute attn scores 212 | let attn_scores = query_2.matmul(&keys.t()?)?; 213 | println!("Attn scores: {:?}", attn_scores.to_vec2::<f32>()); 214 | 215 | // compute attn weights by first scaling, then applying softmax 216 | let d_k = Tensor::new(&[f32::powf(keys.dims()[1] as f32, 0.5_f32)], &dev)?; 217 | let attn_weights = softmax(&attn_scores.broadcast_div(&d_k)?, 1)?; 218 | println!("Attn weights: {:?}", attn_weights.to_vec2::<f32>()); 219 | 220 | // compute context vector 221 | let context_vec_2 = attn_weights.matmul(&values)?; 222 | println!("Context vector 2: {:?}", context_vec_2.to_vec2::<f32>()); 223 | Ok(()) 224 | } 225 | } 226 |
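EG03 above spells the projections out step by step; the whole computation collapses into standard scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V. A compact sketch of that identity with candle (illustrative only, for 2-D `q`, `k`, `v` of shape `(seq_len, d)`):

```rust
use candle_core::{Result, Tensor};
use candle_nn::ops::softmax;

// context = softmax(q @ k^T / sqrt(d_k)) @ v -- the same computation EG03
// performs step by step (illustrative sketch, not a listing from the book).
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let d_k = k.dims()[1] as f64;
    let attn_scores = q.matmul(&k.t()?)?;
    let attn_weights = softmax(&(attn_scores / d_k.sqrt())?, 1)?;
    attn_weights.matmul(v)
}
```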
227 | /// # Example usage of `SelfAttentionV1` to compute context vectors 228 | /// 229 | /// #### Id 230 | /// 03.04 231 | /// 232 | /// #### Page 233 | /// This example starts on page 71 234 | /// 235 | /// #### CLI command 236 | /// ```sh 237 | /// # without cuda 238 | /// cargo run example 03.04 239 | /// 240 | /// # with cuda 241 | /// cargo run --features cuda example 03.04 242 | /// ``` 243 | pub struct EG04; 244 | 245 | impl Example for EG04 { 246 | fn description(&self) -> String { 247 | String::from( 248 | "Implement self-attention mechanism to compute context vectors in the input sequence.", 249 | ) 250 | } 251 | 252 | fn page_source(&self) -> usize { 253 | 71_usize 254 | } 255 | 256 | fn main(&self) -> Result<()> { 257 | use crate::listings::ch03::SelfAttentionV1; 258 | use candle_core::{DType, Module}; 259 | use candle_nn::{VarBuilder, VarMap}; 260 | 261 | let inputs = addons::get_inputs(); 262 | let d_in = inputs.dims()[1]; // input embedding dim 263 | let d_out = 2_usize; 264 | 265 | // construct self attention layer 266 | let varmap = VarMap::new(); 267 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device()); 268 | let attn_v1_layer = SelfAttentionV1::new(d_in, d_out, vb.pp("attn"))?; 269 | 270 | // run a random, embedded input sequence through self-attention 271 | let context_vectors = attn_v1_layer.forward(&inputs)?; 272 | 273 | println!("context vectors: {:?}", context_vectors.to_vec2::<f32>()); 274 | Ok(()) 275 | } 276 | } 277 | 278 | /// # Example usage of `SelfAttentionV2` to compute context vectors 279 | /// 280 | /// #### Id 281 | /// 03.05 282 | /// 283 | /// #### Page 284 | /// This example starts on page 73 285 | /// 286 | /// #### CLI command 287 | /// ```sh 288 | /// # without cuda 289 | /// cargo run example 03.05 290 | /// 291 | /// # with cuda 292 | /// cargo run --features cuda example 03.05 293 | /// ``` 294 | pub struct EG05; 295 | 296 | impl Example for EG05 { 297 | fn description(&self) -> String { 298 | let desc = "Implement self-attention mechanism to compute \ 299 | contextualized vectors, using candle_nn::Linear."; 300 | String::from(desc) 301 | } 302 | 303 | fn page_source(&self) -> usize { 304 | 73_usize 305 | } 306 | 307 | fn main(&self) -> Result<()> { 308 | use crate::listings::ch03::SelfAttentionV2; 309 | use candle_core::{DType, Module}; 310 | use candle_nn::{VarBuilder, VarMap}; 311 | 312 | let inputs = addons::get_inputs(); 313 | let d_in = inputs.dims()[1]; // input embedding dim 314 | let d_out = 2_usize; 315 | 316 | // construct self attention layer 317 | let varmap = VarMap::new(); 318 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device()); 319 | let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?; 320 | 321 | // run a random, embedded input sequence through self-attention 322 | let context_vectors = attn_v2_layer.forward(&inputs)?; 323 | 324 | println!("context vectors: {:?}", context_vectors.to_vec2::<f32>()); 325 | Ok(()) 326 | } 327 | } 328 | 329 | /// # Compute causal attention weights 330 | /// 331 | /// #### Id 332 | /// 03.06 333 | /// 334 | /// #### Page 335 | /// This example starts on page 75 336 | /// 337 | /// #### CLI command 338 | /// ```sh 339 | /// # without cuda 340 | /// cargo run example 03.06 341 | /// 342 | /// # with cuda 343 | /// cargo run --features cuda example 03.06 344 | /// ``` 345 | pub struct EG06; 346 | 347 | impl Example for EG06 { 348 | fn description(&self) -> String { 349 | String::from("Compute causal attention weights.") 350 | } 351 | 352 | fn page_source(&self) -> usize { 353 | 75_usize 354 | } 355 | 356 | fn main(&self) -> Result<()> { 357 | let _ = self.main_with_return()?; 358 | Ok(())
359 | } 360 | } 361 | 362 | impl EG06 { 363 | fn main_with_return(&self) -> Result<candle_core::Tensor> { 364 | use crate::listings::ch03::SelfAttentionV2; 365 | use candle_core::{DType, Module, D}; 366 | use candle_nn::ops::softmax; 367 | use candle_nn::{VarBuilder, VarMap}; 368 | 369 | let inputs = addons::get_inputs(); 370 | let d_in = inputs.dims()[1]; // input embedding dim 371 | let d_out = 2_usize; 372 | 373 | // construct self attention layer 374 | let varmap = VarMap::new(); 375 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device()); 376 | let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?; 377 | 378 | // attn scores 379 | let queries = attn_v2_layer.w_query().forward(&inputs)?; 380 | let keys = attn_v2_layer.w_key().forward(&inputs)?; 381 | let attn_scores = queries.matmul(&keys.t()?)?; 382 | let scaling = 1. / (keys.dims()[1] as f64).sqrt(); 383 | let attn_weights = softmax(&(attn_scores * scaling)?, 1)?; 384 | 385 | // causal mask 386 | let context_length = inputs.dims()[0]; 387 | let mask_simple: Vec<_> = (0..context_length as u32) 388 | .flat_map(|i| (0..context_length as u32).map(move |j| f32::from(j <= i))) 389 | .collect(); 390 | let mask_simple = candle_core::Tensor::from_slice( 391 | &mask_simple, 392 | (context_length, context_length), 393 | inputs.device(), 394 | )?; 395 | let masked_simple = (attn_weights * mask_simple)?; 396 | println!("masked_simple: {:?}", masked_simple.to_vec2::<f32>()); 397 | 398 | // normalize 399 | let row_sums = masked_simple.sum_keepdim(D::Minus1)?; 400 | let attn_weights = masked_simple.broadcast_div(&row_sums)?; 401 | println!("masked_simple_norm: {:?}", attn_weights.to_vec2::<f32>()); 402 | Ok(attn_weights) 403 | } 404 | } 405 |
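As a sanity check on the `flat_map` construction used in EG06 (mirrored with `u32::from(j > i)` in the next example), here is what the row-major causal mask looks like for a context length of 4 -- an illustrative snippet only:

```rust
// Row-major lower-triangular keep-mask, built exactly like EG06 does
// (1 = token j is visible from position i, 0 = masked out).
fn causal_mask_demo() {
    let n = 4u32;
    let mask: Vec<u32> = (0..n)
        .flat_map(|i| (0..n).map(move |j| u32::from(j <= i)))
        .collect();
    // rows: [1,0,0,0], [1,1,0,0], [1,1,1,0], [1,1,1,1]
    assert_eq!(&mask[..4], &[1, 0, 0, 0]); // position 0 sees only itself
    assert_eq!(&mask[12..], &[1, 1, 1, 1]); // the last position sees everything
}
```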
406 | /// # Compute causal attention weights more efficiently with `f32::NEG_INFINITY` 407 | /// 408 | /// #### Id 409 | /// 03.07 410 | /// 411 | /// #### Page 412 | /// This example starts on page 77 413 | /// 414 | /// #### CLI command 415 | /// ```sh 416 | /// # without cuda 417 | /// cargo run example 03.07 418 | /// 419 | /// # with cuda 420 | /// cargo run --features cuda example 03.07 421 | /// ``` 422 | pub struct EG07; 423 | 424 | impl Example for EG07 { 425 | fn description(&self) -> String { 426 | let desc = "Compute causal attention weights more efficiently \ 427 | using `f32::NEG_INFINITY` and `masked_fill()`."; 428 | String::from(desc) 429 | } 430 | 431 | fn page_source(&self) -> usize { 432 | 77_usize 433 | } 434 | 435 | fn main(&self) -> Result<()> { 436 | let _ = self.main_with_return()?; 437 | Ok(()) 438 | } 439 | } 440 | 441 | impl EG07 { 442 | fn main_with_return(&self) -> Result<candle_core::Tensor> { 443 | use crate::listings::ch03::SelfAttentionV2; 444 | use candle_core::{DType, Module}; 445 | use candle_nn::ops::softmax; 446 | use candle_nn::{VarBuilder, VarMap}; 447 | 448 | let inputs = addons::get_inputs(); 449 | let d_in = inputs.dims()[1]; // input embedding dim 450 | let d_out = 2_usize; 451 | 452 | // construct self attention layer 453 | let varmap = VarMap::new(); 454 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device()); 455 | let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?; 456 | 457 | // attn scores 458 | let queries = attn_v2_layer.w_query().forward(&inputs)?; 459 | let keys = attn_v2_layer.w_key().forward(&inputs)?; 460 | let attn_scores = queries.matmul(&keys.t()?)?; 461 | 462 | // efficient computation of causal mask 463 | let context_length = attn_scores.dims()[0]; 464 | let mask: Vec<_> = (0..context_length as u32) 465 | .flat_map(|i| (0..context_length as u32).map(move |j| u32::from(j > i))) 466 | .collect(); 467 | let mask = candle_core::Tensor::from_slice( 468 | &mask, 469 | (context_length, context_length), 470 | inputs.device(), 471 | )?; 472 | let masked = addons::masked_fill(&attn_scores, &mask, f32::NEG_INFINITY)?; 473 | println!("masked: {:?}", masked.to_vec2::<f32>()); 474 | 475 | // masked attn weights 476 | let scaling = 1. / (keys.dims()[1] as f64).sqrt(); 477 | let attn_weights = softmax(&(masked * scaling)?, 1)?; 478 | println!("attn_weights: {:?}", attn_weights.to_vec2::<f32>()); 479 | Ok(attn_weights) 480 | } 481 | } 482 | 483 | /// # Dropout on attention weights 484 | /// 485 | /// #### Id 486 | /// 03.08 487 | /// 488 | /// #### Page 489 | /// This example starts on page 80 490 | /// 491 | /// #### CLI command 492 | /// ```sh 493 | /// # without cuda 494 | /// cargo run example 03.08 495 | /// 496 | /// # with cuda 497 | /// cargo run --features cuda example 03.08 498 | /// ``` 499 | pub struct EG08; 500 | 501 | impl Example for EG08 { 502 | fn description(&self) -> String { 503 | String::from("Dropout on attention weights.") 504 | } 505 | 506 | fn page_source(&self) -> usize { 507 | 80_usize 508 | } 509 | 510 | fn main(&self) -> Result<()> { 511 | use candle_nn::Dropout; 512 | 513 | let eg07 = EG07; 514 | let attn_weights = eg07.main_with_return()?; 515 | let dropout = Dropout::new(0.5); 516 | 517 | // could have also just used candle_nn::ops::dropout directly 518 | let dropped_out = dropout.forward(&attn_weights, true)?; 519 | println!("dropped_out: {:?}", dropped_out.to_vec2::<f32>()); 520 | Ok(()) 521 | } 522 | } 523 | 524 | /// # Example usage of `CausalAttention` 525 | /// 526 | /// #### Id 527 | /// 03.09 528 | /// 529 | /// #### Page 530 | /// This example starts on page 81 531 | /// 532 | /// #### CLI command 533 | /// ```sh 534 | /// # without cuda 535 | /// cargo run example 03.09 536 | /// 537 | /// # with cuda 538 | /// cargo run --features cuda example 03.09 539 | /// ``` 540 | pub struct EG09; 541 | 542 | impl Example for EG09 { 543 | fn description(&self) -> String { 544 | String::from("Example usage of `CausalAttention`.") 545 | } 546 | 547 | fn page_source(&self) -> usize { 548 | 81_usize 549 | } 550 | 551 | fn main(&self) -> Result<()> { 552 | use crate::listings::ch03::CausalAttention; 553 | use candle_core::{DType, Module, Tensor}; 554 | use candle_nn::{VarBuilder, VarMap}; 555 | 556 | // create batch 557 | let inputs = addons::get_inputs(); 558 | let d_in = inputs.dims()[1]; // input embedding dim 559 | let d_out = 2_usize; 560 | let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?; 561 | println!("batch shape: {:?}", batch); 562 | 563 | // build causal attn layer 564 | let varmap = VarMap::new(); 565 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device()); 566 | let causal_attn = CausalAttention::new(d_in, d_out, 0.0_f32, false, vb.pp("casual_attn"))?; 567 | 568 | // context vectors 569 | let context_vectors = causal_attn.forward(&batch)?; 570 | println!("context_vectors.shape: {:?}", context_vectors); 571 | Ok(()) 572 | } 573 | } 574 | 575 | /// # Example usage of `MultiHeadAttentionWrapper` 576 | /// 577 | /// #### Id 578 | /// 03.10 579 | /// 580 | /// #### Page 581 | /// This example starts on page 85 582 | /// 583 | /// #### CLI command 584 | /// ```sh 585 | /// # without cuda 586 | /// cargo run example 03.10 587 | /// 588 | /// # with cuda 589 | /// cargo run --features cuda example 03.10 590 | /// ``` 591 | pub struct EG10; 592 | 593 | impl Example
for EG10 { 594 | fn description(&self) -> String { 595 | String::from("Example usage of `MultiHeadAttentionWrapper`.") 596 | } 597 | 598 | fn page_source(&self) -> usize { 599 | 85_usize 600 | } 601 | 602 | fn main(&self) -> Result<()> { 603 | use crate::listings::ch03::MultiHeadAttentionWrapper; 604 | use candle_core::{DType, Module, Tensor}; 605 | use candle_nn::{VarBuilder, VarMap}; 606 | 607 | // create batch 608 | let inputs = addons::get_inputs(); 609 | let d_in = inputs.dims()[1]; // input embedding dim 610 | let d_out = 2_usize; 611 | let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?; 612 | println!("batch shape: {:?}", batch); 613 | 614 | // build causal attn layer 615 | let varmap = VarMap::new(); 616 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device()); 617 | let num_heads = 2_usize; 618 | let mha = 619 | MultiHeadAttentionWrapper::new(num_heads, d_in, d_out, 0.0_f32, false, vb.pp("mha"))?; 620 | 621 | // context vectors 622 | let context_vectors = mha.forward(&batch)?; 623 | println!("context_vectors.shape: {:?}", context_vectors); 624 | println!("context_vectors: {:?}", context_vectors.to_vec3::<f32>()); 625 | Ok(()) 626 | } 627 | } 628 | 629 | /// # Example usage of `MultiHeadAttention` 630 | /// 631 | /// #### Id 632 | /// 03.11 633 | /// 634 | /// #### Page 635 | /// This example starts on page 90 636 | /// 637 | /// #### CLI command 638 | /// ```sh 639 | /// # without cuda 640 | /// cargo run example 03.11 641 | /// 642 | /// # with cuda 643 | /// cargo run --features cuda example 03.11 644 | /// ``` 645 | pub struct EG11; 646 | 647 | impl Example for EG11 { 648 | fn description(&self) -> String { 649 | String::from("Example usage of `MultiHeadAttention`.") 650 | } 651 | 652 | fn page_source(&self) -> usize { 653 | 90_usize 654 | } 655 | 656 | fn main(&self) -> Result<()> { 657 | use crate::listings::ch03::MultiHeadAttention; 658 | use candle_core::{DType, Tensor}; 659 | use candle_nn::{VarBuilder, VarMap}; 660 | 661 | // create batch 662 | let inputs = addons::get_inputs(); 663 | let d_in = inputs.dims()[1]; // input embedding dim 664 | let d_out = 2_usize; 665 | let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?; 666 | println!("batch shape: {:?}", batch); 667 | 668 | // build causal attn layer 669 | let varmap = VarMap::new(); 670 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device()); 671 | let num_heads = 2_usize; 672 | let mha = MultiHeadAttention::new(d_in, d_out, 0.0_f32, num_heads, false, vb.pp("mha"))?; 673 | 674 | // context vectors 675 | let context_vectors = mha.forward(&batch)?; 676 | println!("mha.head_dim: {:?}", mha.head_dim()); 677 | println!("context_vectors.shape: {:?}", context_vectors); 678 | println!("context_vectors: {:?}", context_vectors.to_vec3::<f32>()); 679 | Ok(()) 680 | } 681 | } 682 |
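The wrapper in EG10 concatenates `num_heads` independent causal-attention outputs (so the feature dim grows to `num_heads * d_out`), whereas `MultiHeadAttention` in EG11 splits a single `d_out` across heads. A shape-only sketch of that split (illustrative, not the crate's implementation):

```rust
use candle_core::{Result, Tensor};

// Split (batch, seq, d_out) into (batch, num_heads, seq, head_dim), the
// reshape/transpose that MultiHeadAttention performs before attention;
// with d_out = 2 and num_heads = 2 this gives head_dim = 1, matching the
// `mha.head_dim()` printout in EG11.
fn split_heads(x: &Tensor, num_heads: usize) -> Result<Tensor> {
    let (b, t, d_out) = x.dims3()?;
    let head_dim = d_out / num_heads;
    x.reshape((b, t, num_heads, head_dim))?.transpose(1, 2)
}
```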
683 | pub mod addons { 684 | //! Auxiliary module for examples::ch03 685 | use candle_core::{Device, Result, Tensor}; 686 | 687 | /// Helper function for getting the sample embedded inputs 688 | pub fn get_inputs() -> Tensor { 689 | let dev = Device::cuda_if_available(0).unwrap(); 690 | Tensor::new( 691 | &[ 692 | [0.43_f32, 0.15, 0.89], // Your 693 | [0.55, 0.87, 0.66], // journey 694 | [0.57, 0.85, 0.64], // starts 695 | [0.22, 0.58, 0.33], // with 696 | [0.77, 0.25, 0.10], // one 697 | [0.05, 0.80, 0.55], // step 698 | ], 699 | &dev, 700 | ) 701 | .unwrap() 702 | } 703 | 704 | /// Helper function that returns `on_true` wherever `mask` is set, and the corresponding value of `on_false` elsewhere 705 | pub fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { 706 | let shape = mask.shape(); 707 | let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; 708 | let m = mask.where_cond(&on_true, on_false)?; 709 | Ok(m) 710 | } 711 | } 712 | -------------------------------------------------------------------------------- /src/examples/ch04.rs: -------------------------------------------------------------------------------- 1 | //! Examples from Chapter 4 2 | 3 | use crate::Example; 4 | use anyhow::Result; 5 | 6 | /// # Getting logits with `DummyGPTModel` 7 | /// 8 | /// #### Id 9 | /// 04.01 10 | /// 11 | /// #### Page 12 | /// This example starts on page 97 13 | /// 14 | /// #### CLI command 15 | /// ```sh 16 | /// # without cuda 17 | /// cargo run example 04.01 18 | /// 19 | /// # with cuda 20 | /// cargo run --features cuda example 04.01 21 | /// ``` 22 | pub struct EG01; 23 | 24 | impl Example for EG01 { 25 | fn description(&self) -> String { 26 | String::from("Getting logits with `DummyGPTModel`.") 27 | } 28 | 29 | fn page_source(&self) -> usize { 30 | 97_usize 31 | } 32 | 33 | fn main(&self) -> Result<()> { 34 | use crate::listings::ch04::{Config, DummyGPTModel}; 35 | use candle_core::{DType, IndexOp, Module}; 36 | use candle_nn::{VarBuilder, VarMap}; 37 | 38 | let batch = addons::get_batch_for_gpts()?; 39 | println!("batch: {:?}", batch.to_vec2::<u32>()); 40 | 41 | // create model 42 | let varmap = VarMap::new(); 43 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, batch.device()); 44 | let model = DummyGPTModel::new(Config::gpt2_124m(), vb)?; 45 | 46 | // get logits 47 | let logits = model.forward(&batch)?; 48 | println!("output shape: {:?}", logits.shape()); 49 | 50 | // print first 10 next-token logits for each token of every input sequence 51 | println!("logits: {:?}", logits.i((.., .., 0..10))?.to_vec3::<f32>()); 52 | Ok(()) 53 | } 54 | } 55 | 56 | /// # Manual computation of layer normalization 57 | /// 58 | /// #### Id 59 | /// 04.02 60 | /// 61 | /// #### Page 62 | /// This example starts on page 100 63 | /// 64 | /// #### CLI command 65 | /// ```sh 66 | /// # without cuda 67 | /// cargo run example 04.02 68 | /// 69 | /// # with cuda 70 | /// cargo run --features cuda example 04.02 71 | /// ``` 72 | pub struct EG02; 73 | 74 | impl Example for EG02 { 75 | fn description(&self) -> String { 76 | String::from("Manual computation of layer normalization.") 77 | } 78 | 79 | fn page_source(&self) -> usize { 80 | 100_usize 81 | } 82 | 83 | fn main(&self) -> Result<()> { 84 | use candle_core::{DType, Device, Module, Tensor, D}; 85 | use candle_nn::{linear_b, seq, Activation, VarBuilder, VarMap}; 86 | 87 | let dev = Device::cuda_if_available(0)?; 88 | let varmap = VarMap::new(); 89 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 90 | 91 | // create batch 92 | let batch_example =
Tensor::rand(0f32, 1f32, (2_usize, 5_usize), vb.device())?; 93 | 94 | // create layer 95 | let layer = seq() 96 | .add(linear_b(5_usize, 6_usize, false, vb.pp("linear"))?) 97 | .add(Activation::Relu); 98 | 99 | // execute layer on batch 100 | let out = layer.forward(&batch_example)?; 101 | println!("out: {:?}", out.to_vec2::<f32>()); 102 | 103 | // calculate stats on outputs 104 | let mean = out.mean_keepdim(D::Minus1)?; 105 | let var = out.var_keepdim(D::Minus1)?; 106 | println!("mean: {:?}", mean.to_vec2::<f32>()); 107 | println!("variance: {:?}", var.to_vec2::<f32>()); 108 | 109 | // layer normalization 110 | let out_norm = (out.broadcast_sub(&mean)?.broadcast_div(&var.sqrt()?))?; 111 | let mean = out_norm.mean_keepdim(D::Minus1)?; 112 | let var = out_norm.var_keepdim(D::Minus1)?; 113 | println!("normalized out: {:?}", out_norm.to_vec2::<f32>()); 114 | println!("mean: {:?}", mean.to_vec2::<f32>()); 115 | println!("variance: {:?}", var.to_vec2::<f32>()); 116 | Ok(()) 117 | } 118 | } 119 | 120 | /// # Example usage of `LayerNorm` 121 | /// 122 | /// #### Id 123 | /// 04.03 124 | /// 125 | /// #### Page 126 | /// This example starts on page 104 127 | /// 128 | /// #### CLI command 129 | /// ```sh 130 | /// # without cuda 131 | /// cargo run example 04.03 132 | /// 133 | /// # with cuda 134 | /// cargo run --features cuda example 04.03 135 | /// ``` 136 | pub struct EG03; 137 | 138 | impl Example for EG03 { 139 | fn description(&self) -> String { 140 | String::from("Example usage of `LayerNorm`.") 141 | } 142 | 143 | fn page_source(&self) -> usize { 144 | 104_usize 145 | } 146 | 147 | fn main(&self) -> Result<()> { 148 | use crate::listings::ch04::LayerNorm; 149 | use candle_core::{DType, Device, Module, Tensor, D}; 150 | use candle_nn::{VarBuilder, VarMap}; 151 | 152 | let dev = Device::cuda_if_available(0)?; 153 | let varmap = VarMap::new(); 154 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 155 | 156 | // create batch 157 | let batch_example = Tensor::rand(0f32, 1f32, (2_usize, 5_usize), vb.device())?; 158 | 159 | // construct layer norm layer 160 | let emb_dim = 5_usize; 161 | let ln = LayerNorm::new(emb_dim, vb.pp("layer_norm"))?; 162 | let out_ln = ln.forward(&batch_example)?; 163 | 164 | // compute stats on out_ln 165 | let mean = out_ln.mean_keepdim(D::Minus1)?; 166 | let var = out_ln.var_keepdim(D::Minus1)?; 167 | println!("mean: {:?}", mean.to_vec2::<f32>()); 168 | println!("variance: {:?}", var.to_vec2::<f32>()); 169 | Ok(()) 170 | } 171 | } 172 |
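EG02 normalizes by hand and EG03 delegates to the crate's `LayerNorm`; in both cases the core operation over the last dimension is `(x - mean) / sqrt(var)`, with `LayerNorm` additionally applying a small epsilon plus a learnable scale and shift. A minimal sketch of the shared core (illustrative only):

```rust
use candle_core::{Result, Tensor, D};

// The normalization core shared by EG02 and EG03: zero mean, unit variance
// along the last dimension (eps guards against division by zero).
fn layer_norm_core(x: &Tensor, eps: f64) -> Result<Tensor> {
    let mean = x.mean_keepdim(D::Minus1)?;
    let var = x.var_keepdim(D::Minus1)?;
    x.broadcast_sub(&mean)?.broadcast_div(&(var + eps)?.sqrt()?)
}
```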
173 | /// # Example usage of `FeedForward` Module. 174 | /// 175 | /// #### Id 176 | /// 04.04 177 | /// 178 | /// #### Page 179 | /// This example starts on page 108 180 | /// 181 | /// #### CLI command 182 | /// ```sh 183 | /// # without cuda 184 | /// cargo run example 04.04 185 | /// 186 | /// # with cuda 187 | /// cargo run --features cuda example 04.04 188 | /// ``` 189 | pub struct EG04; 190 | 191 | impl Example for EG04 { 192 | fn description(&self) -> String { 193 | String::from("Example usage of `FeedForward` Module.") 194 | } 195 | 196 | fn page_source(&self) -> usize { 197 | 108_usize 198 | } 199 | 200 | fn main(&self) -> Result<()> { 201 | use crate::listings::ch04::{Config, FeedForward}; 202 | use candle_core::{DType, Device, IndexOp, Module, Tensor}; 203 | use candle_nn::{VarBuilder, VarMap}; 204 | 205 | let dev = Device::cuda_if_available(0)?; 206 | let varmap = VarMap::new(); 207 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 208 | let cfg = Config::gpt2_124m(); 209 | 210 | // create batch 211 | let (batch_size, seq_len) = (2_usize, 3_usize); 212 | let x = Tensor::rand(0f32, 1f32, (batch_size, seq_len, cfg.emb_dim), vb.device())?; 213 | 214 | // feedforward 215 | let ffn = FeedForward::new(cfg, vb.pp("ffn"))?; 216 | let out = ffn.forward(&x)?; 217 | 218 | println!("{:?}", out); 219 | // first 10 hidden states of the embedding for 1st sequence, 1st token 220 | println!("{:?}", out.i((0, 0, 0..10))?.to_vec1::<f32>()); 221 | Ok(()) 222 | } 223 | } 224 | 225 | /// # Comparison of gradients with and without shortcut connections 226 | /// 227 | /// #### Id 228 | /// 04.05 229 | /// 230 | /// #### Page 231 | /// This example starts on page 111 232 | /// 233 | /// #### CLI command 234 | /// ```sh 235 | /// # without cuda 236 | /// cargo run example 04.05 237 | /// 238 | /// # with cuda 239 | /// cargo run --features cuda example 04.05 240 | /// ``` 241 | pub struct EG05; 242 | 243 | impl Example for EG05 { 244 | fn description(&self) -> String { 245 | String::from("Comparison of gradients with and without shortcut connections.") 246 | } 247 | 248 | fn page_source(&self) -> usize { 249 | 111_usize 250 | } 251 | 252 | fn main(&self) -> Result<()> { 253 | use crate::listings::ch04::ExampleDeepNeuralNetwork; 254 | use candle_core::{DType, Device, Tensor}; 255 | use candle_nn::{VarBuilder, VarMap}; 256 | 257 | let dev = Device::cuda_if_available(0)?; 258 | let varmap = VarMap::new(); 259 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 260 | 261 | let layer_sizes = &[3_usize, 3, 3, 3, 3, 1]; 262 | let sample_input = Tensor::new(&[[1_f32, 0., -1.]], vb.device())?; 263 | let model_without_shortcut = 264 | ExampleDeepNeuralNetwork::new(layer_sizes, false, vb.pp("model_wout_shortcut"))?; 265 | 266 | let model_with_shortcut = 267 | ExampleDeepNeuralNetwork::new(layer_sizes, true, vb.pp("model_with_shortcut"))?; 268 | 269 | println!("model_without_shortcut gradients:"); 270 | addons::print_gradients(model_without_shortcut, &sample_input)?; 271 | println!("model_with_shortcut gradients:"); 272 | addons::print_gradients(model_with_shortcut, &sample_input)?; 273 | Ok(()) 274 | } 275 | } 276 | 277 | /// # Example usage of `TransformerBlock` 278 | /// 279 | /// #### Id 280 | /// 04.06 281 | /// 282 | /// #### Page 283 | /// This example starts on page 116 284 | /// 285 | /// #### CLI command 286 | /// ```sh 287 | /// # without cuda 288 | /// cargo run example 04.06 289 | /// 290 | /// # with cuda 291 | /// cargo run --features cuda example 04.06 292 | /// ``` 293 | pub struct EG06; 294 | 295 | impl Example for EG06 {
296 | fn description(&self) -> String { 297 | String::from("Example usage of `TransformerBlock`.") 298 | } 299 | 300 | fn page_source(&self) -> usize { 301 | 116_usize 302 | } 303 | 304 | fn main(&self) -> Result<()> { 305 | use crate::listings::ch04::{Config, TransformerBlock}; 306 | use candle_core::{DType, Device, IndexOp, Tensor}; 307 | use candle_nn::{VarBuilder, VarMap}; 308 | 309 | // construct transformer block 310 | let dev = Device::cuda_if_available(0)?; 311 | let varmap = VarMap::new(); 312 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 313 | let cfg = Config::gpt2_124m(); 314 | let block = TransformerBlock::new(cfg, vb.pp("block"))?; 315 | 316 | // create sample input 317 | let (batch_size, num_tokens) = (2_usize, 4_usize); 318 | let x = Tensor::rand( 319 | 0f32, 320 | 1f32, 321 | (batch_size, num_tokens, cfg.emb_dim), 322 | vb.device(), 323 | )?; 324 | 325 | // execute forward pass 326 | let output = block.forward(&x)?; 327 | 328 | println!("Input shape: {:?}", x.shape()); 329 | println!("Output shape: {:?}", output.shape()); 330 | 331 | // print the first 10 features of all tokens of the first input 332 | println!( 333 | "Output: {:?}", 334 | output.i((0..1, .., 0..10))?.to_vec3::<f32>() 335 | ); 336 | Ok(()) 337 | } 338 | } 339 | 340 | /// # Example usage of `GPTModel` 341 | /// 342 | /// #### Id 343 | /// 04.07 344 | /// 345 | /// #### Page 346 | /// This example starts on page 120 347 | /// 348 | /// #### CLI command 349 | /// ```sh 350 | /// # without cuda 351 | /// cargo run example 04.07 352 | /// 353 | /// # with cuda 354 | /// cargo run --features cuda example 04.07 355 | /// ``` 356 | pub struct EG07; 357 | 358 | impl Example for EG07 { 359 | fn description(&self) -> String { 360 | String::from("Example usage of `GPTModel`.") 361 | } 362 | 363 | fn page_source(&self) -> usize { 364 | 120_usize 365 | } 366 | 367 | fn main(&self) -> Result<()> { 368 | use crate::listings::ch04::{Config, GPTModel}; 369 | use candle_core::{DType, Error, IndexOp, ModuleT}; 370 | use candle_nn::{VarBuilder, VarMap}; 371 | 372 | let batch = addons::get_batch_for_gpts()?; 373 | println!("batch: {:?}", batch.to_vec2::<u32>()); 374 | 375 | // create model 376 | let varmap = VarMap::new(); 377 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, batch.device()); 378 | let model = GPTModel::new(Config::gpt2_124m(), vb)?; 379 | 380 | // get logits 381 | let logits = model.forward_t(&batch, false)?; 382 | println!("output shape: {:?}", logits.shape()); 383 | 384 | // print first 10 next-token logits for each token of every input sequence 385 | println!("logits: {:?}", logits.i((.., .., 0..10))?.to_vec3::<f32>()); 386 | 387 | // get total number of params from the VarMap (todo: turn this into a util) 388 | let mut total_params = 0_usize; 389 | for t in varmap.all_vars().iter() { 390 | total_params += t.elem_count(); 391 | } 392 | println!("Total number of parameters: {}", total_params); 393 | 394 | // Get token embedding and output layer shapes 395 | let varmap_binding = varmap.data().lock().unwrap(); 396 | let tok_emb_dims = varmap_binding 397 | .get("tok_emb.weight") 398 | .ok_or_else(|| { 399 | Error::CannotFindTensor { 400 | path: "tok_emb.weight".to_string(), 401 | } 402 | .bt() 403 | })? 404 | .dims(); 405 | println!("Token embedding layer shape {:?}", tok_emb_dims); 406 | let out_head_dims = varmap_binding 407 | .get("out_head.weight") 408 | .ok_or_else(|| { 409 | Error::CannotFindTensor { 410 | path: "out_head.weight".to_string(), 411 | } 412 | .bt() 413 | })?
414 | .dims(); 415 | println!("Output layer shape {:?}", out_head_dims); 416 | 417 | // total number of params if weight tying with token emb and output layer shapes 418 | let total_params_gpt2 = total_params - (out_head_dims[0] * out_head_dims[1]); 419 | println!( 420 | "Number of trainable parameters considering weight tying {}", 421 | total_params_gpt2 422 | ); 423 | 424 | // memory requirements 425 | let total_size_bytes = total_params * 4; 426 | let total_size_mb = total_size_bytes as f32 / (1024_f32 * 1024.); 427 | println!("Total size of the model: {} MB", total_size_mb); 428 | Ok(()) 429 | } 430 | } 431 | 432 | /// # Example usage of `generate_text_simple` 433 | /// 434 | /// #### Id 435 | /// 04.08 436 | /// 437 | /// #### Page 438 | /// This example starts on page 125 439 | /// 440 | /// #### CLI command 441 | /// ```sh 442 | /// # without cuda 443 | /// cargo run example 04.08 444 | /// 445 | /// # with cuda 446 | /// cargo run --features cuda example 04.08 447 | /// ``` 448 | pub struct EG08; 449 | 450 | impl Example for EG08 { 451 | fn description(&self) -> String { 452 | String::from("Example usage of `generate_text_simple`.") 453 | } 454 | 455 | fn page_source(&self) -> usize { 456 | 125_usize 457 | } 458 | 459 | fn main(&self) -> Result<()> { 460 | use crate::listings::ch04::{generate_text_simple, Config, GPTModel}; 461 | use candle_core::{DType, Device, Tensor}; 462 | use candle_nn::{VarBuilder, VarMap}; 463 | use tiktoken_rs::get_bpe_from_model; 464 | 465 | // get starting context 466 | let dev = Device::cuda_if_available(0)?; 467 | let start_context = "Hello, I am"; 468 | let tokenizer = get_bpe_from_model("gpt2")?; 469 | let encoded = tokenizer.encode_with_special_tokens(start_context); 470 | let num_tokens = encoded.len(); 471 | println!("encoded: {:?}", encoded); 472 | let encoded_tensor = Tensor::from_vec(encoded, (1_usize, num_tokens), &dev)?; 473 | println!("encoded_tensor.shape {:?}", encoded_tensor); 474 | 475 | // construct model 476 | let varmap = VarMap::new(); 477 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev); 478 | let cfg = Config::gpt2_124m(); 479 | let model = GPTModel::new(cfg, vb)?; 480 | 481 | // run inference 482 | let out = generate_text_simple(&model, encoded_tensor, 6_usize, cfg.context_length)?; 483 | println!("Output: {:?}", out.to_vec2::<u32>()); 484 | println!("Output length: {}", out.dims()[1]); 485 | 486 | // decode with tokenizer 487 | let decoded_text = tokenizer.decode(out.reshape(out.dims()[1])?.to_vec1::<u32>()?); 488 | println!("{:?}", decoded_text); 489 | Ok(()) 490 | } 491 | } 492 |
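One orientation note on EG08 before the `addons` helpers: `generate_text_simple` performs plain greedy decoding. Schematically, each step crops the context to the supported window, runs the model, and appends the argmax of the final position's logits -- a hedged sketch only, not the actual listing in `listings::ch04`:

```rust
use candle_core::{IndexOp, ModuleT, Result, Tensor, D};

// Schematic greedy-decoding loop in the spirit of `generate_text_simple`
// (illustrative sketch, not the listing itself).
fn greedy_generate<M: ModuleT>(
    model: &M,
    mut ids: Tensor, // (batch, seq) of u32 token ids
    max_new_tokens: usize,
    context_length: usize,
) -> Result<Tensor> {
    for _ in 0..max_new_tokens {
        let seq_len = ids.dims()[1];
        // crop to the supported context window
        let ctx = ids.i((.., seq_len.saturating_sub(context_length)..))?;
        let logits = model.forward_t(&ctx, false)?; // (batch, seq, vocab)
        let last = logits.i((.., ctx.dims()[1] - 1, ..))?; // last position only
        let next = last.argmax_keepdim(D::Minus1)?; // greedy pick, (batch, 1)
        ids = Tensor::cat(&[&ids, &next], 1)?; // append and continue
    }
    Ok(ids)
}
```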
    pub fn get_batch_for_gpts() -> Result<Tensor> {
        let dev = Device::cuda_if_available(0)?;

        // create batch
        let mut batch_tokens: Vec<u32> = Vec::new();
        let tokenizer =
            get_bpe_from_model("gpt2").map_err(|e| Error::Msg(format!("Tokenizer error: {e}")))?;
        batch_tokens.append(&mut tokenizer.encode_with_special_tokens("Every effort moves you"));
        batch_tokens.append(&mut tokenizer.encode_with_special_tokens("Every day holds a"));

        Tensor::from_vec(batch_tokens, (2_usize, 4_usize), &dev)
    }

    /// Helper function for printing gradients of `ExampleDeepNeuralNetwork`
    pub fn print_gradients(model: ExampleDeepNeuralNetwork, x: &Tensor) -> Result<()> {
        use candle_nn::loss::mse;

        let output = model.forward(x)?;
        let target = Tensor::new(&[[0_f32]], x.device())?;

        let loss = mse(&output, &target)?;
        let grads = loss.backward()?;

        for (ix, tensor_id) in model.tensor_ids.iter().enumerate() {
            let grad_tensor = grads.get_id(tensor_id.to_owned()).ok_or_else(|| {
                Error::CannotFindTensor {
                    path: format!("{:?}", tensor_id),
                }
                .bt()
            })?;
            println!(
                "layer.{}.weight has gradient mean of {:?}",
                ix,
                grad_tensor.abs()?.mean_all()?.to_scalar::<f32>()?
            );
        }
        println!("\n");
        Ok(())
    }
}
--------------------------------------------------------------------------------
/src/examples/mod.rs:
--------------------------------------------------------------------------------
//! Examples
//!
//! This module contains Rust translations for the example PyTorch code provided
//! in each chapter. That is, the code that is found "in-line" along with the
//! main text and specifically not the code belonging to any given Listing.

pub mod apdx_e;
pub mod ch02;
pub mod ch03;
pub mod ch04;
pub mod ch05;
pub mod ch06;
pub mod ch07;
--------------------------------------------------------------------------------
/src/exercises/ch02.rs:
--------------------------------------------------------------------------------
//! Exercises from Chapter 2

use crate::Exercise;
use anyhow::Result;

/// # Byte pair encoding of unknown words
///
/// #### Id
/// 2.1
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 2.1
///
/// # with cuda
/// cargo run --features cuda exercise 2.1
/// ```
pub struct X1;

impl Exercise for X1 {
    fn name(&self) -> String {
        String::from("2.1")
    }

    fn title(&self) -> String {
        "Byte pair encoding of unknown words".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Try the BPE tokenizer from the tiktoken library on the \
        unknown words 'Akwirw ier' and print the individual token IDs. Then, \
        call the decode function on each of the resulting integers in this list \
        to reproduce the mapping shown in figure 2.11. \
        Lastly, call the decode \
        method on the token IDs to check whether it can reconstruct the \
        original input, 'Akwirw ier.'";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let token_ids = tokenizer.encode_with_special_tokens("Akwirw ier");
        println!("token ids: {:?}", token_ids);

        let decoded_text = tokenizer.decode(token_ids)?;
        println!("decoded text: {}", decoded_text);
        Ok(())
    }
}

/// # Data loaders with different strides and context sizes
///
/// #### Id
/// 2.2
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 2.2
///
/// # with cuda
/// cargo run --features cuda exercise 2.2
/// ```
pub struct X2;

impl Exercise for X2 {
    fn name(&self) -> String {
        String::from("2.2")
    }

    fn title(&self) -> String {
        "Data loaders with different strides and context sizes".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "To develop more intuition for how the data loader works, \
        try to run it with different settings such as `max_length=2` and \
        `stride=2`, and `max_length=8` and `stride=2`.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch02::{create_dataloader_v1, DataLoader};
        use std::fs;

        let raw_text = fs::read_to_string("data/the-verdict.txt").expect("Unable to read the file");
        let max_length = 4_usize;
        let stride = 2_usize;
        let shuffle = false;
        let drop_last = false;
        let batch_size = 2_usize;
        let data_loader = create_dataloader_v1(
            &raw_text[..],
            batch_size,
            max_length,
            stride,
            shuffle,
            drop_last,
        );

        let mut batch_iter = data_loader.batcher();
        match batch_iter.next() {
            Some(Ok((inputs, targets))) => {
                println!(
                    "inputs: {:?}\n\ntargets: {:?}",
                    inputs.to_vec2::<u32>(),
                    targets.to_vec2::<u32>()
                );
            }
            Some(Err(err)) => panic!("{}", err),
            None => panic!("None"),
        }
        Ok(())
    }
}
--------------------------------------------------------------------------------
/src/exercises/ch03.rs:
--------------------------------------------------------------------------------
//! Exercises from Chapter 3

use crate::Exercise;
use anyhow::Result;

/// # Comparing `SelfAttention_v1` and `SelfAttention_v2`
///
/// #### Id
/// 3.1
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 3.1
///
/// # with cuda
/// cargo run --features cuda exercise 3.1
/// ```
pub struct X1;

impl Exercise for X1 {
    fn name(&self) -> String {
        String::from("3.1")
    }

    fn title(&self) -> String {
        "Comparing `SelfAttention_v1` and `SelfAttention_v2`".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Note that `nn.Linear` in `SelfAttention_v2` uses a \
        different weight initialization scheme than `nn.Parameter(torch.rand(d_in, d_out))` \
        as used in `SelfAttention_v1`, which causes both mechanisms to produce \
        different results. \
        To check that both implementations, `SelfAttention_v1` \
        and `SelfAttention_v2`, are otherwise similar, we can transfer the \
        weight matrices from a `SelfAttention_v2` object to a `SelfAttention_v1`, \
        such that both objects then produce the same results. Your task is to \
        correctly assign the weights from an instance of `SelfAttention_v2` to \
        an instance of `SelfAttention_v1`. To do this, you need to understand \
        the relationship between the weights in both versions. (Hint: `nn.Linear` \
        stores the weight matrix in a transposed form.) After the assignment, \
        you should observe that both instances produce the same outputs.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch03::{SelfAttentionV1, SelfAttentionV2};
        use candle_core::{DType, Device, Module, Tensor};
        use candle_nn::{VarBuilder, VarMap};

        let (d_in, d_out) = (3_usize, 5_usize);
        let varmap = VarMap::new();
        let dev = Device::cuda_if_available(0)?;
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn_v2"))?;
        let attn_v1_layer = SelfAttentionV1 {
            w_query: attn_v2_layer.w_query().weight().t()?,
            w_key: attn_v2_layer.w_key().weight().t()?,
            w_value: attn_v2_layer.w_value().weight().t()?,
            scaling: 1. / (attn_v2_layer.w_key().weight().dims()[0] as f64).sqrt(),
        };

        let input_length = 10_usize;
        let xs = Tensor::rand(0f32, 1f32, (input_length, d_in), &dev)?;
        let context_vectors_from_v1 = attn_v1_layer.forward(&xs)?;
        let context_vectors_from_v2 = attn_v2_layer.forward(&xs)?;

        println!(
            "Context vectors from SelfAttention V1 and V2 are equal when using same weights: {}",
            context_vectors_from_v1.to_vec2::<f32>()?
                == context_vectors_from_v2.to_vec2::<f32>()?
        );
        Ok(())
    }
}

/// # Returning two-dimensional embedding vectors
///
/// #### Id
/// 3.2
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 3.2
///
/// # with cuda
/// cargo run --features cuda exercise 3.2
/// ```
pub struct X2;

impl Exercise for X2 {
    fn name(&self) -> String {
        String::from("3.2")
    }

    fn title(&self) -> String {
        "Returning two-dimensional embedding vectors".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Change the input arguments for the \
        `MultiHeadAttentionWrapper(..., num_heads=2)` call such that the output \
        context vectors are two-dimensional instead of four-dimensional while \
        keeping the setting `num_heads=2`. \
        Hint: You don’t have to modify the \
        class implementation; you just have to change one of the other input arguments.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch03::MultiHeadAttentionWrapper;
        use candle_core::{DType, Device, Module, Tensor};
        use candle_nn::{VarBuilder, VarMap};

        let (d_in, d_out) = (3_usize, 1_usize); // set d_out to 1 to get desired final dim
        let varmap = VarMap::new();
        let dev = Device::cuda_if_available(0)?;
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let num_heads = 2_usize;
        let mha =
            MultiHeadAttentionWrapper::new(num_heads, d_in, d_out, 0.0_f32, false, vb.pp("mha"))?;

        // create random input batch
        let input_length = 6_usize;
        let xs = Tensor::rand(0f32, 1f32, (input_length, d_in), vb.device())?;
        let batch = Tensor::stack(&[&xs, &xs], 0)?;
        println!("batch shape: {:?}", batch);

        // run forward on mha
        let context_vectors = mha.forward(&batch)?;
        println!("context_vectors.shape: {:?}", context_vectors);
        println!("context_vectors: {:?}", context_vectors.to_vec3::<f32>());
        Ok(())
    }
}

/// # Initializing GPT-2 size attention modules
///
/// #### Id
/// 3.3
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 3.3
///
/// # with cuda
/// cargo run --features cuda exercise 3.3
/// ```
pub struct X3;

impl Exercise for X3 {
    fn name(&self) -> String {
        String::from("3.3")
    }

    fn title(&self) -> String {
        "Initializing GPT-2 size attention modules".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Using the `MultiHeadAttention` class, initialize a \
        multi-head attention module that has the same number of attention heads \
        as the smallest GPT-2 model (12 attention heads). Also ensure that you \
        use the respective input and output embedding sizes similar to GPT-2 \
        (768 dimensions). Note that the smallest GPT-2 model supports a context \
        length of 1,024 tokens.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch03::MultiHeadAttention;
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};

        let (d_in, d_out, num_heads) = (768_usize, 768_usize, 12_usize); // GPT-2 small dims
        let varmap = VarMap::new();
        let dev = Device::cuda_if_available(0)?;
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let mha = MultiHeadAttention::new(d_in, d_out, 0.0_f32, num_heads, false, vb.pp("mha"))?;

        println!("mha.num_heads: {:?}", mha.num_heads());
        println!("mha.head_dim: {:?}", mha.head_dim());
        println!("mha.w_query.shape: {:?}", mha.w_query().weight().dims());
        Ok(())
    }
}
--------------------------------------------------------------------------------
/src/exercises/ch04.rs:
--------------------------------------------------------------------------------
//! Exercises from Chapter 4

use crate::Exercise;
use anyhow::Result;

/// # Number of parameters in feed forward and attention modules
///
/// #### Id
/// 4.1
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 4.1
///
/// # with cuda
/// cargo run --features cuda exercise 4.1
/// ```
pub struct X1;

impl Exercise for X1 {
    fn name(&self) -> String {
        String::from("4.1")
    }

    fn title(&self) -> String {
        "Number of parameters in feed forward and attention modules".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Calculate and compare the number of parameters that are contained in the feed forward module \
        and those that are contained in the multi-head attention module.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch04::{Config, TransformerBlock};
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};

        // create model
        let dev = Device::cuda_if_available(0)?;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let _ = TransformerBlock::new(Config::gpt2_124m(), vb)?;

        // Get varmap data containing all variables
        let varmap_data = varmap.data().lock().unwrap();

        // Count params for ff and mha modules
        let (mut ff_params, mut mha_params) = (0_usize, 0_usize);
        for (var_name, var) in varmap_data.iter() {
            let num_params = var.elem_count();
            if var_name.starts_with("ff.") {
                ff_params += num_params;
            } else if var_name.starts_with("mha.") {
                mha_params += num_params;
            }
        }
        println!("Ff number of parameters: {}", ff_params);
        println!("Mha number of parameters: {}", mha_params);
        Ok(())
    }
}
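// The counts printed by exercise 4.1 can be verified by hand for the GPT-2
// 124M configuration (emb_dim = 768, qkv_bias = false). A sketch of the
// arithmetic, assuming the attention output projection keeps its bias:
//
//   feed forward: 768 * 3072 + 3072 (expansion layer, with bias)
//               + 3072 * 768 + 768  (contraction layer, with bias) = 4,722,432
//   attention:    3 * (768 * 768)   (query/key/value, no bias)
//               + 768 * 768 + 768   (output projection, with bias) = 2,360,064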
/// # Initializing larger GPT models
///
/// #### Id
/// 4.2
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 4.2
///
/// # with cuda
/// cargo run --features cuda exercise 4.2
/// ```
pub struct X2;

impl Exercise for X2 {
    fn name(&self) -> String {
        String::from("4.2")
    }

    fn title(&self) -> String {
        "Initializing larger GPT models".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "We initialized a 124-million-parameter GPT model, \
        which is known as 'GPT-2 small.' Without making any code modifications \
        besides updating the configuration file, use the GPTModel class to \
        implement GPT-2 medium (using 1,024-dimensional embeddings, 24 transformer \
        blocks, 16 multi-head attention heads), GPT-2 large (1,280-dimensional \
        embeddings, 36 transformer blocks, 20 multi-head attention heads), and \
        GPT-2 XL (1,600-dimensional embeddings, 48 transformer blocks, 25 \
        multi-head attention heads). As a bonus, calculate the total number of \
        parameters in each GPT model.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch04::{Config, GPTModel};
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};

        let configs = &[
            ("gpt2-sm", Config::gpt2_124m()),
            ("gpt2-med", Config::gpt2_medium()),
            ("gpt2-l", Config::gpt2_large()),
            ("gpt2-xl", Config::gpt2_xlarge()),
        ];

        for (mdl_name, cfg) in configs.iter() {
            // construct model which stores the vars in the varmap
            let dev = Device::cuda_if_available(0)?;
            let varmap = VarMap::new();
            let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
            let _ = GPTModel::new(*cfg, vb)?;

            // compute number of params (todo build utility func for this)
            let mut total_params = 0_usize;
            for t in varmap.all_vars().iter() {
                total_params += t.elem_count();
            }
            println!("{} number of parameters: {}", mdl_name, total_params);

            // Get token embedding and output layer shapes
            let varmap_data = varmap.data().lock().unwrap();
            let tok_emb_dims = varmap_data.get("tok_emb.weight").unwrap().dims();
            println!("Token embedding layer shape {:?}", tok_emb_dims);
            let out_head_dims = varmap_data.get("out_head.weight").unwrap().dims();
            println!("Output layer shape {:?}", out_head_dims);

            // total number of params if weight tying with token emb and output layer shapes
            let total_params_gpt2 = total_params - (out_head_dims[0] * out_head_dims[1]);
            println!(
                "Number of trainable parameters considering weight tying {}",
                total_params_gpt2
            );

            // memory requirements (todo: build this out as a util)
            let total_size_bytes = total_params * 4;
            let total_size_mb = total_size_bytes as f32 / (1024_f32 * 1024.);
            println!("Total size of the model: {} MB\n", total_size_mb);
        }
        Ok(())
    }
}

/// # Using separate dropout parameters
///
/// #### Id
/// 4.3
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 4.3
///
/// # with cuda
/// cargo run --features cuda exercise 4.3
/// ```
pub struct X3;

impl Exercise for X3 {
    fn name(&self) -> String {
        String::from("4.3")
    }

    fn title(&self) -> String {
        "Using separate dropout parameters".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "At the beginning of this chapter, we defined a global \
        `drop_rate` setting in the `GPT_CONFIG_124M` dictionary to set the \
        dropout rate in various places throughout the GPTModel architecture. \
        Change the code to specify a separate dropout value for the various \
        dropout layers throughout the model architecture. \
        (Hint: there are three \
        distinct places where we used dropout layers: the embedding layer, \
        shortcut layer, and multi-head attention module.)";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch04::GPTModel;
        use candle_core::{DType, Device, IndexOp, ModuleT, Tensor};
        use candle_nn::{VarBuilder, VarMap};

        // create model
        let dev = Device::cuda_if_available(0)?;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let model = GPTModel::new_v2(addons::ConfigV2::gpt_config_124m(), vb)?;

        // create batch inputs
        let batch = Tensor::new(&[[101_u32, 366, 100, 345], [101, 110, 322, 57]], &dev)?;

        // run model forward
        let logits = model.forward_t(&batch, false)?;

        // print the first 10 vocabulary logits for the last token of each batch input
        let (_b, c, _vocab_size) = logits.dims3()?;
        let last_tokens_logits = logits.i((.., c - 1, ..))?;
        println!(
            "first 10 logits of last vector: {:?}",
            last_tokens_logits.i((.., 0..10))?.to_vec2::<f32>()
        );
        Ok(())
    }
}

pub mod addons {
    //! Auxiliary module for exercises::ch04
    use crate::listings::{
        ch03::MultiHeadAttention,
        ch04::{
            seqtransformers, FFLayer, FeedForward, GPTModel, LayerNorm, TransformerBlock, GELU,
        },
    };
    use candle_core::Result;
    use candle_nn::{embedding, linear_b, Dropout, VarBuilder};

    /// A second `Config` variation for Exercise 4.3 to specify individual drop rates
    #[derive(Debug, Clone, Copy)]
    pub struct ConfigV2 {
        pub vocab_size: usize,
        pub context_length: usize,
        pub emb_dim: usize,
        pub n_heads: usize,
        pub n_layers: usize,
        pub drop_rate_attn: f32,
        pub drop_rate_emb: f32,
        pub drop_rate_shortcut: f32,
        pub qkv_bias: bool,
    }

    impl ConfigV2 {
        pub fn gpt_config_124m() -> Self {
            Self {
                vocab_size: 50_257,
                context_length: 1_024,
                emb_dim: 768,
                n_heads: 12,
                n_layers: 12,
                drop_rate_attn: 0.1,
                drop_rate_emb: 0.1,
                drop_rate_shortcut: 0.1,
                qkv_bias: false,
            }
        }
    }

    /// New `FeedForward` constructor using `ConfigV2`
    impl FeedForward {
        fn new_v2(cfg: ConfigV2, vb: VarBuilder<'_>) -> Result<Self> {
            let layers = vec![
                FFLayer::Linear(linear_b(
                    cfg.emb_dim,
                    4_usize * cfg.emb_dim,
                    true,
                    vb.pp("first_layer"),
                )?),
                FFLayer::GELU(GELU),
                FFLayer::Linear(linear_b(
                    4_usize * cfg.emb_dim,
                    cfg.emb_dim,
                    true,
                    vb.pp("second_layer"),
                )?),
            ];

            FeedForward::from_fields(layers)
        }
    }

    /// New `TransformerBlock` constructor using `ConfigV2`
    impl TransformerBlock {
        fn new_v2(cfg: ConfigV2, vb: VarBuilder<'_>) -> Result<Self> {
            let att = MultiHeadAttention::new(
                cfg.emb_dim,
                cfg.emb_dim,
                cfg.drop_rate_attn,
                cfg.n_heads,
                cfg.qkv_bias,
                vb.pp("mha"),
            )?;
            let ff = FeedForward::new_v2(cfg, vb.pp("ff"))?;
            let norm1 = LayerNorm::new(cfg.emb_dim, vb.pp("norm1"))?;
            let norm2 = LayerNorm::new(cfg.emb_dim, vb.pp("norm2"))?;
            let drop_shortcut = Dropout::new(cfg.drop_rate_shortcut);
            TransformerBlock::from_fields(att, ff, norm1, norm2, drop_shortcut)
        }
    }
    /// New `GPTModel` constructor using `ConfigV2`
    impl GPTModel {
        pub fn new_v2(cfg: ConfigV2, vb: VarBuilder<'_>) -> Result<Self> {
            let tok_emb = embedding(cfg.vocab_size, cfg.emb_dim, vb.pp("tok_emb"))?;
            let pos_emb = embedding(cfg.context_length, cfg.emb_dim, vb.pp("pos_emb"))?;
            let drop_emb = Dropout::new(cfg.drop_rate_emb);
            let mut trf_blocks = seqtransformers();
            for ix in 0..cfg.n_layers {
                trf_blocks =
                    trf_blocks.add(TransformerBlock::new_v2(cfg, vb.pp(format!("trf-{}", ix)))?);
            }
            let final_norm = LayerNorm::new(cfg.emb_dim, vb.pp("final_norm"))?;
            let out_head = linear_b(cfg.emb_dim, cfg.vocab_size, false, vb.pp("out_head"))?;
            GPTModel::from_fields(tok_emb, pos_emb, drop_emb, trf_blocks, final_norm, out_head)
        }
    }
}
--------------------------------------------------------------------------------
/src/exercises/ch05.rs:
--------------------------------------------------------------------------------
//! Exercises from Chapter 5

use crate::Exercise;
use anyhow::Result;

/// # Printing sampling frequencies with various temperatures
///
/// #### Id
/// 5.1
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 5.1
///
/// # with cuda
/// cargo run --features cuda exercise 5.1
/// ```
pub struct X1;

impl Exercise for X1 {
    fn name(&self) -> String {
        String::from("5.1")
    }

    fn title(&self) -> String {
        "Printing sampling frequencies with various temperatures".to_string() // title missing from book
    }

    fn statement(&self) -> String {
        let stmt = "Use the `print_sampled_tokens` function to print the \
        sampling frequencies of the softmax probabilities scaled with the \
        temperatures shown in figure 5.14. How often is the word `pizza` sampled \
        in each case? Can you think of a faster and more accurate way to \
        determine how often the word `pizza` is sampled?";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::{examples, listings::ch05::print_sampled_tokens};
        use candle_core::D;
        use candle_nn::ops::softmax;

        let (_vocab, inverse_vocab) = examples::ch05::addons::get_vocab_and_inversed_vocab();
        let next_token_logits = examples::ch05::addons::get_next_token_logits()?;

        let temperatures = &[1_f64, 0.1, 5.];
        for temp in temperatures.iter() {
            println!(
                "Temp (temp={}) scaling sampling conducted 1000 times:",
                temp
            );
            let scaled_logits = (&next_token_logits / temp.to_owned())?;
            let scaled_probas = softmax(&scaled_logits, D::Minus1)?;
            print_sampled_tokens(&scaled_probas.to_vec1::<f32>()?, &inverse_vocab, true)?;
            println!("\n");
        }
        Ok(())
    }
}

/// # Using various temperatures and top-k values
///
/// #### Id
/// 5.2
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 5.2
///
/// # with cuda
/// cargo run --features cuda exercise 5.2
/// ```
pub struct X2;

impl Exercise for X2 {
    fn name(&self) -> String {
        String::from("5.2")
    }

    fn title(&self) -> String {
        "Using various temperatures and top-k values".to_string() // missing from book
    }

    fn statement(&self) -> String {
        let stmt = "Play around with different temperatures and top-k \
        settings. \
Based on your observations, can you think of applications \ 89 | where lower temperature and top-k settings are desired? Likewise, can \ 90 | you think of applications where higher temperature and top-k settings \ 91 | are preferred? (It’s recommended to also revisit this exercise at the \ 92 | end of the chapter after loading the pretrained weights from OpenAI.)"; 93 | stmt.to_string() 94 | } 95 | 96 | fn main(&self) -> Result<()> { 97 | use crate::listings::{ 98 | ch04::{Config, GPTModel}, 99 | ch05::{generate, text_to_token_ids, token_ids_to_text}, 100 | }; 101 | use candle_core::{DType, Device}; 102 | use candle_nn::{VarBuilder, VarMap}; 103 | use itertools::iproduct; 104 | use rand::{rngs::StdRng, SeedableRng}; 105 | use tiktoken_rs::get_bpe_from_model; 106 | 107 | // construct model 108 | let varmap = VarMap::new(); 109 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 110 | let cfg = Config::gpt2_124m(); 111 | let model = GPTModel::new(Config::gpt2_124m(), vb.pp("model"))?; 112 | 113 | // sample setup and load tokenizer 114 | let start_context = "Every effort moves you"; 115 | let tokenizer = get_bpe_from_model("gpt2")?; 116 | 117 | let temperatures = &[0.1_f64, 1., 5.]; 118 | let top_ks = &[20_usize, 100, cfg.vocab_size]; 119 | let mut rng = StdRng::seed_from_u64(42_u64); 120 | for (temp, top_k) in iproduct!(temperatures, top_ks) { 121 | println!("Temp: {}, Top K: {}", temp, top_k); 122 | 123 | let token_ids = generate( 124 | &model, 125 | text_to_token_ids(start_context, &tokenizer, vb.device())?, 126 | 15_usize, 127 | cfg.context_length, 128 | Some(*temp), 129 | Some(*top_k), 130 | None, 131 | &mut rng, 132 | )?; 133 | 134 | // decode the token ids to print the output text 135 | println!("{:?}\n", token_ids_to_text(token_ids, &tokenizer)) 136 | } 137 | Ok(()) 138 | } 139 | } 140 | 141 | /// # Parameter values for deterministic sampling 142 | /// 143 | /// #### Id 144 | /// 5.3 145 | /// 146 | /// #### CLI command 147 | /// ```sh 148 | /// # without cuda 149 | /// cargo run exercise 5.3 150 | /// 151 | /// # with cuda 152 | /// cargo run --features cuda exercise 5.3 153 | /// ``` 154 | pub struct X3; 155 | 156 | impl Exercise for X3 { 157 | fn name(&self) -> String { 158 | String::from("5.3") 159 | } 160 | 161 | fn title(&self) -> String { 162 | "Parameter values for deterministic sampling".to_string() // missing from book 163 | } 164 | 165 | fn statement(&self) -> String { 166 | let stmt = "What are the different combinations of settings for \ 167 | the `generate` function to force deterministic behavior, that is, \ 168 | disabling the random sampling such that it always produces the same \ 169 | outputs similar to the `generate_simple` function?"; 170 | stmt.to_string() 171 | } 172 | 173 | fn main(&self) -> Result<()> { 174 | use crate::listings::{ 175 | ch04::{Config, GPTModel}, 176 | ch05::{generate, text_to_token_ids}, 177 | }; 178 | use candle_core::{DType, Device, Tensor}; 179 | use candle_nn::{VarBuilder, VarMap}; 180 | use rand::{rngs::StdRng, SeedableRng}; 181 | use tiktoken_rs::get_bpe_from_model; 182 | 183 | // construct model 184 | let varmap = VarMap::new(); 185 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 186 | let cfg = Config::gpt2_124m(); 187 | let model = GPTModel::new(Config::gpt2_124m(), vb.pp("model"))?; 188 | 189 | // sample setup and load tokenizer 190 | let start_context = "Every effort moves you"; 191 | let tokenizer = get_bpe_from_model("gpt2")?; 192 | 193 | // 
deterministic settings: temp set to None (argmax decoding); top_k may be any value
        let temp = None;

        let mut old_token_ids: Option<Tensor> = None;
        let mut rng = StdRng::seed_from_u64(42_u64);
        for ix in 0..4 {
            println!("Iteration {}:", ix);

            let token_ids = generate(
                &model,
                text_to_token_ids(start_context, &tokenizer, vb.device())?,
                15_usize,
                cfg.context_length,
                temp,
                Some(20usize),
                None,
                &mut rng,
            )?;

            if let Some(old) = old_token_ids {
                println!("old token ids: {:?}", old.to_vec2::<u32>());
            } else {
                println!("old token ids: None");
            }

            println!("new token ids: {:?}\n", token_ids.to_vec2::<u32>());

            old_token_ids = Some(token_ids);
        }
        Ok(())
    }
}

/// # Continuing training from pre-loaded weights
///
/// #### Id
/// 5.4
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 5.4
///
/// # with cuda
/// cargo run --features cuda exercise 5.4
/// ```
pub struct X4;

impl Exercise for X4 {
    fn name(&self) -> String {
        String::from("5.4")
    }

    fn title(&self) -> String {
        "Continuing training from pre-loaded weights".to_string() // missing from book
    }

    fn statement(&self) -> String {
        let stmt = "After saving the weights, load the model and optimizer \
        in a new Python session or Jupyter notebook file and continue pretraining \
        it for one more epoch using the `train_model_simple` function.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::{
            examples,
            listings::{
                ch04::{Config, GPTModel},
                ch05::train_model_simple,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
        use tiktoken_rs::get_bpe_from_model;

        // construct model
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let cfg = Config::gpt2_124m();
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        // load from previous checkpoint
        // NOTE: this requires EG 05.09 to have been run, which creates the model
        // checkpoint that we use here
        println!("Loading weights from `./checkpoint.safetensors`");
        varmap.load("checkpoint.safetensors")?; // todo: map to anyhow error with proper msg

        // train model for one epoch
        let optimizer = AdamW::new(
            varmap.all_vars(),
            ParamsAdamW {
                lr: 0.0004,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;
        let tokenizer = get_bpe_from_model("gpt2")?;
        let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 1_usize);
        let (train_loader, val_loader) = examples::ch05::addons::get_train_val_data_loaders(false)?;
        let start_context = "Every effort moves you";
        let _ = train_model_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            start_context,
            &tokenizer,
            None,
        );
        Ok(())
    }
}

/// # Training and validation losses with OpenAI weights
///
/// #### Id
/// 5.5
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 5.5
///
/// # with cuda
/// cargo run --features cuda exercise 5.5
/// ```
pub struct X5;

impl Exercise for X5 {
    fn name(&self) -> String {
        String::from("5.5")
    }

    fn title(&self) -> String {
        "Training and validation losses with OpenAI weights".to_string() // missing from book
    }

    fn statement(&self) -> String {
        let stmt = "Calculate the training and validation set losses of the \
        `GPTModel` with the pretrained weights from OpenAI on the “The Verdict” \
        dataset.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::{
            examples,
            listings::{
                ch04::{Config, GPTModel},
                ch05::{calc_loss_loader, load_weights_into_gpt},
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};
        use hf_hub::api::sync::Api;

        let dev = Device::cuda_if_available(0)?;

        // download openai weights
        let api = Api::new()?;
        let repo = api.model("openai-community/gpt2".to_string());
        let weights = repo.get("model.safetensors")?;
        let weights = candle_core::safetensors::load(weights, &dev)?;

        // construct model
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        // load openai weights
        load_weights_into_gpt(&varmap, weights, Some("model"), cfg.n_layers)?;

        // build train and val loaders with utility function from addons module
        let (train_loader, val_loader) = examples::ch05::addons::get_train_val_data_loaders(false)?;

        // compute train and val loss
        let train_loss = calc_loss_loader(&train_loader, &model, vb.device(), None, None)?;
        let val_loss = calc_loss_loader(&val_loader, &model, vb.device(), None, None)?;

        println!("Training loss {:?}", train_loss);
        println!("Validation loss {:?}", val_loss);
        Ok(())
    }
}

/// # Comparing generations with different GPT-2 model sizes
///
/// #### Id
/// 5.6
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 5.6
///
/// # with cuda
/// cargo run --features cuda exercise 5.6
/// ```
pub struct X6;

impl Exercise for X6 {
    fn name(&self) -> String {
        String::from("5.6")
    }

    fn title(&self) -> String {
        "Comparing generations with different GPT-2 model sizes".to_string() // missing from book
    }

    fn statement(&self) -> String {
        let stmt = "Experiment with GPT-2 models of different sizes—for \
        example, the largest 1,558 million parameter model—and compare the \
        generated text to the 124 million model.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch05::{generate, load_weights_into_gpt, text_to_token_ids, token_ids_to_text},
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};
        use hf_hub::api::sync::Api;
        use rand::{rngs::StdRng, SeedableRng};
        use tiktoken_rs::get_bpe_from_model;

        let dev = Device::cuda_if_available(0)?;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let mut cfg =
Config::gpt2_xlarge(); 431 | cfg.qkv_bias = true; 432 | let model = GPTModel::new(cfg, vb.pp("model"))?; 433 | 434 | // get weights from HF Hub 435 | let model_name = "openai-community/gpt2-xl"; 436 | let api = Api::new()?; 437 | let repo = api.model(model_name.to_string()); 438 | let weights = repo.get("model.safetensors")?; 439 | let weights = candle_core::safetensors::load(weights, &Device::Cpu)?; 440 | 441 | // load weights 442 | load_weights_into_gpt(&varmap, weights, Some("model"), cfg.n_layers)?; 443 | 444 | // sample setup and load tokenizer 445 | let start_context = "Every effort moves you"; 446 | let tokenizer = get_bpe_from_model("gpt2")?; 447 | 448 | let mut rng = StdRng::seed_from_u64(42_u64); 449 | let token_ids = generate( 450 | &model, 451 | text_to_token_ids(start_context, &tokenizer, vb.device())?, 452 | 25_usize, 453 | cfg.context_length, 454 | Some(0.1_f64), 455 | Some(50_usize), 456 | None, 457 | &mut rng, 458 | )?; 459 | 460 | // decode the token ids to print the output text 461 | println!( 462 | "Model:\n{model_name}\n\nOutput text:\n{:?}", 463 | token_ids_to_text(token_ids, &tokenizer)? 464 | ); 465 | Ok(()) 466 | } 467 | } 468 | -------------------------------------------------------------------------------- /src/exercises/ch06.rs: -------------------------------------------------------------------------------- 1 | //! Exercises from Chapter 6 2 | 3 | use crate::Exercise; 4 | use anyhow::Result; 5 | 6 | /// # Increasing the context length 7 | /// 8 | /// #### Id 9 | /// 6.1 10 | /// 11 | /// #### CLI command 12 | /// ```sh 13 | /// # without cuda 14 | /// cargo run exercise 6.1 15 | /// 16 | /// # with cuda 17 | /// cargo run --features cuda exercise 6.1 18 | /// ``` 19 | pub struct X1; 20 | 21 | impl Exercise for X1 { 22 | fn name(&self) -> String { 23 | String::from("6.1") 24 | } 25 | 26 | fn title(&self) -> String { 27 | String::from("Increasing the context length") 28 | } 29 | 30 | fn statement(&self) -> String { 31 | let stmt = "Pad the inputs to the maximum number of tokens the model \ 32 | supports and observe how it affects the predictive performance."; 33 | stmt.to_string() 34 | } 35 | 36 | fn main(&self) -> Result<()> { 37 | use crate::listings::{ 38 | ch04::Config, 39 | ch06::{ 40 | calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification, 41 | train_classifier_simple, SpamDataLoader, SpamDatasetBuilder, HF_GPT2_MODEL_ID, 42 | }, 43 | }; 44 | use anyhow::anyhow; 45 | use candle_core::{DType, Device, Var}; 46 | use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap}; 47 | use std::ops::Not; 48 | use std::path::Path; 49 | use tiktoken_rs::get_bpe_from_model; 50 | 51 | println!("Creating train, val, test datasets"); 52 | // create datasets 53 | let tokenizer = get_bpe_from_model("gpt2")?; 54 | let max_length = Some(512_usize); 55 | 56 | let train_path = Path::new("data").join("train.parquet"); 57 | if train_path.exists().not() { 58 | return Err(anyhow!( 59 | "Missing 'data/train.parquet' file. Please run EG 06.04." 60 | )); 61 | } 62 | let train_dataset = SpamDatasetBuilder::new(&tokenizer) 63 | .load_data_from_parquet(train_path) 64 | .max_length(max_length) 65 | .build(); 66 | println!( 67 | "...train dataset max length: {}", 68 | train_dataset.max_length() 69 | ); 70 | 71 | let val_path = Path::new("data").join("validation.parquet"); 72 | if val_path.exists().not() { 73 | return Err(anyhow!( 74 | "Missing 'data/validation.parquet' file. Please run EG 06.04." 
            ));
        }
        let val_dataset = SpamDatasetBuilder::new(&tokenizer)
            .load_data_from_parquet(val_path)
            .max_length(max_length)
            .build();
        println!("...val dataset max length: {}", val_dataset.max_length());

        let test_path = Path::new("data").join("test.parquet");
        if test_path.exists().not() {
            return Err(anyhow!(
                "Missing 'data/test.parquet' file. Please run EG 06.04."
            ));
        }
        let test_dataset = SpamDatasetBuilder::new(&tokenizer)
            .load_data_from_parquet(test_path)
            .max_length(max_length)
            .build();
        println!("...test dataset max length: {}", test_dataset.max_length());

        // create loaders
        let batch_size = 2_usize;
        let train_loader = SpamDataLoader::new(train_dataset, batch_size, true, true);
        let val_loader = SpamDataLoader::new(val_dataset, batch_size, false, false);
        let test_loader = SpamDataLoader::new(test_dataset, batch_size, false, false);

        // print total number of batches in each data loader
        println!("...{:?} training batches", train_loader.len());
        println!("...{:?} validation batches", val_loader.len());
        println!("...{:?} test batches", test_loader.len());

        // get model
        println!("Loading pre-trained GPT-2 and modifying prediction head");
        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;

        // train model
        // trainable: last trf block, final layer norm, classification head
        let mut training_vars: Vec<Var> = vec![];
        let tensor_data = varmap.data().lock().unwrap();
        let var_names: Vec<&String> = tensor_data
            .keys()
            .filter(|k| k.contains("final_norm") || k.contains("out_head") || k.contains("trf.11"))
            .collect();
        for var_name in var_names.into_iter() {
            let var = tensor_data.get(var_name).unwrap();
            training_vars.push(var.clone());
        }
        drop(tensor_data);

        let optimizer = AdamW::new(
            training_vars,
            ParamsAdamW {
                lr: 5e-5,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;

        println!("Fine-tuning GPT2 on spam training dataset");
        let (eval_freq, eval_iter, num_epochs) = (50_usize, 1_usize, 2_usize);
        let _ = train_classifier_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            None,
        );

        println!("Computing performance metrics");
        // compute accuracies
        let num_batches = None;
        let train_accuracy =
            calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
        let val_accuracy =
            calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
        let test_accuracy =
            calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;

        println!("Training accuracy: {}", train_accuracy);
        println!("Validation accuracy: {}", val_accuracy);
        println!("Test accuracy: {}", test_accuracy);

        Ok(())
    }
}

/// # Fine-tuning the whole model
///
/// #### Id
/// 6.2
///
/// #### CLI command
/// ```sh
/// # without cuda 178 | /// cargo run exercise 6.2 179 | /// 180 | /// # with cuda 181 | /// cargo run --features cuda exercise 6.2 182 | /// ``` 183 | pub struct X2; 184 | 185 | impl Exercise for X2 { 186 | fn name(&self) -> String { 187 | "6.2".to_string() 188 | } 189 | 190 | fn title(&self) -> String { 191 | "Fine-tuning the whole model".to_string() 192 | } 193 | 194 | fn statement(&self) -> String { 195 | let stmt = "Instead of fine-tuning just the final transformer \ 196 | block, fine-tune the entire model and assess the effect on predictive \ 197 | performance."; 198 | stmt.to_string() 199 | } 200 | 201 | fn main(&self) -> Result<()> { 202 | use crate::listings::{ 203 | ch04::Config, 204 | ch06::{ 205 | calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification, 206 | train_classifier_simple, HF_GPT2_MODEL_ID, 207 | }, 208 | }; 209 | use candle_core::{DType, Device}; 210 | use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap}; 211 | 212 | // get gpt model with classification head 213 | let mut cfg = Config::gpt2_124m(); 214 | cfg.qkv_bias = true; 215 | let varmap = VarMap::new(); 216 | let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?); 217 | let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?; 218 | modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?; 219 | 220 | // get data loaders 221 | let eg06 = crate::examples::ch06::EG06; // re-use 222 | let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?; 223 | 224 | // trainable params and optimizer 225 | let optimizer = AdamW::new( 226 | varmap.all_vars(), // train on all vars 227 | ParamsAdamW { 228 | lr: 5e-5, 229 | weight_decay: 0.1, 230 | ..Default::default() 231 | }, 232 | )?; 233 | 234 | println!("Fine-tuning ENTIRE GPT2 on spam training dataset"); 235 | let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize); 236 | let _ = train_classifier_simple( 237 | &model, 238 | &train_loader, 239 | &val_loader, 240 | optimizer, 241 | vb.device(), 242 | num_epochs, 243 | eval_freq, 244 | eval_iter, 245 | None, 246 | ); 247 | 248 | println!("Computing performance metrics"); 249 | // compute accuracies 250 | let num_batches = None; 251 | let train_accuracy = 252 | calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?; 253 | let val_accuracy = 254 | calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?; 255 | let test_accuracy = 256 | calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?; 257 | 258 | println!("Training accuracy: {}", train_accuracy); 259 | println!("Validation accuracy: {}", val_accuracy); 260 | println!("Test accuracy: {}", test_accuracy); 261 | 262 | Ok(()) 263 | } 264 | } 265 | 266 | /// # Fine-tuning the first vs. last token 267 | /// 268 | /// #### Id 269 | /// 6.3 270 | /// 271 | /// #### CLI command 272 | /// ```sh 273 | /// # without cuda 274 | /// cargo run exercise 6.3 275 | /// 276 | /// # with cuda 277 | /// cargo run --features cuda exercise 6.3 278 | /// ``` 279 | pub struct X3; 280 | 281 | impl Exercise for X3 { 282 | fn name(&self) -> String { 283 | "6.3".to_string() 284 | } 285 | 286 | fn title(&self) -> String { 287 | "Fine-tuning the first vs. last token".to_string() 288 | } 289 | 290 | fn statement(&self) -> String { 291 | let stmt = "Try fine-tuning the first output token. 
Notice the \
        changes in predictive performance compared to fine-tuning the last \
        output token.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch06::{
                calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
                train_classifier_simple, HF_GPT2_MODEL_ID,
            },
        };
        use candle_core::{DType, Device, Var};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};

        // get gpt model with classification head
        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;

        // get data loaders
        let eg06 = crate::examples::ch06::EG06; // re-use
        let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;

        // trainable params and optimizer
        // trainable: last trf block, final layer norm, classification head
        let mut training_vars: Vec<Var> = vec![];
        let tensor_data = varmap.data().lock().unwrap();
        let var_names: Vec<&String> = tensor_data
            .keys()
            .filter(|k| k.contains("final_norm") || k.contains("out_head") || k.contains("trf.11"))
            .collect();
        for var_name in var_names.into_iter() {
            let var = tensor_data.get(var_name).unwrap();
            training_vars.push(var.clone());
        }
        drop(tensor_data);

        let optimizer = AdamW::new(
            training_vars,
            ParamsAdamW {
                lr: 5e-5,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;

        println!("Fine-tuning GPT2 on spam training dataset using first-token");
        let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize);
        let custom_pred_token_index = Some(0_usize); // use the first token!
        let _ = train_classifier_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            custom_pred_token_index,
        );

        println!("Computing performance metrics");
        // compute accuracies
        let num_batches = None;
        let train_accuracy = calc_accuracy_loader(
            &train_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;
        let val_accuracy = calc_accuracy_loader(
            &val_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;
        let test_accuracy = calc_accuracy_loader(
            &test_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;

        println!("Training accuracy: {}", train_accuracy);
        println!("Validation accuracy: {}", val_accuracy);
        println!("Test accuracy: {}", test_accuracy);

        Ok(())
    }
}
--------------------------------------------------------------------------------
/src/exercises/mod.rs:
--------------------------------------------------------------------------------
//! Exercises
//!
//! This module contains solutions to the exercises found in each chapter of
//! the LLMs From Scratch book. The book offers code solutions to exercises that
//! are based in PyTorch.
Here, we provide solutions using Rust and the Candle 6 | //! crate. 7 | 8 | pub mod ch02; 9 | pub mod ch03; 10 | pub mod ch04; 11 | pub mod ch05; 12 | pub mod ch06; 13 | pub mod ch07; 14 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Build A Large Language Model From Scratch — Rust Translations 2 | //! 3 | //! #### Intro 4 | //! 5 | //! This crate provides Rust translations of the examples, exercises and listings 6 | //! found in the book Build A LLM From Scratch by Sebastian Raschka 7 | //! ([github](https://github.com/rasbt/LLMs-from-scratch)), which is 8 | //! a great resource for careful learning of LLMs. The book provides several 9 | //! examples and listings which are written in PyTorch in order to learn how to 10 | //! build a GPT (decoder-only) language model. This crate provides the Rust 11 | //! equivalent for nearly all of the code provided in the book using 12 | //! [candle](https://github.com/huggingface/candle) (a Minimalist ML framework for Rust). 13 | //! 14 | //! The lib crate consists of three modules: `examples`, `exercises` and `listings`. 15 | //! Additionally there is a companion binary crate that executes all of the examples 16 | //! and exercises. 17 | 18 | use anyhow::Result; 19 | 20 | pub mod candle_addons; 21 | pub mod examples; 22 | pub mod exercises; 23 | pub mod listings; 24 | 25 | /// Exercise Trait 26 | pub trait Exercise: Send + Sync { 27 | fn name(&self) -> String; 28 | 29 | fn title(&self) -> String; 30 | 31 | fn statement(&self) -> String; 32 | 33 | fn main(&self) -> Result<()>; 34 | } 35 | 36 | /// Example Trait 37 | pub trait Example: Send + Sync { 38 | fn description(&self) -> String; 39 | 40 | fn page_source(&self) -> usize; 41 | 42 | fn main(&self) -> Result<()>; 43 | } 44 | -------------------------------------------------------------------------------- /src/listings/ch02.rs: -------------------------------------------------------------------------------- 1 | //! 
Listings from Chapter 2

use candle_core::{Device, Result, Tensor};
use candle_datasets::{batcher::IterResult2, Batcher};
use fancy_regex::{Captures, Regex};
use rand::{rng, seq::SliceRandom};
use std::collections::HashMap;
use std::fs;
use std::rc::Rc;
use tiktoken_rs::CoreBPE;

/// [Listing 2.1] Reading in a short story as text sample into Rust
pub fn sample_read_text(verbose: bool) -> Result<String> {
    let raw_text = fs::read_to_string("data/the-verdict.txt").expect("Unable to read the file");
    if verbose {
        println!("Total number of character: {:?}", raw_text.len());
        println!("{:?}", &raw_text[..99]);
    }
    Ok(raw_text)
}

/// [Listing 2.2] Creating a vocabulary
pub fn sample_create_vocab() -> Result<HashMap<String, i32>> {
    let raw_text = sample_read_text(false)?;
    let re = Regex::new(r#"([,.?_!"()']|--|\s)"#).unwrap();
    let mut preprocessed: Vec<&str> = re.split(&raw_text[..]).map(|x| x.unwrap()).collect();
    preprocessed.sort();

    let vocab: HashMap<String, i32> = HashMap::from_iter(
        preprocessed
            .iter()
            .enumerate()
            .map(|(idx, el)| (el.to_string(), idx as i32)),
    );
    Ok(vocab)
}

/// [Listing 2.3] Implementing a simple text tokenizer
#[derive(Default, Debug)]
pub struct SimpleTokenizerV1 {
    str_to_int: HashMap<String, i32>,
    int_to_str: HashMap<i32, String>,
}

impl SimpleTokenizerV1 {
    /// Creates a new `SimpleTokenizerV1` from a vocab.
    ///
    /// ```rust
    /// use llms_from_scratch_rs::listings::ch02::SimpleTokenizerV1;
    /// use std::collections::HashMap;
    ///
    /// let vocab: HashMap<&str, i32> = HashMap::from([
    ///     ("this", 1_i32),
    ///     ("is", 2_i32),
    ///     ("a", 3_i32),
    ///     ("test", 4_i32)
    /// ]);
    /// let tokenizer = SimpleTokenizerV1::from_vocab(vocab);
    /// ```
    pub fn from_vocab(vocab: HashMap<&str, i32>) -> Self {
        Self {
            str_to_int: vocab.iter().map(|(k, v)| (String::from(*k), *v)).collect(),
            int_to_str: vocab.iter().map(|(k, v)| (*v, String::from(*k))).collect(),
        }
    }

    /// Encode a text into its token ids.
    pub fn encode(&self, text: &str) -> Vec<i32> {
        let re = Regex::new(r#"([,.?_!"()']|--|\s)"#).unwrap();
        let preprocessed: Vec<&str> = re.split(text).map(|x| x.unwrap()).collect();
        preprocessed
            .into_iter()
            .map(|s| self.str_to_int.get(&String::from(s)).unwrap())
            .cloned()
            .collect()
    }

    /// Decode token ids into its text.
    pub fn decode(&self, ids: Vec<i32>) -> String {
        let text_vec: Vec<String> = ids
            .iter()
            .map(|i| self.int_to_str.get(i).unwrap())
            .cloned()
            .collect();
        let text = &text_vec.join(" ")[..];

        // remove space before any punctuations
        let re = Regex::new(r#"\s+([,.?!"()\'])"#).unwrap();
        String::from(re.replace_all(text, |caps: &Captures| caps[1].to_string()))
    }
}
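// A quick worked example of the tokenizer above, using the 4-word vocab from
// its doc comment (an illustration, not program output): with
// vocab = {"this": 1, "is": 2, "a": 3, "test": 4},
//
//   tokenizer.encode("this is a test") == vec![1, 2, 3, 4]
//   tokenizer.decode(vec![1, 2, 3, 4]) == "this is a test"
//
// Note that `SimpleTokenizerV1::encode` panics on words outside the vocab;
// handling those is exactly what `SimpleTokenizerV2` below adds.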
/// [Listing 2.4] A simple text tokenizer that handles unknown words
#[derive(Default, Debug)]
pub struct SimpleTokenizerV2 {
    str_to_int: HashMap<String, i32>,
    int_to_str: HashMap<i32, String>,
}

impl SimpleTokenizerV2 {
    /// Creates a new `SimpleTokenizerV2` from a vocab.
    ///
    /// ```rust
    /// use llms_from_scratch_rs::listings::ch02::SimpleTokenizerV2;
    /// use std::collections::HashMap;
    ///
    /// let vocab: HashMap<&str, i32> = HashMap::from([
    ///     ("this", 1_i32),
    ///     ("is", 2_i32),
    ///     ("a", 3_i32),
    ///     ("test", 4_i32)
    /// ]);
    /// // Any words not in the vocab will be encoded as "<|unk|>" token
    /// let tokenizer = SimpleTokenizerV2::from_vocab(vocab);
    /// ```
    pub fn from_vocab(vocab: HashMap<&str, i32>) -> Self {
        // add special tokens to vocab if needed
        let mut next_token_id = vocab.len() as i32 + 1_i32;
        let mut vocab_copy = vocab.clone();

        if !vocab.contains_key("<|unk|>") {
            vocab_copy.entry("<|unk|>").or_insert(next_token_id);
            next_token_id += 1;
        }

        if !vocab.contains_key("<|endoftext|>") {
            vocab_copy.entry("<|endoftext|>").or_insert(next_token_id);
        }

        Self {
            str_to_int: vocab_copy
                .iter()
                .map(|(k, v)| (String::from(*k), *v))
                .collect(),
            int_to_str: vocab_copy
                .iter()
                .map(|(k, v)| (*v, String::from(*k)))
                .collect(),
        }
    }

    /// Encode a text into its token ids.
    pub fn encode(&self, text: &str) -> Vec<i32> {
        let re = Regex::new(r#"([,.?_!"()']|--|\s)"#).unwrap();
        let preprocessed: Vec<&str> = re.split(text).map(|x| x.unwrap()).collect();
        preprocessed
            .into_iter()
            .map(|s| {
                self.str_to_int
                    .get(&String::from(s))
                    .unwrap_or(self.str_to_int.get("<|unk|>").unwrap())
            })
            .cloned()
            .collect()
    }

    /// Decode token ids into its text.
    pub fn decode(&self, ids: Vec<i32>) -> String {
        let text_vec: Vec<String> = ids
            .iter()
            .map(|i| self.int_to_str.get(i).unwrap())
            .cloned()
            .collect();
        let text = &text_vec.join(" ")[..];

        // remove space before any punctuations
        let re = Regex::new(r#"\s+([,.?!"()\'])"#).unwrap();
        String::from(re.replace_all(text, |caps: &Captures| caps[1].to_string()))
    }
}

pub struct GPTDatasetV1_ {
    input_ids: Vec<Vec<u32>>,
    target_ids: Vec<Vec<u32>>,
}

/// [Listing 2.5] A dataset for batched inputs and targets
///
/// GPTDatasetV1 is a wrapper for `GPTDatasetV1_` which is refcounted.
/// This makes cloning datasets cheap, e.g., when creating a batcher over a
/// dataset.
#[derive(Clone)]
pub struct GPTDatasetV1(Rc<GPTDatasetV1_>);

impl AsRef<GPTDatasetV1> for GPTDatasetV1 {
    fn as_ref(&self) -> &GPTDatasetV1 {
        self
    }
}

impl std::ops::Deref for GPTDatasetV1 {
    type Target = GPTDatasetV1_;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

impl GPTDatasetV1 {
    /// Creates a new `GPTDatasetV1`.
    ///
    /// ```rust
    /// use tiktoken_rs::get_bpe_from_model;
    /// use llms_from_scratch_rs::listings::ch02::GPTDatasetV1;
    ///
    /// let txt = "In the heart of the city";
    /// let tokenizer = get_bpe_from_model("gpt2").unwrap();
    /// let token_ids = tokenizer.encode_with_special_tokens(&txt[..]);
    /// let stride = 1_usize;
    /// let max_length = 3_usize;
    /// let dataset = GPTDatasetV1::new(&txt[..], tokenizer, max_length, stride);
    /// ```
    pub fn new(txt: &str, tokenizer: CoreBPE, max_length: usize, stride: usize) -> Self {
        let token_ids = tokenizer.encode_with_special_tokens(txt);

        let mut input_ids: Vec<Vec<u32>> = Vec::default();
        let mut target_ids: Vec<Vec<u32>> = Vec::default();
        // get input_ids and target_ids
        for i in (0..token_ids.len() - max_length).step_by(stride) {
            let input_chunk = &token_ids[i..(i + max_length)];
            let target_chunk = &token_ids[(i + 1_usize)..(i + max_length + 1_usize)];
            input_ids.push(input_chunk.to_vec());
            target_ids.push(target_chunk.to_vec());
        }

        let dataset_ = GPTDatasetV1_ {
            input_ids,
            target_ids,
        };

        Self(Rc::new(dataset_))
    }

    /// Gets the number of input-target sequences in the dataset.
    pub fn len(&self) -> usize {
        self.input_ids.len()
    }

    /// Checks whether the dataset is empty or has no input-target sequences.
    pub fn is_empty(&self) -> bool {
        self.input_ids.len() == 0
    }

    /// Returns the input tokens for all input sequences.
    pub fn input_ids(&self) -> &Vec<Vec<u32>> {
        &self.input_ids
    }

    /// Returns the target token ids for all input sequences.
    pub fn target_ids(&self) -> &Vec<Vec<u32>> {
        &self.target_ids
    }

    /// Returns the input-target pair at the specified index.
    pub fn get_pair_at_index(&self, idx: usize) -> (&Vec<u32>, &Vec<u32>) {
        (&self.input_ids[idx], &self.target_ids[idx])
    }
}
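// To make the sliding-window construction in `GPTDatasetV1::new` concrete:
// for token ids [1, 2, 3, 4, 5] with max_length = 2 and stride = 1, the loop
// above produces
//
//   input_ids:  [[1, 2], [2, 3], [3, 4]]
//   target_ids: [[2, 3], [3, 4], [4, 5]]
//
// i.e. each target window is the input window shifted right by one token
// (a worked illustration of the loop, not program output).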
    ///
    /// ```rust
    /// use tiktoken_rs::get_bpe_from_model;
    /// use llms_from_scratch_rs::listings::ch02::GPTDatasetV1;
    ///
    /// let txt = "In the heart of the city";
    /// let tokenizer = get_bpe_from_model("gpt2").unwrap();
    /// let stride = 1_usize;
    /// let max_length = 3_usize;
    /// let dataset = GPTDatasetV1::new(&txt[..], tokenizer, max_length, stride);
    /// ```
    pub fn new(txt: &str, tokenizer: CoreBPE, max_length: usize, stride: usize) -> Self {
        let token_ids = tokenizer.encode_with_special_tokens(txt);

        let mut input_ids: Vec<Vec<u32>> = Vec::default();
        let mut target_ids: Vec<Vec<u32>> = Vec::default();
        // get input_ids and target_ids
        for i in (0..token_ids.len() - max_length).step_by(stride) {
            let input_chunk = &token_ids[i..(i + max_length)];
            let target_chunk = &token_ids[(i + 1_usize)..(i + max_length + 1_usize)];
            input_ids.push(input_chunk.to_vec());
            target_ids.push(target_chunk.to_vec());
        }

        let dataset_ = GPTDatasetV1_ {
            input_ids,
            target_ids,
        };

        Self(Rc::new(dataset_))
    }

    /// Gets the number of input-target sequences in the dataset.
    pub fn len(&self) -> usize {
        self.input_ids.len()
    }

    /// Checks whether the dataset is empty, i.e., has no input-target sequences.
    pub fn is_empty(&self) -> bool {
        self.input_ids.len() == 0
    }

    /// Returns the input token ids for all input sequences.
    pub fn input_ids(&self) -> &Vec<Vec<u32>> {
        &self.input_ids
    }

    /// Returns the target token ids for all input sequences.
    pub fn target_ids(&self) -> &Vec<Vec<u32>> {
        &self.target_ids
    }

    /// Returns the input-target pair at the specified index.
    pub fn get_pair_at_index(&self, idx: usize) -> (&Vec<u32>, &Vec<u32>) {
        (&self.input_ids[idx], &self.target_ids[idx])
    }
}

/// `GPTDatasetIter` is analogous to PyTorch's `DataLoader` class.
///
/// A data loader to generate batches with input-target pairs.
/// We can use `GPTDatasetIter` with `candle_datasets::Batcher` to get the
/// desired batches of examples.
pub struct GPTDatasetIter {
    dataset: GPTDatasetV1,
    remaining_indices: Vec<usize>,
}

impl GPTDatasetIter {
    /// Creates a new `GPTDatasetIter`.
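    ///
    /// When `shuffle` is `true`, the input-target pairs are visited in
    /// random order.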
    ///
    /// ```rust
    /// use llms_from_scratch_rs::listings::ch02::{GPTDatasetV1, GPTDatasetIter};
    /// use tiktoken_rs::get_bpe_from_model;
    ///
    /// let txt = "In the heart of the city";
    /// let tokenizer = get_bpe_from_model("gpt2").unwrap();
    ///
    /// let stride = 1_usize;
    /// let max_length = 3_usize;
    /// let dataset = GPTDatasetV1::new(&txt[..], tokenizer, max_length, stride);
    /// let iter = GPTDatasetIter::new(dataset.clone(), false);
    /// ```
    pub fn new(dataset: GPTDatasetV1, shuffle: bool) -> Self {
        let mut remaining_indices = (0..dataset.len()).rev().collect::<Vec<_>>();
        if shuffle {
            remaining_indices.shuffle(&mut rng());
        }
        Self {
            dataset,
            remaining_indices,
        }
    }
}

impl Iterator for GPTDatasetIter {
    type Item = Result<(Tensor, Tensor)>;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(idx) = self.remaining_indices.pop() {
            let (input_ids, target_ids) = self.dataset.get_pair_at_index(idx);

            // turn into Tensors and return
            let dev = Device::cuda_if_available(0).unwrap();
            let input_tensor = Tensor::new(&input_ids[..], &dev);
            let target_tensor = Tensor::new(&target_ids[..], &dev);
            Some(candle_core::error::zip(input_tensor, target_tensor))
        } else {
            None
        }
    }
}

/// A type alias for `candle_datasets::Batcher`.
///
/// This type is responsible for getting batches from a type that implements
/// the `Iterator` trait.
pub type GPTDataBatcher = Batcher<IterResult2<GPTDatasetIter>>;

/// A type for building a `Batcher` over a `GPTDatasetV1` with specified params.
pub struct GPTDataLoader {
    dataset: GPTDatasetV1,
    batch_size: usize,
    shuffle: bool,
    drop_last: bool,
}

/// A `DataLoader` trait
///
/// NOTE: This was introduced in ch07 since we wanted to re-use the methods
/// here as well as those introduced in ch05, namely `calc_loss_loader`.
pub trait DataLoader {
    type Batcher;

    fn batcher(&self) -> Self::Batcher;
}

impl DataLoader for GPTDataLoader {
    type Batcher = GPTDataBatcher;

    /// Returns a `GPTDataBatcher` that itself provides batches over the
    /// associated dataset.
    fn batcher(&self) -> GPTDataBatcher {
        let iter = GPTDatasetIter::new(self.dataset.clone(), self.shuffle);
        Batcher::new_r2(iter)
            .batch_size(self.batch_size)
            .return_last_incomplete_batch(!self.drop_last)
    }
}

impl GPTDataLoader {
    /// Creates a new `GPTDataLoader`.
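    ///
    /// Each call to `DataLoader::batcher` on the resulting loader builds a
    /// fresh batcher over the dataset, so the loader can be iterated more
    /// than once (e.g., once per training epoch).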
    ///
    /// ```rust
    /// use llms_from_scratch_rs::listings::ch02::{GPTDatasetV1, GPTDataLoader};
    /// use tiktoken_rs::get_bpe_from_model;
    ///
    /// let txt = "In the heart of the city";
    /// let tokenizer = get_bpe_from_model("gpt2").unwrap();
    /// let max_length = 3_usize;
    /// let stride = 1_usize;
    /// let dataset = GPTDatasetV1::new(txt, tokenizer, max_length, stride);
    ///
    /// let batch_size = 2_usize;
    /// let shuffle = false;
    /// let drop_last = false;
    /// let data_loader = GPTDataLoader::new(dataset, batch_size, shuffle, drop_last);
    /// ```
    pub fn new(dataset: GPTDatasetV1, batch_size: usize, shuffle: bool, drop_last: bool) -> Self {
        Self {
            dataset,
            batch_size,
            shuffle,
            drop_last,
        }
    }

    pub fn len(&self) -> usize {
        if self.drop_last {
            self.batcher().count()
        } else {
            // There is a bug in candle_datasets::Batcher, such that if
            // return_last_incomplete_batch is set to true, then the iterator
            // will never return None. This breaks `Iterator::count()`, which
            // consumes the iterator until a None is encountered.
            let mut batcher = self.batcher();
            let mut count = 0_usize;
            while let Some(Ok(_el)) = batcher.next() {
                count += 1;
            }
            count
        }
    }

    pub fn is_empty(&self) -> bool {
        (self.dataset.len() < self.batch_size) && (self.drop_last)
    }
}

/// [Listing 2.6] A data loader to generate batches with input-output pairs
///
/// ```rust
/// use llms_from_scratch_rs::listings::ch02::create_dataloader_v1;
///
/// let txt = "In the heart of the city";
/// let batch_size = 2_usize;
/// let stride = 1_usize;
/// let max_length = 3_usize;
/// let shuffle = false;
/// let drop_last = false;
/// let data_loader =
///     create_dataloader_v1(txt, batch_size, max_length, stride, shuffle, drop_last);
/// ```
pub fn create_dataloader_v1(
    txt: &str,
    batch_size: usize,
    max_length: usize,
    stride: usize,
    shuffle: bool,
    drop_last: bool,
) -> GPTDataLoader {
    let tokenizer = tiktoken_rs::get_bpe_from_model("gpt2").unwrap();
    let dataset = GPTDatasetV1::new(txt, tokenizer, max_length, stride);
    GPTDataLoader::new(dataset, batch_size, shuffle, drop_last)
}

#[cfg(test)]
mod tests {
    use core::panic;

    use super::*;
    use anyhow::Result;
    use candle_datasets::Batcher;
    use rstest::*;
    use tiktoken_rs::get_bpe_from_model;

    #[fixture]
    pub fn vocab() -> HashMap<&'static str, i32> {
        let mut vocab: HashMap<&str, i32> = HashMap::new();
        vocab.entry("this").or_insert(1);
        vocab.entry("is").or_insert(2);
        vocab.entry("a").or_insert(3);
        vocab.entry("test").or_insert(4);
        vocab
    }

    #[fixture]
    pub fn txt_tokenizer() -> (String, CoreBPE) {
        let txt = "In the heart of the city";
        let tokenizer = get_bpe_from_model("gpt2").unwrap();
        (txt.to_string(), tokenizer)
    }

    #[fixture]
    pub fn gpt_dataset(#[from(txt_tokenizer)] (txt, tokenizer): (String, CoreBPE)) -> GPTDatasetV1 {
        let stride = 1_usize;
        let max_length = 3_usize;
        GPTDatasetV1::new(&txt[..], tokenizer, max_length, stride)
    }
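    // Added check (not from the original listings): encoding then decoding
    // with `SimpleTokenizerV1` should round-trip simple text.
    #[rstest]
    fn test_simple_tokenizer_roundtrip(vocab: HashMap<&str, i32>) -> Result<()> {
        let tokenizer = SimpleTokenizerV1::from_vocab(vocab);
        let text = tokenizer.decode(tokenizer.encode("this is a test"));
        assert_eq!(text, "this is a test");
        Ok(())
    }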
    #[rstest]
    fn test_simple_tokenizer_init(vocab: HashMap<&str, i32>) -> Result<()> {
        let tokenizer: SimpleTokenizerV1 = SimpleTokenizerV1::from_vocab(vocab);

        // assert
        assert_eq!(tokenizer.str_to_int.get(&String::from("this")), Some(&1));
        assert_eq!(tokenizer.str_to_int.get(&String::from("is")), Some(&2));
        assert_eq!(tokenizer.str_to_int.get(&String::from("a")), Some(&3));
        assert_eq!(tokenizer.str_to_int.get(&String::from("test")), Some(&4));
        Ok(())
    }

    #[rstest]
    fn test_encode(vocab: HashMap<&str, i32>) -> Result<()> {
        let tokenizer = SimpleTokenizerV1::from_vocab(vocab);
        let token_ids = tokenizer.encode("this is a test");

        assert_eq!(token_ids[0], 1);
        assert_eq!(token_ids[1], 2);
        assert_eq!(token_ids[2], 3);
        assert_eq!(token_ids[3], 4);
        Ok(())
    }

    #[rstest]
    fn test_simple_tokenizer_decode(mut vocab: HashMap<&str, i32>) -> Result<()> {
        vocab.entry(".").or_insert(5);
        let tokenizer = SimpleTokenizerV1::from_vocab(vocab);

        let token_ids = vec![1, 2, 3, 4, 5];
        let text = tokenizer.decode(token_ids);

        assert_eq!(text, "this is a test.");
        Ok(())
    }

    #[rstest]
    fn test_simple_tokenizer_v2_encode(vocab: HashMap<&str, i32>) -> Result<()> {
        let tokenizer = SimpleTokenizerV2::from_vocab(vocab);
        let token_ids = tokenizer.encode("this is a test! <|endoftext|>");

        assert_eq!(token_ids[0], 1);
        assert_eq!(token_ids[1], 2);
        assert_eq!(token_ids[2], 3);
        assert_eq!(token_ids[3], 4);
        assert_eq!(token_ids[4], 5);
        assert_eq!(token_ids[5], 6);
        Ok(())
    }

    #[rstest]
    fn test_simple_tokenizer_v2_decode(vocab: HashMap<&str, i32>) -> Result<()> {
        let tokenizer = SimpleTokenizerV2::from_vocab(vocab);

        let token_ids = vec![1, 2, 3, 4, 5, 6];
        let text = tokenizer.decode(token_ids);

        assert_eq!(text, "this is a test <|unk|> <|endoftext|>");
        Ok(())
    }

    #[rstest]
    fn test_gpt_dataset_v1_init(
        #[from(txt_tokenizer)] (txt, tokenizer): (String, CoreBPE),
    ) -> Result<()> {
        let token_ids = tokenizer.encode_with_special_tokens(&txt[..]);
        let stride = 1_usize;
        let max_length = 3_usize;
        let dataset = GPTDatasetV1::new(&txt[..], tokenizer, max_length, stride);

        for mx in 1..max_length {
            // test target alignments
            assert_eq!(
                dataset.input_ids[0][mx],
                dataset.target_ids[0][mx - 1_usize]
            );
        }

        for ix in 1..dataset.input_ids.len() {
            // test max length per input
            assert!(dataset.input_ids[ix].len() == max_length);
            // test stride alignments
            assert_eq!(dataset.input_ids[ix][0], token_ids[ix * stride]);
        }
        Ok(())
    }
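    // Added check (not from the original listings): the number of windows
    // follows from the token count, `max_length`, and `stride`.
    #[rstest]
    fn test_gpt_dataset_v1_len(
        #[from(txt_tokenizer)] (txt, tokenizer): (String, CoreBPE),
    ) -> Result<()> {
        let num_tokens = tokenizer.encode_with_special_tokens(&txt[..]).len();
        let (max_length, stride) = (3_usize, 2_usize);
        let dataset = GPTDatasetV1::new(&txt[..], tokenizer, max_length, stride);

        // window starts are 0, stride, 2 * stride, ... < num_tokens - max_length
        assert_eq!(dataset.len(), (num_tokens - max_length).div_ceil(stride));
        Ok(())
    }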
    #[rstest]
    fn test_gpt_dataset_v1_iter(
        #[from(txt_tokenizer)] (txt, tokenizer): (String, CoreBPE),
    ) -> Result<()> {
        let stride = 1_usize;
        let max_length = 3_usize;
        let dataset = GPTDatasetV1::new(&txt[..], tokenizer, max_length, stride);
        let mut iter = GPTDatasetIter::new(dataset.clone(), false);
        let mut count = 0_usize;

        // use iter to sequentially get the next pair, checking equality with the dataset
        while let Some(Ok((this_inputs, this_targets))) = iter.next() {
            let this_inputs_vec: Vec<u32> = this_inputs.to_vec1::<u32>()?;
            let this_targets_vec: Vec<u32> = this_targets.to_vec1::<u32>()?;

            assert!(this_inputs.shape().dims()[0] == max_length);
            assert!(this_targets.shape().dims()[0] == max_length);

            for (idx, token_id) in this_inputs_vec.iter().enumerate() {
                assert_eq!(*token_id, dataset.input_ids[count][idx]);
            }
            for (idx, token_id) in this_targets_vec.iter().enumerate() {
                assert_eq!(*token_id, dataset.target_ids[count][idx]);
            }

            count += 1;
        }
        assert_eq!(count, dataset.len());
        Ok(())
    }

    #[rstest]
    fn test_gpt_dataset_with_batch(#[from(gpt_dataset)] dataset: GPTDatasetV1) -> Result<()> {
        let iter = GPTDatasetIter::new(dataset.clone(), false);
        let batch_size = 2_usize;
        let mut batch_iter = Batcher::new_r2(iter).batch_size(batch_size);

        match batch_iter.next() {
            Some(Ok((inputs, targets))) => {
                assert_eq!(inputs.dims(), targets.dims());
                assert_eq!(inputs.dims()[0], batch_size);
            }
            Some(Err(err)) => panic!("{}", err),
            None => panic!("None"),
        }
        Ok(())
    }

    #[rstest]
    fn test_create_dataloader_v1() -> Result<()> {
        let txt = "In the heart of the city";
        let batch_size = 2_usize;
        let stride = 1_usize;
        let max_length = 3_usize;
        let shuffle = false;
        let drop_last = false;
        let data_loader =
            create_dataloader_v1(txt, batch_size, max_length, stride, shuffle, drop_last);

        let mut batcher = data_loader.batcher();
        let mut count = 0_usize;
        while let Some(Ok((inputs, targets))) = batcher.next() {
            assert_eq!(inputs.dims(), targets.dims());
            assert!(inputs.dims()[0] <= batch_size);
            count += 1;
        }
        assert!(!data_loader.is_empty());
        assert_eq!(data_loader.len(), count);
        Ok(())
    }
}
--------------------------------------------------------------------------------
/src/listings/ch03.rs:
--------------------------------------------------------------------------------
//! Listings from Chapter 3

use candle_core::{Device, Module, ModuleT, Result, Tensor, D};
use candle_nn::ops::softmax;
use candle_nn::{linear_b, Dropout, Linear, VarBuilder};

pub fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u32::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

pub fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}
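// `get_mask(size, dev)` returns a `size x size` matrix with ones strictly
// above the diagonal. `masked_fill` uses it to overwrite the attention scores
// at those positions, typically with `f32::NEG_INFINITY`, so that softmax
// assigns zero weight to future tokens.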
/// [Listing 3.1] A compact self-attention class
///
/// `SelfAttentionV1` is a simple implementation of a self-attention layer.
/// It follows a similar interface to other candle `Module`s.
pub struct SelfAttentionV1 {
    pub w_query: Tensor,
    pub w_key: Tensor,
    pub w_value: Tensor,
    pub scaling: f64,
}

impl SelfAttentionV1 {
    /// Creates a new `SelfAttentionV1`
    ///
    /// ```rust
    /// use candle_core::{Device, DType};
    /// use candle_nn::{VarMap, VarBuilder};
    /// use llms_from_scratch_rs::listings::ch03::SelfAttentionV1;
    ///
    /// let dev = Device::cuda_if_available(0).unwrap();
    /// let varmap = VarMap::new();
    /// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    /// let (d_in, d_out) = (3_usize, 5_usize);
    /// let attn_v1_layer = SelfAttentionV1::new(d_in, d_out, vb.pp("attn")).unwrap();
    /// ```
    pub fn new(d_in: usize, d_out: usize, vb: VarBuilder<'_>) -> Result<Self> {
        let init = candle_nn::init::DEFAULT_KAIMING_NORMAL;
        let w_query = vb.get_with_hints((d_in, d_out), "query", init)?;
        let w_key = vb.get_with_hints((d_in, d_out), "key", init)?;
        let w_value = vb.get_with_hints((d_in, d_out), "value", init)?;
        let scaling = 1. / (w_key.dims()[1] as f64).sqrt();

        Ok(Self {
            w_query,
            w_key,
            w_value,
            scaling,
        })
    }

    pub fn w_query(&self) -> &Tensor {
        &self.w_query
    }

    pub fn w_key(&self) -> &Tensor {
        &self.w_key
    }

    pub fn w_value(&self) -> &Tensor {
        &self.w_value
    }
}

impl Module for SelfAttentionV1 {
    /// Computes the context vectors for `xs`
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let queries = xs.matmul(&self.w_query)?;
        let keys = xs.matmul(&self.w_key)?;
        let values = xs.matmul(&self.w_value)?;

        let attn_scores = queries.matmul(&keys.t()?)?;
        let attn_weights = candle_nn::ops::softmax(&(attn_scores * self.scaling)?, 1)?;
        attn_weights.matmul(&values)
    }
}
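// Note (added): `SelfAttentionV2` below swaps the raw weight tensors for
// `candle_nn::Linear` layers. A `Linear` stores its weight with shape
// `(d_out, d_in)`, which is why `scaling` reads `dims()[0]` rather than
// `dims()[1]` as in `SelfAttentionV1`.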
/// [Listing 3.2] A self-attention class using candle_nn::Linear
pub struct SelfAttentionV2 {
    w_query: Linear,
    w_key: Linear,
    w_value: Linear,
    scaling: f64,
}

impl SelfAttentionV2 {
    /// Creates a new `SelfAttentionV2`
    ///
    /// ```rust
    /// use candle_core::{Device, DType};
    /// use candle_nn::{VarMap, VarBuilder};
    /// use llms_from_scratch_rs::listings::ch03::SelfAttentionV2;
    ///
    /// let dev = Device::cuda_if_available(0).unwrap();
    /// let varmap = VarMap::new();
    /// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    /// let (d_in, d_out) = (3_usize, 5_usize);
    /// let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn")).unwrap();
    /// ```
    pub fn new(d_in: usize, d_out: usize, qkv_bias: bool, vb: VarBuilder<'_>) -> Result<Self> {
        let w_query = linear_b(d_in, d_out, qkv_bias, vb.pp("query"))?;
        let w_key = linear_b(d_in, d_out, qkv_bias, vb.pp("key"))?;
        let w_value = linear_b(d_in, d_out, qkv_bias, vb.pp("value"))?;
        let scaling = 1. / (w_key.weight().dims()[0] as f64).sqrt();

        Ok(Self {
            w_query,
            w_key,
            w_value,
            scaling,
        })
    }

    pub fn w_query(&self) -> &Linear {
        &self.w_query
    }

    pub fn w_key(&self) -> &Linear {
        &self.w_key
    }

    pub fn w_value(&self) -> &Linear {
        &self.w_value
    }
}

impl Module for SelfAttentionV2 {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let queries = self.w_query.forward(xs)?;
        let keys = self.w_key.forward(xs)?;
        let values = self.w_value.forward(xs)?;

        let attn_scores = queries.matmul(&keys.t()?)?;
        let attn_weights = candle_nn::ops::softmax(&(attn_scores * self.scaling)?, D::Minus1)?;
        attn_weights.matmul(&values)
    }
}
/// [Listing 3.3] A compact causal attention class
pub struct CausalAttention {
    w_query: Linear,
    w_key: Linear,
    w_value: Linear,
    scaling: f64,
    dropout: Dropout,
    drop_p: f32,
}

impl CausalAttention {
    /// Creates a new `CausalAttention`
    ///
    /// ```rust
    /// use candle_core::{Device, DType};
    /// use candle_nn::{VarMap, VarBuilder};
    /// use llms_from_scratch_rs::listings::ch03::CausalAttention;
    ///
    /// let dev = Device::cuda_if_available(0).unwrap();
    /// let varmap = VarMap::new();
    /// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    /// let (d_in, d_out) = (3_usize, 5_usize);
    /// let causal_attn = CausalAttention::new(d_in, d_out, 0.5_f32, false, vb.pp("attn")).unwrap();
    /// ```
    pub fn new(
        d_in: usize,
        d_out: usize,
        drop_p: f32,
        qkv_bias: bool,
        vb: VarBuilder<'_>,
    ) -> Result<Self> {
        let w_query = linear_b(d_in, d_out, qkv_bias, vb.pp("query"))?;
        let w_key = linear_b(d_in, d_out, qkv_bias, vb.pp("key"))?;
        let w_value = linear_b(d_in, d_out, qkv_bias, vb.pp("value"))?;
        let scaling = 1. / (w_key.weight().dims()[0] as f64).sqrt();
        let dropout = Dropout::new(drop_p);

        Ok(Self {
            w_query,
            w_key,
            w_value,
            scaling,
            dropout,
            drop_p, // a private field in Dropout
        })
    }

    pub fn w_query(&self) -> &Linear {
        &self.w_query
    }

    pub fn w_key(&self) -> &Linear {
        &self.w_key
    }

    pub fn w_value(&self) -> &Linear {
        &self.w_value
    }

    pub fn drop_p(&self) -> f32 {
        self.drop_p
    }
}

impl Module for CausalAttention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // handles batches now
        let (b, num_tokens, _d_in) = xs.dims3()?;
        let queries = self.w_query.forward(xs)?;
        let keys = self.w_key.forward(xs)?;
        let values = self.w_value.forward(xs)?;

        let attn_scores = queries.matmul(&keys.transpose(D::Minus2, D::Minus1)?)?;
        let mask = get_mask(num_tokens, xs.device())?;
        let masked = masked_fill(
            &attn_scores,
            &mask.broadcast_left(b).unwrap(),
            f32::NEG_INFINITY,
        )?;

        // scale
        let mut attn_weights = softmax(&(masked * self.scaling)?, D::Minus1)?;
        // dropout
        attn_weights = self.dropout.forward(&attn_weights, true).unwrap();

        // context vectors
        attn_weights.matmul(&values)
    }
}

/// [Listing 3.4] A wrapper to implement multi-head attention
pub struct MultiHeadAttentionWrapper {
    heads: Vec<CausalAttention>,
}

impl MultiHeadAttentionWrapper {
    /// Creates a new `MultiHeadAttentionWrapper`
    ///
    /// ```rust
    /// use candle_core::{Device, DType};
    /// use candle_nn::{VarMap, VarBuilder};
    /// use llms_from_scratch_rs::listings::ch03::MultiHeadAttentionWrapper;
    ///
    /// let dev = Device::cuda_if_available(0).unwrap();
    /// let varmap = VarMap::new();
    /// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    /// let (d_in, d_out, num_heads) = (3_usize, 6_usize, 3_usize);
    /// let multihead_attn = MultiHeadAttentionWrapper::new(
    ///     num_heads,
    ///     d_in,
    ///     d_out,
    ///     0.5_f32,
    ///     false,
    ///     vb.pp("multihead_attn"),
    /// ).unwrap();
    /// ```
    pub fn new(
        num_heads: usize,
        d_in: usize,
        d_out: usize,
        drop_p: f32,
        qkv_bias: bool,
        vb: VarBuilder<'_>,
    ) -> Result<Self> {
        let heads = (0..num_heads)
            .map(|i| {
                CausalAttention::new(d_in, d_out, drop_p, qkv_bias, vb.pp(format!("head-{}", i)))
                    .unwrap()
            })
            .collect::<Vec<_>>();
        Ok(Self { heads })
    }
}

impl Module for MultiHeadAttentionWrapper {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let context_vectors = self
            .heads
            .iter()
            .map(|attn| attn.forward(xs).unwrap())
            .collect::<Vec<_>>();
        let reduced = context_vectors
            .into_iter()
            .reduce(|acc, e| Tensor::cat(&[&acc, &e], D::Minus1).unwrap())
            .unwrap(); // TODO: use ok_or to convert the Option to a Result
        Ok(reduced)
    }
}

/// [Listing 3.5] An efficient multi-head attention type
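///
/// Unlike `MultiHeadAttentionWrapper`, this type computes all heads from a
/// single set of Q/K/V projections, reshaping the projected tensors to
/// `(batch, num_heads, num_tokens, head_dim)` so every head is handled in
/// one matrix multiplication.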
#[derive(Clone, Debug)]
pub struct MultiHeadAttention {
    num_heads: usize,
    d_out: usize,
    head_dim: usize,
    w_query: Linear,
    w_key: Linear,
    w_value: Linear,
    out_proj: Linear,
    scaling: f64,
    dropout: Dropout,
    drop_p: f32,
}

impl MultiHeadAttention {
    /// Creates a new `MultiHeadAttention`
    ///
    /// ```rust
    /// use candle_core::{Device, DType};
    /// use candle_nn::{VarMap, VarBuilder};
    /// use llms_from_scratch_rs::listings::ch03::MultiHeadAttention;
    ///
    /// let dev = Device::cuda_if_available(0).unwrap();
    /// let varmap = VarMap::new();
    /// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    /// let (d_in, d_out, num_heads) = (3_usize, 6_usize, 2_usize);
    /// let mha = MultiHeadAttention::new(d_in, d_out, 0.5_f32, num_heads, false, vb.pp("attn")).unwrap();
    /// ```
    pub fn new(
        d_in: usize,
        d_out: usize,
        drop_p: f32,
        num_heads: usize,
        qkv_bias: bool,
        vb: VarBuilder<'_>,
    ) -> Result<Self> {
        if d_out % num_heads != 0 {
            panic!("`d_out` must be divisible by `num_heads`.")
        }
        let head_dim = d_out / num_heads;

        let w_query = linear_b(d_in, d_out, qkv_bias, vb.pp("query"))?;
        let w_key = linear_b(d_in, d_out, qkv_bias, vb.pp("key"))?;
        let w_value = linear_b(d_in, d_out, qkv_bias, vb.pp("value"))?;
        let out_proj = linear_b(d_out, d_out, true, vb.pp("out_proj"))?;
        let scaling = 1. / (head_dim as f64).sqrt();
        let dropout = Dropout::new(drop_p);

        Ok(Self {
            num_heads,
            d_out,
            head_dim,
            w_query,
            w_key,
            w_value,
            out_proj,
            scaling,
            dropout,
            drop_p,
        })
    }

    pub fn w_query(&self) -> &Linear {
        &self.w_query
    }

    pub fn w_key(&self) -> &Linear {
        &self.w_key
    }

    pub fn w_value(&self) -> &Linear {
        &self.w_value
    }

    pub fn out_proj(&self) -> &Linear {
        &self.out_proj
    }

    pub fn d_out(&self) -> usize {
        self.d_out
    }

    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    pub fn dropout(&self) -> &Dropout {
        &self.dropout
    }

    pub fn drop_p(&self) -> f32 {
        self.drop_p
    }

    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    pub fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Manual implementation of forward
    ///
    /// Note: the blanket implementation of `ModuleT` for any type that
    /// implements `Module` prevents `forward` from being overridden. Thus,
    /// this type implements `ModuleT` but technically not `Module`.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.forward_t(xs, true)
    }
}
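// Note (added): here `scaling` is derived from `head_dim` (i.e.,
// `d_out / num_heads`) rather than `d_out`, since scaled dot-product
// attention is applied independently per head.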
impl ModuleT for MultiHeadAttention {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        let (b, num_tokens, _d_in) = xs.dims3()?;
        let queries = self.w_query.forward_t(xs, train)?;
        let keys = self.w_key.forward_t(xs, train)?;
        let values = self.w_value.forward_t(xs, train)?;

        // reshape to facilitate getting the attn scores for each of the
        // individual heads with one matrix multiplication
        let queries = queries
            .reshape((b, num_tokens, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let keys = keys
            .reshape((b, num_tokens, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let values = values
            .reshape((b, num_tokens, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;

        let attn_scores = queries.matmul(&keys.transpose(D::Minus2, D::Minus1)?)?;

        let mask = get_mask(num_tokens, xs.device())?;
        let masked = masked_fill(
            &attn_scores,
            &mask.broadcast_left((b, self.num_heads)).unwrap(),
            f32::NEG_INFINITY,
        )?;

        // scale
        let mut attn_weights = softmax(&(masked * self.scaling)?, D::Minus1)?;
        // dropout
        attn_weights = self.dropout.forward(&attn_weights, train)?;

        // context vectors
        let context_vec = attn_weights.matmul(&values)?.transpose(1, 2)?;
        let context_vec = context_vec
            .reshape((b, num_tokens, self.d_out))?
            .contiguous()?;

        // projection
        self.out_proj.forward_t(&context_vec, train)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use candle_core::{DType, Device};
    use candle_nn::{VarBuilder, VarMap};
    use rstest::*;

    #[fixture]
    pub fn vb() -> VarBuilder<'static> {
        let dev = Device::cuda_if_available(0).unwrap();
        let varmap = VarMap::new();
        VarBuilder::from_varmap(&varmap, DType::F32, &dev)
    }

    #[rstest]
    fn test_self_attention_v1_init(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out) = (3_usize, 5_usize);
        let attn_v1_layer = SelfAttentionV1::new(d_in, d_out, vb.pp("attn"))?;

        assert_eq!(attn_v1_layer.w_query.dims(), &[d_in, d_out]);
        assert_eq!(attn_v1_layer.w_key.dims(), &[d_in, d_out]);
        assert_eq!(attn_v1_layer.w_value.dims(), &[d_in, d_out]);
        Ok(())
    }

    #[rstest]
    fn test_self_attention_v1_forward(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out) = (3_usize, 5_usize);
        let attn_v1_layer = SelfAttentionV1::new(d_in, d_out, vb.pp("attn"))?;

        let input_length = 10_usize;
        let xs = Tensor::rand(0f32, 1f32, (input_length, d_in), vb.device())?;
        let context_vectors = attn_v1_layer.forward(&xs)?;

        assert_eq!(context_vectors.dims(), &[input_length, d_out]);
        Ok(())
    }

    #[rstest]
    fn test_self_attention_v2_init(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out) = (3_usize, 5_usize);
        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?;

        assert_eq!(attn_v2_layer.w_query.weight().dims(), &[d_out, d_in]);
        assert_eq!(attn_v2_layer.w_key.weight().dims(), &[d_out, d_in]);
        assert_eq!(attn_v2_layer.w_value.weight().dims(), &[d_out, d_in]);
        Ok(())
    }

    #[rstest]
    fn test_self_attention_v2_forward(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out) = (3_usize, 5_usize);
        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?;

        let input_length = 10_usize;
        let xs = Tensor::rand(0f32, 1f32, (input_length, d_in), vb.device())?;
        let context_vectors = attn_v2_layer.forward(&xs)?;

        assert_eq!(context_vectors.dims(), &[input_length, d_out]);
        Ok(())
    }
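    // Added check (not from the original listings): `get_mask` is strictly
    // upper-triangular and `masked_fill` writes `-inf` at exactly those spots.
    #[rstest]
    fn test_get_mask_and_masked_fill(vb: VarBuilder<'_>) -> Result<()> {
        let mask = get_mask(3_usize, vb.device())?;
        assert_eq!(
            mask.to_vec2::<u32>()?,
            vec![vec![0, 1, 1], vec![0, 0, 1], vec![0, 0, 0]]
        );

        let scores = Tensor::zeros((3_usize, 3_usize), DType::F32, vb.device())?;
        let masked = masked_fill(&scores, &mask, f32::NEG_INFINITY)?;
        let masked_vec = masked.to_vec2::<f32>()?;
        assert_eq!(masked_vec[0][0], 0.0_f32);
        assert_eq!(masked_vec[0][1], f32::NEG_INFINITY);
        assert_eq!(masked_vec[2][2], 0.0_f32);
        Ok(())
    }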
    #[rstest]
    fn test_causal_attention_init(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out) = (3_usize, 5_usize);
        let causal_attn = CausalAttention::new(d_in, d_out, 0.5_f32, false, vb.pp("attn"))?;

        assert_eq!(causal_attn.w_query.weight().dims(), &[d_out, d_in]);
        assert_eq!(causal_attn.w_key.weight().dims(), &[d_out, d_in]);
        assert_eq!(causal_attn.w_value.weight().dims(), &[d_out, d_in]);
        assert_eq!(causal_attn.drop_p, 0.5_f32);
        Ok(())
    }

    #[rstest]
    fn test_causal_attention_forward(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out) = (3_usize, 5_usize);
        let causal_attn = CausalAttention::new(d_in, d_out, 0.5_f32, false, vb.pp("attn"))?;

        // create batch
        let input_length = 10_usize;
        let xs = Tensor::rand(0f32, 1f32, (input_length, d_in), vb.device())?;
        let batch = Tensor::stack(&[&xs, &xs], 0)?;
        let context_vectors = causal_attn.forward(&batch)?;

        assert_eq!(context_vectors.dims(), &[2_usize, input_length, d_out]);
        Ok(())
    }

    #[rstest]
    fn test_multihead_attention_wrapper_init(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out) = (3_usize, 5_usize);
        let num_heads = 3_usize;
        let multihead_attn = MultiHeadAttentionWrapper::new(
            num_heads,
            d_in,
            d_out,
            0.5_f32,
            false,
            vb.pp("multihead_attn"),
        )?;

        assert_eq!(multihead_attn.heads.len(), num_heads);

        for i in 0..num_heads {
            let causal_attn = &multihead_attn.heads[i];
            assert_eq!(causal_attn.w_query.weight().dims(), &[d_out, d_in]);
            assert_eq!(causal_attn.w_key.weight().dims(), &[d_out, d_in]);
            assert_eq!(causal_attn.w_value.weight().dims(), &[d_out, d_in]);
            assert_eq!(causal_attn.drop_p, 0.5_f32);
        }
        Ok(())
    }

    #[rstest]
    fn test_multihead_attention_wrapper_forward(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out) = (3_usize, 5_usize);
        let num_heads = 3_usize;
        let multihead_attn = MultiHeadAttentionWrapper::new(
            num_heads,
            d_in,
            d_out,
            0.5_f32,
            false,
            vb.pp("multihead_attn"),
        )?;

        // create batch
        let input_length = 10_usize;
        let xs = Tensor::rand(0f32, 1f32, (input_length, d_in), vb.device())?;
        let batch = Tensor::stack(&[&xs, &xs], 0)?;
        let context_vectors = multihead_attn.forward(&batch)?;

        assert_eq!(
            context_vectors.dims(),
            &[2_usize, input_length, num_heads * d_out]
        );
        Ok(())
    }

    #[rstest]
    fn test_mha_init(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out, num_heads) = (3_usize, 6_usize, 2_usize);
        let mha = MultiHeadAttention::new(d_in, d_out, 0.5_f32, num_heads, false, vb.pp("attn"))?;

        assert_eq!(mha.w_query.weight().dims(), &[d_out, d_in]);
        assert_eq!(mha.w_key.weight().dims(), &[d_out, d_in]);
        assert_eq!(mha.w_value.weight().dims(), &[d_out, d_in]);
        assert_eq!(mha.out_proj.weight().dims(), &[d_out, d_out]);
        assert_eq!(mha.head_dim, d_out / num_heads);
        assert_eq!(mha.drop_p, 0.5_f32);
        Ok(())
    }

    #[rstest]
    #[should_panic(expected = "`d_out` must be divisible by `num_heads`.")]
    fn test_mha_init_panics_nondivisible_heads(vb: VarBuilder<'_>) {
        let (d_in, d_out, num_heads) = (3_usize, 6_usize, 4_usize);
        let _ =
            MultiHeadAttention::new(d_in, d_out, 0.5_f32, num_heads, false, vb.pp("attn")).unwrap();
    }
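    // Added check (not from the original listings): `forward_t` with
    // `train = false` disables dropout and should produce the same shape.
    #[rstest]
    fn test_mha_forward_t_eval_mode(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out, num_heads) = (3_usize, 6_usize, 3_usize);
        let mha = MultiHeadAttention::new(d_in, d_out, 0.5_f32, num_heads, false, vb.pp("attn"))?;

        let input_length = 10_usize;
        let xs = Tensor::rand(0f32, 1f32, (input_length, d_in), vb.device())?;
        let batch = Tensor::stack(&[&xs, &xs], 0)?;
        let context_vectors = mha.forward_t(&batch, false)?;

        assert_eq!(context_vectors.dims(), &[2_usize, input_length, d_out]);
        Ok(())
    }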
    #[rstest]
    fn test_mha_forward(vb: VarBuilder<'_>) -> Result<()> {
        let (d_in, d_out, num_heads) = (3_usize, 6_usize, 3_usize);
        let mha = MultiHeadAttention::new(d_in, d_out, 0.5_f32, num_heads, false, vb.pp("attn"))?;

        // create batch
        let input_length = 10_usize;
        let xs = Tensor::rand(0f32, 1f32, (input_length, d_in), vb.device())?;
        let batch = Tensor::stack(&[&xs, &xs], 0)?;
        let context_vectors = mha.forward(&batch)?;

        assert_eq!(context_vectors.dims(), &[2_usize, input_length, d_out]);
        Ok(())
    }
}
--------------------------------------------------------------------------------
/src/listings/mod.rs:
--------------------------------------------------------------------------------
//! Listings
//!
//! This module contains Rust/Candle translations of all of the listings
//! provided in the LLMs From Scratch book.

pub mod apdx_e;
pub mod ch02;
pub mod ch03;
pub mod ch04;
pub mod ch05;
pub mod ch06;
pub mod ch07;
--------------------------------------------------------------------------------