├── chF
│   ├── 02_mmlu
│   │   ├── requirements-extra.txt
│   │   ├── random_guessing_baseline.py
│   │   ├── 1_letter_matching.py
│   │   ├── 2_logprob.py
│   │   └── 3_teacher_forcing.py
│   ├── README.md
│   ├── 03_leaderboards
│   │   ├── 1_elo_leaderboard.py
│   │   ├── votes.json
│   │   ├── README.md
│   │   └── 2_bradley_terry_leaderboard.py
│   └── 04_llm-judge
│       └── README.md
├── requirements.txt
├── ch01
│   └── README.md
├── chG
│   ├── 01_main-chapter-code
│   │   ├── public
│   │   │   ├── logo_dark.webp
│   │   │   └── logo_light.webp
│   │   ├── README.md
│   │   ├── qwen3_chat_interface.py
│   │   └── qwen3_chat_interface_multiturn.py
│   └── README.md
├── chC
│   └── README.md
├── ch02
│   ├── 01_main-chapter-code
│   │   └── README.md
│   ├── 02_setup-tips
│   │   ├── README.md
│   │   ├── gpu-instructions.md
│   │   └── python-instructions.md
│   ├── README.md
│   ├── 04_torch-compile-windows
│   │   └── README.md
│   ├── 05_use_model
│   │   ├── generate_simple.py
│   │   ├── chat.py
│   │   ├── chat_multiturn.py
│   │   └── README.md
│   └── 03_optimized-LLM
│       ├── compare_inference.py
│       └── README.md
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── ask-a-question.md
│   │   └── bug-report.yaml
│   ├── workflows
│   │   ├── check-spelling-errors.yml
│   │   ├── tests-linux.yml
│   │   ├── tests-macos.yml
│   │   ├── basic-tests-old-pytorch.yml
│   │   ├── basic-tests-pip.yml
│   │   ├── basic-test-nightly-pytorch.yml
│   │   ├── check-links.yml
│   │   ├── tests-windows.yml
│   │   └── code-linter.yml
│   └── scripts
│       ├── check_notebook_line_length.py
│       └── check_double_quotes.py
├── ch03
│   ├── README.md
│   ├── 01_main-chapter-code
│   │   └── README.md
│   └── 02_math500-verifier-scripts
│       ├── evaluate_math500.py
│       └── README.md
├── ch04
│   ├── README.md
│   └── 02_math500-inference-scaling-scripts
│       ├── cot_prompting_math500.py
│       └── README.md
├── ch05
│   └── README.md
├── tests
│   ├── test_appendix_f.py
│   ├── test_math500_scripts.py
│   ├── test_appendix_c.py
│   ├── conftest.py
│   ├── test_ch02.py
│   ├── test_ch02_ex.py
│   ├── test_ch04.py
│   ├── test_ch05.py
│   ├── test_qwen3_optimized.py
│   └── test_qwen3_batched_stop.py
├── reasoning_from_scratch
│   ├── __init__.py
│   ├── ch02_ex.py
│   ├── appendix_f.py
│   ├── utils.py
│   ├── ch02.py
│   ├── appendix_c.py
│   └── ch05.py
├── pyproject.toml
└── .gitignore
/chF/02_mmlu/requirements-extra.txt:
--------------------------------------------------------------------------------
1 | datasets >= 4.1.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | reasoning-from-scratch >= 0.1.2
2 | sympy>=1.14.0 # For verifier in ch03
3 | torch >= 2.7.1
4 | tokenizers >= 0.21.2
--------------------------------------------------------------------------------
/ch01/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 1: Understanding Reasoning Models
2 |
3 |
4 |
5 | ## Main chapter code
6 |
7 | There is no code in this chapter.
8 |
9 |
--------------------------------------------------------------------------------
/chG/01_main-chapter-code/public/logo_dark.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/HEAD/chG/01_main-chapter-code/public/logo_dark.webp
--------------------------------------------------------------------------------
/chG/01_main-chapter-code/public/logo_light.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/HEAD/chG/01_main-chapter-code/public/logo_light.webp
--------------------------------------------------------------------------------
/chG/README.md:
--------------------------------------------------------------------------------
1 | # Appendix G: Chat Interface
2 |
3 |
4 | ## Main chapter code
5 |
6 | - [01_main-chapter-code](01_main-chapter-code) contains the main chapter code and exercise solutions
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/chC/README.md:
--------------------------------------------------------------------------------
1 | # Appendix C: Qwen3 LLM Source Code
2 |
3 |
4 | ## Main chapter code
5 |
6 | - [01_main-chapter-code](01_main-chapter-code) contains the main chapter code and exercise solutions
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/ch02/01_main-chapter-code/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 2: Generating Text with a Pre-Trained LLM
2 |
3 |
4 |
5 | ## Main chapter code
6 |
7 | - [ch02_main.ipynb](ch02_main.ipynb): main chapter code
8 | - [ch02_exercise-solutions.ipynb](ch02_exercise-solutions.ipynb): exercise solutions
9 |
10 |
--------------------------------------------------------------------------------
/ch02/02_setup-tips/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 2: Generating Text with a Pre-Trained LLM
2 |
3 |
4 |
5 |
6 | ## Bonus material
7 |
8 | - [python-instructions.md](python-instructions.md): optional Python setup recommendations and instructions
9 | - [gpu-instructions.md](gpu-instructions.md): recommendations for cloud compute resources
10 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/ask-a-question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Ask a Question
3 | about: Ask questions related to the book
4 | title: ''
5 | labels: [question]
6 | assignees: rasbt
7 |
8 | ---
9 |
10 | If you have a question that is not a bug, please consider asking it in this GitHub repository's [discussion forum](https://github.com/rasbt/reasoning-from-scratch/discussions).
11 |
--------------------------------------------------------------------------------
/ch03/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 3: Evaluating Reasoning Models
2 |
3 |
4 | ## Main chapter code
5 |
6 | - [01_main-chapter-code](01_main-chapter-code): main chapter code and exercise solutions
7 |
8 |
9 | ## Bonus material
10 |
11 | - [02_math500-verifier-scripts](02_math500-verifier-scripts): optional Python scripts to run the MATH-500 evaluation from the command line, including a batched version with higher throughput
12 |
13 |
--------------------------------------------------------------------------------
/ch02/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 2: Generating Text with a Pre-Trained LLM
2 |
3 |
4 | ## Main chapter code
5 |
6 | - [01_main-chapter-code](01_main-chapter-code): main chapter code and exercise solutions
7 |
8 |
9 | ## Bonus material
10 |
11 | - [02_setup-tips](02_setup-tips/): optional Python setup recommendations and cloud GPU recommendations
12 | - [03_optimized-LLM](03_optimized-LLM): info on how to use a GPU-optimized version of the LLM
13 |
14 |
--------------------------------------------------------------------------------
/ch04/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 4: Improving Reasoning with Inference-Time Scaling
2 |
3 |
4 | ## Main chapter code
5 |
6 | - [01_main-chapter-code](01_main-chapter-code): main chapter code and exercise solutions
7 |
8 |
9 | ## Bonus material
10 |
11 | - [02_math500-inference-scaling-scripts](02_math500-inference-scaling-scripts): optional Python scripts to apply the inference scaling techniques covered in this chapter (CoT prompting and self-consistency) to the MATH-500 evaluation from the previous chapter.
12 |
--------------------------------------------------------------------------------
/ch05/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 5: Inference-Time Scaling via Self-Refinement
2 |
3 |
4 | ## Main chapter code
5 |
6 | - [01_main-chapter-code](01_main-chapter-code): main chapter code and exercise solutions
7 |
8 |
9 | ## Bonus material
10 |
11 | - [02_math500-more-inference-scaling-scripts](02_math500-more-inference-scaling-scripts): optional Python scripts to apply the inference scaling techniques covered in this chapter (Best-of-N and self-refinement) to the MATH-500 evaluation from the previous chapter.
12 |
--------------------------------------------------------------------------------
/chF/README.md:
--------------------------------------------------------------------------------
1 | # Appendix F: Common Approaches to LLM Evaluation
2 |
3 |
4 | ## Main chapter code
5 |
6 | - [01_main-chapter-code](01_main-chapter-code): the main chapter code and exercise solutions
7 |
8 |
9 |
10 | ## Bonus materials
11 |
12 | - [02_mmlu](02_mmlu): MMLU benchmark evaluation using the three different MMLU approaches (letter matching, logprobs, and teacher forcing)
13 | - [03_leaderboards](03_leaderboards): Elo and Bradley-Terry implementations of leaderboard rankings
14 | - [04_llm-judge](04_llm-judge): LLM-as-a-judge approach, where a judge LLM evaluates a candidate LLM
15 |
--------------------------------------------------------------------------------
/tests/test_appendix_f.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | from reasoning_from_scratch.appendix_f import elo_ratings
6 | import math
7 |
8 |
9 | def test_elo_single_match():
10 | r = elo_ratings([("A", "B")], k_factor=32, initial_rating=1000)
11 | assert math.isclose(r["A"], 1016)
12 | assert math.isclose(r["B"], 984)
13 |
14 |
15 | def test_elo_total_points_constant():
16 | votes = [("A", "B"), ("B", "C"), ("A", "C")]
17 | r = elo_ratings(votes)
18 | assert math.isclose(sum(r.values()), 3000)
--------------------------------------------------------------------------------
/ch03/01_main-chapter-code/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 3: Evaluating Reasoning Models
2 |
3 |
4 | ## Main chapter code
5 |
6 | - [ch03_main.ipynb](ch03_main.ipynb): main chapter code
7 | - [ch03_exercise-solutions.ipynb](ch03_exercise-solutions.ipynb): exercise solutions
8 |
9 |
10 |
11 | ## Bonus materials
12 |
13 | - [../02_math500-verifier-scripts/evaluate_math500.py](../02_math500-verifier-scripts/evaluate_math500.py): standalone script to evaluate models on the MATH-500 dataset
14 | - [../02_math500-verifier-scripts/evaluate_math500_batched.py](../02_math500-verifier-scripts/evaluate_math500_batched.py): same as above, but processes multiple examples in parallel during generation (for higher throughput)
15 |
16 | Both evaluation scripts import functionality from the [`reasoning_from_scratch`](../../reasoning_from_scratch) package to avoid code duplication. (See [chapter 2 setup instructions](../../ch02/02_setup-tips/python-instructions.md) for installation details.)
17 |
--------------------------------------------------------------------------------
/.github/workflows/check-spelling-errors.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Spell Check
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | pull_request:
11 | branches: [ main ]
12 |
13 | jobs:
14 | spellcheck:
15 | runs-on: ubuntu-latest
16 | env:
17 | SKIP_EXPENSIVE: "1"
18 | steps:
19 | - uses: actions/checkout@v4
20 | - name: Set up Python
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: "3.10"
24 | - name: Install codespell
25 | run: |
26 | curl -LsSf https://astral.sh/uv/install.sh | sh
27 | uv sync --dev --python=3.10
28 | uv add codespell
29 | - name: Run codespell
30 | run: |
31 | source .venv/bin/activate
32 | codespell -L "ocassion,occassion,ot,te,tje" **/*.{txt,md,py,ipynb}
33 |
--------------------------------------------------------------------------------
/reasoning_from_scratch/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | """
6 | Reasoning package used by the "Reasoning Models From Scratch" book.
7 |
8 | Copyright (c) 2025, Sebastian Raschka
9 |
10 | Licensed under the Apache License, Version 2.0 (the "License");
11 | you may not use this file except in compliance with the License.
12 | You may obtain a copy of the License at
13 |
14 | http://www.apache.org/licenses/LICENSE-2.0
15 |
16 | or in the top-level LICENSE file of this repository.
17 |
18 | Unless required by applicable law or agreed to in writing, software
19 | distributed under the License is distributed on an "AS IS" BASIS,
20 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 | See the License for the specific language governing permissions and
22 | limitations under the License.
23 | """
24 |
25 | __version__ = "0.1.12"
26 |
--------------------------------------------------------------------------------
/.github/workflows/tests-linux.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Code tests Linux
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | paths:
11 | - '**/*.py'
12 | - '**/*.ipynb'
13 | - '**/*.yaml'
14 | - '**/*.yml'
15 | - '**/*.sh'
16 | pull_request:
17 | branches: [ main ]
18 | paths:
19 | - '**/*.py'
20 | - '**/*.ipynb'
21 | - '**/*.yaml'
22 | - '**/*.yml'
23 | - '**/*.sh'
24 | workflow_dispatch:
25 |
26 | concurrency:
27 | group: ${{ github.workflow }}-${{ github.ref }}
28 | cancel-in-progress: true
29 |
30 | jobs:
31 | uv-tests:
32 | name: Code tests Linux
33 | runs-on: ubuntu-latest
34 | env:
35 | SKIP_EXPENSIVE: "1"
36 | steps:
37 | - uses: actions/checkout@v4
38 |
39 | - name: Set up Python
40 | uses: actions/setup-python@v5
41 | with:
42 | python-version: "3.13"
43 |
44 | - name: Install uv and dependencies
45 | run: |
46 | curl -LsSf https://astral.sh/uv/install.sh | sh
47 | uv sync --group dev
48 |
49 | - name: Run tests
50 | run: uv run pytest tests
--------------------------------------------------------------------------------
/.github/workflows/tests-macos.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Code tests macOS
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | paths:
11 | - '**/*.py'
12 | - '**/*.ipynb'
13 | - '**/*.yaml'
14 | - '**/*.yml'
15 | - '**/*.sh'
16 | pull_request:
17 | branches: [ main ]
18 | paths:
19 | - '**/*.py'
20 | - '**/*.ipynb'
21 | - '**/*.yaml'
22 | - '**/*.yml'
23 | - '**/*.sh'
24 | workflow_dispatch:
25 |
26 | concurrency:
27 | group: ${{ github.workflow }}-${{ github.ref }}
28 | cancel-in-progress: true
29 |
30 | jobs:
31 | uv-tests:
32 | name: Code tests macOS
33 | runs-on: macos-latest
34 | env:
35 | SKIP_EXPENSIVE: "1"
36 | steps:
37 | - uses: actions/checkout@v4
38 |
39 | - name: Set up Python
40 | uses: actions/setup-python@v5
41 | with:
42 | python-version: "3.13"
43 |
44 | - name: Install uv and dependencies
45 | run: |
46 | curl -LsSf https://astral.sh/uv/install.sh | sh
47 | uv sync --group dev
48 |
49 | - name: Run tests
50 | run: uv run pytest tests
51 |
--------------------------------------------------------------------------------
/tests/test_math500_scripts.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import subprocess
6 | import sys
7 | from pathlib import Path
8 |
9 | import pytest
10 |
11 |
12 | SCRIPT_PATHS = [
13 | Path("ch03/02_math500-verifier-scripts/evaluate_math500_batched.py"),
14 | Path("ch03/02_math500-verifier-scripts/evaluate_math500.py"),
15 | Path("ch04/02_math500-inference-scaling-scripts/self_consistency_math500.py"),
16 | Path("ch04/02_math500-inference-scaling-scripts/cot_prompting_math500.py"),
17 | ]
18 |
19 |
20 | @pytest.mark.parametrize("script_path", SCRIPT_PATHS)
21 | def test_script_help_runs_without_import_errors(script_path):
22 |
23 | repo_root = Path(__file__).resolve().parent.parent
24 | full_path = repo_root / script_path
25 | assert full_path.exists(), f"Expected script at {full_path}"
26 |
27 | # Run scripts with --help to make sure they work
28 |
29 | result = subprocess.run(
30 | [sys.executable, str(full_path), "--help"],
31 | cwd=repo_root,
32 | capture_output=True,
33 | text=True,
34 | )
35 |
36 | assert result.returncode == 0, result.stderr
37 | assert "usage" in result.stdout.lower()
38 |
--------------------------------------------------------------------------------
/tests/test_appendix_c.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import os
6 | from pathlib import Path
7 | import sys
8 | import torch
9 | import pytest
10 |
11 | from reasoning_from_scratch.ch02 import (
12 | generate_text_basic,
13 | generate_text_basic_cache,
14 | )
15 | # Local imports
16 | from test_qwen3 import test_model
17 | from conftest import import_definitions_from_notebook
18 |
19 |
20 | ROOT = Path(__file__).resolve().parents[1]
21 | sys.path.insert(0, str(ROOT))
22 |
23 |
24 | nb_path = ROOT / "chC" / "01_main-chapter-code" / "chC_main.ipynb"
25 | mod = import_definitions_from_notebook(nb_path, "chC_chC_main_defs")
26 | Qwen3Model = getattr(mod, "Qwen3Model")
27 |
28 | # Make CI more reproducible & robust
29 | os.environ["MKL_NUM_THREADS"] = "1"
30 | os.environ["OMP_NUM_THREADS"] = "1"
31 | torch.backends.mkldnn.enabled = False
32 | torch.set_num_threads(1)
33 | torch.use_deterministic_algorithms(True)
34 |
35 |
36 | @pytest.mark.parametrize("ModelClass", [Qwen3Model])
37 | @pytest.mark.parametrize("generate_fn", [generate_text_basic, generate_text_basic_cache])
38 | def test_model_here_too(ModelClass, qwen3_weights_path, generate_fn):
39 | test_model(ModelClass, qwen3_weights_path, generate_fn)
40 |
--------------------------------------------------------------------------------
/.github/workflows/basic-tests-old-pytorch.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Test Old PyTorch
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | paths:
11 | - '**/*.py'
12 | - '**/*.ipynb'
13 | - '**/*.yaml'
14 | - '**/*.yml'
15 | - '**/*.sh'
16 | pull_request:
17 | branches: [ main ]
18 | paths:
19 | - '**/*.py'
20 | - '**/*.ipynb'
21 | - '**/*.yaml'
22 | - '**/*.yml'
23 | - '**/*.sh'
24 | workflow_dispatch:
25 |
26 | concurrency:
27 | group: ${{ github.workflow }}-${{ github.ref }}
28 | cancel-in-progress: true
29 |
30 | jobs:
31 | uv-tests:
32 | name: Code tests
33 | runs-on: ubuntu-latest
34 | env:
35 | SKIP_EXPENSIVE: "1"
36 | steps:
37 | - uses: actions/checkout@v4
38 |
39 | - name: Set up Python (uv)
40 | uses: actions/setup-python@v5
41 | with:
42 | python-version: "3.13"
43 |
44 | - name: Install uv and dependencies
45 | shell: bash
46 | run: |
47 | curl -LsSf https://astral.sh/uv/install.sh | sh
48 | uv sync --group dev
49 | uv add pytest-ruff torch==2.7.1
50 |
51 | - name: Test reasoning_from_scratch package
52 | shell: bash
53 | run: |
54 | uv run pytest tests
55 |
--------------------------------------------------------------------------------
/.github/workflows/basic-tests-pip.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Code Tests (Plain pip)
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | paths:
11 | - '**/*.py'
12 | - '**/*.ipynb'
13 | - '**/*.yaml'
14 | - '**/*.yml'
15 | - '**/*.sh'
16 | pull_request:
17 | branches: [ main ]
18 | paths:
19 | - '**/*.py'
20 | - '**/*.ipynb'
21 | - '**/*.yaml'
22 | - '**/*.yml'
23 | - '**/*.sh'
24 | workflow_dispatch:
25 |
26 | concurrency:
27 | group: ${{ github.workflow }}-${{ github.ref }}
28 | cancel-in-progress: true
29 |
30 | jobs:
31 | pip-tests:
32 | name: Pip Tests (Ubuntu Only)
33 | runs-on: ubuntu-latest
34 | env:
35 | SKIP_EXPENSIVE: "1"
36 | steps:
37 | - uses: actions/checkout@v4
38 |
39 | - name: Set up Python
40 | uses: actions/setup-python@v5
41 | with:
42 | python-version: "3.13"
43 | - name: Create Virtual Environment and Install Dependencies
44 | run: |
45 | python -m venv .venv
46 | source .venv/bin/activate
47 | pip install --upgrade pip
48 | pip install -r requirements.txt
49 | pip install .
50 | pip install pytest
51 | - name: Test Selected Python Scripts
52 | run: |
53 | source .venv/bin/activate
54 | pytest tests
55 |
--------------------------------------------------------------------------------
/.github/workflows/basic-test-nightly-pytorch.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Test Nightly PyTorch
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | paths:
11 | - '**/*.py'
12 | - '**/*.ipynb'
13 | - '**/*.yaml'
14 | - '**/*.yml'
15 | - '**/*.sh'
16 | pull_request:
17 | branches: [ main ]
18 | paths:
19 | - '**/*.py'
20 | - '**/*.ipynb'
21 | - '**/*.yaml'
22 | - '**/*.yml'
23 | - '**/*.sh'
24 | workflow_dispatch:
25 |
26 | concurrency:
27 | group: ${{ github.workflow }}-${{ github.ref }}
28 | cancel-in-progress: true
29 |
30 | jobs:
31 | uv-tests:
32 | name: Code tests
33 | runs-on: ubuntu-latest
34 | env:
35 | SKIP_EXPENSIVE: "1"
36 | steps:
37 | - uses: actions/checkout@v4
38 |
39 | - name: Set up Python (uv)
40 | uses: actions/setup-python@v5
41 | with:
42 | python-version: "3.13"
43 |
44 | - name: Install uv and dependencies
45 | shell: bash
46 | run: |
47 | curl -LsSf https://astral.sh/uv/install.sh | sh
48 | uv sync --group dev
49 | uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
50 |
51 | - name: Test reasoning_from_scratch package
52 | shell: bash
53 | run: |
54 | # uv run pytest tests
55 | # temporarily disabled
56 |
--------------------------------------------------------------------------------
/.github/workflows/check-links.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Check Hyperlinks
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | pull_request:
11 | branches: [ main ]
12 |
13 | jobs:
14 | test:
15 | runs-on: ubuntu-latest
16 | env:
17 | SKIP_EXPENSIVE: "1"
18 | steps:
19 | - uses: actions/checkout@v4
20 | - name: Set up Python
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: "3.10"
24 | - name: Install dependencies
25 | run: |
26 | curl -LsSf https://astral.sh/uv/install.sh | sh
27 | uv add pytest-check-links
28 | - name: Check links
29 | run: |
30 | source .venv/bin/activate
31 | pytest --check-links ./ \
32 | --check-links-ignore "https://platform.openai.com/*" \
33 | --check-links-ignore "https://openai.com/*" \
34 | --check-links-ignore "https://arena.lmsys.org" \
35 | --check-links-ignore "https://unsloth.ai/blog/gradient" \
36 | --check-links-ignore "https://www.reddit.com/r/*" \
37 | --check-links-ignore "https://code.visualstudio.com/*" \
38 | --check-links-ignore "https://arxiv.org/*" \
39 | --check-links-ignore "https://ai.stanford.edu/~amaas/data/sentiment/" \
40 | --check-links-ignore "https://x.com/*" \
41 | --check-links-ignore "https://scholar.google.com/*" \
42 | --check-links-ignore "https://en.wikipedia.org/*"
43 |
--------------------------------------------------------------------------------
/.github/workflows/tests-windows.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Code tests Windows
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | paths:
11 | - '**/*.py'
12 | - '**/*.ipynb'
13 | - '**/*.yaml'
14 | - '**/*.yml'
15 | - '**/*.sh'
16 | pull_request:
17 | branches: [ main ]
18 | paths:
19 | - '**/*.py'
20 | - '**/*.ipynb'
21 | - '**/*.yaml'
22 | - '**/*.yml'
23 | - '**/*.sh'
24 | workflow_dispatch:
25 |
26 | concurrency:
27 | group: ${{ github.workflow }}-${{ github.ref }}
28 | cancel-in-progress: true
29 |
30 | jobs:
31 | uv-tests:
32 | name: Code tests Windows
33 | runs-on: windows-latest
34 | env:
35 | SKIP_EXPENSIVE: "1"
36 | steps:
37 | - uses: actions/checkout@v4
38 |
39 | - name: Set up Python
40 | uses: actions/setup-python@v5
41 | with:
42 | python-version: "3.13"
43 |
44 | - name: Install uv and dependencies
45 | shell: bash
46 | run: |
47 | export PATH="$HOME/.local/bin:$PATH"
48 | curl -LsSf https://astral.sh/uv/install.sh | sh
49 | uv sync --group dev
50 |
51 | - name: Run tests
52 | shell: bash
53 | env:
54 | OMP_NUM_THREADS: "1"
55 | MKL_NUM_THREADS: "1"
56 | CUBLAS_WORKSPACE_CONFIG: ":16:8"
57 | run: |
58 | export PATH="$HOME/.local/bin:$PATH"
59 | uv run pytest tests
60 |
--------------------------------------------------------------------------------
/chG/01_main-chapter-code/README.md:
--------------------------------------------------------------------------------
1 | # Appendix G: Chat Interface
2 |
3 |
4 |
5 | This folder contains code for running a ChatGPT-like user interface to interact with the LLMs used and/or developed in this book, as shown below.
6 |
7 |
8 |
9 | 
10 |
11 |
12 |
13 | To implement this user interface, we use the open-source [Chainlit Python package](https://github.com/Chainlit/chainlit).
14 |
15 |
16 | ## Step 1: Install dependencies
17 |
18 | First, we install the `chainlit` package and its dependencies:
19 |
20 | ```bash
21 | pip install chainlit
22 | ```
23 |
24 | Or, if you are using `uv`:
25 |
26 | ```bash
27 | uv add chainlit
28 | ```
29 |
30 |
31 |
32 |
33 |
34 | ## Step 2: Run `app` code
35 |
36 | This folder contains 2 files:
37 |
38 | 1. [`qwen3_chat_interface.py`](qwen3_chat_interface.py): This file loads and uses the Qwen3 0.6B model in thinking mode.
39 | 2. [`qwen3_chat_interface_multiturn.py`](qwen3_chat_interface_multiturn.py): The same as above, but configured to remember the message history.
40 |
41 | (Open and inspect these files to learn more.)
42 |
43 | Run one of the following commands from the terminal to start the UI server:
44 |
45 | ```bash
46 | chainlit run qwen3_chat_interface.py
47 | ```
48 |
49 | or, if you are using `uv`:
50 |
51 | ```bash
52 | uv run chainlit run qwen3_chat_interface.py
53 | ```
54 |
55 | Running one of the commands above should open a new browser tab where you can interact with the model. If the browser tab does not open automatically, inspect the terminal output and copy the local address into your browser's address bar (usually, the address is `http://localhost:8000`).
--------------------------------------------------------------------------------
/chF/03_leaderboards/1_elo_leaderboard.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 | import json
5 | import argparse
6 |
7 |
8 | def elo_ratings(vote_pairs, k_factor=32, initial_rating=1000):
9 | ratings = {
10 | model: initial_rating
11 | for pair in vote_pairs
12 | for model in pair
13 | }
14 | for winner, loser in vote_pairs:
15 | expected = 1.0 / (
16 | 1.0 + 10 ** (
17 | (ratings[loser] - ratings[winner]) / 400.0
18 | )
19 | )
20 | ratings[winner] += k_factor * (1 - expected)
21 | ratings[loser] += k_factor * (0 - (1 - expected))
22 | return ratings
23 |
24 |
25 | def main():
26 | parser = argparse.ArgumentParser(
27 | description="Compute Elo leaderboard."
28 | )
29 | parser.add_argument("--path", type=str, help="Path to votes JSON")
30 | parser.add_argument("--k", type=int, default=32,
31 | help="Elo k-factor")
32 | parser.add_argument("--init", type=int, default=1000,
33 | help="Initial rating")
34 | args = parser.parse_args()
35 |
36 | with open(args.path, "r", encoding="utf-8") as f:
37 | votes = json.load(f)
38 |
39 | ratings = elo_ratings(votes, args.k, args.init)
40 | leaderboard = sorted(ratings.items(),
41 | key=lambda x: -x[1])
42 |
43 | print("\nLeaderboard (Elo) \n-----------------------")
44 | for i, (model, score) in enumerate(leaderboard, 1):
45 | print(f"{i:>2}. {model:<10} {score:7.1f}")
46 | print()
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/chF/03_leaderboards/votes.json:
--------------------------------------------------------------------------------
1 | [
2 | ["GPT-5", "Claude-3"],
3 | ["GPT-5", "Llama-4"],
4 | ["Claude-3", "Llama-3"],
5 | ["Llama-4", "Llama-3"],
6 | ["Claude-3", "Llama-3"],
7 | ["GPT-5", "Llama-3"],
8 | ["Claude-3", "GPT-5"],
9 | ["Llama-4", "Claude-3"],
10 | ["Llama-3", "Llama-4"],
11 | ["GPT-5", "Claude-3"],
12 | ["Claude-3", "Llama-4"],
13 | ["GPT-5", "Llama-4"],
14 | ["Claude-3", "GPT-5"],
15 | ["Llama-4", "Claude-3"],
16 | ["GPT-5", "Llama-3"],
17 | ["Claude-3", "Llama-3"],
18 | ["Llama-3", "Claude-3"],
19 | ["GPT-5", "Llama-4"],
20 | ["Claude-3", "GPT-5"],
21 | ["Llama-4", "GPT-5"],
22 | ["Llama-3", "Claude-3"],
23 | ["GPT-5", "Claude-3"],
24 | ["Claude-3", "Llama-4"],
25 | ["Llama-4", "Llama-3"],
26 | ["GPT-5", "Llama-3"],
27 | ["Claude-3", "Llama-3"],
28 | ["GPT-5", "Claude-3"],
29 | ["Claude-3", "Llama-4"],
30 | ["GPT-5", "Claude-3"],
31 | ["Llama-4", "Llama-3"],
32 | ["Claude-3", "GPT-5"],
33 | ["GPT-5", "Llama-3"],
34 | ["Claude-3", "Llama-3"],
35 | ["Llama-3", "Llama-4"],
36 | ["GPT-5", "Claude-3"],
37 | ["Claude-3", "Llama-4"],
38 | ["GPT-5", "Llama-4"],
39 | ["Claude-3", "GPT-5"],
40 | ["Llama-4", "Claude-3"],
41 | ["GPT-5", "Llama-3"],
42 | ["Claude-3", "Llama-3"],
43 | ["Llama-3", "Claude-3"],
44 | ["GPT-5", "Llama-4"],
45 | ["Claude-3", "GPT-5"],
46 | ["Llama-4", "GPT-5"],
47 | ["Llama-3", "Claude-3"],
48 | ["GPT-5", "Claude-3"],
49 | ["Claude-3", "Llama-4"],
50 | ["Llama-4", "Llama-3"],
51 | ["GPT-5", "Llama-3"],
52 | ["Claude-3", "Llama-3"],
53 | ["GPT-5", "Claude-3"],
54 | ["Claude-3", "Llama-4"],
55 | ["GPT-5", "Claude-3"],
56 | ["Llama-4", "Llama-3"],
57 | ["Claude-3", "GPT-5"],
58 | ["GPT-5", "Llama-3"],
59 | ["Claude-3", "Llama-3"],
60 | ["Llama-3", "Llama-4"],
61 | ["Claude-3", "Llama-3"]
62 | ]
63 |
--------------------------------------------------------------------------------
/.github/workflows/code-linter.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | name: Code Style Checks
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | paths:
11 | - '**/*.py'
12 | - '**/*.ipynb'
13 | - '**/*.yaml'
14 | - '**/*.yml'
15 | - '**/*.sh'
16 | pull_request:
17 | branches: [ main ]
18 | paths:
19 | - '**/*.py'
20 | - '**/*.ipynb'
21 | - '**/*.yaml'
22 | - '**/*.yml'
23 | - '**/*.sh'
24 |
25 | jobs:
26 | flake8:
27 | runs-on: ubuntu-latest
28 | env:
29 | SKIP_EXPENSIVE: "1"
30 | steps:
31 | - uses: actions/checkout@v4
32 | - name: Set up Python
33 | uses: actions/setup-python@v5
34 | with:
35 | python-version: "3.13"
36 | - name: Install ruff (a faster flake8 equivalent)
37 | run: |
38 | curl -LsSf https://astral.sh/uv/install.sh | sh
39 | uv sync --dev --python=3.13
40 | uv add ruff
41 |
42 | - name: Run ruff with exceptions
43 | run: |
44 | source .venv/bin/activate
45 | ruff check .
46 |
47 | - name: Run quote style check
48 | run: |
49 | source .venv/bin/activate
50 | uv run .github/scripts/check_double_quotes.py
51 |
52 | - name: Run line length check on main chapter notebooks
53 | run: |
54 | source .venv/bin/activate
55 | uv run .github/scripts/check_notebook_line_length.py \
56 | ch02/01_main-chapter-code/ch02_main.ipynb \
57 | ch03/01_main-chapter-code/ch03_main.ipynb \
58 | chC/01_main-chapter-code/chC_main.ipynb \
59 | chF/01_main-chapter-code/chF_main.ipynb
--------------------------------------------------------------------------------
/reasoning_from_scratch/ch02_ex.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | from .qwen3 import KVCache
6 | import torch
7 |
8 |
9 | @torch.inference_mode()
10 | def generate_text_basic_stream(
11 | model,
12 | token_ids,
13 | max_new_tokens,
14 | eos_token_id=None
15 | ):
16 | # input_length = token_ids.shape[1]
17 | model.eval()
18 |
19 | for _ in range(max_new_tokens):
20 | out = model(token_ids)[:, -1]
21 | next_token = torch.argmax(out, dim=-1, keepdim=True)
22 |
23 | if (eos_token_id is not None
24 | and next_token.item() == eos_token_id):
25 | break
26 |
27 | yield next_token # New: Yield each token as it's generated
28 |
29 | token_ids = torch.cat([token_ids, next_token], dim=1)
30 | # return token_ids[:, input_length:]
31 |
32 |
33 | @torch.inference_mode()
34 | def generate_text_basic_stream_cache(
35 | model,
36 | token_ids,
37 | max_new_tokens,
38 | eos_token_id=None
39 | ):
40 | # input_length = token_ids.shape[1]
41 | model.eval()
42 | cache = KVCache(n_layers=model.cfg["n_layers"])
43 | model.reset_kv_cache()
44 |
45 | out = model(token_ids, cache=cache)[:, -1]
46 | for _ in range(max_new_tokens):
47 | next_token = torch.argmax(out, dim=-1, keepdim=True)
48 |
49 | if (eos_token_id is not None
50 | and next_token.item() == eos_token_id):
51 | break
52 |
53 | yield next_token # New: Yield each token as it's generated
54 | # token_ids = torch.cat([token_ids, next_token], dim=1)
55 | out = model(next_token, cache=cache)[:, -1]
56 |
57 | # return token_ids[:, input_length:]
58 |
--------------------------------------------------------------------------------
/reasoning_from_scratch/appendix_f.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
3 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
4 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
5 |
6 | from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache
7 |
8 |
9 | def predict_choice(
10 | model, tokenizer, prompt_fmt, max_new_tokens=8
11 | ):
12 | pred = None
13 | for t in generate_text_basic_stream_cache(
14 | model=model,
15 | token_ids=prompt_fmt,
16 | max_new_tokens=max_new_tokens,
17 | eos_token_id=tokenizer.eos_token_id,
18 | ):
19 | answer = tokenizer.decode(t.squeeze(0).tolist())
20 | for letter in answer:
21 | letter = letter.upper()
22 | if letter in "ABCD":
23 | pred = letter
24 | break
25 | if pred: # stop as soon as a letter appears
26 | break
27 | return pred
28 |
29 |
30 | def elo_ratings(vote_pairs, k_factor=32, initial_rating=1000):
31 | # Initialize all models with the same base rating
32 | ratings = {
33 | model: initial_rating
34 | for pair in vote_pairs
35 | for model in pair
36 | }
37 |
38 | # Update ratings after each match
39 | for winner, loser in vote_pairs:
40 | rating_winner, rating_loser = ratings[winner], ratings[loser]
41 |
42 | # Expected score for the current winner given the ratings
43 | expected_winner = 1.0 / (
44 | 1.0 + 10 ** ((rating_loser - rating_winner) / 400.0)
45 | )
46 |
47 | # k_factor determines sensitivity of rating updates
48 | ratings[winner] = (
49 | rating_winner + k_factor * (1 - expected_winner)
50 | )
51 | ratings[loser] = (
52 | rating_loser + k_factor * (0 - (1 - expected_winner))
53 | )
54 |
55 | return ratings
56 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | [tool.setuptools.dynamic]
6 | version = { attr = "reasoning_from_scratch.__version__" }
7 |
8 | [build-system]
9 | requires = ["setuptools>=61.0", "wheel"]
10 | build-backend = "setuptools.build_meta"
11 |
12 | [project]
13 | name = "reasoning_from_scratch"
14 | dynamic = ["version"]
15 | description = "Reasoning Models From Scratch"
16 | readme = "README.md"
17 |
18 | authors = [
19 | { name = "Sebastian Raschka" }
20 | ]
21 |
22 | license = { file = "LICENSE" }
23 | requires-python = ">=3.10"
24 |
25 | dependencies = [
26 | "jupyterlab>=4.4.7",
27 | "torch>=2.7.1",
28 | "tokenizers>=0.21.2",
29 | "nbformat>=5.10.4",
30 | "sympy>=1.14.0", # For verifier in ch03
31 | "matplotlib>=3.10.7",
32 | ]
33 |
34 | classifiers = [
35 | "License :: OSI Approved :: Apache Software License",
36 | "Programming Language :: Python :: 3.10",
37 | "Programming Language :: Python :: 3.11",
38 | "Programming Language :: Python :: 3.12",
39 | "Programming Language :: Python :: 3.13",
40 | "Intended Audience :: Developers",
41 | "Operating System :: OS Independent"
42 | ]
43 |
44 | [dependency-groups]
45 | dev = [
46 | "pytest>=8.3.5",
47 | "ruff>=0.4.4",
48 | "pytest-ruff>=0.5",
49 | "transformers==4.53.0",
50 | "reasoning-from-scratch",
51 | "twine>=6.1.0",
52 | "build>=1.2.2.post1",
53 | ]
54 |
55 | extra = [
56 | "datasets>=4.1.1", # Appendix F, MMLU
57 | "chainlit>=2.8.3", # Appendix G, Chat UI
58 | ]
59 |
60 | [tool.setuptools.packages.find]
61 | where = ["."]
62 | include = ["reasoning_from_scratch"]
63 |
64 | [tool.uv.sources]
65 | reasoning = { workspace = true }
66 | reasoning-from-scratch = { workspace = true }
67 |
68 | [tool.uv.workspace]
69 | members = [
70 | ".",
71 | ]
72 |
--------------------------------------------------------------------------------
/chF/03_leaderboards/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Leaderboard Rankings
3 |
4 | This bonus material implements two different ways to construct LM Arena (formerly Chatbot Arena) style leaderboards from pairwise comparisons.
5 |
6 | Both implementations read a list of pairwise preferences (left entry: winner, right entry: loser) from a JSON file passed via the `--path` argument. Here's an excerpt of the provided [votes.json](votes.json) file:
7 |
8 | ```json
9 | [
10 | ["GPT-5", "Claude-3"],
11 | ["GPT-5", "Llama-4"],
12 | ["Claude-3", "Llama-3"],
13 | ["Llama-4", "Llama-3"],
14 | ...
15 | ]
16 | ```
17 |
18 |
19 |
20 |
21 |
22 | ---
23 |
24 | **Note**: If you are not a `uv` user, replace `uv run ...py` with `python ...py` in the examples below.
25 |
26 | ---
27 |
28 |
29 | ## Method 1: Elo ratings
30 |
31 | - Implements the popular Elo rating method (inspired by chess rankings) that was originally used by LM Arena
32 | - See the [main notebook](../01_main-chapter-code/chF_main.ipynb) for details
33 |
34 | ```bash
35 | ➜ 03_leaderboards git:(main) ✗ uv run 1_elo_leaderboard.py --path votes.json
36 |
37 | Leaderboard (Elo)
38 | -----------------------
39 | 1. GPT-5 1095.9
40 | 2. Claude-3 1058.7
41 | 3. Llama-4 958.2
42 | 4. Llama-3 887.2
43 | ```
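To see how a single vote moves the ratings, the same `elo_ratings` function (also exposed as `reasoning_from_scratch.appendix_f.elo_ratings`) can be called directly; a minimal example:

```python
from reasoning_from_scratch.appendix_f import elo_ratings

# One "A beats B" vote with k=32 and equal starting ratings of 1000:
# the expected score is 0.5, so 32 * 0.5 = 16 points move from B to A
ratings = elo_ratings([("A", "B")], k_factor=32, initial_rating=1000)
print(ratings)  # {'A': 1016.0, 'B': 984.0}
```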
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 | ## Method 2: Bradley-Terry model
52 |
53 | - Implements a [Bradley-Terry model](https://en.wikipedia.org/wiki/Bradley–Terry_model), similar to the new LM Arena leaderboard as described in the official paper ([Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference](https://arxiv.org/abs/2403.04132))
54 | - As on the LM Arena leaderboard, the scores are rescaled so that they are comparable to the original Elo scores
55 | - The code here fits the model with PyTorch's Adam optimizer (chosen for code readability and familiarity)
56 |
57 |
58 |
59 | ```bash
60 | ➜ 03_leaderboards git:(main) ✗ uv run 2_bradley_terry_leaderboard.py --path votes.json
61 |
62 | Leaderboard (Bradley-Terry)
63 | -----------------------------
64 | 1. GPT-5 1140.6
65 | 2. Claude-3 1058.7
66 | 3. Llama-4 950.3
67 | 4. Llama-3 850.4
68 | ```
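The Elo-like rescaling in the script maps the fitted Bradley-Terry scores onto the familiar rating scale using the factor 400 / ln(10) and then centers the result around 1000. A minimal sketch of that step (the score values below are made up for illustration):

```python
import math
import torch

# Hypothetical fitted Bradley-Terry scores (latent log-strengths)
scores = torch.tensor([0.55, 0.34, -0.29, -0.60])

scale = 400.0 / math.log(10.0)                # same scale factor as in the script
ratings = scores * scale
ratings = ratings - ratings.mean() + 1000.0   # center around 1000
print(ratings)
```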
69 |
70 |
--------------------------------------------------------------------------------
/chG/01_main-chapter-code/qwen3_chat_interface.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 |
6 |
7 | import torch
8 | import chainlit
9 |
10 | from reasoning_from_scratch.ch02 import (
11 | get_device,
12 | )
13 | from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache
14 | from reasoning_from_scratch.ch03 import load_model_and_tokenizer
15 |
16 |
17 | # ============================================================
18 | # EDIT ME: Simple configuration
19 | # ============================================================
20 | WHICH_MODEL = "reasoning" # "base" for base model
21 | MAX_NEW_TOKENS = 38912
22 | LOCAL_DIR = "qwen3"
23 | COMPILE = False
24 | # ============================================================
25 |
26 |
27 | DEVICE = get_device()
28 |
29 | MODEL, TOKENIZER = load_model_and_tokenizer(
30 | which_model=WHICH_MODEL,
31 | device=DEVICE,
32 | use_compile=COMPILE,
33 | local_dir=LOCAL_DIR
34 | )
35 |
36 |
37 | @chainlit.on_chat_start
38 | async def on_start():
39 | chainlit.user_session.set("history", [])
40 | chainlit.user_session.get("history").append(
41 | {"role": "system", "content": "You are a helpful assistant."}
42 | )
43 |
44 |
45 | @chainlit.on_message
46 | async def main(message: chainlit.Message):
47 | """
48 | The main Chainlit function.
49 | """
50 | # 1) Encode input
51 | input_ids = TOKENIZER.encode(message.content)
52 | input_ids_tensor = torch.tensor(input_ids, device=DEVICE).unsqueeze(0)
53 |
54 | # 2) Start an outgoing message we can stream into
55 | out_msg = chainlit.Message(content="")
56 | await out_msg.send()
57 |
58 | # 3) Stream generation
59 | for tok in generate_text_basic_stream_cache(
60 | model=MODEL,
61 | token_ids=input_ids_tensor,
62 | max_new_tokens=MAX_NEW_TOKENS,
63 | eos_token_id=TOKENIZER.eos_token_id
64 | ):
65 | token_id = tok.squeeze(0)
66 | piece = TOKENIZER.decode(token_id.tolist())
67 | await out_msg.stream_token(piece)
68 |
69 | # 4) Finalize the streamed message
70 | await out_msg.update()
71 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yaml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: Report errors related to the book content or code
3 | title: "Description"
4 | labels: [bug]
5 | assignees: rasbt
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | Thank you for taking the time to report an issue. Please fill out the details below to help resolve it.
11 |
12 | - type: textarea
13 | id: bug_description
14 | attributes:
15 | label: Bug description
16 | description: A description of the issue.
17 | placeholder: |
18 | Please provide a description of what the bug or issue is.
19 | validations:
20 | required: true
21 |
22 | - type: dropdown
23 | id: operating_system
24 | attributes:
25 | label: What operating system are you using?
26 | description: If applicable, please select the operating system where you experienced this issue.
27 | options:
28 | - "Unknown"
29 | - "macOS"
30 | - "Linux"
31 | - "Windows"
32 | validations:
33 | required: False
34 |
35 | - type: dropdown
36 | id: compute_environment
37 | attributes:
38 | label: Where do you run your code?
39 | description: Please select the computing environment where you ran this code.
40 | options:
41 | - "Local (laptop, desktop)"
42 | - "Lightning AI Studio"
43 | - "Google Colab"
44 | - "Other cloud environment (AWS, Azure, GCP)"
45 | validations:
46 | required: False
47 |
48 | - type: textarea
49 | id: environment
50 | attributes:
51 | label: Environment
52 | description: |
53 | Please provide details about your Python environment by pasting the output of the following code:
54 | ```python
55 | import sys
56 | import torch
57 |
58 | print(f"Python version: {sys.version}")
59 | print(f"PyTorch version: {torch.__version__}")
60 | print(f"CUDA available: {torch.cuda.is_available()}")
61 | if torch.cuda.is_available():
62 | print(f"CUDA device: {torch.cuda.get_device_name(0)}")
63 | print(f"CUDA version: {torch.version.cuda}")
64 | print(f"Device count: {torch.cuda.device_count()}")
65 | ```
66 | value: |
67 | ```
68 |
69 |
70 |
71 | ```
72 | validations:
73 | required: false
74 |
--------------------------------------------------------------------------------
/reasoning_from_scratch/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | from pathlib import Path
6 | import sys
7 | import requests
8 | from urllib.parse import urlparse
9 |
10 |
11 | def download_file(url, out_dir=".", backup_url=None):
12 | out_dir = Path(out_dir)
13 | out_dir.mkdir(parents=True, exist_ok=True)
14 | filename = Path(urlparse(url).path).name
15 | dest = out_dir / filename
16 |
17 | def try_download(u):
18 | try:
19 | with requests.get(u, stream=True, timeout=30) as r:
20 | r.raise_for_status()
21 | size_remote = int(r.headers.get("Content-Length", 0))
22 |
23 | # Skip download if already complete
24 | if dest.exists() and size_remote and dest.stat().st_size == size_remote:
25 | print(f"✓ {dest} already up-to-date")
26 | return True
27 |
28 | # Download in 1 MiB chunks with progress display
29 | block = 1024 * 1024
30 | downloaded = 0
31 | with open(dest, "wb") as f:
32 | for chunk in r.iter_content(chunk_size=block):
33 | if not chunk:
34 | continue
35 | f.write(chunk)
36 | downloaded += len(chunk)
37 | if size_remote:
38 | pct = downloaded * 100 // size_remote
39 | sys.stdout.write(
40 | f"\r{filename}: {pct:3d}% "
41 | f"({downloaded // (1024*1024)} MiB / "
42 | f"{size_remote // (1024*1024)} MiB)"
43 | )
44 | sys.stdout.flush()
45 | if size_remote:
46 | sys.stdout.write("\n")
47 | return True
48 | except requests.RequestException:
49 | return False
50 |
51 | # Try main URL first
52 | if try_download(url):
53 | return dest
54 |
55 | # Try backup URL if provided
56 | if backup_url:
57 | print(f"Primary URL ({url}) failed.\nTrying backup URL ({backup_url})...")
58 | if try_download(backup_url):
59 | return dest
60 |
61 | raise RuntimeError(f"Failed to download {filename} from both mirrors.")
62 |
--------------------------------------------------------------------------------
/chF/03_leaderboards/2_bradley_terry_leaderboard.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import json
6 | import math
7 | import argparse
8 | import torch
9 | from reasoning_from_scratch.ch02 import get_device
10 |
11 |
12 | def bradley_terry_torch(vote_pairs, device):
13 |
14 | # Collect all unique model names
15 | models = sorted({m for winner, loser in vote_pairs for m in (winner, loser)})
16 | n = len(models)
17 | idx = {m: i for i, m in enumerate(models)}
18 |
19 | # Convert to index tensors
20 | winners = torch.tensor([idx[winner] for winner, _ in vote_pairs], dtype=torch.long)
21 | losers = torch.tensor([idx[loser] for _, loser in vote_pairs], dtype=torch.long)
22 |
23 | # Learnable parameters
24 | theta = torch.nn.Parameter(torch.zeros(n - 1, device=device))
25 | optimizer = torch.optim.Adam([theta], lr=0.01, weight_decay=1e-4)
26 |
27 | def scores():
28 | return torch.cat([theta, torch.zeros(1, device=device)])
29 |
30 | for epoch in range(500):
31 | s = scores()
32 | delta = s[winners] - s[losers] # score difference
33 | loss = -torch.nn.functional.logsigmoid(delta).mean() # negative log-likelihood
34 | optimizer.zero_grad(set_to_none=True)
35 | loss.backward()
36 | optimizer.step()
37 |
38 | # Convert latent scores to Elo-like scale
39 | with torch.no_grad():
40 | s = scores()
41 | scale = 400.0 / math.log(10.0)
42 | R = s * scale
43 | R -= R.mean()
44 | R += 1000.0 # center around 1000
45 |
46 | return {m: float(r) for m, r in zip(models, R.cpu().tolist())}
47 |
48 |
49 | def main():
50 | parser = argparse.ArgumentParser(description="Bradley-Terry leaderboard.")
51 | parser.add_argument("--path", type=str, help="Path to votes JSON")
52 | args = parser.parse_args()
53 |
54 | with open(args.path, "r", encoding="utf-8") as f:
55 | votes = json.load(f)
56 |
57 | device = get_device()
58 | ratings = bradley_terry_torch(votes, device)
59 |
60 | leaderboard = sorted(ratings.items(),
61 | key=lambda x: -x[1])
62 | print("\nLeaderboard (Bradley-Terry)")
63 | print("-----------------------------")
64 | for i, (model, score) in enumerate(leaderboard, 1):
65 | print(f"{i:>2}. {model:<10} {score:7.1f}")
66 | print()
67 |
68 |
69 | if __name__ == "__main__":
70 | main()
71 |
--------------------------------------------------------------------------------
/ch02/02_setup-tips/gpu-instructions.md:
--------------------------------------------------------------------------------
1 |
2 | # GPU Cloud Resources
3 |
4 | This section describes cloud alternatives for running the code presented in this book.
5 |
6 | While the code can run on conventional laptops and desktop computers without a dedicated GPU, cloud platforms with NVIDIA GPUs can substantially improve the runtime of the code, especially in chapters 5 to 7.
7 |
8 |
9 |
10 | ## Using Lightning Studio
11 |
12 | For a smooth development experience in the cloud, I recommend the [Lightning AI Studio](https://lightning.ai/) platform, which allows users to set up a persistent environment and use both VSCode and Jupyter Lab on cloud CPUs and GPUs.
13 |
14 | Once you start a new Studio, you can open the terminal and execute the following setup steps to clone the repository and install the dependencies:
15 |
16 | ```bash
17 | git clone https://github.com/rasbt/reasoning-from-scratch.git
18 | cd reasoning-from-scratch
19 | pip install -r requirements.txt
20 | ```
21 |
22 | (In contrast to Google Colab, these only need to be executed once since the Lightning AI Studio environments are persistent, even if you switch between CPU and GPU machines.)
23 |
24 | Then, navigate to the Python script or Jupyter notebook you want to run. Optionally, you can also connect a GPU to accelerate the code's runtime, for example, for the more compute-intensive code in the later chapters.
25 |
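After attaching a GPU, a quick, purely illustrative way to confirm that PyTorch can see it:

```python
import torch

# Prints True and the device name if a CUDA GPU is available
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```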
26 |
27 |
28 |
29 |
30 | ## Using Google Colab
31 |
32 | To use a Google Colab environment in the cloud, head over to [https://colab.research.google.com/](https://colab.research.google.com/) and open the respective chapter notebook from the GitHub menu or by dragging the notebook into the *Upload* field as shown in the figure below.
33 |
34 |
35 |
36 |
37 | Also make sure to upload the relevant files (dataset files and the .py files the notebook imports from) to the Colab environment, as shown below.
38 |
39 |
40 |
41 |
42 | You can optionally run the code on a GPU by changing the *Runtime* as illustrated in the figure below.
43 |
44 |
45 |
46 |
47 |
48 | ## Questions?
49 |
50 | If you have any questions, please don't hesitate to reach out via the [Discussions](https://github.com/rasbt/reasoning-from-scratch/discussions) forum in this GitHub repository.
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | from reasoning_from_scratch.qwen3 import (
6 | download_qwen3_small,
7 | QWEN_CONFIG_06_B,
8 | Qwen3Model,
9 | )
10 | import sys
11 | import types
12 | import nbformat
13 | import pytest
14 | import torch
15 |
16 |
17 | @pytest.fixture(scope="session")
18 | def qwen3_weights_path(tmp_path_factory):
19 | """Creates and saves a deterministic model for testing."""
20 |
21 | base_path = tmp_path_factory.mktemp("models")
22 | model_path = base_path / "qwen3_test_weights.pt"
23 | download_qwen3_small(kind="base", tokenizer_only=True, out_dir=base_path)
24 |
25 | if not model_path.exists():
26 | torch.manual_seed(123)
27 | model = Qwen3Model(QWEN_CONFIG_06_B)
28 | torch.save(model.state_dict(), model_path)
29 |
30 | return base_path
31 |
32 |
33 | def import_definitions_from_notebook(nb_path, module_name):
34 | if not nb_path.exists():
35 | raise FileNotFoundError(f"Notebook file not found at: {nb_path}")
36 |
37 | nb = nbformat.read(str(nb_path), as_version=4)
38 |
39 | mod = types.ModuleType(module_name)
40 | sys.modules[module_name] = mod
41 |
42 | # Pass 1: execute only imports (handle multi-line)
43 | for cell in nb.cells:
44 | if cell.cell_type != "code":
45 | continue
46 | lines = cell.source.splitlines()
47 | collecting = False
48 | buf = []
49 | paren_balance = 0
50 | for line in lines:
51 | stripped = line.strip()
52 | if not collecting and (stripped.startswith("import ") or stripped.startswith("from ")):
53 | collecting = True
54 | buf = [line]
55 | paren_balance = line.count("(") - line.count(")")
56 | if paren_balance == 0:
57 | exec("\n".join(buf), mod.__dict__)
58 | collecting = False
59 | buf = []
60 | elif collecting:
61 | buf.append(line)
62 | paren_balance += line.count("(") - line.count(")")
63 | if paren_balance == 0:
64 | exec("\n".join(buf), mod.__dict__)
65 | collecting = False
66 | buf = []
67 |
68 | # Pass 2: execute only def/class definitions
69 | for cell in nb.cells:
70 | if cell.cell_type != "code":
71 | continue
72 | src = cell.source
73 | if "def " in src or "class " in src:
74 | exec(src, mod.__dict__)
75 |
76 | return mod
77 |
--------------------------------------------------------------------------------
/ch02/04_torch-compile-windows/README.md:
--------------------------------------------------------------------------------
1 | # Using `torch.compile()` on Windows
2 |
3 | `torch.compile()` relies on *TorchInductor*, which JIT-compiles kernels and requires a working C/C++ compiler toolchain.
4 |
5 | On Windows, the setup required to make `torch.compile` work can be a bit more involved than on Linux or macOS, which usually don't require any extra steps beyond installing PyTorch.
6 |
7 | If you are a Windows user and this setup sounds too involved, don't worry: all code examples in this repository work fine without compilation.
8 |
9 | Below are some tips that I compiled based on recommendations by [Daniel Kleine](https://github.com/d-kleine) and the following [PyTorch guide](https://docs.pytorch.org/tutorials/unstable/inductor_windows.html).
10 |
11 |
12 | ## 1 Basic Setup (CPU or CUDA)
13 |
14 |
15 | ### 1.1 Install Visual Studio 2022
16 |
17 | - Select the **“Desktop development with C++”** workload.
18 | - Make sure to include the **English language pack** (without it, you may run into UTF-8 encoding errors).
19 |
20 |
21 | ### 1.2 Open the correct command prompt
22 |
23 |
24 | Launch Python from the
25 |
26 | **"x64 Native Tools Command Prompt for VS 2022"**
27 |
28 | or from the
29 |
30 | **"Visual Studio 2022 Developer Command Prompt"**.
31 |
32 | Alternatively, you can initialize the environment manually by running:
33 |
34 | ```bash
35 | "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat"
36 | ```
37 |
38 |
39 | ### 1.3 Verify that the compiler works
40 |
41 | Run
42 |
43 | ```bash
44 | cl.exe
45 | ```
46 |
47 | If you see version information printed, the compiler is ready.
48 |
49 |
50 | ## 2 Troubleshooting Common Errors
51 |
52 |
53 | ### 2.1 Error: `cl not found`
54 |
55 | Install **Visual Studio Build Tools** with the "C++ build tools" workload and run Python from a developer command prompt. (See this Microsoft [guide](https://learn.microsoft.com/en-us/cpp/build/vscpp-step-0-installation?view=msvc-170) for details.)
56 |
57 |
58 | ### 2.2 Error: `triton not found` (when using CUDA)
59 |
60 | Install the Windows build of Triton manually:
61 |
62 | ```bash
63 | pip install "triton-windows<3.4"
64 | ```
65 |
66 | or, if you are using `uv`:
67 |
68 | ```bash
69 | uv pip install "triton-windows<3.4"
70 | ```
71 |
72 | (Triton is required by TorchInductor for CUDA kernel compilation.)
73 |
74 |
75 |
76 |
77 | ## 3 Additional Notes
78 |
79 | On Windows, the `cl.exe` compiler is only accessible from within the Visual Studio Developer environment. This means that using `torch.compile()` in notebook environments such as Jupyter may not work unless the notebook server was launched from a Developer Command Prompt.
80 |
81 | As mentioned at the beginning of this article, there is also a [PyTorch guide](https://docs.pytorch.org/tutorials/unstable/inductor_windows.html) that some users found helpful when getting `torch.compile()` running on Windows CPU builds. However, note that it refers to PyTorch's unstable branch, so use it as a reference only.
82 |
83 | **If compilation continues to cause issues, please feel free to skip it. It's a nice bonus, but it's not required to follow the book.**
84 |
85 |
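86 | As an optional sanity check (a minimal sketch, not part of the guides above), you can verify that `torch.compile()` works in your environment by compiling a tiny function:
87 |
88 | ```python
89 | import torch
90 |
91 | def f(x):
92 |     return torch.sin(x) + torch.cos(x)
93 |
94 | # The first call triggers JIT compilation via TorchInductor, so it may take a moment
95 | compiled_f = torch.compile(f)
96 | x = torch.randn(8)
97 | print(torch.allclose(f(x), compiled_f(x)))  # should print True
98 | ```
99 |
100 | If this prints `True` without compiler errors, the toolchain setup above worked.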
--------------------------------------------------------------------------------
/chF/02_mmlu/random_guessing_baseline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import argparse
6 | import random
7 | import statistics as stats
8 | from collections import Counter
9 |
10 | from datasets import load_dataset
11 |
12 |
13 | # "Gold letter" is MMLU jargon for the correct answer letter
14 | def gold_letter(ans):
15 | if isinstance(ans, int):
16 | return "ABCD"[ans]
17 | s = str(ans).strip().upper()
18 | return s if s in {"A", "B", "C", "D"} else s[:1]
19 |
20 |
21 | def main():
22 | parser = argparse.ArgumentParser(
23 | description="Show gold answer distribution for an MMLU subset and a random-guess baseline."
24 | )
25 | parser.add_argument(
26 | "--subset",
27 | type=str,
28 | default="high_school_mathematics",
29 | help="MMLU subset name (default: 'high_school_mathematics').",
30 | )
31 | parser.add_argument(
32 | "--seed",
33 | type=int,
34 | default=42,
35 | help="Random seed for the random-guess baseline (default: 42).",
36 | )
37 | parser.add_argument(
38 | "--trials",
39 | type=int,
40 | default=10_000,
41 | help="Number of random-guess trials (default: 10,000).",
42 | )
43 | args = parser.parse_args()
44 |
45 | ds = load_dataset("cais/mmlu", args.subset, split="test")
46 |
47 | labels = [gold_letter(ex["answer"]) for ex in ds]
48 | n = len(labels)
49 | counts = Counter(labels)
50 |
51 | print(f"Subset: {args.subset} | split: test | n={n}")
52 | print("Gold distribution provided in the dataset:")
53 | for letter in "ABCD":
54 | c = counts.get(letter, 0)
55 | pct = (c / n) if n else 0.0
56 | print(f" {letter}: {c} ({pct:.2%})")
57 |
58 | if n == 0:
59 | print("\nNo items. Baseline undefined.")
60 | return
61 |
62 | # Repeat random guessing
63 | rng = random.Random(args.seed)
64 | accs = []
65 | for _ in range(args.trials):
66 | guesses = [rng.choice("ABCD") for _ in range(n)]
67 | correct = sum(1 for g, y in zip(guesses, labels) if g == y)
68 | accs.append(correct / n)
69 |
70 | mean_acc = stats.mean(accs)
71 | sd_acc = stats.stdev(accs) if len(accs) > 1 else 0.0
72 |
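73 |     # Sanity check: for uniform guessing over 4 options, the analytical
74 |     # expectation is a mean accuracy of 25% with a standard deviation of
75 |     # about sqrt(0.25 * 0.75 / n) across trials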
73 | print(f"\nRandom guessing over {args.trials:,} trials (uniform A/B/C/D, seed={args.seed}):")
74 | print(f" Mean accuracy: {mean_acc:.2%}")
75 | print(f" Std dev across trials: {sd_acc:.2%}")
76 |
77 | # Quantiles
78 | qs = [0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99]
79 | accs_sorted = sorted(accs)
80 | print("\nSelected quantiles of accuracy:")
81 | for q in qs:
82 | idx = int(q * len(accs_sorted))
83 | print(f" {q:.0%} quantile: {accs_sorted[idx]:.3%}")
84 |
85 | # Frequency table (rounded)
86 | acc_counts = Counter(round(a, 2) for a in accs)
87 | print("\nFull frequency table of accuracies (rounded):")
88 | for acc_val in sorted(acc_counts):
89 | freq = acc_counts[acc_val]
90 | pct = freq / args.trials
91 | print(f" {acc_val:.3f}: {freq} times ({pct:.2%})")
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 |
--------------------------------------------------------------------------------
/reasoning_from_scratch/ch02.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | from .qwen3 import KVCache
6 | import torch
7 |
8 |
9 | def get_device(enable_tensor_cores=True):
10 | if torch.cuda.is_available():
11 | device = torch.device("cuda")
12 | print("Using NVIDIA CUDA GPU")
13 |
14 | if enable_tensor_cores:
15 | major, minor = map(int, torch.__version__.split(".")[:2])
16 | if (major, minor) >= (2, 9):
17 | torch.backends.cuda.matmul.fp32_precision = "tf32"
18 | torch.backends.cudnn.conv.fp32_precision = "tf32"
19 | else:
20 | torch.backends.cuda.matmul.allow_tf32 = True
21 | torch.backends.cudnn.allow_tf32 = True
22 |
23 | elif torch.backends.mps.is_available():
24 | device = torch.device("mps")
25 | print("Using Apple Silicon GPU (MPS)")
26 |
27 | elif torch.xpu.is_available():
28 | device = torch.device("xpu")
29 | print("Using Intel GPU")
30 |
31 | else:
32 | device = torch.device("cpu")
33 | print("Using CPU")
34 |
35 | return device
36 |
37 |
38 | @torch.inference_mode()
39 | def generate_text_basic(model, token_ids, max_new_tokens, eos_token_id=None):
40 | input_length = token_ids.shape[1]
41 | model.eval()
42 |
43 | for _ in range(max_new_tokens):
44 | out = model(token_ids)[:, -1]
45 | next_token = torch.argmax(out, dim=-1, keepdim=True)
46 |
47 |         # Stop if the sequence has generated an EOS token (assumes batch size 1)
48 | if (eos_token_id is not None
49 | and next_token.item() == eos_token_id):
50 | break
51 |
52 | token_ids = torch.cat([token_ids, next_token], dim=1)
53 | return token_ids[:, input_length:]
54 |
55 |
56 | @torch.inference_mode()
57 | def generate_text_basic_cache(
58 | model,
59 | token_ids,
60 | max_new_tokens,
61 | eos_token_id=None
62 | ):
63 |
64 | input_length = token_ids.shape[1]
65 | model.eval()
66 | cache = KVCache(n_layers=model.cfg["n_layers"])
67 | model.reset_kv_cache()
68 | out = model(token_ids, cache=cache)[:, -1]
69 |
70 | for _ in range(max_new_tokens):
71 | next_token = torch.argmax(out, dim=-1, keepdim=True)
72 |
73 | if (eos_token_id is not None
74 | and next_token.item() == eos_token_id):
75 | break
76 |
77 | token_ids = torch.cat([token_ids, next_token], dim=1)
78 | out = model(next_token, cache=cache)[:, -1]
79 |
80 | return token_ids[:, input_length:]
81 |
82 |
83 | def generate_stats(output_token_ids, tokenizer, start_time,
84 | end_time, print_tokens=True):
85 | total_time = end_time - start_time
86 | print(f"Time: {total_time:.2f} sec")
87 | print(f"{int(output_token_ids.numel() / total_time)} tokens/sec")
88 |
89 | for name, backend in (("CUDA", getattr(torch, "cuda", None)),
90 | ("XPU", getattr(torch, "xpu", None))):
91 | if backend is not None and backend.is_available():
92 | max_mem_bytes = backend.max_memory_allocated()
93 | max_mem_gb = max_mem_bytes / (1024 ** 3)
94 | print(f"Max {name} memory allocated: {max_mem_gb:.2f} GB")
95 | backend.reset_peak_memory_stats()
96 |
97 | if print_tokens:
98 | output_text = tokenizer.decode(output_token_ids.squeeze(0).tolist())
99 | print(f"\n{output_text}")
100 |
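101 |
102 | # Minimal usage sketch (illustrative only; see ch02/05_use_model/generate_simple.py
103 | # for a complete script). Assumes `model` and `tokenizer` are already loaded:
104 | #
105 | #   device = get_device()
106 | #   token_ids = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
107 | #   output = generate_text_basic_cache(model, token_ids, max_new_tokens=2048,
108 | #                                      eos_token_id=tokenizer.eos_token_id)
109 | #   print(tokenizer.decode(output.squeeze(0).tolist()))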
--------------------------------------------------------------------------------
/.github/scripts/check_notebook_line_length.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | # Check if code in notebooks exceeds line length
6 | # Usage: uv run .github/scripts/check_notebook_line_length.py --max-len 76 ch02/01_main-chapter-code/ch02_main.ipynb ch03/01_main-chapter-code/ch03_main.ipynb
7 |
8 | import argparse
9 | import sys
10 | from pathlib import Path
11 | import nbformat as nbf
12 |
13 |
14 | def parse_args():
15 | p = argparse.ArgumentParser(
16 | description="Check code-cell line lengths in specific .ipynb files."
17 | )
18 | p.add_argument(
19 | "--max-len",
20 | type=int,
21 | default=76,
22 | help="Maximum allowed characters per line (default: 76)",
23 | )
24 | p.add_argument(
25 | "notebooks",
26 | nargs="+",
27 | help="Paths to .ipynb files to check.",
28 | )
29 | return p.parse_args()
30 |
31 |
32 | def strip_inline_comment(line):
33 | """Return line with any trailing #... comment removed, but keep # inside quotes."""
34 | in_quote = None
35 | escaped = False
36 | for i, ch in enumerate(line):
37 | if escaped:
38 | escaped = False
39 | continue
40 | if ch == "\\":
41 | escaped = True
42 | continue
43 | if ch in ("'", '"'):
44 | if in_quote is None:
45 | in_quote = ch
46 | elif in_quote == ch:
47 | in_quote = None
48 | continue
49 | if ch == "#" and in_quote is None:
50 | return line[:i].rstrip()
51 | return line.rstrip()
52 |
53 |
54 | def main():
55 | args = parse_args()
56 |
57 | nb_paths = []
58 | for raw in args.notebooks:
59 | p = Path(raw).resolve()
60 | if not p.exists():
61 | print(f"::warning file={raw}::File not found, skipping.")
62 | continue
63 | if p.suffix != ".ipynb":
64 | print(f"::warning file={raw}::Not a .ipynb file, skipping.")
65 | continue
66 | nb_paths.append(p)
67 |
68 | if not nb_paths:
69 | print("No valid notebooks to check.")
70 | return 0
71 |
72 | violations = []
73 |
74 | for nb_path in nb_paths:
75 | try:
76 | nb = nbf.read(nb_path, as_version=4)
77 | except Exception as e:
78 | print(f"::warning file={nb_path}::Failed to read notebook: {e}")
79 | continue
80 |
81 | for ci, cell in enumerate(nb.cells, start=1):
82 | if cell.get("cell_type") != "code":
83 | continue
84 |
85 | src = cell.get("source", "")
86 | lines = src if isinstance(src, list) else src.splitlines()
87 |
88 | for li, line in enumerate(lines, start=1):
89 | code_part = strip_inline_comment(line)
90 | length = len(code_part)
91 | if length > args.max_len:
92 | print(
93 | f"::error file={nb_path}::Line length {length} exceeds "
94 | f"{args.max_len} in code cell {ci}, line {li}"
95 | )
96 | snippet = code_part if len(code_part) <= 120 else code_part[:120] + "…"
97 | violations.append((str(nb_path), ci, li, length, snippet))
98 |
99 | if violations:
100 | print("\nFound lines exceeding the limit:\n")
101 | for path, ci, li, length, snippet in violations:
102 | print(f"- {path} | cell {ci} line {li}: {length} chars\n {snippet}")
103 | return 1
104 |
105 | print(f"All notebooks pass: no code-cell lines exceed {args.max_len} characters.")
106 | return 0
107 |
108 |
109 | if __name__ == "__main__":
110 | sys.exit(main())
111 |
--------------------------------------------------------------------------------
/ch03/02_math500-verifier-scripts/evaluate_math500.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import argparse
6 | import torch
7 |
8 | from reasoning_from_scratch.qwen3 import (
9 | Qwen3Model,
10 | QWEN_CONFIG_06_B
11 | )
12 | from reasoning_from_scratch.ch02 import get_device
13 | from reasoning_from_scratch.ch03 import (
14 | load_math500_test,
15 | evaluate_math500_stream,
16 | load_model_and_tokenizer,
17 | load_tokenizer_only
18 | )
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument(
24 | "--device",
25 | type=str,
26 | default="auto",
27 | help="Device to use: 'auto' (default), or any torch device string like 'cpu', 'cuda', 'cuda:0', 'mps'.",
28 | )
29 | parser.add_argument(
30 | "--which_model",
31 | type=str,
32 | default="base",
33 | choices=["base", "reasoning", "instruct"],
34 | help="Model variant to load. Defaults to 'base'.",
35 | )
36 | parser.add_argument(
37 | "--dataset_size",
38 | type=int,
39 | default=10,
40 | help="Number of MATH-500 examples to evaluate. Default: 10",
41 | )
42 | parser.add_argument(
43 | "--max_new_tokens",
44 | type=int,
45 | default=2048,
46 | help="Max new tokens for generation. Default: 2048",
47 | )
48 | parser.add_argument(
49 | "--compile",
50 | action="store_true",
51 | help="Enable torch.compile for the model.",
52 | )
53 | parser.add_argument(
54 | "--checkpoint_path",
55 | type=str,
56 | default=None,
57 | help="Optional path to a .pth checkpoint to load model weights from.",
58 | )
59 | parser.add_argument(
60 | "--verbose",
61 | action="store_true",
62 | help="Print per-sample correctness while evaluating.",
63 | )
64 | return parser.parse_args()
65 |
66 |
67 | if __name__ == "__main__":
68 | args = parse_args()
69 |
70 | if args.device == "auto":
71 | device = get_device()
72 | else:
73 | device = torch.device(args.device)
74 |
75 | which_model = args.which_model
76 | dataset_size = args.dataset_size
77 | max_new_tokens = args.max_new_tokens
78 | use_compile = args.compile
79 |
80 | print("Model:", which_model)
81 | print("Device:", device)
82 | dev_name = str(device).replace(":", "-")
83 |
84 | math_data = load_math500_test()
85 |
86 | if args.which_model == "instruct":
87 | which_model = "reasoning"
88 | else:
89 | which_model = args.which_model
90 |
91 | if args.checkpoint_path:
92 | # To load the saved RL checkpoint files from chapter 6
93 | tokenizer = load_tokenizer_only(which_model=which_model)
94 | model = Qwen3Model(QWEN_CONFIG_06_B)
95 | model.to(device)
96 | state_dict = torch.load(args.checkpoint_path, map_location=device)
97 | model.load_state_dict(state_dict)
98 | if args.compile:
99 | torch._dynamo.config.allow_unspec_int_on_nn_module = True
100 | model = torch.compile(model)
101 | else:
102 | model, tokenizer = load_model_and_tokenizer(
103 | which_model=which_model,
104 | device=device,
105 | use_compile=args.compile
106 | )
107 |
108 | if args.which_model == "instruct":
109 | tokenizer.add_thinking = False
110 |
111 | model.eval()
112 | torch.set_float32_matmul_precision("high")
113 |
114 | num_correct, num_examples, acc = evaluate_math500_stream(
115 | model=model,
116 | out_path=f"math500_{which_model}-{dev_name}-evaluate-script.jsonl",
117 | tokenizer=tokenizer,
118 | device=device,
119 | math_data=math_data[:dataset_size],
120 | max_new_tokens=max_new_tokens,
121 | verbose=args.verbose,
122 | )
123 |
--------------------------------------------------------------------------------
/chG/01_main-chapter-code/qwen3_chat_interface_multiturn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
2 | # Source for "Build a Large Language Model From Scratch"
3 | # - https://www.manning.com/books/build-a-large-language-model-from-scratch
4 | # Code: https://github.com/rasbt/LLMs-from-scratch
5 |
6 |
7 | import torch
8 | import chainlit
9 |
10 | from reasoning_from_scratch.ch02 import (
11 | get_device,
12 | )
13 | from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache
14 | from reasoning_from_scratch.ch03 import load_model_and_tokenizer
15 |
16 | # ============================================================
17 | # EDIT ME: Simple configuration
18 | # ============================================================
19 | WHICH_MODEL = "reasoning" # "base" for base model
20 | MAX_NEW_TOKENS = 38912
21 | LOCAL_DIR = "qwen3"
22 | COMPILE = False
23 | # ============================================================
24 |
25 |
26 | def trim_input_tensor(input_ids_tensor, context_len, max_new_tokens):
27 | assert max_new_tokens < context_len
28 | keep_len = max(1, context_len - max_new_tokens)
29 |
30 | # If the prompt is too long, left-truncate to keep_len
31 | if input_ids_tensor.shape[1] > keep_len:
32 | input_ids_tensor = input_ids_tensor[:, -keep_len:]
33 |
34 | return input_ids_tensor
35 |
36 |
37 | def build_prompt_from_history(history, add_assistant_header=True):
38 | """
39 | history: [{"role": "system"|"user"|"assistant", "content": str}, ...]
40 | """
41 | parts = []
42 | for m in history:
43 | role = m["role"]
44 | content = m["content"]
45 | parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
46 |
47 | if add_assistant_header:
48 | parts.append("<|im_start|>assistant\n")
49 | return "".join(parts)
50 |
51 |
52 | DEVICE = get_device()
53 | MODEL, TOKENIZER = load_model_and_tokenizer(
54 | which_model=WHICH_MODEL,
55 | device=DEVICE,
56 | use_compile=COMPILE,
57 | local_dir=LOCAL_DIR
58 | )
59 |
60 | # Even though the official TOKENIZER.eos_token_id is either <|im_end|> (reasoning) or
61 | # <|endoftext|> (base), the reasoning model sometimes emits either one, so we stop on both.
62 | EOS_TOKEN_IDS = (
63 | TOKENIZER.encode("<|im_end|>")[0],
64 | TOKENIZER.encode("<|endoftext|>")[0]
65 | )
66 |
67 |
68 | @chainlit.on_chat_start
69 | async def on_start():
70 | chainlit.user_session.set("history", [])
71 | chainlit.user_session.get("history").append(
72 | {"role": "system", "content": "You are a helpful assistant."}
73 | )
74 |
75 |
76 | @chainlit.on_message
77 | async def main(message: chainlit.Message):
78 | """
79 | The main Chainlit function.
80 | """
81 | # 0) Get and track chat history
82 | history = chainlit.user_session.get("history")
83 | history.append({"role": "user", "content": message.content})
84 |
85 | # 1) Encode input
86 | prompt = build_prompt_from_history(history, add_assistant_header=True)
87 | input_ids = TOKENIZER.encode(prompt)
88 | input_ids_tensor = torch.tensor(input_ids, device=DEVICE).unsqueeze(0)
89 |
90 | # Multi-turn can be very long, so we add this left-trimming
91 | input_ids_tensor = trim_input_tensor(
92 | input_ids_tensor=input_ids_tensor,
93 | context_len=MODEL.cfg["context_length"],
94 | max_new_tokens=MAX_NEW_TOKENS
95 | )
96 |
97 | # 2) Start an outgoing message we can stream into
98 | out_msg = chainlit.Message(content="")
99 | await out_msg.send()
100 |
101 | # 3) Stream generation
102 | for tok in generate_text_basic_stream_cache(
103 | model=MODEL,
104 | token_ids=input_ids_tensor,
105 | max_new_tokens=MAX_NEW_TOKENS,
106 | # eos_token_id=TOKENIZER.eos_token_id
107 | ):
108 | token_id = tok.squeeze(0)
109 | if token_id in EOS_TOKEN_IDS:
110 | break
111 | piece = TOKENIZER.decode(token_id.tolist())
112 | await out_msg.stream_token(piece)
113 |
114 | # 4) Finalize the streamed message
115 | await out_msg.update()
116 |
117 | # 5) Update chat history
118 | history.append({"role": "assistant", "content": out_msg.content})
119 | chainlit.user_session.set("history", history)
120 |
--------------------------------------------------------------------------------
/tests/test_ch02.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import torch
6 |
7 | from reasoning_from_scratch.ch02 import (
8 | get_device,
9 | generate_text_basic,
10 | generate_text_basic_cache,
11 | generate_stats
12 | )
13 |
14 |
15 | # Dummy model for generate_text_basic tests.
16 | class DummyModel:
17 | def __init__(self, fixed_token, vocab_size=5):
18 | self.fixed_token = fixed_token
19 | self.vocab_size = vocab_size
20 | self.eval_called = False
21 |
22 | def eval(self):
23 | self.eval_called = True
24 | return self
25 |
26 | def __call__(self, token_ids, cache=None):
27 | batch_size, seq_len = token_ids.size()
28 | out = torch.zeros(batch_size, seq_len, self.vocab_size)
29 | # Set the fixed_token column to the highest value so argmax returns fixed_token
30 | out[..., self.fixed_token] = 1.0
31 | return out
32 |
33 |
34 | class DummyModelCache(DummyModel):
35 | def __init__(self, fixed_token, vocab_size=5, n_layers=2):
36 | super().__init__(fixed_token, vocab_size)
37 | self.cfg = {"n_layers": n_layers}
38 | self.reset_called = False
39 |
40 | def reset_kv_cache(self):
41 | self.reset_called = True
42 |
43 |
44 | class DummyTokenizer:
45 | def decode(self, token_list):
46 | return " ".join(str(t) for t in token_list)
47 |
48 |
49 | def test_get_device_returns_torch_device(capsys):
50 | device = get_device()
51 | assert isinstance(device, torch.device)
52 |     assert device.type in ("cpu", "cuda", "mps", "xpu")
53 |
54 |
55 | def test_generate_text_basic_stops_on_eos():
56 | # batch_size = 1
57 | # seq_len = 3
58 | max_new_tokens = 10
59 | fixed_token = 2
60 |
61 | dummy_model = DummyModel(fixed_token=fixed_token)
62 | token_ids = torch.tensor([[1, 3, 4]]) # shape (batch, seq_len)
63 |
64 | # Set eos_token_id to be the fixed_token so that generation stops immediately
65 | output = generate_text_basic(dummy_model, token_ids, max_new_tokens, eos_token_id=fixed_token)
66 | assert output.size(1) == 0
67 | assert dummy_model.eval_called is True
68 |
69 |
70 | def test_generate_text_basic_generates_tokens_without_eos():
71 | # batch_size = 1
72 | # seq_len = 2
73 | max_new_tokens = 3
74 | fixed_token = 1
75 |
76 | dummy_model = DummyModel(fixed_token=fixed_token)
77 | token_ids = torch.tensor([[0, 4]])
78 | output = generate_text_basic(dummy_model, token_ids, max_new_tokens, eos_token_id=None)
79 | assert output.size(1) == max_new_tokens
80 | assert torch.all(output == fixed_token)
81 |
82 |
83 | def test_generate_text_basic_cache_stops_on_eos():
84 | # batch_size = 1
85 | # seq_len = 2
86 | max_new_tokens = 10
87 | fixed_token = 3
88 |
89 | dummy_model = DummyModelCache(fixed_token=fixed_token, n_layers=4)
90 | token_ids = torch.tensor([[2, 2]])
91 | output = generate_text_basic_cache(dummy_model, token_ids, max_new_tokens, eos_token_id=fixed_token)
92 | assert output.size(1) == 0
93 | assert dummy_model.reset_called is True
94 |
95 |
96 | def test_generate_text_basic_cache_generates_tokens_without_eos():
97 | # batch_size = 1
98 | # seq_len = 1
99 | max_new_tokens = 4
100 | fixed_token = 0
101 |
102 | dummy_model = DummyModelCache(fixed_token=fixed_token, n_layers=3)
103 | token_ids = torch.tensor([[5]])
104 |
105 | output = generate_text_basic_cache(dummy_model, token_ids, max_new_tokens, eos_token_id=None)
106 | assert output.size(1) == max_new_tokens
107 | assert torch.all(output == fixed_token)
108 | assert dummy_model.reset_called is True
109 |
110 |
111 | def test_generate_stats_prints_output(monkeypatch, capsys):
112 | output_token_ids = torch.tensor([[10, 20, 30]])
113 | tokenizer = DummyTokenizer()
114 | start_time = 100.0
115 | end_time = 102.0
116 |
117 | monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
118 | generate_stats(output_token_ids, tokenizer, start_time, end_time)
119 |
120 | captured = capsys.readouterr().out
121 | assert "Time:" in captured
122 | assert "tokens/sec" in captured
123 | assert "10 20 30" in captured
124 |
--------------------------------------------------------------------------------
/ch02/05_use_model/generate_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | # Runs the model similarly to chapters 2 and 3 in streaming mode with the least
6 | # amount of bells and whistles. Uses KV caching by default.
7 |
8 | import argparse
9 | from pathlib import Path
10 | import time
11 | import torch
12 |
13 | from reasoning_from_scratch.ch02 import (
14 | get_device,
15 | generate_stats
16 | )
17 | from reasoning_from_scratch.ch02_ex import (
18 | generate_text_basic_stream_cache
19 | )
20 | from reasoning_from_scratch.qwen3 import (
21 | download_qwen3_small,
22 | Qwen3Model,
23 | Qwen3Tokenizer,
24 | QWEN_CONFIG_06_B
25 | )
26 |
27 | parser = argparse.ArgumentParser(description="Run Qwen3 text generation")
28 | parser.add_argument(
29 | "--device",
30 | type=str,
31 | default=None,
32 | help="Device to run on (e.g. 'cpu', 'cuda', 'mps'). "
33 | "If not provided, will auto-detect with get_device()."
34 | )
35 | parser.add_argument(
36 | "--max_new_tokens",
37 | type=int,
38 | default=2048,
39 | help="Maximum number of new tokens to generate (default: 2048)."
40 | )
41 | parser.add_argument(
42 | "--compile",
43 | action="store_true",
44 | help="Compile PyTorch model (default: False)."
45 | )
46 | parser.add_argument(
47 | "--reasoning",
48 | action="store_true",
49 | help="Use reasoning model variant (default: False)."
50 | )
51 | parser.add_argument(
52 | "--prompt",
53 | type=str,
54 | default=None,
55 | help=("Use a custom prompt. If not explicitly provided, uses the following defaults: "
56 | "'Explain large language models in a single sentence.' for the base model, and "
57 | "'Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.' for the reasoning model.")
58 | )
59 |
60 | args = parser.parse_args()
61 | device = torch.device(args.device) if args.device else get_device()
62 |
63 | if args.reasoning:
64 | download_qwen3_small(kind="reasoning", tokenizer_only=False, out_dir="qwen3")
65 | tokenizer_path = Path("qwen3") / "tokenizer-reasoning.json"
66 | model_path = Path("qwen3") / "qwen3-0.6B-reasoning.pth"
67 | tokenizer = Qwen3Tokenizer(
68 | tokenizer_file_path=tokenizer_path,
69 | apply_chat_template=True,
70 | add_generation_prompt=True,
71 | add_thinking=True
72 | )
73 | else:
74 | download_qwen3_small(kind="base", tokenizer_only=False, out_dir="qwen3")
75 | tokenizer_path = Path("qwen3") / "tokenizer-base.json"
76 | model_path = Path("qwen3") / "qwen3-0.6B-base.pth"
77 | tokenizer = Qwen3Tokenizer(
78 | tokenizer_file_path=tokenizer_path,
79 | apply_chat_template=False,
80 | add_generation_prompt=False,
81 | add_thinking=False
82 | )
83 |
84 | model = Qwen3Model(QWEN_CONFIG_06_B)
85 | state = torch.load(model_path, map_location=device)
86 | model.load_state_dict(state)
87 | model.to(device)
88 | model.eval()
89 |
90 | if args.compile:
91 | model = torch.compile(model)
92 |
93 |
94 | if args.prompt is None:
95 | if args.reasoning:
96 | prompt = "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field."
97 | else:
98 | prompt = "Explain large language models in a single sentence."
99 | else:
100 | prompt = args.prompt
101 |
102 |
103 | input_ids = tokenizer.encode(prompt)
104 | input_token_ids_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)
105 |
106 | print()
107 | print("=" * 60)
108 | print(f"torch : {torch.__version__}")
109 | print(f"device : {device}")
110 | print("cache : True")
111 | print(f"compile : {args.compile}")
112 | print(f"reasoning : {args.reasoning}")
113 | print("=" * 60)
114 | print()
115 |
116 | start_time = time.time()
117 | all_token_ids = []
118 |
119 | for token in generate_text_basic_stream_cache(
120 | model=model,
121 | token_ids=input_token_ids_tensor,
122 | max_new_tokens=args.max_new_tokens,
123 | eos_token_id=tokenizer.eos_token_id
124 | ):
125 | token_id = token.squeeze(0).item()
126 | print(tokenizer.decode([token_id]), end="", flush=True)
127 | all_token_ids.append(token_id)
128 | end_time = time.time()
129 |
130 | print("\n")
131 | generate_stats(torch.tensor(all_token_ids), tokenizer, start_time, end_time, print_tokens=False)
--------------------------------------------------------------------------------
/ch03/02_math500-verifier-scripts/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 3: Evaluating Reasoning Models
2 |
3 |
4 |
5 |
6 |
7 | ## Bonus materials
8 |
9 | - [evaluate_math500.py](evaluate_math500.py): standalone script to evaluate models on the MATH-500 dataset
10 | - [evaluate_math500_batched.py](evaluate_math500_batched.py): same as above, but processes multiple examples in parallel during generation (for higher throughput)
11 |
12 | Both evaluation scripts import functionality from the [`reasoning_from_scratch`](../../reasoning_from_scratch) package to avoid code duplication. (See [chapter 2 setup instructions](../../ch02/02_setup-tips/python-instructions.md) for installation details.)
13 |
14 |
15 |
16 |
17 |
18 | ---
19 |
20 | **Note**: If you are not a `uv` user, replace `uv run ...py` with `python ...py` in the examples below.
21 |
22 | ---
23 |
24 |
25 |
26 |
27 |
28 | ## `evaluate_math500.py` usage
29 |
30 | Run with:
31 |
32 | ```bash
33 | python evaluate_math500.py
34 | ```
35 |
36 | Or, with `uv`:
37 |
38 |
39 | ```bash
40 | uv run evaluate_math500.py
41 | ```
42 |
43 | Options:
44 |
45 | ```bash
46 | uv run evaluate_math500.py --help
47 |
48 | options:
49 | -h, --help show this help message and exit
50 | --device DEVICE Device to use: "auto" (default) or any torch device string
51 | (e.g., "cpu", "cuda", "cuda:0", "mps").
52 |   --which_model {base,reasoning,instruct}
53 | Model variant to load (default: "base").
54 | --dataset_size DATASET_SIZE
55 | Number of MATH-500 examples to evaluate (default: 10).
56 | --max_new_tokens MAX_NEW_TOKENS
57 | Max new tokens to generate (default: 2048).
58 | --compile Enable torch.compile.
59 | --verbose Print per-sample correctness while evaluating.
60 | ```
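61 |
62 | For example, to evaluate the reasoning model on 50 examples and print per-sample correctness (an illustrative combination of the options above):
63 |
64 | ```bash
65 | uv run evaluate_math500.py --which_model reasoning --dataset_size 50 --verbose
66 | ```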
61 |
62 |
63 | ## `evaluate_math500_batched.py` usage
64 |
65 | This version extends batching to generation itself, enabling parallel decoding:
66 |
67 | ```bash
68 | uv run evaluate_math500_batched.py --help
69 | ```
70 |
71 | Extra options:
72 |
73 | ```bash
74 | --batch_size BATCH_SIZE
75 | Number of examples to generate in parallel (default: 4).
76 | --disable_efficient_mode
77 | Use a simpler batched inference method. Slower and more
78 | memory-intensive, but easier to debug.
79 | ```
80 |
81 |
82 |
83 |
84 | **Implementation note:**
85 | By default, batched generation stops generating for sequences that have emitted a stop token. With `--disable_efficient_mode`, all sequences continue until the longest one finishes. This affects compute efficiency only, not the results, since tokens generated after the stop token are discarded either way. (A schematic sketch of this per-sequence stop handling is shown at the end of this page.)
86 |
87 |
88 |
89 | **Tip (MPS devices):**
90 | Run with:
91 |
92 | ```bash
93 | PYTORCH_ENABLE_MPS_FALLBACK=1 uv run evaluate_math500_batched.py
94 | ```
95 |
96 | Some PyTorch ops used in efficient batched inference are not yet supported on MPS. As a fallback, you can also use `--disable_efficient_mode`.
97 |
98 |
99 |
100 |
101 |
102 | - `evaluate_math500.py --dataset_size 500`
103 |
104 |
105 | | Device / Dataset size | Base model | Reasoning model |
106 | | ------------------------------------------- | ---------- | --------------- |
107 | | **Mac Mini M4 CPU** (500 examples, sequential) | 43.6 min | Didn't run (too hot) |
108 | | **Mac Mini M4 GPU** (500 examples, sequential) | 37.5 min | Didn't run (too hot) |
109 | | **DGX Spark** (500 examples, sequential) | 10.0 min | 182.2 min |
110 | | **H100 GPU** (500 examples, sequential) | 13.3 min | 185.4 min |
111 |
112 |
113 |
114 |
115 | - `evaluate_math500_batched.py --dataset_size 500 --batch_size 128`
116 |
117 | | Device / Dataset size | Base model | Reasoning model |
118 | | ------------------------------------------------------------ | ---------- | --------------- |
119 | | **Mac Mini M4 CPU** (500 examples, batched, `--batch_size 128`) | 167.2 min | Didn't run (too hot) |
120 | | **Mac Mini M4 GPU** (500 examples, batched, `--batch_size 128`) | Error* | Error |
121 | | **DGX Spark** (500 examples, batched, `--batch_size 128`) | 16.3 min | 119.3 min |
122 | | **H100 GPU** (500 examples, batched, `--batch_size 128`) | 3.3 min | 14.6 min |
123 |
124 |
125 |
126 | - The accuracy of the base model is 15.6% (78/500); the accuracy of the reasoning model is 50.8% (254/500).
127 |
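128 |
129 | Regarding the implementation note above: the sketch below (illustrative only, not the repository's actual implementation) shows how batched greedy decoding can track a per-sequence stop token. Finished sequences simply keep emitting the stop token, and those trailing tokens are discarded afterwards; the default efficient mode additionally avoids spending compute on sequences that have already stopped.
130 |
131 | ```python
132 | import torch
133 |
134 | @torch.inference_mode()
135 | def generate_batched_greedy(model, token_ids, max_new_tokens, eos_token_id):
136 |     # token_ids: (batch_size, seq_len) prompt tokens
137 |     batch_size = token_ids.shape[0]
138 |     finished = torch.zeros(batch_size, dtype=torch.bool, device=token_ids.device)
139 |
140 |     for _ in range(max_new_tokens):
141 |         logits = model(token_ids)[:, -1]                         # (batch_size, vocab_size)
142 |         next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)
143 |         # Finished sequences keep "emitting" the stop token; these trailing
144 |         # tokens are discarded during post-processing
145 |         next_token[finished] = eos_token_id
146 |         token_ids = torch.cat([token_ids, next_token], dim=1)
147 |         finished |= (next_token.squeeze(1) == eos_token_id)
148 |         if finished.all():  # every sequence has emitted a stop token
149 |             break
150 |     return token_ids
151 | ```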
--------------------------------------------------------------------------------
/chF/02_mmlu/1_letter_matching.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 |
4 | import torch
5 | from datasets import load_dataset, get_dataset_config_names
6 | from reasoning_from_scratch.ch02 import get_device
7 | from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache
8 | from reasoning_from_scratch.ch03 import load_model_and_tokenizer
9 |
10 |
11 | # Same as in main notebook
12 | def format_prompt(example):
13 | return (
14 | f"{example['question']}\n"
15 | f"A. {example['choices'][0]}\n"
16 | f"B. {example['choices'][1]}\n"
17 | f"C. {example['choices'][2]}\n"
18 | f"D. {example['choices'][3]}\n"
19 | "Answer: " # trailing space encourages a single-letter next token
20 | )
21 |
22 |
23 | # Same as in main notebook
24 | def predict_choice(
25 | model, tokenizer, prompt_fmt, max_new_tokens=8
26 | ):
27 | pred = None
28 | for t in generate_text_basic_stream_cache(
29 | model=model,
30 | token_ids=prompt_fmt,
31 | max_new_tokens=max_new_tokens,
32 | eos_token_id=tokenizer.eos_token_id,
33 | ):
34 | answer = tokenizer.decode(t.squeeze(0).tolist())
35 | for letter in answer:
36 | letter = letter.upper()
37 | if letter in "ABCD":
38 | pred = letter
39 | break
40 | if pred: # stop as soon as a letter appears
41 | break
42 | return pred
43 |
44 |
45 | def evaluate_mmlu_letter(
46 | model,
47 | tokenizer,
48 | device,
49 | subsets="high_school_mathematics", # str, list of str, or "all"
50 | split="test",
51 | max_new_tokens=8,
52 | verbose_every=50,
53 | ):
54 | if subsets == "all":
55 | subset_list = get_dataset_config_names("cais/mmlu")
56 | elif isinstance(subsets, str):
57 | subset_list = [s.strip() for s in subsets.split(",")] if "," in subsets else [subsets]
58 | else:
59 | subset_list = list(subsets)
60 |
61 | total = 0
62 | correct = 0
63 | start = time.time()
64 |
65 | for subset in subset_list:
66 | ds = load_dataset("cais/mmlu", subset, split=split)
67 | for ex in ds:
68 | prompt = format_prompt(ex)
69 | tok = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
70 | pred = predict_choice(model, tokenizer, tok, max_new_tokens)
71 |
72 | ans = ex["answer"]
73 | # "Gold" is the MMLU jargon for the correct answer (ground truth)
74 | gold = "ABCD"[ans] if isinstance(ans, int) else str(ans).strip().upper()
75 |
76 | total += 1
77 | correct += int(pred == gold)
78 |
79 | if verbose_every and total % verbose_every == 0:
80 | print(f"MMLU {total} acc={correct/total:.3f} [{subset}]")
81 |
82 | acc = correct / max(1, total)
83 | print(
84 | f"\nMMLU letter accuracy: {correct}/{total} = {acc:.2%} "
85 | f"in {time.time()-start:.1f}s"
86 | )
87 | return {"accuracy": acc, "num_examples": total, "subsets": subset_list, "split": split}
88 |
89 |
90 | def main():
91 | parser = argparse.ArgumentParser(
92 | description="Zero-shot MMLU letter evaluator (A/B/C/D matching)."
93 | )
94 | parser.add_argument(
95 | "--device",
96 | type=str,
97 | default="auto",
98 | help="Device to use: 'auto' (default), or any torch device string like "
99 | "'cpu', 'cuda', 'cuda:0', 'mps'.",
100 | )
101 | parser.add_argument(
102 | "--which_model",
103 | type=str,
104 | default="base",
105 | choices=["base", "reasoning"],
106 | help="Model variant to load. Defaults to 'base'.",
107 | )
108 | parser.add_argument(
109 | "--subsets",
110 | type=str,
111 | default="high_school_mathematics",
112 | help="Comma-separated subset names or 'all'. "
113 | "Default: 'high_school_mathematics'.",
114 | )
115 | args = parser.parse_args()
116 |
117 | if args.device == "auto":
118 | device = get_device()
119 | else:
120 | device = torch.device(args.device)
121 | print(f"Using device: {device}")
122 |
123 | model, tokenizer = load_model_and_tokenizer(args.which_model, device, use_compile=False)
124 | model.eval()
125 | torch.set_float32_matmul_precision("high")
126 |
127 | metrics = evaluate_mmlu_letter(
128 | model=model,
129 | tokenizer=tokenizer,
130 | device=device,
131 | subsets=args.subsets,
132 | )
133 | print(metrics)
134 |
135 |
136 | if __name__ == "__main__":
137 | main()
138 |
--------------------------------------------------------------------------------
/tests/test_ch02_ex.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import torch
6 |
7 | from reasoning_from_scratch.ch02 import (
8 | generate_text_basic,
9 | generate_text_basic_cache,
10 | )
11 | from reasoning_from_scratch.ch02_ex import (
12 | generate_text_basic_stream,
13 | generate_text_basic_stream_cache,
14 | )
15 |
16 |
17 | # Dummy model for generate_text_basic tests.
18 | class DummyModel:
19 | def __init__(self, fixed_token, vocab_size=5):
20 | self.fixed_token = fixed_token
21 | self.vocab_size = vocab_size
22 | self.eval_called = False
23 |
24 | def eval(self):
25 | self.eval_called = True
26 | return self
27 |
28 | def __call__(self, token_ids, cache=None):
29 | batch_size, seq_len = token_ids.size()
30 | out = torch.zeros(batch_size, seq_len, self.vocab_size)
31 | # Set the fixed_token column to the highest value so argmax returns fixed_token
32 | out[..., self.fixed_token] = 1.0
33 | return out
34 |
35 |
36 | class DummyModelCache(DummyModel):
37 | def __init__(self, fixed_token, vocab_size=5, n_layers=2):
38 | super().__init__(fixed_token, vocab_size)
39 | self.cfg = {"n_layers": n_layers}
40 | self.reset_called = False
41 |
42 | def reset_kv_cache(self):
43 | self.reset_called = True
44 |
45 |
46 | class DummyTokenizer:
47 | def decode(self, token_list):
48 | return " ".join(str(t) for t in token_list)
49 |
50 |
51 | def test_generate_text_basic_stream_equivalence():
52 | max_new_tokens = 10
53 | fixed_token = 2
54 |
55 | dummy_model = DummyModel(fixed_token=fixed_token)
56 | token_ids = torch.tensor([[1, 3, 4]]) # shape (batch, seq_len)
57 |
58 | # Set eos_token_id to be the fixed_token so that generation stops immediately
59 | output_1 = generate_text_basic(dummy_model, token_ids, max_new_tokens, eos_token_id=fixed_token)
60 | output_1 = output_1.squeeze(0).tolist()
61 |
62 | output_2 = []
63 | for token in generate_text_basic_stream(
64 | model=dummy_model,
65 | token_ids=token_ids,
66 | max_new_tokens=max_new_tokens,
67 | eos_token_id=fixed_token
68 | ):
69 | output_2.append(token.squeeze(0).item())
70 |
71 | assert output_1 == output_2
72 |
73 |
74 | def test_generate_text_basic_stream_generates_tokens_without_eos():
75 | max_new_tokens = 3
76 | fixed_token = 1
77 |
78 | dummy_model = DummyModel(fixed_token=fixed_token)
79 | token_ids = torch.tensor([[0, 4]])
80 | output_1 = generate_text_basic(dummy_model, token_ids, max_new_tokens, eos_token_id=None)
81 | output_1 = output_1.squeeze(0).tolist()
82 |
83 | output_2 = []
84 | for token in generate_text_basic_stream(
85 | model=dummy_model,
86 | token_ids=token_ids,
87 | max_new_tokens=max_new_tokens,
88 | eos_token_id=None
89 | ):
90 | output_2.append(token.squeeze(0).item())
91 |
92 | assert output_1 == output_2
93 |
94 |
95 | def test_generate_text_basic_cache_stream_equivalence():
96 | max_new_tokens = 10
97 | fixed_token = 2
98 |
99 | dummy_model = DummyModelCache(fixed_token=fixed_token)
100 | token_ids = torch.tensor([[1, 3, 4]]) # shape (batch, seq_len)
101 |
102 | # Set eos_token_id to be the fixed_token so that generation stops immediately
103 | output_1 = generate_text_basic(dummy_model, token_ids, max_new_tokens, eos_token_id=fixed_token)
104 | output_1 = output_1.squeeze(0).tolist()
105 |
106 | output_2 = []
107 | for token in generate_text_basic_stream_cache(
108 | model=dummy_model,
109 | token_ids=token_ids,
110 | max_new_tokens=max_new_tokens,
111 | eos_token_id=fixed_token
112 | ):
113 | output_2.append(token.squeeze(0).item())
114 |
115 | assert output_1 == output_2
116 |
117 |
118 | def test_generate_text_basic_cache_stream_generates_tokens_without_eos():
119 | max_new_tokens = 3
120 | fixed_token = 1
121 |
122 | dummy_model = DummyModelCache(fixed_token=fixed_token)
123 | token_ids = torch.tensor([[0, 4]])
124 | output_1 = generate_text_basic_cache(dummy_model, token_ids, max_new_tokens, eos_token_id=None)
125 | output_1 = output_1.squeeze(0).tolist()
126 |
127 | output_2 = []
128 | for token in generate_text_basic_stream_cache(
129 | model=dummy_model,
130 | token_ids=token_ids,
131 | max_new_tokens=max_new_tokens,
132 | eos_token_id=None
133 | ):
134 | output_2.append(token.squeeze(0).item())
135 |
136 | assert output_1 == output_2
137 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Other temporary files
2 | *.jsonl
3 | Untitled.ipynb
4 | qwen3/
5 | tmp
6 | qwen3-0.6B.pth
7 | qwen3-0.6B-reasoning.pth
8 | qwen3-0.6B-base.pth
9 | tokenizer-reasoning.json
10 | tokenizer.json
11 | tokenizer-base.json
12 | .chainlit
13 | chainlit.md
14 | math500_test.json
15 | math500_base-mps.jsonl
16 | math500_base-cpu.jsonl
17 | math500_base-cuda.jsonl
18 | math500_reasoning-mps.jsonl
19 | math500_reasoning-cpu.jsonl
20 | math500_reasoning-cuda.jsonl
21 |
22 | # Installation artifacts
23 | uv.lock
24 |
25 | # Temporary OS files
26 | .DS_Store
27 |
28 | # Byte-compiled / optimized / DLL files
29 | __pycache__/
30 | *.py[cod]
31 | *$py.class
32 |
33 | # C extensions
34 | *.so
35 |
36 | # Distribution / packaging
37 | .Python
38 | build/
39 | develop-eggs/
40 | dist/
41 | downloads/
42 | eggs/
43 | .eggs/
44 | lib/
45 | lib64/
46 | parts/
47 | sdist/
48 | var/
49 | wheels/
50 | share/python-wheels/
51 | *.egg-info/
52 | .installed.cfg
53 | *.egg
54 | MANIFEST
55 |
56 | # PyInstaller
57 | # Usually these files are written by a python script from a template
58 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
59 | *.manifest
60 | *.spec
61 |
62 | # Installer logs
63 | pip-log.txt
64 | pip-delete-this-directory.txt
65 |
66 | # Unit test / coverage reports
67 | htmlcov/
68 | .tox/
69 | .nox/
70 | .coverage
71 | .coverage.*
72 | .cache
73 | nosetests.xml
74 | coverage.xml
75 | *.cover
76 | *.py,cover
77 | .hypothesis/
78 | .pytest_cache/
79 | cover/
80 |
81 | # Translations
82 | *.mo
83 | *.pot
84 |
85 | # Django stuff:
86 | *.log
87 | local_settings.py
88 | db.sqlite3
89 | db.sqlite3-journal
90 |
91 | # Flask stuff:
92 | instance/
93 | .webassets-cache
94 |
95 | # Scrapy stuff:
96 | .scrapy
97 |
98 | # Sphinx documentation
99 | docs/_build/
100 |
101 | # PyBuilder
102 | .pybuilder/
103 | target/
104 |
105 | # Jupyter Notebook
106 | .ipynb_checkpoints
107 |
108 | # IPython
109 | profile_default/
110 | ipython_config.py
111 |
112 | # pyenv
113 | # For a library or package, you might want to ignore these files since the code is
114 | # intended to run in multiple environments; otherwise, check them in:
115 | # .python-version
116 |
117 | # pipenv
118 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
119 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
120 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
121 | # install all needed dependencies.
122 | #Pipfile.lock
123 |
124 | # UV
125 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
126 | # This is especially recommended for binary packages to ensure reproducibility, and is more
127 | # commonly ignored for libraries.
128 | #uv.lock
129 |
130 | # poetry
131 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
132 | # This is especially recommended for binary packages to ensure reproducibility, and is more
133 | # commonly ignored for libraries.
134 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
135 | #poetry.lock
136 |
137 | # pdm
138 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
139 | #pdm.lock
140 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
141 | # in version control.
142 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
143 | .pdm.toml
144 | .pdm-python
145 | .pdm-build/
146 |
147 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
148 | __pypackages__/
149 |
150 | # Celery stuff
151 | celerybeat-schedule
152 | celerybeat.pid
153 |
154 | # SageMath parsed files
155 | *.sage.py
156 |
157 | # Environments
158 | .env
159 | .venv
160 | env/
161 | venv/
162 | ENV/
163 | env.bak/
164 | venv.bak/
165 |
166 | # Spyder project settings
167 | .spyderproject
168 | .spyproject
169 |
170 | # Rope project settings
171 | .ropeproject
172 |
173 | # mkdocs documentation
174 | /site
175 |
176 | # mypy
177 | .mypy_cache/
178 | .dmypy.json
179 | dmypy.json
180 |
181 | # Pyre type checker
182 | .pyre/
183 |
184 | # pytype static type analyzer
185 | .pytype/
186 |
187 | # Cython debug symbols
188 | cython_debug/
189 |
190 | # PyCharm
191 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
192 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
193 | # and can be added to the global gitignore or merged into this file. For a more nuclear
194 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
195 | #.idea/
196 |
197 | # PyPI configuration file
198 | .pypirc
199 |
--------------------------------------------------------------------------------
/ch02/05_use_model/chat.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | # Runs the model similarly to chapters 2 and 3 in streaming mode with the least
6 | # amount of bells and whistles. Uses KV caching by default.
7 | # Similar to generate_simple.py but uses an interactive REPL (without memory).
8 |
9 | import argparse
10 | from pathlib import Path
11 | import time
12 | import torch
13 |
14 | from reasoning_from_scratch.ch02 import (
15 | get_device,
16 | generate_stats
17 | )
18 | from reasoning_from_scratch.ch02_ex import (
19 | generate_text_basic_stream_cache
20 | )
21 | from reasoning_from_scratch.qwen3 import (
22 | download_qwen3_small,
23 | Qwen3Model,
24 | Qwen3Tokenizer,
25 | QWEN_CONFIG_06_B
26 | )
27 |
28 | parser = argparse.ArgumentParser(description="Run Qwen3 text generation (interactive REPL)")
29 | parser.add_argument(
30 | "--device",
31 | type=str,
32 | default=None,
33 | help="Device to run on (e.g. 'cpu', 'cuda', 'mps'). "
34 | "If not provided, will auto-detect with get_device()."
35 | )
36 | parser.add_argument(
37 | "--max_new_tokens",
38 | type=int,
39 | default=2048,
40 | help="Maximum number of new tokens to generate (default: 2048)."
41 | )
42 | parser.add_argument(
43 | "--compile",
44 | action="store_true",
45 | help="Compile PyTorch model (default: False)."
46 | )
47 | parser.add_argument(
48 | "--reasoning",
49 | action="store_true",
50 | help="Use reasoning model variant (default: False)."
51 | )
52 |
53 | args = parser.parse_args()
54 | device = torch.device(args.device) if args.device else get_device()
55 |
56 | if args.reasoning:
57 | download_qwen3_small(kind="reasoning", tokenizer_only=False, out_dir="qwen3")
58 | tokenizer_path = Path("qwen3") / "tokenizer-reasoning.json"
59 | model_path = Path("qwen3") / "qwen3-0.6B-reasoning.pth"
60 | tokenizer = Qwen3Tokenizer(
61 | tokenizer_file_path=tokenizer_path,
62 | apply_chat_template=True,
63 | add_generation_prompt=True,
64 | add_thinking=True
65 | )
66 | else:
67 | download_qwen3_small(kind="base", tokenizer_only=False, out_dir="qwen3")
68 | tokenizer_path = Path("qwen3") / "tokenizer-base.json"
69 | model_path = Path("qwen3") / "qwen3-0.6B-base.pth"
70 | tokenizer = Qwen3Tokenizer(
71 | tokenizer_file_path=tokenizer_path,
72 | apply_chat_template=False,
73 | add_generation_prompt=False,
74 | add_thinking=False
75 | )
76 |
77 | model = Qwen3Model(QWEN_CONFIG_06_B)
78 | state = torch.load(model_path, map_location=device)
79 | model.load_state_dict(state)
80 | model.to(device)
81 | model.eval()
82 |
83 | if args.compile:
84 | model = torch.compile(model)
85 |
86 | print()
87 | print("=" * 60)
88 | print(f"torch : {torch.__version__}")
89 | print(f"device : {device}")
90 | print("cache : True")
91 | print(f"compile : {args.compile}")
92 | print(f"reasoning : {args.reasoning}")
93 | print("memory : False")
94 | print("=" * 60)
95 | print()
96 | print("Interactive REPL (no memory). Type '\exit' or '\quit' to quit.\n")
97 |
98 |
99 | def run_once(prompt: str):
100 | input_ids = tokenizer.encode(prompt)
101 | input_token_ids_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)
102 |
103 | start_time = time.time()
104 | all_token_ids = []
105 |
106 | print("[Model]\n", end="", flush=True)
107 | for token in generate_text_basic_stream_cache(
108 | model=model,
109 | token_ids=input_token_ids_tensor,
110 | max_new_tokens=args.max_new_tokens,
111 | eos_token_id=tokenizer.eos_token_id
112 | ):
113 | token_id = token.squeeze(0).item()
114 | print(tokenizer.decode([token_id]), end="", flush=True)
115 |
116 | all_token_ids.append(token_id)
117 |
118 | end_time = time.time()
119 | print("\n")
120 |
121 | print("[Stats]")
122 | generate_stats(
123 | torch.tensor(all_token_ids),
124 | tokenizer,
125 | start_time,
126 | end_time,
127 | print_tokens=False
128 | )
129 | print("-" * 60)
130 |
131 |
132 | # REPL loop
133 | try:
134 | while True:
135 | try:
136 | user_in = input(">> ").strip()
137 | except EOFError:
138 | print("")
139 | break
140 |         if user_in.lower() in {"\\exit", "\\quit"}:
141 | break
142 | if not user_in:
143 | continue
144 |
145 | print("\n" + "-" * 60)
146 | print("[User]")
147 | print(user_in + "\n")
148 | run_once(user_in)
149 | except KeyboardInterrupt:
150 | print("\nInterrupted by user.")
151 |
--------------------------------------------------------------------------------
/ch02/03_optimized-LLM/compare_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 |
6 | import argparse
7 | from pathlib import Path
8 | import time
9 | import torch
10 |
11 | from reasoning_from_scratch.ch02 import (
12 | get_device,
13 | generate_stats
14 | )
15 | from reasoning_from_scratch.qwen3 import (
16 | download_qwen3_small,
17 | Qwen3Tokenizer,
18 | QWEN_CONFIG_06_B
19 | )
20 |
21 |
22 | ############################
23 | # Parse command-line args
24 | ############################
25 | parser = argparse.ArgumentParser(description="Run Qwen3 text generation")
26 | parser.add_argument(
27 | "--device",
28 | type=str,
29 | default=None,
30 | help="Device to run on (e.g. 'cpu', 'cuda', 'mps'). "
31 | "If not provided, will auto-detect with get_device()."
32 | )
33 | parser.add_argument(
34 | "--cache",
35 | action="store_true",
36 | help="Use KV cache during generation (default: False)."
37 | )
38 |
39 | parser.add_argument(
40 | "--compile",
41 | action="store_true",
42 | help="Compile PyTorch model (default: False)."
43 | )
44 |
45 | parser.add_argument(
46 | "--reasoning",
47 | action="store_true",
48 | help="Use reasoning model variant."
49 | )
50 |
51 | parser.add_argument(
52 | "--optimized",
53 | action="store_true",
54 |     help="Use the optimized model implementation (default: False)."
55 | )
56 |
57 |
58 | args = parser.parse_args()
59 |
60 | if args.optimized:
61 | from reasoning_from_scratch.qwen3_optimized import Qwen3Model
62 | else:
63 | from reasoning_from_scratch.qwen3 import Qwen3Model
64 |
65 |
66 | if args.cache:
67 | if args.optimized:
68 | from reasoning_from_scratch.qwen3_optimized import generate_text_basic_cache as generate_text_basic
69 | else:
70 | from reasoning_from_scratch.ch02 import generate_text_basic_cache as generate_text_basic
71 |
72 | else:
73 | from reasoning_from_scratch.ch02 import generate_text_basic
74 |
75 | device = torch.device(args.device) if args.device else get_device()
76 |
77 | #########################
78 | # Model + tokenizer setup
79 | #########################
80 |
81 | if args.reasoning:
82 | download_qwen3_small(kind="reasoning", tokenizer_only=False, out_dir="qwen3")
83 | tokenizer_path = Path("qwen3") / "tokenizer-reasoning.json"
84 | model_path = Path("qwen3") / "qwen3-0.6B-reasoning.pth"
85 | tokenizer = Qwen3Tokenizer(
86 | tokenizer_file_path=tokenizer_path,
87 | apply_chat_template=True,
88 | add_generation_prompt=True,
89 | add_thinking=True
90 | )
91 |
92 | else:
93 | download_qwen3_small(kind="base", tokenizer_only=False, out_dir="qwen3")
94 | tokenizer_path = Path("qwen3") / "tokenizer-base.json"
95 | model_path = Path("qwen3") / "qwen3-0.6B-base.pth"
96 | tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path)
97 |
98 | model = Qwen3Model(QWEN_CONFIG_06_B)
99 | model.load_state_dict(torch.load(model_path, map_location=device))
100 |
101 | model.to(device)
102 |
103 | if args.compile:
104 | major, minor = map(int, torch.__version__.split(".")[:2])
105 | if (major, minor) >= (2, 8):
106 | # This avoids retriggering model recompilations
107 | # in PyTorch 2.8 and newer
108 | # if the model contains code like self.pos = self.pos + 1
109 | torch._dynamo.config.allow_unspec_int_on_nn_module = True
110 | model = torch.compile(model)
111 |
112 | #########################
113 | # Prompt + generation
114 | #########################
115 |
116 | if args.reasoning:
117 | prompt = "Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field."
118 | else:
119 | prompt = "Explain large language models in a single sentence."
120 |
121 | input_token_ids_tensor = torch.tensor(
122 | tokenizer.encode(prompt),
123 | device=device
124 | ).unsqueeze(0)
125 |
126 | max_new_tokens = 2048
127 |
128 |
129 | for iteration in range(1, 4):
130 | print("=" * 60)
131 | print(f"Iteration : {iteration}")
132 | print(f"optimized : {args.optimized}")
133 | print(f"torch : {torch.__version__}")
134 | print(f"device : {device}")
135 | print(f"cache : {args.cache}")
136 | print(f"compile : {args.compile}")
137 | print(f"reasoning : {args.reasoning}")
138 | print("=" * 60)
139 |
140 | start_time = time.time()
141 | output_token_ids_tensor = generate_text_basic(
142 | model=model,
143 | token_ids=input_token_ids_tensor,
144 | max_new_tokens=max_new_tokens,
145 | eos_token_id=tokenizer.eos_token_id,
146 | )
147 | end_time = time.time()
148 |
149 | print(f"Output length: {output_token_ids_tensor.numel()}")
150 | generate_stats(output_token_ids_tensor, tokenizer, start_time, end_time)
151 |
--------------------------------------------------------------------------------
/tests/test_ch04.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import pytest
6 | import torch
7 |
8 | import reasoning_from_scratch.ch04 as ch04
9 |
10 |
11 | class DummyTokenizer:
12 | def __init__(self, eos_token_id=None):
13 | self.eos_token_id = eos_token_id
14 | self.decode_map = {7: "X", 8: "Y"}
15 |
16 | def encode(self, prompt):
17 | # Content of the prompt is irrelevant for these tests.
18 | return [1, 2]
19 |
20 | def decode(self, ids):
21 | if isinstance(ids, int):
22 | ids = [ids]
23 | return "".join(self.decode_map.get(i, str(i)) for i in ids)
24 |
25 |
26 | class DummyModelCache:
27 | def __init__(self, fixed_token, vocab_size=5, n_layers=2):
28 | self.fixed_token = fixed_token
29 | self.vocab_size = vocab_size
30 | self.cfg = {"n_layers": n_layers}
31 | self.reset_called = False
32 | self.eval_called = False
33 |
34 | def eval(self):
35 | self.eval_called = True
36 | return self
37 |
38 | def reset_kv_cache(self):
39 | self.reset_called = True
40 |
41 | def __call__(self, token_ids, cache=None):
42 | batch_size, seq_len = token_ids.size()
43 | logits = torch.zeros(batch_size, seq_len, self.vocab_size)
44 | logits[..., self.fixed_token] = 1.0
45 | return logits
46 |
47 |
48 | def test_generate_text_stream_concat_flex_uses_custom_generator():
49 | calls = []
50 |
51 | def fake_generate_func(**kwargs):
52 | calls.append(kwargs)
53 | for tok in (torch.tensor([7]), torch.tensor([8])):
54 | yield tok
55 |
56 | tok = DummyTokenizer(eos_token_id=0)
57 | out = ch04.generate_text_stream_concat_flex(
58 | model=None,
59 | tokenizer=tok,
60 | prompt="ignored",
61 | device="cpu",
62 | max_new_tokens=2,
63 | generate_func=fake_generate_func,
64 | temperature=0.5,
65 | )
66 |
67 | assert out == "XY"
68 | assert calls, "Generator should have been invoked"
69 | assert calls[0]["temperature"] == 0.5
70 | assert calls[0]["eos_token_id"] == tok.eos_token_id
71 |
72 |
73 | def test_scale_logits_by_temperature_validates_and_scales():
74 | logits = torch.tensor([[2.0, 4.0]])
75 | scaled = ch04.scale_logits_by_temperature(logits, 2.0)
76 | assert torch.allclose(scaled, logits / 2.0)
77 |
78 | with pytest.raises(ValueError):
79 | ch04.scale_logits_by_temperature(logits, 0.0)
80 |
81 |
82 | def test_top_p_filter_truncates_and_renormalizes():
83 | probas = torch.tensor([[0.5, 0.4, 0.1]])
84 | filtered = ch04.top_p_filter(probas, top_p=0.6)
85 | assert torch.allclose(filtered, torch.tensor([[1.0, 0.0, 0.0]]))
86 |
87 | # When no filtering is needed, output should match input
88 | unfiltered = ch04.top_p_filter(probas, top_p=1.0)
89 | assert torch.allclose(unfiltered, probas)
90 |
91 |
92 | def test_generate_text_temp_stream_cache_stops_on_eos():
93 | model = DummyModelCache(fixed_token=3)
94 | token_ids = torch.tensor([[0, 1]])
95 |
96 | out = list(
97 | ch04.generate_text_temp_stream_cache(
98 | model,
99 | token_ids=token_ids,
100 | max_new_tokens=5,
101 | eos_token_id=3,
102 | temperature=0.0,
103 | )
104 | )
105 |
106 | assert out == []
107 | assert model.reset_called is True
108 | assert model.eval_called is True
109 |
110 |
111 | def test_self_consistency_vote_majority(monkeypatch):
112 | answers = ["2", "2", "3"]
113 |
114 | def fake_generate_text_stream_concat_flex(**kwargs):
115 | idx = kwargs.pop("_call_idx", None)
116 | idx = idx if idx is not None else 0
117 | return answers[idx % len(answers)]
118 |
119 | # Wrap to inject call index so we can cycle through answers deterministically
120 | call_counter = {"i": 0}
121 |
122 | def wrapped_generate(**kwargs):
123 | kwargs["_call_idx"] = call_counter["i"]
124 | call_counter["i"] += 1
125 | return fake_generate_text_stream_concat_flex(**kwargs)
126 |
127 | monkeypatch.setattr(ch04, "generate_text_stream_concat_flex", wrapped_generate)
128 |
129 | result = ch04.self_consistency_vote(
130 | model=None,
131 | tokenizer=DummyTokenizer(),
132 | prompt="unused",
133 | device="cpu",
134 | num_samples=3,
135 | temperature=0.7,
136 | top_p=0.9,
137 | max_new_tokens=5,
138 | show_progress=False,
139 | show_long_answer=False,
140 | seed=123,
141 | )
142 |
143 | assert result["final_answer"] == "2"
144 | assert result["counts"]["2"] == 2
145 | assert result["majority_winners"] == ["2"]
146 |
--------------------------------------------------------------------------------
/chF/04_llm-judge/README.md:
--------------------------------------------------------------------------------
1 |
2 | # LLM-as-a-judge
3 |
4 | This bonus material implements an LLM-as-a-judge approach, where gpt-oss:20b (served via the open-source Ollama application) evaluates Qwen3 0.6B base and reasoning variants on MATH-500.
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | - Ollama is an open-source application to run LLMs efficiently
13 | - It is a wrapper around llama.cpp ([https://github.com/ggerganov/llama.cpp](https://github.com/ggerganov/llama.cpp)), which implements LLMs in pure C/C++ to maximize efficiency
14 | - Note that it is a tool for using LLMs to generate text (inference), not for training or finetuning them
15 | - Before running the code below, install ollama by visiting [https://ollama.com](https://ollama.com) and following the instructions (for instance, clicking on the "Download" button and downloading the ollama application for your operating system)
16 | - For macOS and Windows users, click on the ollama application you downloaded; if it prompts you to install the command line usage, say "yes"
17 | - Linux users can use the installation command provided on the ollama website
18 | - There are 3 ways we can run ollama on our computer:
19 |
20 |
21 |
22 | **1. `ollama serve`**
23 |
24 | - This runs the ollama backend as a server, usually on `http://localhost:11434`. It doesn't load a model until we call it through the API. This is the mode we want when using ollama from Python (see the sketch after this list).
25 |
26 | **2. `ollama run gpt-oss:20b`**
27 |
28 | - This is a convenience wrapper. If the server is not already running, it will start it, then download the model (the first time), and drop us into an interactive terminal where we can chat with the model. Behind the scenes, it uses the same server API.
29 |
30 | **3. Ollama desktop app**
31 |
32 | - This runs the same backend automatically and provides a GUI on top of it (as shown in the figure above).
33 | It also applies defaults (system prompt, temperature, stop sequences), which can explain why answers look different from raw API usage.
34 |
35 |
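The snippet below is a minimal sketch of what "using ollama through Python" looks like: it posts a chat request to the local Ollama server and prints the reply. It assumes `ollama serve` is running and that `gpt-oss:20b` has already been pulled; it is illustrative only and not the exact code used by the judge script.

```python
import json
import urllib.request

URL = "http://localhost:11434/api/chat"  # default Ollama chat endpoint

payload = {
    "model": "gpt-oss:20b",  # assumes the model has been pulled already
    "messages": [
        {"role": "user", "content": "Rate this answer on a scale from 1 to 5: ..."}
    ],
    "stream": False,  # return the full reply as a single JSON object
}

request = urllib.request.Request(
    URL,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(request) as response:
    reply = json.loads(response.read().decode("utf-8"))

print(reply["message"]["content"])  # the judge model's text response
```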
36 |
37 | ## Usage
38 |
39 |
40 |
41 | The options and defaults are shown below.
42 |
43 |
44 |
45 | ---
46 |
47 | **Note**: If you are not a `uv` user, replace `uv run ...py` with `python ...py` in the examples below.
48 |
49 | ---
50 |
51 |
52 |
53 | ```bash
54 | uv run ollama-judge.py --help
55 | usage: ollama-judge.py [-h] [--device DEVICE]
56 | [--which_model {base,reasoning}]
57 | [--dataset_size DATASET_SIZE]
58 | [--max_new_tokens MAX_NEW_TOKENS]
59 | [--url URL]
60 | [--judge_model JUDGE_MODEL]
61 |
62 | options:
63 | -h, --help show this help message and
64 | exit
65 | --device DEVICE Device e.g., "cpu",
66 | "cuda", "cuda:0", "mps".
67 | --which_model {base,reasoning}
68 | Candidate variant to use.
69 | Defaults to "base".
70 | --dataset_size DATASET_SIZE
71 | Number of MATH-500
72 | examples to evaluate.
73 | Default: 10
74 | --max_new_tokens MAX_NEW_TOKENS
75 | Max new tokens for
76 | candidate generation.
77 | Default: 2048
78 | --url URL Ollama chat endpoint for
79 | the judge. Default: "http:
80 | //localhost:11434/api/chat
81 | "
82 | --judge_model JUDGE_MODEL
83 | Judge model name (Ollama).
84 | Used only for scoring.
85 | Default: "gpt-oss:20b"
86 | ```
87 |
88 |
89 |
90 | **Base model**
91 |
92 | ```bash
93 | ➜ uv run ollama-judge.py
94 | Using Apple Silicon GPU (MPS)
95 | Model: base
96 | Device: mps
97 | ✓ qwen3/qwen3-0.6B-base.pth already up-to-date
98 | ✓ qwen3/tokenizer-base.json already up-to-date
99 | Ollama running: True
100 | [1/10] score=5
101 | [2/10] score=1
102 | [3/10] score=5
103 | [4/10] score=5
104 | [5/10] score=3
105 | [6/10] score=5
106 | [7/10] score=5
107 | [8/10] score=3
108 | [9/10] score=5
109 | [10/10] score=1
110 |
111 | Summary
112 | -------
113 | Average score: 3.800 over 10 example(s)
114 | Counts: 1:2 2:0 3:2 4:0 5:6
115 | ```
116 |
117 | **Reasoning model**
118 |
119 | ```bash
120 | ➜ uv run ollama-judge.py --which_model reasoning
121 | Using Apple Silicon GPU (MPS)
122 | Model: reasoning
123 | Device: mps
124 | ✓ qwen3/qwen3-0.6B-reasoning.pth already up-to-date
125 | ✓ qwen3/tokenizer-reasoning.json already up-to-date
126 | Ollama running: True
127 | [1/10] score=5
128 | [2/10] score=5
129 | [3/10] score=5
130 | [4/10] score=5
131 | [5/10] score=4
132 | [6/10] score=5
133 | [7/10] score=5
134 | [8/10] score=1
135 | [9/10] score=5
136 | [10/10] score=3
137 |
138 | Summary
139 | -------
140 | Average score: 4.300 over 10 example(s)
141 | Counts: 1:1 2:0 3:1 4:1 5:7
142 | ```
143 |
144 |
--------------------------------------------------------------------------------
/.github/scripts/check_double_quotes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | # Verify that Python source files and notebooks use double quotes for strings.
6 |
7 | import ast
8 | import io
9 | import json
10 | import sys
11 | import tokenize
12 | from pathlib import Path
13 |
14 | EXCLUDED_DIRS = {
15 | ".git",
16 | ".hg",
17 | ".mypy_cache",
18 | ".pytest_cache",
19 | ".ruff_cache",
20 | ".svn",
21 | ".tox",
22 | ".venv",
23 | "__pycache__",
24 | "build",
25 | "dist",
26 | "node_modules",
27 | }
28 |
29 | PREFIX_CHARS = {"r", "u", "f", "b"}
30 | SINGLE_QUOTE = "'"
31 | DOUBLE_QUOTE = "\""
32 | TRIPLE_SINGLE = SINGLE_QUOTE * 3
33 | TRIPLE_DOUBLE = DOUBLE_QUOTE * 3
34 |
35 |
36 | def should_skip(path):
37 | parts = set(path.parts)
38 | return bool(EXCLUDED_DIRS & parts)
39 |
40 |
41 | def collect_fstring_expr_string_positions(source):
42 | """
43 | Return set of (lineno, col_offset) for string literals that appear inside
44 | formatted expressions of f-strings. These should be exempt from the double
45 | quote check, since enforcing double quotes there is unnecessarily strict.
46 | """
47 | try:
48 | tree = ast.parse(source)
49 | except SyntaxError:
50 | return set()
51 |
52 | positions = set()
53 |
54 | class Collector(ast.NodeVisitor):
55 | def visit_JoinedStr(self, node):
56 | for value in node.values:
57 | if isinstance(value, ast.FormattedValue):
58 | self._collect_from_expr(value.value)
59 | # Continue walking to catch nested f-strings within expressions
60 | self.generic_visit(node)
61 |
62 | def _collect_from_expr(self, node):
63 | if isinstance(node, ast.Constant) and isinstance(node.value, str):
64 | positions.add((node.lineno, node.col_offset))
65 | else:
66 | for child in ast.iter_child_nodes(node):
67 | self._collect_from_expr(child)
68 |
69 | Collector().visit(tree)
70 | return positions
71 |
72 |
73 | def check_quotes_in_source(source, path):
74 | violations = []
75 | ignored_positions = collect_fstring_expr_string_positions(source)
76 | tokens = tokenize.generate_tokens(io.StringIO(source).readline)
77 | lines = source.splitlines()
78 |
79 | for tok_type, tok_str, start, _, _ in tokens:
80 | if tok_type == tokenize.STRING:
81 | line_no, col = start
82 |
83 | # Skip if the line contains both quote types (e.g. if ch in ("'", '"'))
84 | if line_no - 1 < len(lines):
85 | line_text = lines[line_no - 1]
86 | if "'" in line_text and '"' in line_text:
87 | continue
88 |
89 | if start in ignored_positions:
90 | continue
91 |
92 | lowered = tok_str.lower()
93 | # ignore triple-quoted strings
94 | if lowered.startswith((TRIPLE_DOUBLE, TRIPLE_SINGLE)):
95 | continue
96 |
97 | # find the prefix and quote type
98 | for c in PREFIX_CHARS:
99 | if lowered.startswith(c):
100 | lowered = lowered[1:]
101 | break
102 |
103 | # report if not using double quotes
104 | if lowered.startswith(SINGLE_QUOTE):
105 | violations.append(f"{path}:{line_no}:{col}: uses single quotes")
106 | return violations
107 |
108 |
109 | def check_file(path):
110 | try:
111 | if path.suffix == ".ipynb":
112 | return check_notebook(path)
113 | else:
114 | text = path.read_text(encoding="utf-8")
115 | return check_quotes_in_source(text, path)
116 | except Exception as e:
117 | return [f"{path}: failed to check ({e})"]
118 |
119 |
120 | def check_notebook(path):
121 | violations = []
122 | with open(path, encoding="utf-8") as f:
123 | nb = json.load(f)
124 | for cell in nb.get("cells", []):
125 | if cell.get("cell_type") == "code":
126 | src = "".join(cell.get("source", []))
127 | violations.extend(check_quotes_in_source(src, path))
128 | return violations
129 |
130 |
131 | def main():
132 | project_root = Path(".").resolve()
133 | py_files = sorted(project_root.rglob("*.py"))
134 | notebook_files = sorted(project_root.rglob("*.ipynb"))
135 |
136 | violations = []
137 | for path in py_files + notebook_files:
138 | if should_skip(path):
139 | continue
140 | violations.extend(check_file(path))
141 |
142 | if violations:
143 | print("\n".join(violations))
144 | print(f"\n{len(violations)} violations found.")
145 | return 1
146 |
147 | print("All files use double quotes correctly.")
148 | return 0
149 |
150 |
151 | if __name__ == "__main__":
152 | sys.exit(main())
153 |
--------------------------------------------------------------------------------
/chF/02_mmlu/2_logprob.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import argparse
6 | import time
7 |
8 | import torch
9 | from datasets import load_dataset, get_dataset_config_names
10 | from reasoning_from_scratch.ch02 import get_device
11 | from reasoning_from_scratch.ch03 import load_model_and_tokenizer
12 |
13 |
14 | # Same as in main notebook
15 | def format_prompt(example):
16 | return (
17 | f"{example['question']}\n"
18 | f"A. {example['choices'][0]}\n"
19 | f"B. {example['choices'][1]}\n"
20 | f"C. {example['choices'][2]}\n"
21 | f"D. {example['choices'][3]}\n"
22 | "Answer: " # trailing space encourages a single-letter next token
23 | )
24 |
25 |
26 | def common_prefix_len(a, b):
27 | i = 0
28 | n = min(len(a), len(b))
29 | while i < n and a[i] == b[i]:
30 | i += 1
31 | return i
32 |
33 |
34 | def first_new_token_id(tokenizer, prompt, prompt_ids, continuation):
35 | ids_full = tokenizer.encode(prompt + continuation)
36 | j = common_prefix_len(ids_full, prompt_ids)
37 | if j >= len(ids_full):
38 | raise ValueError("Continuation produced no new tokens.")
39 | return ids_full[j]
40 |
41 |
42 | def predict_choice_logprobs(model, tokenizer, prompt_fmt, prompt, prompt_ids, example):
43 | with torch.no_grad():
44 | logits = model(prompt_fmt) # [1, T, V]
45 | next_logp = torch.log_softmax(logits[0, -1], dim=-1)
46 |
47 | scores = {}
48 | for letter in "ABCD":
49 | tok_id = first_new_token_id(tokenizer, prompt, prompt_ids, letter)
50 | scores[letter] = next_logp[tok_id].item()
51 |
52 | pred = max(scores, key=scores.get)
53 | return pred, scores
54 |
55 |
56 | def evaluate_mmlu_logprobs(
57 | model,
58 | tokenizer,
59 | device,
60 | subsets="high_school_mathematics", # str, list of str, or "all"
61 | split="test",
62 | verbose_every=50,
63 | ):
64 | if subsets == "all":
65 | subset_list = get_dataset_config_names("cais/mmlu")
66 | elif isinstance(subsets, str):
67 | subset_list = [s.strip() for s in subsets.split(",")] if "," in subsets else [subsets]
68 | else:
69 | subset_list = list(subsets)
70 |
71 | total = 0
72 | correct = 0
73 | start = time.time()
74 |
75 | for subset in subset_list:
76 | ds = load_dataset("cais/mmlu", subset, split=split)
77 | for ex in ds:
78 | prompt = format_prompt(ex)
79 | prompt_ids = tokenizer.encode(prompt)
80 | prompt_fmt = torch.tensor(prompt_ids, device=device).unsqueeze(0)
81 |
82 | pred, _scores = predict_choice_logprobs(
83 | model, tokenizer, prompt_fmt, prompt, prompt_ids, ex
84 | )
85 |
86 | ans = ex["answer"]
87 | # "Gold" is the MMLU jargon for the correct answer (ground truth)
88 | gold = "ABCD"[ans] if isinstance(ans, int) else str(ans).strip().upper()
89 |
90 | total += 1
91 | correct += int(pred == gold)
92 |
93 | if verbose_every and total % verbose_every == 0:
94 | print(f"MMLU {total} acc={correct/total:.3f} [{subset}]")
95 |
96 | acc = correct / max(1, total)
97 | print(
98 | f"\nMMLU letter accuracy (log-prob): {correct}/{total} = {acc:.2%} "
99 | f"in {time.time()-start:.1f}s"
100 | )
101 | return {"accuracy": acc, "num_examples": total, "subsets": subset_list, "split": split}
102 |
103 |
104 | def main():
105 | parser = argparse.ArgumentParser(
106 | description="Zero-shot MMLU evaluator via next-token log-prob scoring (A/B/C/D)."
107 | )
108 | parser.add_argument(
109 | "--device",
110 | type=str,
111 | default="auto",
112 | help="Device to use: 'auto' (default), or any torch device string like "
113 | "'cpu', 'cuda', 'cuda:0', 'mps'.",
114 | )
115 | parser.add_argument(
116 | "--which_model",
117 | type=str,
118 | default="base",
119 | choices=["base", "reasoning"],
120 | help="Model variant to load. Defaults to 'base'.",
121 | )
122 | parser.add_argument(
123 | "--subsets",
124 | type=str,
125 | default="high_school_mathematics",
126 | help="Comma-separated subset names or 'all'. "
127 | "Default: 'high_school_mathematics'.",
128 | )
129 | args = parser.parse_args()
130 |
131 | device = get_device() if args.device == "auto" else torch.device(args.device)
132 | print(f"Using device: {device}")
133 | model, tokenizer = load_model_and_tokenizer(args.which_model, device, use_compile=False)
134 | model.eval()
135 | torch.set_float32_matmul_precision("high")
136 |
137 | metrics = evaluate_mmlu_logprobs(
138 | model=model,
139 | tokenizer=tokenizer,
140 | device=device,
141 | subsets=args.subsets,
142 | )
143 | print(metrics)
144 |
145 |
146 | if __name__ == "__main__":
147 | main()
148 |
--------------------------------------------------------------------------------
/tests/test_ch05.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 |
6 | import math
7 | import torch
8 |
9 | import reasoning_from_scratch.ch05 as ch05
10 |
11 |
12 | class DummyTokenizer:
13 | def __init__(self, mapping=None):
14 | self.mapping = mapping or {}
15 | self.eos_token_id = 0
16 |
17 | def encode(self, text):
18 | if text in self.mapping:
19 | return self.mapping[text]
20 | return [len(text) % 10]
21 |
22 |
23 | class DummyLogitModel:
24 | def __init__(self, base_logits):
25 | self.base_logits = torch.tensor(base_logits, dtype=torch.float32)
26 |
27 | def __call__(self, token_ids):
28 | seq_len = token_ids.size(1)
29 | logits = self.base_logits.repeat(seq_len, 1)
30 | return logits.unsqueeze(0)
31 |
32 |
33 | def test_heuristic_score_rewards_boxed_answers_more_than_numbers():
34 | boxed = "Result \\boxed{7}"
35 | number_only = "The answer is 7"
36 |
37 | boxed_score = ch05.heuristic_score(boxed)
38 | number_score = ch05.heuristic_score(number_only)
39 |
40 | expected_boxed = 2.0 + 1.5 * math.exp(-len(boxed) / 500.0)
41 | expected_number = 1.0 + 1.5 * math.exp(-len(number_only) / 500.0)
42 |
43 | assert math.isclose(boxed_score, expected_boxed, rel_tol=1e-6)
44 | assert math.isclose(number_score, expected_number, rel_tol=1e-6)
45 | assert boxed_score > number_score
46 |
47 |
48 | def test_heuristic_score_adds_fulltext_bonus_when_no_number():
49 | response = "No numeric result here"
50 | score = ch05.heuristic_score(
51 | response, brevity_bonus=100.0, fulltext_bonus=0.3
52 | )
53 |
54 | expected = 0.3 + 1.5 * math.exp(-len(response) / 100.0)
55 | assert math.isclose(score, expected, rel_tol=1e-6)
56 |
57 |
58 | def test_avg_logprob_answer_uses_answer_token_logprobs():
59 | tokenizer = DummyTokenizer(
60 | mapping={
61 | "prompt": [1, 2],
62 | "answer": [3, 4],
63 | }
64 | )
65 | base_logits = [0.0, 0.1, 0.2, 0.3, 0.4]
66 | model = DummyLogitModel(base_logits)
67 |
68 | expected_logprobs = torch.log_softmax(
69 | model.base_logits, dim=-1
70 | )
71 | expected = torch.mean(expected_logprobs[[3, 4]]).item()
72 |
73 | out = ch05.avg_logprob_answer(
74 | model=model,
75 | tokenizer=tokenizer,
76 | prompt="prompt",
77 | answer="answer",
78 | device="cpu",
79 | )
80 |
81 | assert math.isclose(out, expected, rel_tol=1e-6)
82 |
83 |
84 | def test_prompt_builders_embed_question_and_context():
85 | raw_prompt = "What is 1+1?"
86 | draft = "It is \\boxed{2}."
87 | critique_text = "Looks fine."
88 |
89 | critique_prompt = ch05.make_critique_prompt(raw_prompt, draft)
90 | assert "meticulous reviewer" in critique_prompt
91 | assert f"Question:\n{raw_prompt}" in critique_prompt
92 | assert f"Draft answer:\n{draft}" in critique_prompt
93 | assert critique_prompt.strip().endswith("Critique:")
94 |
95 | refine_prompt = ch05.make_refine_prompt(
96 | raw_prompt, draft, critique_text
97 | )
98 | assert "Revised answer:" in refine_prompt
99 | assert f"Previous answer:\n{draft}" in refine_prompt
100 | assert f"Critique:\n{critique_text}" in refine_prompt
101 | assert refine_prompt.strip().endswith("Revised answer:")
102 |
103 |
104 | def test_self_refinement_loop_accepts_improving_revisions(monkeypatch):
105 | responses = iter(
106 | [
107 | "initial draft", # initial generation
108 | "first critique", # critique 1
109 | "draft with more detail", # refine 1 (accepted)
110 | "second critique", # critique 2
111 | "bad", # refine 2 (rejected)
112 | ]
113 | )
114 | prompts_seen = []
115 |
116 | def fake_generate_text_stream_concat_flex(**kwargs):
117 | prompts_seen.append(kwargs.get("prompt"))
118 | return next(responses)
119 |
120 | monkeypatch.setattr(
121 | ch05, "generate_text_stream_concat_flex",
122 | fake_generate_text_stream_concat_flex
123 | )
124 |
125 | def score_fn(answer, prompt):
126 | return len(answer)
127 |
128 | result = ch05.self_refinement_loop(
129 | model=None,
130 | tokenizer=DummyTokenizer(),
131 | raw_prompt="Compute something",
132 | device="cpu",
133 | iterations=2,
134 | score_fn=score_fn,
135 | prompt_renderer=lambda x: f"Q: {x}",
136 | temperature=0.3,
137 | top_p=0.8,
138 | )
139 |
140 | assert result["final_extracted"] == "draft with more detail"
141 | assert len(result["steps"]) == 2
142 | assert result["steps"][0]["draft_full"] == "initial draft"
143 | assert result["steps"][0]["revised_full"] == "draft with more detail"
144 | assert result["steps"][1]["draft_full"] == "draft with more detail"
145 | assert result["steps"][1]["revised_full"] == "bad"
146 | assert result["steps"][0]["score_after"] >= result["steps"][0]["score_before"]
147 | assert result["steps"][1]["score_after"] < result["steps"][1]["score_before"]
148 | assert len(prompts_seen) == 5
149 |
--------------------------------------------------------------------------------
/ch02/02_setup-tips/python-instructions.md:
--------------------------------------------------------------------------------
1 | # Python Setup Recommendations
2 |
3 | The code in this book is largely self-contained, and I have made an effort to minimize external dependencies. However, to keep the book accessible, readable, and well under 2000 pages, a few Python packages are necessary.
4 |
5 | This section introduces two beginner-friendly methods for installing the required packages so you can run the code examples.
6 |
7 | There are, of course, many other ways to install and manage Python packages. If you are an experienced Python user and already have your own setup or preferences, feel free to skip this section.
8 |
9 | If neither of the two options below works for you, please do not hesitate to reach out, for example, by opening a [Discussion](https://github.com/rasbt/reasoning-from-scratch/discussions).
10 |
11 |
12 | ## Option 1: Using `pip` (built-in, works everywhere)
13 |
14 | If you are using a recent version of Python already, you can install packages using the built-in `pip` installer.
15 |
16 | I used Python 3.12 for this book. However, older versions like 3.11 and 3.10 will also work fine. You can check your Python version by running:
17 |
18 | ```bash
19 | python --version
20 | ```
21 |
22 | If you are using Python 3.9 or older, consider installing the latest from [python.org](https://www.python.org/downloads/) or using a tool like [`pyenv`](https://github.com/pyenv/pyenv) to manage versions. However, if you are installing a new Python version, please make sure that it is supported by PyTorch by checking the recommendation on the [official PyTorch website](https://pytorch.org/get-started/locally/). PyTorch typically lags a few months behind the latest Python release, so newly released Python versions are not supported or recommended immediately.
23 |
24 | To install new packages as needed (for example, PyTorch and Jupyter Lab), run:
25 |
26 | ```bash
27 | pip install torch jupyterlab
28 | ```
29 |
30 | Alternatively, you can install all required Python packages used in this book at once via the [`requirements.txt`](https://github.com/rasbt/reasoning-from-scratch/blob/main/requirements.txt) file:
31 |
32 | ```bash
33 | pip install -r https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/refs/heads/main/requirements.txt
34 | ```
35 |
36 |
37 |
38 | ## Option 2: Use `uv` (faster and widely recommended)
39 |
40 | While `pip` remains the classic and official way to install Python packages, [`uv`](https://github.com/astral-sh/uv) is a modern and widely recommended Python package manager that automatically:
41 |
42 | - Creates and manages a virtual environment
43 | - Installs packages quickly
44 | - Keeps a lockfile for reproducible installs
45 | - Supports `pip`-like commands
46 |
47 |
48 | ### Installing `uv` and Python packages
49 |
50 | To install `uv`, you can use the commands below (also see the official [Installation](https://docs.astral.sh/uv/getting-started/installation/) page for the latest recommendations).
51 |
52 |
53 | **macOS / Linux:**
54 |
55 | ```bash
56 | curl -LsSf https://astral.sh/uv/install.sh | sh
57 | ```
58 |
59 |
60 | **Windows (PowerShell):**
61 |
62 | ```powershell
63 | powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
64 | ```
65 |
66 | Once installed, you can install new Python packages just as you would with `pip` (as described in the previous section), except that you replace `pip` with `uv pip`. For example:
67 |
68 | ```bash
69 | uv pip install torch jupyterlab
70 | ```
71 |
72 | However, if you are using `uv`, which I recommend and use myself, it's even better to use the native `uv` syntax instead of `uv pip`, as described below.
73 |
74 |
75 | ### Recommended `uv` workflow
76 |
77 | Instead of using `uv pip`, I recommend and use the native `uv` workflow.
78 |
79 | First, clone the GitHub repository to your local machine:
80 |
81 |
82 |
83 | ```bash
84 | git clone https://github.com/rasbt/reasoning-from-scratch.git
85 | ```
86 |
87 | Next, navigate into this folder, e.g., on Linux and macOS:
88 |
89 | ```bash
90 | cd reasoning-from-scratch
91 | ```
92 |
93 | Then, since this folder contains a `pyproject.toml` file, you are already good to go: the first time you run a script or open Jupyter Lab, `uv` automatically creates a (hidden by default) virtual environment folder (`.venv`) for this `reasoning-from-scratch` project and installs all the dependencies into it.
94 |
95 | You will probably not need this, but in general you can install additional packages that are not already listed in `pyproject.toml` via `uv add`:
96 |
97 |
98 | ```bash
99 | uv add llms_from_scratch
100 | ```
101 |
102 | The above command will then add the package to the virtual environment and `pyproject.toml` file.
103 |
104 |
105 | ### Running code via `uv`
106 |
107 | This section describes the `uv` commands to run Jupyter Lab and Python scripts.
108 |
109 | To open Jupyter Lab, execute:
110 |
111 | ```bash
112 | uv run jupyter lab
113 | ```
114 |
115 | Python scripts can be run via:
116 |
117 | ```bash
118 | uv run python script.py
119 | ```
120 |
121 |
122 |
123 |
124 | > **Advanced usage:** This section describes a simple way to use `uv` that looks familiar to `pip` users. If you are interested in more advanced usage, please see [this document](https://github.com/rasbt/LLMs-from-scratch/tree/main/setup/01_optional-python-setup-preferences) for more explicit instructions on managing virtual environments in `uv`.
125 | > If you are a macOS or Linux user and prefer the native uv commands, please refer to [this tutorial](https://github.com/rasbt/LLMs-from-scratch/blob/main/setup/01_optional-python-setup-preferences/native-uv.md). I also recommend checking the [official uv documentation](https://docs.astral.sh/uv/) for additional information.
126 |
127 |
128 |
129 |
130 | ## Questions?
131 |
132 | If you have any questions, please don't hesitate to reach out via the [Discussions](https://github.com/rasbt/reasoning-from-scratch/discussions) forum in this GitHub repository.
133 |
--------------------------------------------------------------------------------
/tests/test_qwen3_optimized.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import importlib
6 | import os
7 | import pytest
8 | import torch
9 | from pathlib import Path
10 |
11 | from reasoning_from_scratch.ch02 import (
12 | generate_text_basic_cache,
13 | )
14 | from reasoning_from_scratch.qwen3 import (
15 | download_qwen3_small,
16 | load_hf_weights_into_qwen,
17 | Qwen3Tokenizer,
18 | Qwen3Model,
19 | QWEN_CONFIG_06_B,
20 | )
21 |
22 | from reasoning_from_scratch.qwen3_optimized import (
23 | Qwen3Model as Qwen3ModelOptimized,
24 | generate_text_basic_cache as generate_text_basic_cache_optimized,
25 | )
26 |
27 |
28 | skip_expensive = os.environ.get("SKIP_EXPENSIVE", "0") == "1"
29 | transformers_installed = importlib.util.find_spec("transformers") is not None
30 |
31 | # Make CI more reproducible & robust
32 | os.environ["MKL_NUM_THREADS"] = "1"
33 | os.environ["OMP_NUM_THREADS"] = "1"
34 | torch.backends.mkldnn.enabled = False
35 | torch.set_num_threads(1)
36 | torch.use_deterministic_algorithms(True)
37 |
38 |
39 | @torch.inference_mode()
40 | @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
41 | def test_qwen3_base_equivalence_with_transformers():
42 |
43 | from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
44 |
45 | # Tiny config so the test is fast
46 | cfg = {
47 | "vocab_size": 257,
48 | "context_length": 8,
49 | "emb_dim": 32,
50 | "n_heads": 4,
51 | "n_layers": 2,
52 | "hidden_dim": 64,
53 | "head_dim": 8,
54 | "qk_norm": True,
55 | "n_kv_groups": 2,
56 | "rope_base": 1_000_000.0,
57 | "dtype": torch.float32,
58 | }
59 | model = Qwen3Model(cfg)
60 |
61 | hf_cfg = Qwen3Config(
62 | vocab_size=cfg["vocab_size"],
63 | max_position_embeddings=cfg["context_length"],
64 | hidden_size=cfg["emb_dim"],
65 | num_attention_heads=cfg["n_heads"],
66 | num_hidden_layers=cfg["n_layers"],
67 | intermediate_size=cfg["hidden_dim"],
68 | head_dim=cfg["head_dim"],
69 | num_key_value_heads=cfg["n_kv_groups"],
70 | rope_theta=cfg["rope_base"],
71 | tie_word_embeddings=False,
72 | attn_implementation="eager",
73 | torch_dtype=torch.float32,
74 | )
75 | hf_model = Qwen3ForCausalLM(hf_cfg)
76 |
77 | hf_state = hf_model.state_dict()
78 | param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
79 | load_hf_weights_into_qwen(model, param_config, hf_state)
80 |
81 | x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
82 | ours_logits = model(x)
83 | theirs_logits = hf_model(x).logits
84 | torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
85 |
86 |
87 | @pytest.mark.skipif(skip_expensive, reason="Skipping expensive test on CI")
88 | @pytest.mark.parametrize("reasoning", [False, True])
89 | def test_qwen3_vs_optimized_qwen3(reasoning):
90 |
91 | device = "cpu" # get_device()
92 |
93 | # Download and init tokenizer
94 | kind = "reasoning" if reasoning else "base"
95 | download_qwen3_small(kind=kind, tokenizer_only=False, out_dir="qwen3")
96 | tokenizer_path = Path("qwen3") / (
97 | "tokenizer-reasoning.json" if reasoning else "tokenizer-base.json"
98 | )
99 | model_path = Path("qwen3") / (
100 | "qwen3-0.6B-reasoning.pth" if reasoning else "qwen3-0.6B-base.pth"
101 | )
102 | tokenizer = Qwen3Tokenizer(
103 | tokenizer_file_path=tokenizer_path,
104 | apply_chat_template=True if reasoning else False,
105 | add_generation_prompt=True if reasoning else False,
106 | add_thinking=True if reasoning else False,
107 | )
108 |
109 | # Models
110 | model = Qwen3Model(QWEN_CONFIG_06_B)
111 | model.load_state_dict(torch.load(model_path, map_location=device))
112 | model.to(device)
113 | model.eval()
114 |
115 | model_optimized = Qwen3ModelOptimized(QWEN_CONFIG_06_B, exact=True)
116 | model_optimized.load_state_dict(torch.load(model_path, map_location=device))
117 | model_optimized.to(device)
118 | model_optimized.eval()
119 |
120 | # Prompts
121 | prompts = [
122 | "Explain large language models in two sentences.",
123 | "Explain large language models in one sentence.",
124 | "1+1?"
125 | ]
126 |
127 | single_inputs = [
128 | torch.tensor(tokenizer.encode(p), device=device).unsqueeze(0)
129 | for p in prompts
130 | ]
131 |
132 | # Generation
133 | max_new_tokens = 12 # cheap but enough to check consistency
134 | outputs_simple = []
135 | for input_ids in single_inputs:
136 | out = generate_text_basic_cache(
137 | model=model,
138 | token_ids=input_ids,
139 | max_new_tokens=max_new_tokens,
140 | eos_token_id=tokenizer.eos_token_id,
141 | )
142 | outputs_simple.append(out[0].tolist())
143 |
144 | outputs_optimized = []
145 | for input_ids in single_inputs:
146 | out = generate_text_basic_cache_optimized(
147 | model=model_optimized,
148 | token_ids=input_ids,
149 | max_new_tokens=max_new_tokens,
150 | eos_token_id=tokenizer.eos_token_id,
151 | )
152 | outputs_optimized.append(out[0])
153 |
154 | # Check equivalency
155 | for idx, out_single in enumerate(outputs_simple):
156 | out_batch = outputs_optimized[idx].tolist()
157 |
158 | text_single = tokenizer.decode(out_single)
159 | text_batch = tokenizer.decode(out_batch)
160 |
161 |         # Assert that the generated texts are identical
162 |         assert text_single == text_batch, (
163 |             f"Mismatch at prompt {idx}:\n"
164 |             f"simple   ={text_single}\n"
165 |             f"optimized={text_batch}"
166 | )
167 |
--------------------------------------------------------------------------------
/ch04/02_math500-inference-scaling-scripts/cot_prompting_math500.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import argparse
6 | import json
7 | from pathlib import Path
8 | import time
9 |
10 | import torch
11 |
12 | from reasoning_from_scratch.ch02 import get_device
13 | from reasoning_from_scratch.ch03 import (
14 | load_math500_test,
15 | eta_progress_message,
16 | extract_final_candidate,
17 | render_prompt,
18 | grade_answer,
19 | generate_text_stream_concat,
20 | load_model_and_tokenizer
21 | )
22 |
23 |
24 | def evaluate_math500_stream(
25 | model,
26 | tokenizer,
27 | device,
28 | math_data,
29 | out_path=None,
30 | max_new_tokens=512,
31 | verbose=False,
32 | prompt_suffix="" # NEW
33 | ):
34 |
35 | if out_path is None:
36 | dev_name = str(device).replace(":", "-")
37 | out_path = Path(f"math500-{dev_name}.jsonl")
38 |
39 | num_examples = len(math_data)
40 | num_correct = 0
41 | start_time = time.time()
42 |
43 | with open(out_path, "w", encoding="utf-8") as f:
44 | for i, row in enumerate(math_data, start=1):
45 | prompt = render_prompt(row["problem"])
46 | prompt += prompt_suffix # NEW
47 | gen_text = generate_text_stream_concat(
48 | model, tokenizer, prompt, device,
49 | max_new_tokens=max_new_tokens,
50 | verbose=verbose,
51 | )
52 |
53 | extracted = extract_final_candidate(
54 | gen_text
55 | )
56 | is_correct = grade_answer(
57 | extracted, row["answer"]
58 | )
59 | num_correct += int(is_correct)
60 |
61 | record = {
62 | "index": i,
63 | "problem": row["problem"],
64 | "gtruth_answer": row["answer"],
65 | "generated_text": gen_text,
66 | "extracted": extracted,
67 | "correct": bool(is_correct),
68 | }
69 | f.write(json.dumps(record, ensure_ascii=False) + "\n")
70 |
71 | progress_msg = eta_progress_message(
72 | processed=i,
73 | total=num_examples,
74 | start_time=start_time,
75 | show_eta=True,
76 | label="MATH-500",
77 | )
78 | print(progress_msg, end="\r", flush=True)
79 | if verbose:
80 | print(
81 | f"\n\n{'='*50}\n{progress_msg}\n"
82 | f"{'='*50}\nExtracted: {extracted}\n"
83 | f"Expected: {row['answer']}\n"
84 | f"Correct so far: {num_correct}\n{'-'*50}"
85 | )
86 |
87 | seconds_elapsed = time.time() - start_time
88 | acc = num_correct / num_examples if num_examples else 0.0
89 | print(f"\nAccuracy: {acc*100:.1f}% ({num_correct}/{num_examples})")
90 | print(f"Total time: {seconds_elapsed/60:.1f} min")
91 | print(f"Logs written to: {out_path}")
92 | return num_correct, num_examples, acc
93 |
94 |
95 | def parse_args():
96 | parser = argparse.ArgumentParser()
97 | parser.add_argument(
98 | "--device",
99 | type=str,
100 | default="auto",
101 | help="Device to use: 'auto' (default), or any torch device string like 'cpu', 'cuda', 'cuda:0', 'mps'.",
102 | )
103 | parser.add_argument(
104 | "--which_model",
105 | type=str,
106 | default="base",
107 | choices=["base", "reasoning", "instruct"],
108 | help="Model variant to load. Defaults to 'base'.",
109 | )
110 | parser.add_argument(
111 | "--dataset_size",
112 | type=int,
113 | default=10,
114 | help="Number of MATH-500 examples to evaluate. Default: 10",
115 | )
116 | parser.add_argument(
117 | "--max_new_tokens",
118 | type=int,
119 | default=2048,
120 | help="Max new tokens for generation. Default: 2048",
121 | )
122 | parser.add_argument(
123 | "--compile",
124 | action="store_true",
125 | help="Enable torch.compile for the model.",
126 | )
127 | parser.add_argument(
128 | "--prompt_suffix",
129 | type=str,
130 | default="",
131 |         help="Can be used to add a chain-of-thought prompt suffix (default: '')",
132 | )
133 | parser.add_argument(
134 | "--verbose",
135 | action="store_true",
136 | help="Print per-sample correctness while evaluating.",
137 | )
138 | return parser.parse_args()
139 |
140 |
141 | if __name__ == "__main__":
142 | args = parse_args()
143 |
144 | if args.device == "auto":
145 | device = get_device()
146 | else:
147 | device = torch.device(args.device)
148 |
149 | which_model = args.which_model
150 | dataset_size = args.dataset_size
151 | max_new_tokens = args.max_new_tokens
152 | use_compile = args.compile
153 |
154 | print("Model:", which_model)
155 | print("Device:", device)
156 | dev_name = str(device).replace(":", "-")
157 |
158 | math_data = load_math500_test()
159 |
160 | if args.which_model == "instruct":
161 | which_model = "reasoning"
162 | else:
163 | which_model = args.which_model
164 |
165 | model, tokenizer = load_model_and_tokenizer(
166 | which_model=which_model,
167 | device=device,
168 | use_compile=args.compile
169 | )
170 | if args.which_model == "instruct":
171 | tokenizer.add_thinking = False
172 |
173 | model.eval()
174 | torch.set_float32_matmul_precision("high")
175 |
176 | num_correct, num_examples, acc = evaluate_math500_stream(
177 | model=model,
178 | out_path=f"math500_{which_model}-{dev_name}-evaluate-script.jsonl",
179 | tokenizer=tokenizer,
180 | device=device,
181 | math_data=math_data[:dataset_size],
182 | max_new_tokens=max_new_tokens,
183 | verbose=args.verbose,
184 | prompt_suffix=args.prompt_suffix # NEW
185 | )
186 |
--------------------------------------------------------------------------------
/chF/02_mmlu/3_teacher_forcing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import argparse
6 | import time
7 |
8 | import torch
9 | from datasets import load_dataset, get_dataset_config_names
10 | from reasoning_from_scratch.ch02 import get_device
11 | from reasoning_from_scratch.ch03 import load_model_and_tokenizer
12 |
13 |
14 | # Same as before
15 | def format_prompt(example):
16 | return (
17 | f"{example['question']}\n"
18 | f"A. {example['choices'][0]}\n"
19 | f"B. {example['choices'][1]}\n"
20 | f"C. {example['choices'][2]}\n"
21 | f"D. {example['choices'][3]}\n"
22 | "Answer: " # trailing space encourages a single-letter next token
23 | )
24 |
25 |
26 | def common_prefix_len(a, b):
27 | i = 0
28 | n = min(len(a), len(b))
29 | while i < n and a[i] == b[i]:
30 | i += 1
31 | return i
32 |
33 |
34 | def avg_logprob_teacher_forced(model, tokenizer, prompt_fmt, prompt, prompt_ids, letter, choice_text):
35 | # Build full answer text and then just extract the continuation token IDs
36 | answer_text = f"{letter}. {choice_text}"
37 | ids_full = tokenizer.encode(prompt + answer_text)
38 | j = common_prefix_len(ids_full, prompt_ids)
39 | if j >= len(ids_full):
40 | raise ValueError("Continuation produced no new tokens.")
41 | answer_ids = ids_full[j:] # tokens for the answer continuation
42 |
43 |     # The model input is the prompt plus all answer tokens except the last (each position predicts the next token)
44 | device = prompt_fmt.device
45 | if len(answer_ids) == 0:
46 | return float("-inf")
47 |
48 | answer_prefix = torch.tensor(answer_ids[:-1], dtype=torch.long, device=device).unsqueeze(0)
49 | combined = torch.cat([prompt_fmt, answer_prefix], dim=1)
50 |
51 | with torch.no_grad():
52 |         # Logits for every position in `combined`
53 | scores = model(combined).squeeze(0) # shape [num_tokens, vocab_size]
54 | logp = torch.log_softmax(scores, dim=-1)
55 |
56 | prompt_len = prompt_fmt.shape[1]
57 | answer_len = len(answer_ids)
58 |
59 | # Slice the exact rows where the model predicts the answer tokens
60 | steps = logp[prompt_len-1:prompt_len-1+answer_len, :] # [answer_len, vocab_size]
61 |
62 | # Gather log-probs of the ground-truth answer tokens
63 | targets = torch.tensor(answer_ids, dtype=torch.long, device=device).unsqueeze(1) # [answer_len, 1]
64 | avg_logp = steps.gather(dim=1, index=targets).mean().item()
65 | return avg_logp
66 |
67 |
68 | def predict_choice_teacher_forced(model, tokenizer, prompt_fmt, prompt, prompt_ids, example):
69 | scores = {}
70 | for letter in "ABCD":
71 | idx = ord(letter) - ord("A")
72 | choice_text = example["choices"][idx]
73 | scores[letter] = avg_logprob_teacher_forced(
74 | model, tokenizer, prompt_fmt, prompt, prompt_ids, letter, choice_text
75 | )
76 | pred = max(scores, key=scores.get)
77 | return pred, scores
78 |
79 |
80 | def evaluate_mmlu_teacher_forced(
81 | model,
82 | tokenizer,
83 | device,
84 | subsets="high_school_mathematics", # str, list of str, or "all"
85 | split="test",
86 | verbose_every=50,
87 | ):
88 | if subsets == "all":
89 | subset_list = get_dataset_config_names("cais/mmlu")
90 | elif isinstance(subsets, str):
91 | subset_list = [s.strip() for s in subsets.split(",")] if "," in subsets else [subsets]
92 | else:
93 | subset_list = list(subsets)
94 |
95 | total = 0
96 | correct = 0
97 | start = time.time()
98 |
99 | for subset in subset_list:
100 | ds = load_dataset("cais/mmlu", subset, split=split)
101 | for ex in ds:
102 | prompt = format_prompt(ex)
103 | prompt_ids = tokenizer.encode(prompt)
104 | prompt_fmt = torch.tensor(prompt_ids, device=device).unsqueeze(0)
105 |
106 | pred, _scores = predict_choice_teacher_forced(
107 | model, tokenizer, prompt_fmt, prompt, prompt_ids, ex
108 | )
109 |
110 | ans = ex["answer"]
111 | gold = "ABCD"[ans] if isinstance(ans, int) else str(ans).strip().upper()
112 |
113 | total += 1
114 | correct += int(pred == gold)
115 |
116 | if verbose_every and total % verbose_every == 0:
117 | print(f"MMLU {total} acc={correct/total:.3f} [{subset}]")
118 |
119 | acc = correct / max(1, total)
120 | print(
121 | f"\nMMLU letter accuracy (teacher-forced): {correct}/{total} = {acc:.2%} "
122 | f"in {time.time()-start:.1f}s"
123 | )
124 | return {"accuracy": acc, "num_examples": total, "subsets": subset_list, "split": split}
125 |
126 |
127 | def main():
128 | parser = argparse.ArgumentParser(
129 | description="Zero-shot MMLU via teacher-forced log-prob over 'A. '."
130 | )
131 | parser.add_argument(
132 | "--device",
133 | type=str,
134 | default="auto",
135 | help="Device to use: 'auto' (default), or any torch device string like "
136 | "'cpu', 'cuda', 'cuda:0', 'mps'.",
137 | )
138 | parser.add_argument(
139 | "--which_model",
140 | type=str,
141 | default="base",
142 | choices=["base", "reasoning"],
143 | help="Model variant to load. Defaults to 'base'.",
144 | )
145 | parser.add_argument(
146 | "--subsets",
147 | type=str,
148 | default="high_school_mathematics",
149 | help="Comma-separated subset names or 'all'. "
150 | "Default: 'high_school_mathematics'.",
151 | )
152 | args = parser.parse_args()
153 |
154 | device = get_device() if args.device == "auto" else torch.device(args.device)
155 | print(f"Using device: {device}")
156 |
157 | model, tokenizer = load_model_and_tokenizer(args.which_model, device, use_compile=False)
158 | model.eval()
159 | torch.set_float32_matmul_precision("high")
160 |
161 | metrics = evaluate_mmlu_teacher_forced(
162 | model=model,
163 | tokenizer=tokenizer,
164 | device=device,
165 | subsets=args.subsets,
166 | )
167 | print(metrics)
168 |
169 |
170 | if __name__ == "__main__":
171 | main()
172 |
--------------------------------------------------------------------------------
/ch04/02_math500-inference-scaling-scripts/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 4: Improving Reasoning with Inference-Time Scaling
2 |
3 |
4 |
5 | ## Bonus materials
6 |
7 | - [cot_prompting_math500.py](cot_prompting_math500.py): standalone script to evaluate models with chain-of-thought prompting on the MATH-500 dataset
8 | - [self_consistency_math500.py](self_consistency_math500.py): standalone script to evaluate models with self-consistency sampling on the MATH-500 dataset
9 |
10 | Both evaluation scripts import functionality from the [`reasoning_from_scratch`](../../reasoning_from_scratch) package to avoid code duplication. (See [chapter 2 setup instructions](../../ch02/02_setup-tips/python-instructions.md) for installation details.)
11 |
12 |
13 |
14 |
15 |
16 | ---
17 |
18 | **Note**: If you are not a `uv` user, replace `uv run ...py` with `python ...py` in the examples below.
19 |
20 | ---
21 |
22 |
23 |
24 |
25 |
26 | ## Chain-of-thought prompting
27 |
28 | The [`cot_prompting_math500.py`](cot_prompting_math500.py) script implements the chain-of-thought prompting method from chapter 4.
29 |
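In essence, chain-of-thought prompting here means appending a short instruction to the chapter 3 baseline prompt before generation, which the script exposes via the `--prompt_suffix` flag. The sketch below illustrates the idea; the suffix shown matches the one used for the "CoT" rows later in this README, and whether row 3 used exactly this suffix is an assumption.

```python
# Minimal sketch (not the full evaluation script): chain-of-thought prompting
# simply appends an instruction to the baseline prompt before generating.
from reasoning_from_scratch.ch03 import render_prompt

problem = "What is 1 + 1?"               # toy example problem
prompt = render_prompt(problem)          # baseline prompt formatting from chapter 3
prompt += "\n\nExplain step by step."    # chain-of-thought suffix (--prompt_suffix)
print(prompt)
```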
30 |
31 |
32 |
33 |
34 |
35 |
36 | The table below compares this approach (row 3) with the baselines from chapter 3:
37 |
38 | | | Method | Model | Accuracy | Time |
39 | |----|----------------------------------------------|-----------|----------|------------|
40 | | 1 | Baseline (chapter 3), greedy decoding | Base | 15.2% | 10.1 min |
41 | | 2 | Baseline (chapter 3), greedy decoding | Reasoning | 48.2% | 182.1 min |
42 | | 3 | Chain-of-thought prompting ("CoT") | Base | 40.6% | 84.5 min |
43 |
44 | The accuracy values and runtimes shown in the table were computed on all 500 samples in the MATH-500 test set using a "cuda" GPU (DGX Spark).
45 |
46 | To run the experiment in row one, use:
47 |
48 | ```bash
49 | python cot_prompting_math500.py \
50 | --which_model "base" \
51 | --dataset_size 500
52 | ```
53 |
54 | Or, with `uv`:
55 |
56 |
57 | ```bash
58 | uv run cot_prompting_math500.py \
59 | --which_model "base" \
60 | --dataset_size 500
61 | ```
62 |
63 | For additional options, use the `--help` flag.
64 |
65 |
66 |
67 |
68 | ## Self-consistency sampling
69 |
70 | The [`self_consistency_math500.py`](self_consistency_math500.py) script implements the sampling method from chapter 4.
71 |
72 | (Optionally, there is a [`self_consistency_math500_batched.py`](self_consistency_math500_batched.py) variant, which generates all `--num_samples` completions as a single batch for faster processing. Note, however, that this requires more memory.)
73 |
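The core idea behind self-consistency is to sample several completions per problem (using temperature and top-p sampling), extract a final answer from each, and pick the most frequent one. The sketch below illustrates this with placeholder helpers; `sample_completion` and `extract_final_candidate` are stand-ins for the script's sampling and answer-extraction routines, and this is not the chapter's implementation.

```python
# Minimal sketch of self-consistency voting (illustrative, not the chapter code).
from collections import Counter

def self_consistency_sketch(sample_completion, extract_final_candidate,
                            prompt, num_samples=5):
    answers = []
    for _ in range(num_samples):
        completion = sample_completion(prompt)             # stochastic generation
        answers.append(extract_final_candidate(completion))
    counts = Counter(answers)
    final_answer, _count = counts.most_common(1)[0]        # majority vote
    return final_answer, counts
```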
74 |
75 |
76 |
77 |
78 |
79 |
80 | The table below compares this approach (rows 4-12) with the baselines from chapter 3 (rows 1-2):
81 |
82 | | | Method | Model | Accuracy | Time |
83 | | ---- | ----------------------------------------- | --------- | -------- | --------- |
84 | | 1 | Baseline (chapter 3), greedy decoding | Base | 15.2% | 10.1 min |
85 | | 2 | Baseline (chapter 3), greedy decoding | Reasoning | 48.2% | 182.1 min |
86 | | 3 | Chain-of-thought prompting ("CoT") | Base | 40.6% | 84.5 min |
87 | | 4 | Temperature and top-p ("Top-p") | Base | 17.8% | 30.7 min |
88 | | 5 | "Top-p" + Self-consistency (n=3) | Base | 29.6% | 97.6 min |
89 | | 6 | "Top-p" + Self-consistency (n=5) | Base | 27.8% | 116.8 min |
90 | | 7 | "Top-p" + Self-consistency (n=10) | Base | 31.6% | 300.4 min |
91 | | 8 | "Top-p" + "CoT" | Base | 33.4% | 129.2 min |
92 | | 9 | Self-consistency (n=3) + "Top-p" + "CoT" | Base | 42.2% | 211.6 min |
93 | | 10 | Self-consistency (n=5) + "Top-p" + "CoT" | Base | 48.0% | 452.9 min |
94 | | 11 | Self-consistency (n=10) + "Top-p" + "CoT" | Base | 52.0% | 862.6 min |
95 | | 12 | Self-consistency (n=3) + "Top-p" + "CoT" | Reasoning | 55.2% | 544.4 min |
96 |
97 | The accuracy values and runtimes shown in the table were computed on all 500 samples in the MATH-500 test set using a "cuda" GPU (DGX Spark).
98 |
99 | The following commands show how to run the self-consistency experiments in rows 4-12 (replace `uv run` with `python` if you are not a `uv` user).
100 |
101 | **Row 4:**
102 |
103 | ```bash
104 | uv run self_consistency_math500.py \
105 | --which_model "base" \
106 | --temperature 0.9 \
107 | --top_p 0.9 \
108 | --num_samples 1 \
109 | --dataset_size 500
110 | ```
111 |
112 | **Row 5:**
113 |
114 | ```bash
115 | uv run self_consistency_math500.py \
116 | --which_model "base" \
117 | --temperature 0.9 \
118 | --top_p 0.9 \
119 | --num_samples 3 \
120 | --dataset_size 500
121 | ```
122 |
123 | **Row 6:**
124 |
125 | ```bash
126 | uv run self_consistency_math500.py \
127 | --which_model "base" \
128 | --temperature 0.9 \
129 | --top_p 0.9 \
130 | --num_samples 5 \
131 | --dataset_size 500
132 | ```
133 |
134 | **Row 7:**
135 |
136 | ```bash
137 | uv run self_consistency_math500.py \
138 | --which_model "base" \
139 | --temperature 0.9 \
140 | --top_p 0.9 \
141 | --num_samples 10 \
142 | --dataset_size 500
143 | ```
144 |
145 | **Row 8:**
146 |
147 | ```bash
148 | uv run self_consistency_math500.py \
149 | --which_model "base" \
150 | --temperature 0.9 \
151 | --top_p 0.9 \
152 | --num_samples 1 \
153 | --dataset_size 500 \
154 | --prompt_suffix "\n\nExplain step by step."
155 | ```
156 |
157 | **Row 9:**
158 |
159 | ```bash
160 | uv run self_consistency_math500.py \
161 | --which_model "base" \
162 | --temperature 0.9 \
163 | --top_p 0.9 \
164 | --num_samples 3 \
165 | --dataset_size 500 \
166 | --prompt_suffix "\n\nExplain step by step."
167 | ```
168 |
169 | **Row 10:**
170 |
171 | ```bash
172 | uv run self_consistency_math500.py \
173 | --which_model "base" \
174 | --temperature 0.9 \
175 | --top_p 0.9 \
176 | --num_samples 5 \
177 | --dataset_size 500 \
178 | --prompt_suffix "\n\nExplain step by step."
179 | ```
180 |
181 | **Row 11:**
182 |
183 | ```bash
184 | uv run self_consistency_math500.py \
185 | --which_model "base" \
186 | --temperature 0.9 \
187 | --top_p 0.9 \
188 | --num_samples 10 \
189 | --dataset_size 500 \
190 | --prompt_suffix "\n\nExplain step by step."
191 | ```
192 |
193 | **Row 12:**
194 |
195 | ```bash
196 | uv run self_consistency_math500.py \
197 | --which_model "reasoning" \
198 | --temperature 0.9 \
199 | --top_p 0.9 \
200 | --num_samples 3 \
201 | --dataset_size 500 \
202 | --prompt_suffix "\n\nExplain step by step."
203 | ```
204 |
205 |
206 | For additional options, use the `--help` flag.
207 |
208 |
--------------------------------------------------------------------------------
/tests/test_qwen3_batched_stop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | import os
6 | import torch
7 | import pytest
8 | from pathlib import Path
9 |
10 | from reasoning_from_scratch.ch02 import get_device
11 | from reasoning_from_scratch.qwen3 import (
12 | download_qwen3_small,
13 | Qwen3Tokenizer,
14 | QWEN_CONFIG_06_B,
15 | )
16 | from reasoning_from_scratch.qwen3_batched import (
17 | generate_text_basic_batched_cache,
18 | generate_text_basic_batched_cache_stop,
19 | generate_text_basic_batched_stream_cache,
20 | generate_text_basic_batched_stream_cache_stop,
21 | Qwen3Model as Qwen3ModelBatched,
22 | )
23 |
24 | skip_expensive = os.environ.get("SKIP_EXPENSIVE", "0") == "1"
25 |
26 | # Make CI more reproducible & robust
27 | os.environ["MKL_NUM_THREADS"] = "1"
28 | os.environ["OMP_NUM_THREADS"] = "1"
29 | torch.backends.mkldnn.enabled = False
30 | torch.set_num_threads(1)
31 | torch.use_deterministic_algorithms(True)
32 |
33 |
34 | @pytest.mark.skipif(skip_expensive, reason="Skipping expensive test on CI")
35 | @pytest.mark.parametrize("reasoning", [False, True])
36 | def test_batched_vs_batched_stop_equivalence(reasoning):
37 |
38 | device = get_device()
39 |
40 | # Download and init tokenizer
41 | kind = "reasoning" if reasoning else "base"
42 | download_qwen3_small(kind=kind, tokenizer_only=False, out_dir="qwen3")
43 | tokenizer_path = Path("qwen3") / (
44 | "tokenizer-reasoning.json" if reasoning else "tokenizer-base.json"
45 | )
46 | model_path = Path("qwen3") / (
47 | "qwen3-0.6B-reasoning.pth" if reasoning else "qwen3-0.6B-base.pth"
48 | )
49 | tokenizer = Qwen3Tokenizer(
50 | tokenizer_file_path=tokenizer_path,
51 | apply_chat_template=reasoning,
52 | add_generation_prompt=reasoning,
53 | add_thinking=reasoning,
54 | )
55 |
56 | # Model
57 | model_batched = Qwen3ModelBatched(QWEN_CONFIG_06_B)
58 | model_batched.load_state_dict(torch.load(model_path, map_location=device))
59 | model_batched.to(device).eval()
60 |
61 | # Prompts
62 | prompts = [
63 | "Explain large language models in two sentences.",
64 | "Explain large language models in one sentence.",
65 | "1+1?",
66 | ]
67 |
68 | # Batched inputs (left-padded)
69 | tokenized = [tokenizer.encode(p) for p in prompts]
70 | max_len = max(len(t) for t in tokenized)
71 | pad_id = tokenizer.pad_token_id
72 | left_padded = [[pad_id] * (max_len - len(t)) + t for t in tokenized]
73 | attn_mask = [
74 | [0] * (max_len - len(t)) + [1] * len(t) for t in tokenized
75 | ]
76 | input_ids_batched = torch.tensor(left_padded, device=device)
77 | attn_mask_batched = torch.tensor(attn_mask, device=device, dtype=torch.bool)
78 |
79 | # Generation
80 | max_new_tokens = 12
81 | outputs_reg = generate_text_basic_batched_cache(
82 | model=model_batched,
83 | token_ids=input_ids_batched,
84 | max_new_tokens=max_new_tokens,
85 | eos_token_id=tokenizer.eos_token_id,
86 | attn_mask=attn_mask_batched,
87 | pad_id=pad_id,
88 | )
89 | outputs_stop = generate_text_basic_batched_cache_stop(
90 | model=model_batched,
91 | token_ids=input_ids_batched,
92 | max_new_tokens=max_new_tokens,
93 | eos_token_id=tokenizer.eos_token_id,
94 | attn_mask=attn_mask_batched,
95 | pad_id=pad_id,
96 | )
97 |
98 | # Check equivalency
99 | for idx in range(len(prompts)):
100 | reg_toks = outputs_reg[idx].tolist()
101 | stop_toks = outputs_stop[idx].tolist()
102 | assert reg_toks == stop_toks, (
103 | f"Token mismatch at prompt {idx}:\n"
104 | f"regular_tokens={reg_toks}\n"
105 | f"stop_tokens ={stop_toks}\n"
106 | f"regular_text={tokenizer.decode(reg_toks)}\n"
107 | f"stop_text ={tokenizer.decode(stop_toks)}"
108 | )
109 |
110 |
111 | @pytest.mark.skipif(skip_expensive, reason="Skipping expensive test on CI")
112 | @pytest.mark.parametrize("reasoning", [False, True])
113 | def test_stream_vs_stream_stop_equivalence(reasoning):
114 |
115 | device = get_device()
116 |
117 | # Download and init tokenizer
118 | kind = "reasoning" if reasoning else "base"
119 | download_qwen3_small(kind=kind, tokenizer_only=False, out_dir="qwen3")
120 | tokenizer_path = Path("qwen3") / (
121 | "tokenizer-reasoning.json" if reasoning else "tokenizer-base.json"
122 | )
123 | model_path = Path("qwen3") / (
124 | "qwen3-0.6B-reasoning.pth" if reasoning else "qwen3-0.6B-base.pth"
125 | )
126 | tokenizer = Qwen3Tokenizer(
127 | tokenizer_file_path=tokenizer_path,
128 | apply_chat_template=reasoning,
129 | add_generation_prompt=reasoning,
130 | add_thinking=reasoning,
131 | )
132 |
133 | # Model
134 | model_batched = Qwen3ModelBatched(QWEN_CONFIG_06_B)
135 | model_batched.load_state_dict(torch.load(model_path, map_location=device))
136 | model_batched.to(device).eval()
137 |
138 | # Prompts
139 | prompts = [
140 | "Explain large language models in two sentences.",
141 | "Explain large language models in one sentence.",
142 | "1+1?",
143 | ]
144 |
145 | # Batched inputs (left-padded)
146 | tokenized = [tokenizer.encode(p) for p in prompts]
147 | max_len = max(len(t) for t in tokenized)
148 | pad_id = tokenizer.pad_token_id
149 | left_padded = [[pad_id] * (max_len - len(t)) + t for t in tokenized]
150 | attn_mask = [
151 | [0] * (max_len - len(t)) + [1] * len(t) for t in tokenized
152 | ]
153 | input_ids_batched = torch.tensor(left_padded, device=device)
154 | attn_mask_batched = torch.tensor(attn_mask, device=device, dtype=torch.bool)
155 |
156 | # Generation
157 | max_new_tokens = 12
158 | B = input_ids_batched.size(0)
159 |
160 | # Regular streaming
161 | reg_stream_tokens = [[] for _ in range(B)]
162 | for step_tokens in generate_text_basic_batched_stream_cache(
163 | model=model_batched,
164 | token_ids=input_ids_batched,
165 | max_new_tokens=max_new_tokens,
166 | eos_token_id=tokenizer.eos_token_id,
167 | attn_mask=attn_mask_batched,
168 | pad_id=pad_id,
169 | ):
170 | step_tokens = step_tokens.squeeze(1)
171 | for b in range(B):
172 | reg_stream_tokens[b].append(int(step_tokens[b].item()))
173 |
174 | # Stop streaming
175 | stop_stream_tokens = [[] for _ in range(B)]
176 | for step_tokens in generate_text_basic_batched_stream_cache_stop(
177 | model=model_batched,
178 | token_ids=input_ids_batched,
179 | max_new_tokens=max_new_tokens,
180 | eos_token_id=tokenizer.eos_token_id,
181 | attn_mask=attn_mask_batched,
182 | pad_id=pad_id,
183 | ):
184 | step_tokens = step_tokens.squeeze(1)
185 | for b in range(B):
186 | stop_stream_tokens[b].append(int(step_tokens[b].item()))
187 |
188 | # Check equivalency
189 | for idx in range(B):
190 | assert reg_stream_tokens[idx] == stop_stream_tokens[idx], (
191 | f"Token mismatch at prompt {idx}:\n"
192 | f"regular_tokens={reg_stream_tokens[idx]}\n"
193 | f"stop_tokens ={stop_stream_tokens[idx]}\n"
194 | f"regular_text={tokenizer.decode(reg_stream_tokens[idx])}\n"
195 | f"stop_text ={tokenizer.decode(stop_stream_tokens[idx])}"
196 | )
197 |
--------------------------------------------------------------------------------
/ch02/05_use_model/chat_multiturn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | # Runs the model similarly to chapters 2 and 3 in streaming mode with the
6 | # fewest bells and whistles. Uses KV caching by default.
7 | # Interactive REPL (Read, Evaluate, Print, Loop) with multi-turn memory.
8 |
9 | import argparse
10 | from pathlib import Path
11 | import time
12 | import torch
13 |
14 | from reasoning_from_scratch.ch02 import (
15 | get_device,
16 | generate_stats
17 | )
18 | from reasoning_from_scratch.ch02_ex import (
19 | generate_text_basic_stream_cache
20 | )
21 | from reasoning_from_scratch.qwen3 import (
22 | download_qwen3_small,
23 | Qwen3Model,
24 | Qwen3Tokenizer,
25 | QWEN_CONFIG_06_B
26 | )
27 |
28 | parser = argparse.ArgumentParser(description="Run Qwen3 text generation (interactive REPL)")
29 | parser.add_argument(
30 | "--device",
31 | type=str,
32 | default=None,
33 | help="Device to run on (e.g. 'cpu', 'cuda', 'mps'). "
34 | "If not provided, will auto-detect with get_device()."
35 | )
36 | parser.add_argument(
37 | "--max_new_tokens",
38 | type=int,
39 | default=2048,
40 | help="Maximum number of new tokens to generate in each turn (default: 2048)."
41 | )
42 | parser.add_argument(
43 | "--compile",
44 | action="store_true",
45 | help="Compile PyTorch model (default: False)."
46 | )
47 | parser.add_argument(
48 | "--reasoning",
49 | action="store_true",
50 | help="Use reasoning model variant (default: False)."
51 | )
52 |
53 | args = parser.parse_args()
54 | device = torch.device(args.device) if args.device else get_device()
55 |
56 | if args.reasoning:
57 | download_qwen3_small(kind="reasoning", tokenizer_only=False, out_dir="qwen3")
58 | tokenizer_path = Path("qwen3") / "tokenizer-reasoning.json"
59 | model_path = Path("qwen3") / "qwen3-0.6B-reasoning.pth"
60 | else:
61 | download_qwen3_small(kind="base", tokenizer_only=False, out_dir="qwen3")
62 | tokenizer_path = Path("qwen3") / "tokenizer-base.json"
63 | model_path = Path("qwen3") / "qwen3-0.6B-base.pth"
64 |
65 | # We will apply the chat template manually later
66 | tokenizer = Qwen3Tokenizer(
67 | tokenizer_file_path=tokenizer_path,
68 | apply_chat_template=False
69 | )
70 |
71 | model = Qwen3Model(QWEN_CONFIG_06_B)
72 | state = torch.load(model_path, map_location=device)
73 | model.load_state_dict(state)
74 | model.to(device)
75 | model.eval()
76 |
77 | if args.compile:
78 | model = torch.compile(model)
79 |
80 | # The reasoning model may emit <|im_end|>; base may emit <|endoftext|>.
81 | EOS_TOKEN_IDS = (
82 | tokenizer.encode("<|im_end|>")[0],
83 | tokenizer.encode("<|endoftext|>")[0]
84 | )
85 |
86 | print()
87 | print("=" * 60)
88 | print(f"torch : {torch.__version__}")
89 | print(f"device : {device}")
90 | print("cache : True")
91 | print(f"compile : {args.compile}")
92 | print(f"reasoning : {args.reasoning}")
93 | print("memory : True")
94 | print(f"max_new_tokens (per turn): {args.max_new_tokens}")
95 | print(f"context_length: {model.cfg['context_length']}")
96 | print("=" * 60)
97 | print()
98 | print("Interactive REPL with memory. Type '\\exit' or '\\quit' to quit.")
99 | print("Commands: \\clear (forget memory), \\history (show turn count)\n")
100 |
101 | # Multi-turn memory as a list of role-content dicts
102 | # Example: {"role": "system"|"user"|"assistant", "content": str}
103 | history = [
104 | {"role": "system", "content": "You are a helpful assistant."}
105 | ]
106 |
107 |
108 | def build_prompt_from_history(history, add_assistant_header=True):
109 | """
110 | history: [{"role": "system"|"user"|"assistant", "content": str}, ...]
111 | """
112 | parts = []
113 | for m in history:
114 | role = m["role"]
115 | content = m["content"]
116 | parts.append(f"<|im_start|>{role}\n{content}<|im_end|>\n")
117 |
118 | if add_assistant_header:
119 | parts.append("<|im_start|>assistant\n")
120 | return "".join(parts)
121 |
122 |
123 | def trim_input_tensor(input_ids_tensor, context_len, max_new_tokens):
124 | assert max_new_tokens < context_len
125 | keep_len = max(1, context_len - max_new_tokens)
126 |
127 | # If the prompt is too long, left-truncate to keep_len
128 | if input_ids_tensor.shape[1] > keep_len:
129 | input_ids_tensor = input_ids_tensor[:, -keep_len:]
130 |
131 | return input_ids_tensor
132 |
133 |
134 | def run_generate(user_text):
135 | # Add user prompt to history
136 | history.append({"role": "user", "content": user_text})
137 |
138 | # Encode full history
139 | prompt = build_prompt_from_history(history, add_assistant_header=True)
140 | input_ids = tokenizer.encode(prompt)
141 | input_token_ids_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)
142 |
143 |     # Left-truncate (to make space for generation)
144 | input_token_ids_tensor = trim_input_tensor(
145 | input_ids_tensor=input_token_ids_tensor,
146 | context_len=model.cfg["context_length"],
147 | max_new_tokens=args.max_new_tokens
148 | )
149 |
150 | start_time = time.time()
151 | all_token_ids = []
152 |
153 | print("[Model]\n", end="", flush=True)
154 | for tok in generate_text_basic_stream_cache(
155 | model=model,
156 | token_ids=input_token_ids_tensor,
157 | max_new_tokens=args.max_new_tokens,
158 |         # eos_token_id=tokenizer.eos_token_id
159 | ):
160 | token_id = tok.squeeze(0)
161 | if token_id in EOS_TOKEN_IDS: # Manually break at stop tokens
162 | break
163 | piece = tokenizer.decode(token_id.tolist())
164 | print(piece, end="", flush=True)
165 | all_token_ids.append(token_id.item())
166 |
167 | end_time = time.time()
168 | print("\n")
169 |
170 | print("[Stats]")
171 | generate_stats(
172 | torch.tensor(all_token_ids),
173 | tokenizer,
174 | start_time,
175 | end_time,
176 | print_tokens=False
177 | )
178 | print("-" * 60)
179 |
180 | # Add model reply to history
181 | assistant_text = tokenizer.decode(all_token_ids)
182 | history.append({"role": "assistant", "content": assistant_text})
183 | return assistant_text
184 |
185 |
186 | # Interactive REPL (Read, Evaluate, Print, Loop)
187 | try:
188 | while True:
189 | try:
190 | user_in = input(">> ").strip()
191 | except EOFError:
192 | print("")
193 | break
194 |
195 | low = user_in.lower()
196 | if low in {r"\exit", r"\quit"}:
197 | break
198 | if low == r"\clear":
199 | # Reset history but keep the system prompt
200 | system_entries = [m for m in history if m["role"] == "system"]
201 | history.clear()
202 | if system_entries:
203 | history.extend(system_entries)
204 | else:
205 | history.append({"role": "system", "content": "You are a helpful assistant."})
206 | print("(memory cleared)\n")
207 | continue
208 | if low == r"\history":
209 | # Count assistant turns as the number of model replies so far
210 | assistant_turns = sum(1 for m in history if m["role"] == "assistant")
211 | print(f"(stored turns: {assistant_turns})\n")
212 | continue
213 | if not user_in:
214 | continue
215 |
216 | print("\n" + "-" * 60)
217 | print("[User]")
218 | print(user_in + "\n")
219 | run_generate(user_in)
220 |
221 | except KeyboardInterrupt:
222 | print("\nInterrupted by user.")
223 |
--------------------------------------------------------------------------------
/ch02/03_optimized-LLM/README.md:
--------------------------------------------------------------------------------
1 | # Optimized Qwen3
2 |
3 | The Qwen3 from-scratch implementation used in this book strikes a balance between being efficient (on both CPU and GPU) and lean, while remaining easy for a human to read.
4 |
5 | As an alternative, you can use the optional `Qwen3Model` drop-in replacement, which is slightly more GPU-efficient. The optimized version in [`qwen3_optimized.py`](../../reasoning_from_scratch/qwen3_optimized.py) (discussed further in Appendix C) differs from the baseline implementation in [`qwen3.py`](../../reasoning_from_scratch/qwen3.py) in two key ways:
6 |
7 | - It implements attention using PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` instead of a custom implementation (see the sketch below).
8 | - It introduces a modified `KVCache` that pre-allocates key/value tensors. This increases memory usage but avoids repeatedly allocating new storage during execution.
9 |
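For illustration, the following minimal sketch (using random tensors, not the book's actual attention module) shows that the fused PyTorch kernel produces the same result as a manual causal-attention computation:

```python
# Minimal sketch: manual causal attention vs. the fused PyTorch kernel.
# Illustrative only; the attention in qwen3.py additionally handles
# grouped-query attention, RoPE, and QK normalization.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, t, d = 1, 2, 4, 8  # batch, heads, tokens, head dimension
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))

# Manual implementation: scaled scores, causal mask, softmax, weighted sum
scores = q @ k.transpose(-2, -1) / d**0.5
causal_mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ v

# Fused kernel used by qwen3_optimized.py
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```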
10 |
11 | To explore the differences, I recommend opening [`qwen3.py`](../../reasoning_from_scratch/qwen3.py) and [`qwen3_optimized.py`](../../reasoning_from_scratch/qwen3_optimized.py) side by side and/or looking at a file-diff:
12 |
13 |
14 |
15 | 
16 |
17 |
18 |
19 |
20 | ## How to use
21 |
22 | The optimized code can be used as a drop-in replacement for the code used in the main chapters, as shown below.
23 |
24 | **Before:**
25 |
26 | ```python
27 | from reasoning_from_scratch.qwen3 import Qwen3Model
28 | from reasoning_from_scratch.ch02 import generate_text_basic_cache
29 | ```
30 |
31 |
32 | **After:**
33 |
34 | ```python
35 | from reasoning_from_scratch.qwen3_optimized import Qwen3Model
36 | from reasoning_from_scratch.qwen3_optimized import generate_text_basic_cache
37 | ```
38 |
39 |
40 | ## How to run comparisons
41 |
42 | To evaluate the performance on your system, you can use the [`compare_inference.py`](compare_inference.py) script contained in this folder:
43 |
44 | ```bash
45 | python compare_inference.py
46 | ```
47 | 
48 | or
49 | 
50 | ```bash
51 | uv run compare_inference.py
52 | ```
53 |
54 | Optionally, add the following flags:
55 |
56 | - `--device`: Select the device, e.g., `cpu`, `mps`, or `cuda`
57 | - `--cache`: Enables the KV cache
58 | - `--compile`: Uses `torch.compile`
59 | - `--reasoning`: Uses the Qwen3 reasoning variant instead of the base model. The base model generates approximately 50 tokens in response to the given prompt. The reasoning variant generates about 2000 tokens.
60 | - `--optimize`: Uses the optimized model from `qwen3_optimized.py` instead of the standard model from `qwen3.py`.
61 |
62 |
63 |
64 |
65 | ### Standard model
66 |
67 |
68 |
69 | | Model | Mode | Command | Hardware | Tokens/sec | GPU Memory (VRAM) |
70 | | -------- | ----------------- | ------------------------------- | --------------- | ------------- | ----------------- |
71 | | qwen3.py | Regular | --device cpu | Mac Mini M4 CPU | 6 | - |
72 | | qwen3.py | Regular compiled | --device cpu --compile | Mac Mini M4 CPU | 6 | - |
73 | | qwen3.py | KV cache | --device cpu --cache | Mac Mini M4 CPU | 28 | - |
74 | | qwen3.py | KV cache compiled | --device cpu --compile --cache | Mac Mini M4 CPU | 68 | - |
75 | | | | | | | |
76 | | qwen3.py | Regular | --device mps | Mac Mini M4 GPU | 17 | - |
77 | | qwen3.py | Regular compiled | --device mps --compile | Mac Mini M4 GPU | InductorError | - |
78 | | qwen3.py | KV cache | --device mps --cache | Mac Mini M4 GPU | 18 | - |
79 | | qwen3.py | KV cache compiled | --device mps --compile --cache | Mac Mini M4 GPU | InductorError | - |
80 | | | | | | | |
81 | | qwen3.py | Regular | --device cuda | NVIDIA H100 GPU | 51 | 1.55 GB |
82 | | qwen3.py | Regular compiled | --device cuda --compile | NVIDIA H100 GPU | 164 | 1.81 GB |
83 | | qwen3.py | KV cache | --device cuda --cache | NVIDIA H100 GPU | 48 | 1.52 GB |
84 | | qwen3.py | KV cache compiled | --device cuda --compile --cache | NVIDIA H100 GPU | 141 | 1.81 GB |
85 |
86 |
87 |
88 |
89 | ### Optimized model
90 |
91 |
92 | | Model | Mode | Command | Hardware | Tokens/sec | GPU Memory (VRAM) |
93 | | ------------------ | ----------------- | ------------------------------------------- | --------------- | ---------- | ----------------- |
94 | | qwen3_optimized.py | Regular | --optimized --device cpu | Mac Mini M4 CPU | 5 | - |
95 | | qwen3_optimized.py | Regular compiled | --optimized --device cpu --compile | Mac Mini M4 CPU | 7 | - |
96 | | qwen3_optimized.py | KV cache | --optimized --device cpu --cache | Mac Mini M4 CPU | 49 | - |
97 | | qwen3_optimized.py | KV cache compiled | --optimized --device cpu --compile --cache | Mac Mini M4 CPU | 51 | - |
98 | | | | | | | |
99 | | qwen3_optimized.py | Regular | --optimized --device mps | Mac Mini M4 GPU | 21 | - |
100 | | qwen3_optimized.py | Regular compiled | --optimized --device mps --compile | Mac Mini M4 GPU | NameError | - |
101 | | qwen3_optimized.py | KV cache | --optimized --device mps --cache | Mac Mini M4 GPU | 29 | - |
102 | | qwen3_optimized.py | KV cache compiled | --optimized --device mps --compile --cache | Mac Mini M4 GPU | 38 | - |
103 | | | | | | | |
104 | | qwen3_optimized.py | Regular | --optimized --device cuda | NVIDIA H100 GPU | 55 | 1.50 GB |
105 | | qwen3_optimized.py | Regular compiled | --optimized --device cuda --compile | NVIDIA H100 GPU | 173 | 1.81 GB |
106 | | qwen3_optimized.py | KV cache | --optimized --device cuda --cache | NVIDIA H100 GPU | 56 | 5.85 GB |
107 | | qwen3_optimized.py | KV cache compiled | --optimized --device cuda --compile --cache | NVIDIA H100 GPU | 177 | 5.85 GB |
108 |
109 |
110 |
111 | Comparing the two tables above, we can see that the optimized variant is clearly faster in terms of tokens per second in most cases.
112 | 
113 | However, note that on the Mac Mini M4 CPU, the unoptimized version (68 tok/sec) is faster than the optimized version (51 tok/sec) when running compiled with the KV cache.
114 | 
115 | The optimized version also uses more GPU memory (5.85 GB with the KV cache on the H100) than the unoptimized version (1.52 GB). This is because it pre-allocates the key/value tensors for the maximum supported context length. (So, when running the unoptimized version on a prompt that fills the roughly 41k-token context window, the memory usage would be approximately similar.)
116 |
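To make the pre-allocation idea concrete, here is a minimal sketch (with toy tensor shapes, not the book's `KVCache` class) contrasting growing a cache by concatenation with writing into a pre-allocated buffer:

```python
# Toy illustration of the two KV-cache strategies (not the actual KVCache code).
import torch

max_ctx, n_kv_heads, head_dim = 8, 2, 4

# qwen3.py-style (schematically): grow the cache by concatenation,
# which allocates new storage at every generation step.
k_cache = torch.empty(1, n_kv_heads, 0, head_dim)
for _ in range(3):
    new_k = torch.randn(1, n_kv_heads, 1, head_dim)
    k_cache = torch.cat([k_cache, new_k], dim=2)

# qwen3_optimized.py-style (schematically): pay the memory cost up front by
# pre-allocating for max_ctx tokens, then write new entries in place.
k_prealloc = torch.zeros(1, n_kv_heads, max_ctx, head_dim)
pos = 0
for _ in range(3):
    new_k = torch.randn(1, n_kv_heads, 1, head_dim)
    k_prealloc[:, :, pos:pos + 1] = new_k
    pos += 1

print(k_cache.shape, k_prealloc[:, :, :pos].shape)  # both: torch.Size([1, 2, 3, 4])
```
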
117 | **Perhaps the best recommendation is to use the unoptimized version (with `--cache` and `--compile`) when using a CPU. When using a GPU, use the optimized version (with `--cache` and `--compile`).**
--------------------------------------------------------------------------------
/reasoning_from_scratch/appendix_c.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
3 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
4 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
5 |
6 | import os
7 | import json
8 | import torch
9 |
10 |
11 | QWEN3_CONFIG_1_7B = {
12 | "vocab_size": 151_936,
13 | "context_length": 40_960,
14 | "emb_dim": 2048, # 2x larger than 0.6B model
15 | "n_heads": 16,
16 | "n_layers": 28,
17 | "hidden_dim": 6144, # 2x larger than 0.6B model
18 | "head_dim": 128,
19 | "qk_norm": True,
20 | "n_kv_groups": 8,
21 | "rope_base": 1_000_000.0,
22 | "dtype": torch.bfloat16,
23 | }
24 |
25 | # 4 billion parameters
26 | QWEN3_CONFIG_4B = {
27 | "vocab_size": 151_936,
28 | "context_length": 40_960,
29 | "emb_dim": 2560, # 25% larger than above
30 | "n_heads": 32, # 2x larger than above
31 | "n_layers": 36, # 29% larger than above
32 |     "hidden_dim": 9728,             # ~58% larger than above
33 | "head_dim": 128,
34 | "qk_norm": True,
35 | "n_kv_groups": 8,
36 | "rope_base": 1_000_000.0,
37 | "dtype": torch.bfloat16,
38 | }
39 |
40 | # 8 billion parameters
41 | QWEN3_CONFIG_8B = {
42 | "vocab_size": 151_936,
43 | "context_length": 40_960,
44 | "emb_dim": 4096, # 60% larger than above
45 | "n_heads": 32,
46 |     "n_layers": 36,
47 |     "hidden_dim": 12288,            # 26% larger than above
48 | "head_dim": 128,
49 | "qk_norm": True,
50 | "n_kv_groups": 8,
51 | "rope_base": 1_000_000.0,
52 | "dtype": torch.bfloat16,
53 | }
54 |
55 | # 14 billion parameters
56 | QWEN3_CONFIG_14B = {
57 | "vocab_size": 151_936,
58 | "context_length": 40_960,
59 | "emb_dim": 5120, # 25% larger than above
60 | "n_heads": 40, # 25% larger than above
61 | "n_layers": 40, # 11% larger than above
62 | "hidden_dim": 17408, # 42% larger than above
63 | "head_dim": 128,
64 | "qk_norm": True,
65 | "n_kv_groups": 8,
66 | "rope_base": 1_000_000.0,
67 | "dtype": torch.bfloat16,
68 | }
69 |
70 | QWEN3_CONFIG_32B = {
71 | "vocab_size": 151_936,
72 | "context_length": 40_960,
73 | "emb_dim": 5120,
74 | "n_heads": 64, # 60% larger than above
75 | "n_layers": 64, # 60% larger than above
76 | "hidden_dim": 25600, # 47% larger than above
77 | "head_dim": 128,
78 | "qk_norm": True,
79 | "n_kv_groups": 8,
80 | "rope_base": 1_000_000.0,
81 | "dtype": torch.bfloat16,
82 | }
83 |
84 |
85 | def load_weights_into_qwen(model, param_config, params):
86 | def assign(left, right, tensor_name="unknown"):
87 | if left.shape != right.shape:
88 | raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
89 | return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))
90 |
91 | model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
92 |
93 | for ln in range(param_config["n_layers"]):
94 | block = model.trf_blocks[ln]
95 | att = block.att
96 |
97 | # Q, K, V projections
98 | att.W_query.weight = assign(
99 | att.W_query.weight,
100 | params[f"model.layers.{ln}.self_attn.q_proj.weight"],
101 | f"model.layers.{ln}.self_attn.q_proj.weight"
102 | )
103 | att.W_key.weight = assign(
104 | att.W_key.weight,
105 | params[f"model.layers.{ln}.self_attn.k_proj.weight"],
106 | f"model.layers.{ln}.self_attn.k_proj.weight"
107 | )
108 | att.W_value.weight = assign(
109 | att.W_value.weight,
110 | params[f"model.layers.{ln}.self_attn.v_proj.weight"],
111 | f"model.layers.{ln}.self_attn.v_proj.weight"
112 | )
113 |
114 | # Output projection
115 | att.out_proj.weight = assign(
116 | att.out_proj.weight,
117 | params[f"model.layers.{ln}.self_attn.o_proj.weight"],
118 | f"model.layers.{ln}.self_attn.o_proj.weight"
119 | )
120 |
121 | # QK norms
122 | if hasattr(att, "q_norm") and att.q_norm is not None:
123 | att.q_norm.scale = assign(
124 | att.q_norm.scale,
125 | params[f"model.layers.{ln}.self_attn.q_norm.weight"],
126 | f"model.layers.{ln}.self_attn.q_norm.weight"
127 | )
128 | if hasattr(att, "k_norm") and att.k_norm is not None:
129 | att.k_norm.scale = assign(
130 | att.k_norm.scale,
131 | params[f"model.layers.{ln}.self_attn.k_norm.weight"],
132 | f"model.layers.{ln}.self_attn.k_norm.weight"
133 | )
134 |
135 | # Attention layernorm
136 | block.norm1.scale = assign(
137 | block.norm1.scale,
138 | params[f"model.layers.{ln}.input_layernorm.weight"],
139 | f"model.layers.{ln}.input_layernorm.weight"
140 | )
141 |
142 | # Feedforward weights
143 | block.ff.fc1.weight = assign(
144 | block.ff.fc1.weight,
145 | params[f"model.layers.{ln}.mlp.gate_proj.weight"],
146 | f"model.layers.{ln}.mlp.gate_proj.weight"
147 | )
148 | block.ff.fc2.weight = assign(
149 | block.ff.fc2.weight,
150 | params[f"model.layers.{ln}.mlp.up_proj.weight"],
151 | f"model.layers.{ln}.mlp.up_proj.weight"
152 | )
153 | block.ff.fc3.weight = assign(
154 | block.ff.fc3.weight,
155 | params[f"model.layers.{ln}.mlp.down_proj.weight"],
156 | f"model.layers.{ln}.mlp.down_proj.weight"
157 | )
158 | block.norm2.scale = assign(
159 | block.norm2.scale,
160 | params[f"model.layers.{ln}.post_attention_layernorm.weight"],
161 | f"model.layers.{ln}.post_attention_layernorm.weight"
162 | )
163 |
164 | # Final normalization and output head
165 | model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")
166 |
167 | # Model uses weight tying, hence we reuse the embedding layer weights here
168 | model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
169 |
170 |
171 | def download_from_huggingface_from_snapshots(repo_id, local_dir):
172 | from huggingface_hub import hf_hub_download, snapshot_download
173 | from safetensors.torch import load_file # or your preferred loader
174 |
175 | repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)
176 |
177 | index_path = os.path.join(repo_dir, "model.safetensors.index.json")
178 | single_file_path = os.path.join(repo_dir, "model.safetensors")
179 |
180 | if os.path.exists(index_path):
181 | # Multi-shard model
182 | with open(index_path, "r") as f:
183 | index = json.load(f)
184 |
185 | weights_dict = {}
186 | for filename in set(index["weight_map"].values()):
187 | shard_path = os.path.join(repo_dir, filename)
188 | shard = load_file(shard_path)
189 | weights_dict.update(shard)
190 | elif os.path.exists(single_file_path):
191 | # Single-shard model
192 | weights_file = hf_hub_download(
193 | repo_id=repo_id,
194 | filename="model.safetensors",
195 | local_dir=local_dir,
196 | )
197 | weights_dict = load_file(weights_file)
198 | else:
199 | raise FileNotFoundError("No model.safetensors or model.safetensors.index.json found.")
200 |
201 | return weights_dict
202 |
--------------------------------------------------------------------------------
/ch02/05_use_model/README.md:
--------------------------------------------------------------------------------
1 | # Run Inference and Chat With the Model
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | This folder contains standalone example scripts to generate text with the model we loaded in chapter 2 (and exercises):
10 |
11 | - `generate_simple.py`: Generates text similar to the main chapter.
12 | - `chat.py`: Similar to the code above, but wrapped in an interactive loop so that we can prompt the model multiple times without having to reload it into memory each time.
13 | - `chat_multiturn.py`: Same as above, but with a memory feature to remember the message history.
14 |
15 |
16 |
17 | More usage details are provided in the sections below.
18 |
19 |
20 | ## generate_simple.py
21 |
22 | This simple script loads the model as described in chapter 2 and uses the `generate_text_simple_cache_stream` function from the chapter 2 exercises. You can run it as follows (replace `uv run` with `python` if you are not using `uv`):
23 |
24 | ```bash
25 | uv run ch02/05_use_model/generate_simple.py
26 | Using Apple Silicon GPU (MPS)
27 | ✓ qwen3/qwen3-0.6B-base.pth already up-to-date
28 |
29 | ============================================================
30 | torch : 2.7.1
31 | device : mps
32 | cache : True
33 | compile : False
34 | reasoning : False
35 | ============================================================
36 |
37 | Large language models are artificial intelligence systems that can understand, generate, and process human language, enabling them to perform a wide range of tasks, from answering questions to writing essays.
38 |
39 | Time: 1.52 sec
40 | 22 tokens/sec
41 | ```
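
If you prefer to experiment from Python rather than through the script, the building blocks can be combined as in the following minimal sketch (it mirrors the imports used by `chat_multiturn.py` in this folder and is not the exact `generate_simple.py` code; EOS handling is omitted for brevity):

```python
# Minimal sketch: load the base model and stream tokens from Python.
# Mirrors chat_multiturn.py; not the exact code used in generate_simple.py.
from pathlib import Path
import torch

from reasoning_from_scratch.ch02 import get_device
from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache
from reasoning_from_scratch.qwen3 import (
    download_qwen3_small, Qwen3Model, Qwen3Tokenizer, QWEN_CONFIG_06_B
)

device = get_device()
download_qwen3_small(kind="base", tokenizer_only=False, out_dir="qwen3")

tokenizer = Qwen3Tokenizer(
    tokenizer_file_path=Path("qwen3") / "tokenizer-base.json",
    apply_chat_template=False,
)
model = Qwen3Model(QWEN_CONFIG_06_B)
model.load_state_dict(
    torch.load(Path("qwen3") / "qwen3-0.6B-base.pth", map_location=device)
)
model.to(device).eval()

prompt = "Explain large language models in a single sentence."
token_ids = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)

# Stream up to 64 new tokens (a full script would also stop at the EOS token)
for tok in generate_text_basic_stream_cache(
    model=model,
    token_ids=token_ids,
    max_new_tokens=64,
):
    print(tokenizer.decode(tok.squeeze(0).tolist()), end="", flush=True)
print()
```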
42 |
43 | The script is useful if you want to quickly try out different prompts with the base or reasoning variant. The additional options are listed below:
44 |
45 | ```bash
46 | usage: generate_simple.py [-h] [--device DEVICE]
47 | [--max_new_tokens MAX_NEW_TOKENS] [--compile]
48 | [--reasoning] [--prompt PROMPT]
49 |
50 | Run Qwen3 text generation
51 |
52 | options:
53 | -h, --help show this help message and exit
54 | --device DEVICE Device to run on (e.g. 'cpu', 'cuda', 'mps'). If not
55 | provided, will auto-detect with get_device().
56 | --max_new_tokens MAX_NEW_TOKENS
57 | Maximum number of new tokens to generate (default:
58 | 2048).
59 | --compile Compile PyTorch model (default: False).
60 | --reasoning Use reasoning model variant (default: False).
61 | --prompt PROMPT Use a custom prompt. If not explicitly provided, uses
62 | the following defaults: 'Explain large language models
63 | in a single sentence.' for the base model, and 'Find
64 | all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.'
65 | for the reasoning model.
66 | ```
67 |
68 |
69 | ## chat.py
70 |
71 | Like the script above, this script is useful for trying different prompts on the base and reasoning models.
72 | 
73 | However, in contrast to the previous script, it keeps the user in an interactive session so that the model doesn't have to be reloaded each time:
74 |
75 | ```bash
76 | uv run ch02/05_use_model/chat.py
77 | Using Apple Silicon GPU (MPS)
78 | ✓ qwen3/qwen3-0.6B-base.pth already up-to-date
79 |
80 | ============================================================
81 | torch : 2.7.1
82 | device : mps
83 | cache : True
84 | compile : False
85 | reasoning : False
86 | memory : False
87 | ============================================================
88 |
89 | Interactive REPL (no memory). Type '\exit' or '\quit' to quit.
90 |
91 | >> Explain language models in 1 sentence
92 |
93 | ------------------------------------------------------------
94 | [User]
95 | Explain language models in 1 sentence
96 |
97 | [Model]
98 |
99 | Language models are algorithms that analyze and predict the likelihood of future words in a text based on the words already seen, enabling them to generate coherent and contextually relevant text.
100 |
101 | [Stats]
102 | Time: 1.53 sec
103 | 22 tokens/sec
104 | ------------------------------------------------------------
105 | >> Explain machine learning in 1 sentence.
106 |
107 | ------------------------------------------------------------
108 | [User]
109 | Explain machine learning in 1 sentence.
110 |
111 | [Model]
112 | Machine learning is a subset of artificial intelligence that enables computers to learn from data and improve their performance over time without being explicitly programmed.
113 |
114 | [Stats]
115 | Time: 1.04 sec
116 | 24 tokens/sec
117 | ------------------------------------------------------------
118 | ```
119 |
120 | Additional options are listed below:
121 |
122 | ```bash
123 | usage: chat.py [-h] [--device DEVICE] [--max_new_tokens MAX_NEW_TOKENS] [--compile]
124 | [--reasoning]
125 |
126 | Run Qwen3 text generation (interactive REPL)
127 |
128 | options:
129 | -h, --help show this help message and exit
130 | --device DEVICE Device to run on (e.g. 'cpu', 'cuda', 'mps'). If not provided,
131 | will auto-detect with get_device().
132 | --max_new_tokens MAX_NEW_TOKENS
133 | Maximum number of new tokens to generate (default: 2048).
134 | --compile Compile PyTorch model (default: False).
135 | --reasoning Use reasoning model variant (default: False).
136 | ```
137 |
138 |
139 |
140 |
141 |
142 | ## chat_multiturn.py
143 |
144 | This script is similar to the one above, except that it adds multi-turn memory so that the LLM remembers the conversation from past turns (see the prompt-formatting sketch after the example below). It is highly recommended to use the reasoning variant here, as the base model struggles with conversations:
145 |
146 |
147 |
148 | ```bash
149 | uv run ch02/05_use_model/chat_multiturn.py --reasoning
150 | Using Apple Silicon GPU (MPS)
151 | ✓ qwen3/qwen3-0.6B-reasoning.pth already up-to-date
152 | ✓ qwen3/tokenizer-reasoning.json already up-to-date
153 |
154 | ============================================================
155 | torch : 2.7.1
156 | device : mps
157 | cache : True
158 | compile : False
159 | reasoning : True
160 | memory : True
161 | max_new_tokens (per turn): 2048
162 | context_length: 40960
163 | ============================================================
164 |
165 | Interactive REPL with memory. Type '\exit' or '\quit' to quit.
166 | Commands: \clear (forget memory), \history (show turn count)
167 |
168 | >> What is 1+1 in short?
169 |
170 | ------------------------------------------------------------
171 | [User]
172 | What is 1+1 in short?
173 |
174 | [Model]
175 |
176 | Okay, the user is asking, "What is 1+1 in short?" Let me break this down. First, they want to know the result of adding 1 and 1. In math, 1 plus 1 equals 2. But the question says "in short," which probably means they want a concise answer without the full calculation.
177 |
178 | So, the answer is straightforward. 1+1=2. But maybe they want a more concise way to write it? Like, "2" or "2+2"? But "2" is more direct. Let me check if there's any trick here. Sometimes people might think of 1+1 as something else, but no, it's just two ones.
179 |
180 | I should make sure to present the answer clearly. Since the user is asking in a short form, maybe they just want the number 2. So the final answer is 2.
181 |
182 |
183 | 1+1 equals 2.
184 |
185 | [Stats]
186 | Time: 8.27 sec
187 | 23 tokens/sec
188 | ------------------------------------------------------------
189 | >> What were you just asked?
190 |
191 | ------------------------------------------------------------
192 | [User]
193 | What were you just asked?
194 |
195 | [Model]
196 |
197 | Okay, the user just asked, "What were you just asked?" and I responded with "1+1 equals 2." Now, they're asking again. Let me check if there's any hidden context or if they want more information. Since the previous answer was clear, maybe they want confirmation or a different interpretation. But since the user is asking again, perhaps they want to know if I provided the answer correctly. I should confirm that 1+1 is indeed 2 and that the answer is correct. No further information is needed here. Just a simple confirmation.
198 |
199 | [Stats]
200 | Time: 5.21 sec
201 | 22 tokens/sec
202 | ------------------------------------------------------------
203 | ```
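
Under the hood, the script stores the conversation as a list of role/content dictionaries and flattens it into the Qwen3 chat format before each turn. The following minimal sketch reproduces that formatting step from the `build_prompt_from_history` helper in `chat_multiturn.py`:

```python
# Sketch of how chat_multiturn.py turns the stored history into a single
# prompt string using the Qwen3 chat markers.
def build_prompt_from_history(history, add_assistant_header=True):
    parts = []
    for m in history:
        parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n")
    if add_assistant_header:
        # The model continues generating from the assistant header
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)

history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 1+1 in short?"},
    {"role": "assistant", "content": "1+1 equals 2."},
    {"role": "user", "content": "What were you just asked?"},
]
print(build_prompt_from_history(history))
```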
204 |
205 |
206 |
207 | Additional options are listed below:
208 |
209 | ```bash
210 | usage: chat_multiturn.py [-h] [--device DEVICE] [--max_new_tokens MAX_NEW_TOKENS]
211 | [--compile] [--reasoning]
212 |
213 | Run Qwen3 text generation (interactive REPL)
214 |
215 | options:
216 | -h, --help show this help message and exit
217 | --device DEVICE Device to run on (e.g. 'cpu', 'cuda', 'mps'). If not provided,
218 | will auto-detect with get_device().
219 | --max_new_tokens MAX_NEW_TOKENS
220 | Maximum number of new tokens to generate in each turn (default:
221 | 2048).
222 | --compile Compile PyTorch model (default: False).
223 | --reasoning Use reasoning model variant (default: False).
224 | ```
225 |
226 |
--------------------------------------------------------------------------------
/reasoning_from_scratch/ch05.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt)
2 | # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
3 | # Code repository: https://github.com/rasbt/reasoning-from-scratch
4 |
5 | from .ch03 import extract_final_candidate, render_prompt
6 | from .ch04 import (
7 | generate_text_stream_concat_flex,
8 | generate_text_top_p_stream_cache
9 | )
10 | import math
11 | import torch
12 |
13 |
14 | def heuristic_score(
15 | answer,
16 | prompt=None, # Placeholder that is ignored
17 | brevity_bonus=500.0,
18 | boxed_bonus=2.0,
19 | extract_bonus=1.0,
20 | fulltext_bonus=0.0,
21 | ):
22 | score = 0.0
23 |
24 | # Reward answers that have a final boxed value
25 | cand = extract_final_candidate(answer, fallback="none")
26 | if cand:
27 | score += boxed_bonus
28 |
29 | # Give weaker rewards if answer doesn't have a boxed value
30 | else:
31 | cand = extract_final_candidate(answer, fallback="number_only")
32 | if cand:
33 | score += extract_bonus
34 | else:
35 | cand = extract_final_candidate(
36 | answer, fallback="number_then_full"
37 | )
38 | if cand:
39 | score += fulltext_bonus
40 |
41 |     # Add a brevity reward that decays with the answer length
42 | score += 1.5 * math.exp(-len(answer) / brevity_bonus)
43 | return score
44 |
45 |
46 | @torch.inference_mode()
47 | def calc_next_token_probas(model, tokenizer, prompt, device):
48 |
49 | token_ids = torch.tensor(tokenizer.encode(prompt), device=device)
50 |
51 | # Get logits and probabilities similar to text generation functions
52 | logits = model(token_ids.unsqueeze(0)).squeeze(0)
53 | all_probas = torch.softmax(logits, dim=-1)
54 |
55 | # Positions we score (here: all)
56 | t_idx = torch.arange(0, token_ids.shape[0] - 1, device=device)
57 |
58 | # Since we have the text, we know the true next tokens
59 | next_ids = token_ids[1:]
60 |
61 | # Get probabilities for each next token
62 | next_token_probas = all_probas[t_idx, next_ids]
63 |
64 | print(
65 | "Next-token probabilities:",
66 | [p.item() for p in next_token_probas]
67 | )
68 |
69 | # Likelihood of the sequence is the product of the probability scores
70 | print(
71 | "Joint probability:",
72 | torch.prod(next_token_probas)
73 | )
74 |
75 |
76 | @torch.inference_mode()
77 | def calc_next_token_logprobas(model, tokenizer, prompt, device):
78 |
79 | token_ids = torch.tensor(tokenizer.encode(prompt), device=device)
80 |
81 | logits = model(token_ids.unsqueeze(0)).squeeze(0)
82 | # We now use log_softmax
83 | all_logprobas = torch.log_softmax(logits, dim=-1)
84 |
85 | t_idx = torch.arange(0, token_ids.shape[0] - 1, device=device)
86 | next_ids = token_ids[1:]
87 | next_token_logprobas = all_logprobas[t_idx, next_ids]
88 |
89 | print(
90 | "Next-token log-probabilities:",
91 | [p.item() for p in next_token_logprobas]
92 | )
93 | # We replace the product with a sum
94 | print(
95 | "Joint log-probability:",
96 | torch.sum(next_token_logprobas)
97 | )
98 |
99 |
100 | @torch.inference_mode()
101 | def avg_logprob_answer(model, tokenizer, prompt, answer, device="cpu"):
102 |
103 | # Encode prompt and answer tokens separately to get the prompt length later
104 | prompt_ids = tokenizer.encode(prompt)
105 | answer_ids = tokenizer.encode(answer)
106 | full_ids = torch.tensor(prompt_ids + answer_ids, device=device)
107 |
108 | # Same as in calc_next_token_logprobas before
109 | logits = model(full_ids.unsqueeze(0)).squeeze(0)
110 | logprobs = torch.log_softmax(logits, dim=-1)
111 |
112 | # Index range for positions corresponding to answer tokens
113 | start = len(prompt_ids) - 1
114 | end = full_ids.shape[0] - 1
115 |
116 | # Same as before, except for using start and end
117 | t_idx = torch.arange(start, end, device=device)
118 | next_tokens = full_ids[start + 1 : end + 1]
119 | next_token_logps = logprobs[t_idx, next_tokens]
120 |
121 | # Average over the answer token scores
122 | return torch.mean(next_token_logps).item()
123 |
124 |
125 | def make_critique_prompt(raw_prompt, draft):
126 | return (
127 | "You are a meticulous reviewer. Identify logical errors, missing "
128 | "steps, or arithmetic mistakes. If the answer seems correct, "
129 | "say so briefly. Then propose a concise plan to fix issues.\n\n"
130 | f"Question:\n{raw_prompt}\n\n"
131 | f"Draft answer:\n{draft}\n\n"
132 | "Write a short critique and bullet-point fix plan "
133 | "(under ~120 words).\n"
134 | "Critique:"
135 | )
136 |
137 |
138 | def make_refine_prompt(raw_prompt, draft, critique):
139 | return (
140 | "Revise the answer using the critique. Keep it concise and "
141 | "end with a final boxed result: \\boxed{ANSWER}\n\n"
142 | f"Question:\n{raw_prompt}\n\n"
143 | f"Previous answer:\n{draft}\n\n"
144 | f"Critique:\n{critique}\n\n"
145 | "Revised answer:"
146 | )
147 |
148 |
149 | def self_refinement_loop(
150 | model,
151 | tokenizer,
152 | raw_prompt,
153 | device,
154 | iterations=2,
155 | max_response_tokens=2048,
156 | max_critique_tokens=256,
157 | score_fn=None,
158 | prompt_renderer=render_prompt,
159 | prompt_suffix="",
160 | verbose=False,
161 | temperature=0.7,
162 | top_p=0.9,
163 | ):
164 | steps = []
165 |
166 | # Initial response (draft)
167 | prompt = prompt_renderer(raw_prompt) + prompt_suffix
168 | current_full = generate_text_stream_concat_flex(
169 | model=model,
170 | tokenizer=tokenizer,
171 | prompt=prompt,
172 | device=device,
173 | max_new_tokens=max_response_tokens,
174 | verbose=False,
175 | generate_func=generate_text_top_p_stream_cache,
176 | temperature=temperature,
177 | top_p=top_p,
178 | )
179 |
180 | current_extracted = extract_final_candidate(
181 | current_full, fallback="number_then_full"
182 | )
183 | if score_fn:
184 | current_score = score_fn(answer=current_full, prompt=prompt)
185 | else:
186 | current_score = 0.0
187 |
188 | # Run for one or more iterations
189 | for it in range(iterations):
190 | draft_before_full = current_full
191 | draft_before_extracted = current_extracted
192 | score_before = current_score
193 |
194 | # Critique the response
195 | critique_prompt = make_critique_prompt(
196 | raw_prompt, draft_before_full
197 | )
198 | critique_full = generate_text_stream_concat_flex(
199 | model=model,
200 | tokenizer=tokenizer,
201 | prompt=critique_prompt,
202 | device=device,
203 | max_new_tokens=max_critique_tokens,
204 | verbose=False,
205 | generate_func=generate_text_top_p_stream_cache,
206 | temperature=temperature,
207 | top_p=top_p,
208 | )
209 |
210 | # Refine the response
211 | refine_prompt = make_refine_prompt(
212 | raw_prompt, draft_before_full, critique_full
213 | )
214 | revised_full = generate_text_stream_concat_flex(
215 | model=model,
216 | tokenizer=tokenizer,
217 | prompt=refine_prompt,
218 | device=device,
219 | max_new_tokens=max_response_tokens,
220 | verbose=False,
221 | generate_func=generate_text_top_p_stream_cache,
222 | temperature=temperature,
223 | top_p=top_p,
224 | )
225 |
226 | revised_extracted = extract_final_candidate(
227 | revised_full, fallback="number_then_full"
228 | )
229 | if score_fn:
230 | revised_score = score_fn(
231 | answer=revised_full, prompt=prompt # Still use original prompt here
232 | )
233 | else:
234 | revised_score = 0.0
235 |
236 | # Log the results
237 | step = {
238 | "iteration": it + 1,
239 | "draft_full": draft_before_full,
240 | "draft_extracted": draft_before_extracted,
241 | "critique": critique_full,
242 | "revised_full": revised_full,
243 | "revised_extracted": revised_extracted,
244 | "score_before": score_before,
245 | "score_after": revised_score,
246 | }
247 | steps.append(step)
248 |
249 | if verbose:
250 | print(
251 | f"[Refinement {it+1}/{iterations}]"
252 | f"\nCurrent: {draft_before_extracted}"
253 | f"\nRevised: {revised_extracted}"
254 | f"\nScore before: {score_before:.3f}"
255 | f"\nScore after: {revised_score:.3f}"
256 | f"\n{'=' * 25}"
257 | )
258 |
259 | # Accept revised response if it's not worse
260 | if revised_score >= current_score:
261 | current_full = revised_full
262 | current_extracted = revised_extracted
263 | current_score = revised_score
264 |
265 | return {
266 | "final_full": current_full,
267 | "final_extracted": current_extracted,
268 | "steps": steps,
269 | }
270 |
--------------------------------------------------------------------------------