├── .github └── workflows │ ├── docs.yml │ └── main.yml ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── assets ├── EinsumIJKToIK.mp4 ├── EinsumIJTo.mp4 ├── EinsumIJToJ.mp4 ├── EinsumIJToJI.mp4 └── GPT.svg ├── datasets └── bee_movie.txt ├── docs ├── Makefile ├── conf.py ├── generate_rst.py ├── index.rst ├── make.bat ├── modules.rst ├── tricycle │ ├── activation.rst │ ├── attention.rst │ ├── binary.rst │ ├── blocks.rst │ ├── configs.rst │ ├── context.rst │ ├── dataset.rst │ ├── einsum.rst │ ├── exceptions.rst │ ├── functions.rst │ ├── initialisers.rst │ ├── layers.rst │ ├── loss.rst │ ├── models.rst │ ├── ops.rst │ ├── optimisers.rst │ ├── reduce.rst │ ├── scheduler.rst │ ├── tensor.rst │ ├── tokeniser.rst │ ├── unary.rst │ ├── utils.rst │ └── weakset.rst └── tricycle_datasets │ ├── codeparrot.rst │ ├── fineweb.rst │ └── shakespeare.rst ├── inference.py ├── pyproject.toml ├── requirements ├── environment.cpu.test.yml ├── environment.cpu.yml ├── environment.test.yml └── environment.yml ├── setup.cfg ├── src ├── tricycle │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── binary.py │ ├── blocks.py │ ├── configs.py │ ├── context.py │ ├── dataset.py │ ├── einsum.py │ ├── exceptions.py │ ├── functions.py │ ├── initialisers.py │ ├── layers.py │ ├── loss.py │ ├── models.py │ ├── ops.py │ ├── optimisers.py │ ├── reduce.py │ ├── scheduler.py │ ├── tensor.py │ ├── tokeniser.py │ ├── unary.py │ ├── utils.py │ └── weakset.py └── tricycle_datasets │ ├── __init__.py │ ├── codeparrot.py │ ├── fineweb.py │ └── shakespeare.py ├── tests ├── __init__.py ├── conftest.py ├── test_activations.py ├── test_attention.py ├── test_binary.py ├── test_blocks.py ├── test_composite.py ├── test_dataset.py ├── test_einsum.py ├── test_functions.py ├── test_layers.py ├── test_loss.py ├── test_mixed_precision.py ├── test_model_matches_pytorch.py ├── test_optimisers.py ├── test_python_and_numpy.py ├── test_reduce.py ├── test_simple_neural_network.py ├── test_tensor_api.py ├── 
test_tensor_pbt.py ├── test_tokeniser.py ├── test_unary_ops.py └── test_vectorise.py └── train_smol_gpt.py /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | build-and-deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Python 3.10 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: "3.10" 21 | 22 | - name: Install Mamba 23 | uses: mamba-org/setup-micromamba@v1 24 | with: 25 | micromamba-version: '1.5.6-0' # any version from https://github.com/mamba-org/micromamba-releases 26 | environment-file: requirements/environment.cpu.test.yml 27 | init-shell: >- 28 | bash 29 | cache-environment: true 30 | post-cleanup: 'all' 31 | 32 | - name: Activate Mamba environment 33 | run: | 34 | eval "$(micromamba shell hook --shell bash)" 35 | micromamba activate tricycle 36 | 37 | - name: Build docs 38 | run: | 39 | cd docs 40 | sphinx-build -b html . 
_build/html 41 | shell: bash -el {0} 42 | 43 | - name: Deploy to GitHub Pages 44 | uses: peaceiris/actions-gh-pages@v3 45 | with: 46 | github_token: ${{ secrets.GITHUB_TOKEN }} 47 | publish_dir: ./docs/_build/html 48 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | test: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | 23 | - name: Set up Python 3.10 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: "3.10" 27 | 28 | - name: Install Mamba 29 | uses: mamba-org/setup-micromamba@v1 30 | with: 31 | micromamba-version: '1.5.6-0' # any version from https://github.com/mamba-org/micromamba-releases 32 | environment-file: requirements/environment.cpu.test.yml 33 | init-shell: >- 34 | bash 35 | cache-environment: true 36 | post-cleanup: 'all' 37 | 38 | #---------------------------------------------- 39 | # run pre-commit checks 40 | #---------------------------------------------- 41 | - name: Run pre-commit checks 42 | run: pre-commit run --all-files 43 | shell: bash -el {0} 44 | #---------------------------------------------- 45 | # run test suite 46 | #---------------------------------------------- 47 | - name: Run tests 48 | run: pytest tests 49 | shell: bash -el {0} 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # environment 2 | 
.venv 3 | poetry.lock 4 | *__pycache__* 5 | 6 | # nvim 7 | .null-ls* 8 | 9 | # macos 10 | *.DS_Store 11 | 12 | # datasets 13 | datasets/shakespeare/tokens_*.pkl 14 | datasets/shakespeare/raw_data.txt 15 | 16 | # jupyter 17 | .ipynb_checkpoints 18 | *.ipynb 19 | output.bin 20 | 21 | # experiment results 22 | *.csv 23 | results/* 24 | .hypothesis 25 | results 26 | models 27 | benchmarks/.profiles 28 | tests/.profiles 29 | 30 | richbench*.json 31 | python*.json 32 | 33 | # binaries 34 | *.pkl 35 | build/lib 36 | datasets/fineweb 37 | 38 | # docs 39 | docs/_build/ 40 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | 9 | - repo: https://github.com/psf/black 10 | rev: 24.2.0 11 | hooks: 12 | - id: black 13 | 14 | - repo: https://github.com/pycqa/isort 15 | rev: 5.12.0 16 | hooks: 17 | - id: isort 18 | name: isort (python) 19 | args: ["--profile", "black", "--filter-files"] 20 | -------------------------------------------------------------------------------- /assets/EinsumIJKToIK.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bclarkson-code/Tricycle/ebccb095d029c0d504eef11bace1a5287226bbc4/assets/EinsumIJKToIK.mp4 -------------------------------------------------------------------------------- /assets/EinsumIJTo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bclarkson-code/Tricycle/ebccb095d029c0d504eef11bace1a5287226bbc4/assets/EinsumIJTo.mp4 -------------------------------------------------------------------------------- /assets/EinsumIJToJ.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bclarkson-code/Tricycle/ebccb095d029c0d504eef11bace1a5287226bbc4/assets/EinsumIJToJ.mp4 -------------------------------------------------------------------------------- /assets/EinsumIJToJI.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bclarkson-code/Tricycle/ebccb095d029c0d504eef11bace1a5287226bbc4/assets/EinsumIJToJI.mp4 -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Autodoc ----------------------------------------------------------------- 7 | import os 8 | import sys 9 | 10 | sys.path.insert(0, os.path.abspath("../../src")) 11 | autodoc_typehints = "description" 12 | autodoc_default_options = { 13 | "members": True, 14 | "undoc-members": True, 15 | "show-inheritance": True, 16 | } 17 | 18 | # -- Project information ----------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 20 | 21 | project = "Tricycle" 22 | copyright = "2024, Ben Clarkson" 23 | author = "Ben Clarkson" 24 | version = "0.1" 25 | release = "0.1.0" 26 | 27 | # -- General configuration --------------------------------------------------- 28 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 29 | 30 | extensions = [ 31 | "sphinx.ext.autodoc", 32 | "sphinx.ext.napoleon", 33 | "sphinx.ext.viewcode", 34 | "recommonmark", 35 | ] 36 | autosummary_generate = True 37 | add_module_names = False 38 | source_suffix = [".rst", ".md"] 39 | 40 | templates_path = ["_templates"] 41 | exclude_patterns = [] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 46 | 47 | html_theme = "furo" 48 | html_static_path = ["_static"] 49 | html_build_dir = "../docs" 50 | -------------------------------------------------------------------------------- /docs/generate_rst.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def create_rst_files(src_dir, docs_dir, package_name): 5 | # Ensure the docs directory exists 6 | os.makedirs(os.path.join(docs_dir, package_name), exist_ok=True) 7 | 8 | # Walk through the source 
directory 9 | for root, _, files in os.walk(src_dir): 10 | for file in files: 11 | if file.endswith(".py") and file != "__init__.py": 12 | # Get the module name 13 | module_name = os.path.splitext(file)[0] 14 | 15 | # Create the full module path 16 | module_path = os.path.relpath( 17 | os.path.join(root, file), src_dir 18 | ) 19 | module_path = os.path.splitext(module_path)[0].replace( 20 | os.path.sep, "." 21 | ) 22 | 23 | # Create the RST content 24 | rst_content = f"""{module_name} 25 | {'=' * len(module_name)} 26 | 27 | .. automodule:: {package_name}.{module_path} 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | """ 32 | 33 | # Write the RST file 34 | rst_path = os.path.join( 35 | docs_dir, package_name, f"{module_name}.rst" 36 | ) 37 | with open(rst_path, "w") as rst_file: 38 | rst_file.write(rst_content) 39 | 40 | print(f"Created {rst_path}") 41 | 42 | 43 | def update_modules_rst(docs_dir, package_names): 44 | modules_content = "API Reference\n=============\n\n" 45 | 46 | for package_name in package_names: 47 | modules_content += ( 48 | f"{package_name.capitalize()}\n{'-' * len(package_name)}\n\n" 49 | ) 50 | modules_content += ".. toctree::\n :maxdepth: 1\n\n" 51 | 52 | package_dir = os.path.join(docs_dir, package_name) 53 | for file in os.listdir(package_dir): 54 | if file.endswith(".rst"): 55 | modules_content += ( 56 | f" {package_name}/{os.path.splitext(file)[0]}\n" 57 | ) 58 | 59 | modules_content += "\n" 60 | 61 | with open(os.path.join(docs_dir, "modules.rst"), "w") as modules_file: 62 | modules_file.write(modules_content) 63 | 64 | print("Updated modules.rst") 65 | 66 | 67 | if __name__ == "__main__": 68 | # Set your directories here 69 | src_dir = "../src" 70 | docs_dir = "." 
71 | package_names = ["tricycle", "tricycle_datasets"] 72 | 73 | for package_name in package_names: 74 | create_rst_files( 75 | os.path.join(src_dir, package_name), docs_dir, package_name 76 | ) 77 | 78 | update_modules_rst(docs_dir, package_names) 79 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Tricycle documentation master file, created by 2 | sphinx-quickstart on Thu Jul 25 11:07:47 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Tricycle's documentation! 7 | ==================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | modules 14 | 15 | Indices and tables 16 | ================== 17 | 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | Tricycle 5 | -------- 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | tricycle/optimisers 11 | tricycle/attention 12 | tricycle/tokeniser 13 | tricycle/einsum 14 | tricycle/dataset 15 | tricycle/functions 16 | tricycle/reduce 17 | tricycle/blocks 18 | tricycle/initialisers 19 | tricycle/tensor 20 | tricycle/models 21 | tricycle/activation 22 | tricycle/layers 23 | tricycle/loss 24 | tricycle/configs 25 | tricycle/scheduler 26 | tricycle/binary 27 | tricycle/context 28 | tricycle/weakset 29 | tricycle/unary 30 | tricycle/utils 31 | tricycle/ops 32 | tricycle/exceptions 33 | 34 | Tricycle_datasets 35 | ----------------- 36 | 37 | .. toctree:: 38 | :maxdepth: 1 39 | 40 | tricycle_datasets/shakespeare 41 | tricycle_datasets/codeparrot 42 | tricycle_datasets/fineweb 43 | -------------------------------------------------------------------------------- /docs/tricycle/activation.rst: -------------------------------------------------------------------------------- 1 | activation 2 | ========== 3 | 4 | .. automodule:: tricycle.activation 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/attention.rst: -------------------------------------------------------------------------------- 1 | attention 2 | ========= 3 | 4 | .. 
automodule:: tricycle.attention 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/binary.rst: -------------------------------------------------------------------------------- 1 | binary 2 | ====== 3 | 4 | .. automodule:: tricycle.binary 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/blocks.rst: -------------------------------------------------------------------------------- 1 | blocks 2 | ====== 3 | 4 | .. automodule:: tricycle.blocks 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/configs.rst: -------------------------------------------------------------------------------- 1 | configs 2 | ======= 3 | 4 | .. automodule:: tricycle.configs 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/context.rst: -------------------------------------------------------------------------------- 1 | context 2 | ======= 3 | 4 | .. automodule:: tricycle.context 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/dataset.rst: -------------------------------------------------------------------------------- 1 | dataset 2 | ======= 3 | 4 | .. automodule:: tricycle.dataset 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/einsum.rst: -------------------------------------------------------------------------------- 1 | einsum 2 | ====== 3 | 4 | .. 
automodule:: tricycle.einsum 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/exceptions.rst: -------------------------------------------------------------------------------- 1 | exceptions 2 | ========== 3 | 4 | .. automodule:: tricycle.exceptions 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/functions.rst: -------------------------------------------------------------------------------- 1 | functions 2 | ========= 3 | 4 | .. automodule:: tricycle.functions 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/initialisers.rst: -------------------------------------------------------------------------------- 1 | initialisers 2 | ============ 3 | 4 | .. automodule:: tricycle.initialisers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/layers.rst: -------------------------------------------------------------------------------- 1 | layers 2 | ====== 3 | 4 | .. automodule:: tricycle.layers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/loss.rst: -------------------------------------------------------------------------------- 1 | loss 2 | ==== 3 | 4 | .. automodule:: tricycle.loss 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/models.rst: -------------------------------------------------------------------------------- 1 | models 2 | ====== 3 | 4 | .. 
automodule:: tricycle.models 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/ops.rst: -------------------------------------------------------------------------------- 1 | ops 2 | === 3 | 4 | .. automodule:: tricycle.ops 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/optimisers.rst: -------------------------------------------------------------------------------- 1 | optimisers 2 | ========== 3 | 4 | .. automodule:: tricycle.optimisers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/reduce.rst: -------------------------------------------------------------------------------- 1 | reduce 2 | ====== 3 | 4 | .. automodule:: tricycle.reduce 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/scheduler.rst: -------------------------------------------------------------------------------- 1 | scheduler 2 | ========= 3 | 4 | .. automodule:: tricycle.scheduler 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/tensor.rst: -------------------------------------------------------------------------------- 1 | tensor 2 | ====== 3 | 4 | .. automodule:: tricycle.tensor 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/tokeniser.rst: -------------------------------------------------------------------------------- 1 | tokeniser 2 | ========= 3 | 4 | .. 
automodule:: tricycle.tokeniser 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/unary.rst: -------------------------------------------------------------------------------- 1 | unary 2 | ===== 3 | 4 | .. automodule:: tricycle.unary 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/utils.rst: -------------------------------------------------------------------------------- 1 | utils 2 | ===== 3 | 4 | .. automodule:: tricycle.utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle/weakset.rst: -------------------------------------------------------------------------------- 1 | weakset 2 | ======= 3 | 4 | .. automodule:: tricycle.weakset 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle_datasets/codeparrot.rst: -------------------------------------------------------------------------------- 1 | codeparrot 2 | ========== 3 | 4 | .. automodule:: tricycle_datasets.codeparrot 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle_datasets/fineweb.rst: -------------------------------------------------------------------------------- 1 | fineweb 2 | ======= 3 | 4 | .. automodule:: tricycle_datasets.fineweb 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/tricycle_datasets/shakespeare.rst: -------------------------------------------------------------------------------- 1 | shakespeare 2 | =========== 3 | 4 | .. 
automodule:: tricycle_datasets.shakespeare 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from copy import copy 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import tiktoken 8 | from tqdm import tqdm 9 | 10 | from tricycle.configs import DebugConfig, ShakespeareConfig, SmolGPTConfig 11 | from tricycle.functions import Softmax 12 | from tricycle.layers import Dropout, Layer 13 | from tricycle.models import GPT 14 | from tricycle.tensor import Tensor 15 | from tricycle.tokeniser import BPETokeniser 16 | from tricycle_datasets.fineweb import FineWeb 17 | from tricycle_datasets.shakespeare import Shakespeare 18 | 19 | config = SmolGPTConfig() 20 | 21 | 22 | def load_model(path: str | Path) -> Layer: 23 | print(f"LOADING MODEL: {path}") 24 | with open( 25 | path, 26 | "rb", 27 | ) as f: 28 | return pickle.load(f) 29 | 30 | 31 | def deactivate_dropout(model: Layer) -> Layer: 32 | """ 33 | Traverse through the model and deactivate any dropout layers 34 | """ 35 | stack = [model] 36 | 37 | while stack: 38 | node = stack.pop() 39 | if isinstance(node, Dropout): 40 | node.probability = 0 41 | 42 | if not node.layers: 43 | continue 44 | 45 | stack.extend(iter(node.layers)) 46 | return model 47 | 48 | 49 | def generate( 50 | model: GPT, 51 | tokens: np.ndarray | None = None, 52 | sample=True, 53 | temperature=0.8, 54 | pad_token=-1, 55 | ): 56 | """ 57 | Given a prompt, yield next token predictions for a model 58 | """ 59 | if isinstance(tokens, np.ndarray): 60 | tokens = tokens.tolist() 61 | 62 | while True: 63 | tokens = tokens[-config.context_window :] 64 | n_tokens = len(tokens) 65 | if n_tokens < config.context_window: 66 | pad_tokens = [pad_token] * (config.context_window - n_tokens) 67 | tokens += pad_tokens 68 | 69 | encoded = 
Tensor( 70 | tokens, dtype=np.uint32, requires_grad=False, is_batched=False 71 | ) 72 | 73 | pred = model(encoded) 74 | pred = Softmax()(pred / temperature) 75 | 76 | next_token_idx = n_tokens - 1 77 | 78 | if pred.on_gpu: 79 | probabilities = pred.xp.asnumpy(pred.array[0][next_token_idx]) 80 | else: 81 | probabilities = pred.array[0][next_token_idx] 82 | 83 | # sample according to probabilities 84 | if sample: 85 | next_token = np.random.choice( 86 | list(range(config.vocab_size)), p=probabilities 87 | ) 88 | else: 89 | next_token = np.argmax(probabilities) 90 | 91 | # remove padding + add new token 92 | tokens = tokens[:n_tokens] 93 | tokens.append(next_token) 94 | 95 | # convert from numpy int to python int 96 | yield int(next_token) 97 | 98 | 99 | def get_sample( 100 | model: GPT, 101 | tokeniser: BPETokeniser | tiktoken.core.Encoding, 102 | sample_tokens: np.ndarray | None = None, 103 | ) -> str: 104 | """ 105 | Given a prompt, generate some new tokens and return them as a string 106 | """ 107 | sampled = [] 108 | for i, next_token in tqdm( 109 | enumerate( 110 | generate( 111 | tokens=sample_tokens, 112 | model=model, 113 | ) 114 | ), 115 | desc="Sampling", 116 | total=config.n_tokens_to_generate, 117 | position=1, 118 | leave=False, 119 | ): 120 | if i > config.n_tokens_to_generate: 121 | break 122 | sampled.append(next_token) 123 | 124 | decoded = tokeniser.decode(sampled) 125 | sample_text = tokeniser.decode(sample_tokens) 126 | decoded = f"PROMPT:\n{sample_text}\nGENERATED:\n{decoded}" 127 | return decoded 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = argparse.ArgumentParser( 132 | prog="inference.py", description="Generate predictions from a GPT" 133 | ) 134 | 135 | parser.add_argument("model_path") 136 | parser.add_argument("prompt", help="Text that will be passed to the model") 137 | parser.add_argument( 138 | "-c", 139 | "--model_config", 140 | choices=["debug", "smol_gpt", "shakespeare"], 141 | default="shakespeare", 142 | ) 143 | 
parser.add_argument("-s", "--seed", default=0, type=int) 144 | parser.add_argument( 145 | "-d", 146 | "--dataset", 147 | choices=["shakespeare", "fineweb"], 148 | default="shakespeare", 149 | ) 150 | parser.add_argument("--use-gpu", action="store_true") 151 | 152 | args = parser.parse_args() 153 | print(args) 154 | 155 | match args.model_config: 156 | case "shakespeare": 157 | config = ShakespeareConfig() 158 | case "smol_gpt": 159 | config = SmolGPTConfig() 160 | case "debug": 161 | config = DebugConfig() 162 | case _: 163 | raise ValueError(f"Unknown dataset: {args.config}") 164 | 165 | match args.dataset: 166 | case "shakespeare": 167 | dataset = Shakespeare(config.vocab_size) 168 | case "fineweb": 169 | dataset = FineWeb(config.vocab_size, split="valid") 170 | case _: 171 | raise ValueError(f"Unknown dataset: {args.dataset}") 172 | 173 | np.random.seed(args.seed) 174 | 175 | model_path = Path(args.model_path) 176 | if model_path.exists(): 177 | model = load_model(model_path) 178 | else: 179 | raise FileNotFoundError( 180 | f"Could not find model file: {model_path.absolute()}" 181 | ) 182 | 183 | if args.use_gpu: 184 | model.to_gpu(0) 185 | else: 186 | model.from_gpu() 187 | 188 | model.zero_grad() 189 | deactivate_dropout(model) 190 | 191 | sample_tokens = dataset.tokeniser.encode(args.prompt) 192 | if isinstance(sample_tokens, np.ndarray): 193 | sample_tokens = sample_tokens.tolist() 194 | generated = copy(sample_tokens) 195 | prev = args.prompt 196 | for token in generate( 197 | tokens=sample_tokens, model=model, sample=True, pad_token=0 198 | ): 199 | if args.dataset == "fineweb" and token == dataset.tokeniser.eot_token: 200 | break 201 | generated += [token] 202 | try: 203 | if isinstance(dataset.tokeniser, BPETokeniser): 204 | decoded = dataset.tokeniser.decode(np.array(generated)) 205 | else: 206 | decoded = dataset.tokeniser.decode(generated, errors="strict") 207 | except UnicodeDecodeError: 208 | new = decoded[len(prev) :] 209 | prev = decoded 210 | 
continue 211 | 212 | new = decoded[len(prev) :] 213 | print(new, end="", flush=True) 214 | prev = decoded 215 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tricycle" 3 | version = "0.1.0" 4 | description = "A Deep Learning library you can actually understand" 5 | authors = [ 6 | { name = "benedictclarkson1", email = "benedictclarkson1@gmail.com" } 7 | ] 8 | license = { text = "MIT" } 9 | readme = "README.md" 10 | 11 | [tool.isort] 12 | profile = "black" 13 | line_length = 79 14 | 15 | [tool.black] 16 | line-length = 79 17 | 18 | [build-system] 19 | requires = ["setuptools", "wheel"] 20 | build-backend = "setuptools.build_meta" 21 | -------------------------------------------------------------------------------- /requirements/environment.cpu.test.yml: -------------------------------------------------------------------------------- 1 | name: tricycle 2 | 3 | channels: 4 | - defaults 5 | - conda-forge 6 | - pytorch 7 | 8 | dependencies: 9 | # base dependencies 10 | - python=3.10 11 | - numpy 12 | - humanize 13 | - tqdm 14 | - mlflow 15 | - psutil 16 | - numba 17 | - tiktoken 18 | - datasets 19 | # test dependencies 20 | - scikit-learn 21 | - pytest 22 | - hypothesis 23 | - pytorch==2 # we use this to test correctness, not to cheat 24 | - pre-commit 25 | # docs dependencies 26 | - sphinx 27 | - furo 28 | - recommonmark 29 | # install tricycle 30 | - pip 31 | - pip: 32 | - -e ../ 33 | -------------------------------------------------------------------------------- /requirements/environment.cpu.yml: -------------------------------------------------------------------------------- 1 | name: tricycle 2 | 3 | channels: 4 | - defaults 5 | - conda-forge 6 | 7 | dependencies: 8 | - python=3.10 9 | - numpy 10 | - humanize 11 | - tqdm 12 | - mlflow 13 | - psutil 14 | - numba 15 | - tiktoken 16 | - datasets 17 | - pip 18 | 
- pip: 19 | - -e ../ 20 | -------------------------------------------------------------------------------- /requirements/environment.test.yml: -------------------------------------------------------------------------------- 1 | name: tricycle 2 | 3 | channels: 4 | - defaults 5 | - conda-forge 6 | - pytorch 7 | 8 | dependencies: 9 | # base dependencies 10 | - python=3.10 11 | - numpy 12 | - humanize 13 | - tqdm 14 | - mlflow 15 | - psutil 16 | - numba 17 | - tiktoken 18 | - datasets 19 | # gpu dependencies 20 | - cuda-version==12 21 | - cudnn 22 | - cutensor 23 | - nccl 24 | - pynvml 25 | - cupy 26 | # test dependencies 27 | - scikit-learn 28 | - pytest 29 | - hypothesis 30 | - pytorch==2 # we use this to test correctness, not to cheat 31 | - pre-commit 32 | # docs dependencies 33 | - sphinx 34 | - furo 35 | - recommonmark 36 | # install tricycle 37 | - pip 38 | - pip: 39 | - -e ../ 40 | -------------------------------------------------------------------------------- /requirements/environment.yml: -------------------------------------------------------------------------------- 1 | name: tricycle 2 | 3 | channels: 4 | - defaults 5 | - conda-forge 6 | 7 | dependencies: 8 | # base dependencies 9 | - python=3.10 10 | - numpy 11 | - humanize 12 | - tqdm 13 | - mlflow 14 | - psutil 15 | - numba 16 | - tiktoken 17 | - datasets 18 | # gpu dependencies 19 | - cuda-version==12 20 | - cudnn 21 | - cutensor 22 | - nccl 23 | - pynvml 24 | - cupy 25 | # install tricycle 26 | - pip 27 | - pip: 28 | - -e ../ 29 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = tricycle 3 | version = 0.1.0 4 | description = "A Deep Learning library you can actually understand" 5 | author = benedictclarkson1 6 | author_email = benedictclarkson1@gmail.com 7 | license = MIT 8 | long_description = file: README.md 9 | 10 | [options] 11 | package_dir = 12 
| = src 13 | packages = find: 14 | 15 | [options.packages.find] 16 | where = src 17 | -------------------------------------------------------------------------------- /src/tricycle/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tricycle: A deep learning framework. 3 | 4 | This module initializes the Tricycle framework and imports its various components. 5 | It also checks for GPU support using CuPY. 6 | 7 | Attributes: 8 | GPU_ENABLED (bool): Indicates whether GPU support is available. 9 | 10 | Imports: 11 | Various submodules of the Tricycle framework. 12 | """ 13 | 14 | from warnings import warn 15 | 16 | try: 17 | import cupy 18 | 19 | # check that we can create a cupy array and operate on it 20 | cupy.array([1, 2, 3]) * 2 21 | GPU_ENABLED = True 22 | except ImportError: 23 | warn("Could not find CuPY, disabling GPU features") 24 | GPU_ENABLED = False 25 | except Exception as e: 26 | GPU_ENABLED = False 27 | warn(f"Failed to build cupy array: {e}. Disabling GPU features") 28 | 29 | from . 
import ( 30 | activation, 31 | attention, 32 | binary, 33 | blocks, 34 | configs, 35 | context, 36 | dataset, 37 | einsum, 38 | exceptions, 39 | functions, 40 | initialisers, 41 | layers, 42 | loss, 43 | models, 44 | ops, 45 | optimisers, 46 | reduce, 47 | scheduler, 48 | tensor, 49 | tokeniser, 50 | unary, 51 | utils, 52 | weakset, 53 | ) 54 | 55 | __all__ = [ 56 | "activation", 57 | "attention", 58 | "binary", 59 | "blocks", 60 | "configs", 61 | "context", 62 | "dataset", 63 | "einsum", 64 | "exceptions", 65 | "functions", 66 | "initialisers", 67 | "layers", 68 | "loss", 69 | "models", 70 | "ops", 71 | "optimisers", 72 | "reduce", 73 | "scheduler", 74 | "tensor", 75 | "tokeniser", 76 | "unary", 77 | "utils", 78 | "weakset", 79 | ] 80 | -------------------------------------------------------------------------------- /src/tricycle/activation.py: -------------------------------------------------------------------------------- 1 | from tricycle.context import TRICYCLE_CONTEXT 2 | from tricycle.functions import Sigmoid 3 | from tricycle.initialisers import init_xavier 4 | from tricycle.layers import Dense, Layer 5 | from tricycle.optimisers import Optimiser 6 | from tricycle.tensor import Tensor 7 | from tricycle.unary import UnaryMax 8 | 9 | 10 | class ReLU(Layer): 11 | """ 12 | Rectified Linear Unit (ReLU) activation function. 13 | 14 | This layer applies the ReLU function element-wise to the input tensor. 15 | ReLU(x) = max(0, x) 16 | """ 17 | 18 | def forward(self, x: Tensor): 19 | """ 20 | Apply the ReLU function to the input tensor. 21 | 22 | Args: 23 | x (Tensor): Input tensor. 24 | 25 | Returns: 26 | Tensor: Output tensor after applying ReLU. 27 | """ 28 | return UnaryMax()(x, 0) 29 | 30 | 31 | class Swish(Layer): 32 | """ 33 | Swish activation function. 34 | 35 | This layer applies the Swish function element-wise to the input tensor. 
36 | Swish(x) = x * sigmoid(x) 37 | 38 | Note: This implementation is equivalent to the SiLU activation function 39 | as it omits the bias term. 40 | """ 41 | 42 | def backward(self, grad: Tensor): 43 | """ 44 | Compute the gradient of the Swish function. 45 | 46 | Args: 47 | grad (Tensor): Upstream gradient. 48 | 49 | Returns: 50 | Tensor: Gradient with respect to the input. 51 | """ 52 | xp = grad.xp 53 | 54 | # Exponents tend to overflow/underflow when using 16 bit precision 55 | # so we need to switch to 32 bit 56 | if TRICYCLE_CONTEXT.use_mixed_precision: 57 | self._input = self._input.astype(xp.float32) 58 | 59 | exp = xp.exp(-self._input) 60 | numerator = 1 + exp + self._input * exp 61 | denominator = (1 + exp) ** 2 62 | coef = numerator / denominator 63 | 64 | if TRICYCLE_CONTEXT.use_mixed_precision: 65 | coef = coef.astype(xp.float16) 66 | 67 | return Tensor(grad * coef) 68 | 69 | def forward(self, tensor: Tensor): 70 | """ 71 | Apply the Swish function to the input tensor. 72 | 73 | Args: 74 | tensor (Tensor): Input tensor. 75 | 76 | Returns: 77 | Tensor: Output tensor after applying Swish. 78 | """ 79 | xp = tensor.xp 80 | 81 | self._input = tensor.array 82 | # Exponents tend to overflow/underflow when using 16 bit precision 83 | # so we need to switch to 32 bit 84 | if TRICYCLE_CONTEXT.use_mixed_precision: 85 | self._input = self._input.astype(xp.float32) 86 | 87 | out = tensor.array / (1 + xp.exp(-tensor.array)) 88 | 89 | if TRICYCLE_CONTEXT.use_mixed_precision: 90 | self._input = self._input.astype(xp.float16) 91 | out = out.astype(xp.float16) 92 | 93 | return Tensor( 94 | out, args=(tensor,), back_fns=(self.backward,), name="swish" 95 | ) 96 | 97 | 98 | class GeLU(Layer): 99 | """ 100 | Gaussian Error Linear Unit (GELU) activation function. 101 | 102 | This layer applies the GELU function element-wise to the input tensor. 
103 | GELU(x) ≈ 0.5x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3))) 104 | 105 | Args: 106 | approximate (bool): Whether to use the approximate version of GELU. 107 | Defaults to False. 108 | """ 109 | 110 | CONST_1 = 0.7978845608028654 111 | CONST_2 = 0.044715 112 | 113 | def __init__(self, *args, approximate: bool = False, **kwargs): 114 | """ 115 | Initialize the GELU layer. 116 | 117 | Args: 118 | approximate (bool): Whether to use the approximate version of GELU. 119 | Defaults to False. 120 | *args: Variable length argument list. 121 | **kwargs: Arbitrary keyword arguments. 122 | """ 123 | super().__init__(*args, **kwargs) 124 | self.approximate = approximate 125 | 126 | def backward(self, grad: Tensor): 127 | """ 128 | Compute the gradient of the GELU function. 129 | 130 | Args: 131 | grad (Tensor): Upstream gradient. 132 | 133 | Returns: 134 | Tensor: Gradient with respect to the input. 135 | """ 136 | xp = grad.xp 137 | 138 | # Hyperbolic trig functions (cosh and tanh) use exponents under the 139 | # hood which can overflow/underflow when using 16 bit precision so 140 | # we need to switch to 32 bit precision 141 | if TRICYCLE_CONTEXT.use_mixed_precision: 142 | self._input = self._input.astype(xp.float32) 143 | 144 | inner = ( 145 | self.CONST_1 * self._input * (1 + self.CONST_2 * self._input**2) 146 | ) 147 | coef = ( 148 | self.CONST_1 149 | * self._input 150 | * (1 + self.CONST_2 * 3 * self._input**2) 151 | ) 152 | 153 | left = xp.tanh(inner) 154 | cosh = xp.cosh(inner) 155 | right = coef / (cosh * cosh) 156 | 157 | if TRICYCLE_CONTEXT.use_mixed_precision: 158 | left = left.astype(xp.float16) 159 | right = right.astype(xp.float16) 160 | 161 | self._grad = 0.5 * (1 + left + right) * grad.array 162 | 163 | result = Tensor( 164 | self._grad, 165 | is_batched=grad.is_batched, 166 | requires_grad=grad.requires_grad, 167 | ) 168 | result.name = "gelu_back" 169 | return result 170 | 171 | def forward(self, tensor: Tensor): 172 | """ 173 | Apply the GELU 
function to the input tensor. 174 | 175 | Args: 176 | tensor (Tensor): Input tensor. 177 | 178 | Returns: 179 | Tensor: Output tensor after applying GELU. 180 | """ 181 | xp = tensor.xp 182 | self._input = tensor.array 183 | 184 | # Tanh tends to overflow/underflow when using 16 bit precision 185 | # so we need to switch to 32 bit 186 | if TRICYCLE_CONTEXT.use_mixed_precision: 187 | self._input = self._input.astype(xp.float32) 188 | 189 | inner = self.CONST_1 * (self._input + self.CONST_2 * self._input**3) 190 | result = self._input * 0.5 * (1 + xp.tanh(inner)) 191 | 192 | if TRICYCLE_CONTEXT.use_mixed_precision: 193 | self._input = self._input.astype(xp.float16) 194 | result = result.astype(xp.float16) 195 | 196 | result = Tensor( 197 | result, 198 | is_batched=tensor.is_batched, 199 | requires_grad=tensor.requires_grad, 200 | ) 201 | result.name = "gelu" 202 | result.args = (tensor,) 203 | result.back_fns = (self.backward,) 204 | return result 205 | 206 | 207 | class GLU(Layer): 208 | """ 209 | Gated Linear Unit (GLU) activation function. 210 | 211 | This layer applies the GLU function to the input tensor. 212 | GLU(x) = x_left * sigmoid(x_right) 213 | 214 | Args: 215 | size (int): Size of the input tensor. 216 | initialiser (callable): Function to initialize the weights. 217 | Defaults to init_xavier. 218 | """ 219 | 220 | linear: Dense 221 | 222 | def __init__(self, size: int, initialiser=init_xavier, *args, **kwargs): 223 | """ 224 | Initialize the GLU layer. 225 | 226 | Args: 227 | size (int): Size of the input tensor. 228 | initialiser (callable): Function to initialize the weights. 229 | Defaults to init_xavier. 230 | *args: Variable length argument list. 231 | **kwargs: Arbitrary keyword arguments. 
232 | """ 233 | super().__init__(*args, **kwargs) 234 | self.linear = Dense(size, 2 * size, initialiser) 235 | self.layers = [self.linear] 236 | self.sigmoid = Sigmoid() 237 | 238 | def forward(self, x: Tensor): 239 | """ 240 | Apply the GLU function to the input tensor. 241 | 242 | Args: 243 | x (Tensor): Input tensor. 244 | 245 | Returns: 246 | Tensor: Output tensor after applying GLU. 247 | """ 248 | x = self.linear(x) 249 | left, right = x.split(2) 250 | return left * self.sigmoid(right) 251 | 252 | def update(self, optimiser: Optimiser): 253 | """ 254 | Update the layer parameters using the given optimizer. 255 | 256 | Args: 257 | optimiser (Optimiser): The optimizer to use for updating parameters. 258 | """ 259 | self.linear.update(optimiser) 260 | 261 | def zero_grad(self): 262 | """ 263 | Reset the gradients of the layer parameters to zero. 264 | """ 265 | self.linear.zero_grad() 266 | 267 | def to_gpu(self): 268 | """ 269 | Move the layer parameters to GPU memory. 270 | """ 271 | self.linear.to_gpu() 272 | 273 | def from_gpu(self): 274 | """ 275 | Move the layer parameters from GPU to CPU memory. 276 | """ 277 | self.linear.from_gpu() 278 | -------------------------------------------------------------------------------- /src/tricycle/attention.py: -------------------------------------------------------------------------------- 1 | """Attention module for multi-head attention operations. 2 | 3 | This module implements the multi-head attention mechanism as described in 4 | "Attention Is All You Need" (Vaswani et al., 2017). It includes functions for 5 | building attention masks and the main Attention class for performing 6 | multi-head attention operations. 
7 | """ 8 | 9 | from math import sqrt 10 | 11 | import numpy as np 12 | 13 | from tricycle import GPU_ENABLED 14 | from tricycle.context import TRICYCLE_CONTEXT 15 | from tricycle.ops import Op 16 | from tricycle.tensor import Tensor 17 | 18 | 19 | def build_mask(context_window: int, n_heads: int) -> Tensor: 20 | """Build an attention mask to prevent attending to future tokens. 21 | 22 | This function creates a boolean mask that can be used in multi-head attention 23 | mechanisms to implement causal (unidirectional) attention. 24 | 25 | Args: 26 | context_window: An integer representing the size of the context window. 27 | n_heads: An integer representing the number of attention heads. 28 | 29 | Returns: 30 | A boolean tensor of shape (n_heads, context_window, context_window) 31 | representing the attention mask. 32 | """ 33 | mask = np.ones((context_window, context_window), dtype=bool) 34 | idx = np.tril(mask) 35 | mask = np.stack([~idx] * n_heads) 36 | return mask 37 | 38 | 39 | class Attention(Op): 40 | """Multi-head attention operation. 41 | 42 | This class implements the multi-head attention mechanism as described in 43 | "Attention Is All You Need" (Vaswani et al., 2017). 44 | 45 | Attributes: 46 | embedding_dim: An integer representing the dimension of the input embeddings. 47 | n_heads: An integer representing the number of attention heads. 48 | context_window: An integer representing the size of the context window. 49 | mask: A tensor representing the attention mask. 50 | _grad: A tensor to store gradients during backpropagation. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | embedding_dim: int, 56 | n_heads: int, 57 | context_window: int, 58 | ): 59 | """Initialize the Attention operation. 60 | 61 | Args: 62 | embedding_dim: An integer representing the dimension of the input embeddings. 63 | n_heads: An integer representing the number of attention heads. 64 | context_window: An integer representing the size of the context window. 
65 | """ 66 | super().__init__() 67 | self.embedding_dim = embedding_dim 68 | self.n_heads = n_heads 69 | self.context_window = context_window 70 | self.mask = build_mask( 71 | context_window=self.context_window, n_heads=self.n_heads 72 | ) 73 | self._grad = None 74 | 75 | def backward(self, grad: Tensor): 76 | """Compute the gradient of the attention operation. 77 | 78 | Args: 79 | grad: A Tensor representing the upstream gradient. 80 | 81 | Returns: 82 | A Tensor representing the gradient with respect to the input. 83 | """ 84 | xp = grad.xp 85 | in_shape = (self.batch_size, self.context_window, self.embedding_dim) 86 | 87 | attention = grad.array 88 | 89 | # TODO: come up with a better name 90 | # smush 91 | attention = attention.reshape( 92 | ( 93 | self.batch_size, 94 | self.context_window, 95 | self.n_heads, 96 | self.head_size, 97 | ) 98 | ) 99 | value = xp.einsum("BNIj, BINH -> BNjH", self._before_smush, attention) 100 | attention = xp.einsum("BINH, BNjH -> BNIj", attention, self._value) 101 | 102 | # softmax 103 | inner = xp.sum(attention * self._before_smush, axis=-1, keepdims=True) 104 | attention = self._before_smush * (attention - inner) 105 | 106 | # mask 107 | attention = xp.where( 108 | self.mask[:, : self.n_tokens, : self.n_tokens], 0, attention 109 | ) 110 | attention /= self.divisor 111 | 112 | # attend 113 | query = xp.einsum("BNIJ, BNJh -> BNIh", attention, self._key) 114 | key = xp.einsum("BNIh, BNIJ -> BNJh", self._query, attention) 115 | 116 | # reshape + reorder 117 | key = xp.einsum("BNTH->BTNH", key) 118 | query = xp.einsum("BNTH->BTNH", query) 119 | value = xp.einsum("BNTH->BTNH", value) 120 | 121 | key = key.reshape(in_shape) 122 | query = query.reshape(in_shape) 123 | value = value.reshape(in_shape) 124 | 125 | # merge into single tensor 126 | if self._grad is None: 127 | self._grad = xp.zeros( 128 | (self.batch_size, self.context_window, self.embedding_dim * 3) 129 | ) 130 | self._grad[:, :, : self.embedding_dim] = query 131 | 
self._grad[:, :, self.embedding_dim : self.embedding_dim * 2] = key 132 | self._grad[:, :, self.embedding_dim * 2 :] = value 133 | 134 | return Tensor(self._grad) 135 | 136 | def forward(self, tensor: Tensor): 137 | """Apply the multi-head attention operation to the input tensor. 138 | 139 | Args: 140 | tensor: A Tensor of shape (batch_size, seq_len, embedding_dim * 3). 141 | The input should contain concatenated query, key, and value projections. 142 | 143 | Returns: 144 | A Tensor representing the output after applying multi-head attention. 145 | """ 146 | xp = tensor.xp 147 | 148 | assert tensor.is_batched 149 | 150 | # split the input into 3 peices 151 | self._input = tensor 152 | query = tensor[:, :, : self.embedding_dim] 153 | key = tensor[:, :, self.embedding_dim : self.embedding_dim * 2] 154 | value = tensor[:, :, self.embedding_dim * 2 :] 155 | 156 | # Figure out how big everything is 157 | self.batch_size = key.array.shape[0] 158 | self.head_size = self.embedding_dim // self.n_heads 159 | self.n_tokens = key.shape[-2] 160 | head_shape = ( 161 | self.batch_size, 162 | self.n_tokens, # number of tokens 163 | self.n_heads, # number of heads 164 | self.head_size, # embedding per head 165 | ) 166 | out_shape = (self.batch_size, self.n_tokens, self.embedding_dim) 167 | 168 | # reshape and reorder the heads 169 | key = key.array 170 | query = query.array 171 | value = value.array 172 | 173 | key = key.reshape(head_shape) 174 | query = query.reshape(head_shape) 175 | value = value.reshape(head_shape) 176 | 177 | key = xp.einsum("BTNH->BNTH", key) 178 | query = xp.einsum("BTNH->BNTH", query) 179 | value = xp.einsum("BTNH->BNTH", value) 180 | 181 | self._key = key 182 | self._query = query 183 | self._value = value 184 | 185 | # attend 186 | self.divisor = sqrt(self.head_size) 187 | attention = xp.einsum("BNIh, BNJh -> BNIJ", query, key) 188 | attention = attention / self.divisor 189 | 190 | # mask 191 | attention = xp.where( 192 | self.mask[:, : self.n_tokens, : 
self.n_tokens], -xp.inf, attention 193 | ) 194 | 195 | # Exponents tend to overflow/underflow when using 16 bit precision 196 | # so we need to switch to 32 bit 197 | if TRICYCLE_CONTEXT.use_mixed_precision: 198 | attention = attention.astype(xp.float32) 199 | 200 | # softmax 201 | exp = xp.exp(attention - xp.max(attention, axis=-1, keepdims=True)) 202 | denominator = xp.sum(exp, axis=-1, keepdims=True) 203 | attention = exp / denominator 204 | 205 | if TRICYCLE_CONTEXT.use_mixed_precision: 206 | attention = attention.astype(xp.float16) 207 | 208 | # smush the heads back together 209 | self._before_smush = attention 210 | attention = xp.einsum("BNTi, BNiH -> BTNH", attention, value) 211 | attention = attention.reshape(out_shape) 212 | 213 | result = Tensor(attention, is_batched=True) 214 | result.back_fns = (self.backward,) 215 | result.args = (self._input,) 216 | return result 217 | 218 | def to_gpu(self, device: int): 219 | """Move this operation to a GPU. 220 | 221 | Args: 222 | device: An integer representing the GPU device number. 223 | """ 224 | if GPU_ENABLED: 225 | import cupy as cp 226 | 227 | cp.cuda.Device(device).use() 228 | self.mask = cp.array(self.mask) 229 | 230 | def from_gpu(self): 231 | """Move the operation back to CPU.""" 232 | if GPU_ENABLED: 233 | import cupy as cp 234 | 235 | self.mask = cp.asnumpy(self.mask) 236 | -------------------------------------------------------------------------------- /src/tricycle/configs.py: -------------------------------------------------------------------------------- 1 | """Configurations for different GPT models. 2 | 3 | This module contains configuration classes for various GPT models, including 4 | a base configuration class and specific configurations for debugging, 5 | Shakespeare-based models, and a small GPT model. 6 | 7 | Classes: 8 | GPTConfig: Base configuration class for GPT models. 9 | DebugConfig: Configuration for debugging purposes. 
10 | ShakespeareConfig: Configuration for Shakespeare-based models. 11 | SmolGPTConfig: Configuration for a small GPT model. 12 | """ 13 | 14 | from typing import Literal 15 | 16 | 17 | class GPTConfig: 18 | """Base configuration class for GPT models. 19 | 20 | This class defines the common parameters and hyperparameters used in 21 | GPT model training and evaluation. 22 | 23 | Attributes: 24 | embedding_dim (int): Dimension of the embedding layer. 25 | context_window (int): Size of the context window. 26 | vocab_size (int): Size of the vocabulary. 27 | n_heads (int): Number of attention heads. 28 | n_layers (int): Number of transformer layers. 29 | expansion_ratio (float): Expansion ratio for feed-forward layers. 30 | activation_fn (str): Activation function used in the model. 31 | norm_fn (str): Normalization function used in the model. 32 | input_dropout_prob (float): Dropout probability for input embeddings. 33 | residual_dropout_prob (float): Dropout probability for residual connections. 34 | linear_dropout_prob (float): Dropout probability for linear layers. 35 | max_learning_rate (float): Maximum learning rate for training. 36 | min_learning_rate (float): Minimum learning rate for training. 37 | warmup_steps (int): Number of warmup steps for learning rate scheduling. 38 | weight_decay (float): Weight decay factor for regularization. 39 | momentum (float): Momentum factor for optimization. 40 | beta1 (float): Beta1 parameter for Adam optimizer. 41 | beta2 (float): Beta2 parameter for Adam optimizer. 42 | steps (int | Literal["chinchilla_optimal"]): Number of training steps or "chinchilla_optimal". 43 | eval_interval (int): Interval between evaluations. 44 | batch_size (int): Batch size for training. 45 | gradient_accumulation_steps (int): Number of steps for gradient accumulation. 46 | device_idx (int): Index of the device to use for training. 47 | mlflow_tracking_uri (str): URI for MLflow tracking server. 
48 | mlflow_experiment_name (str): Name of the MLflow experiment. 49 | """ 50 | 51 | embedding_dim: int 52 | context_window: int 53 | vocab_size: int 54 | n_heads: int 55 | n_layers: int 56 | expansion_ratio: float 57 | activation_fn: str 58 | norm_fn: str 59 | 60 | input_dropout_prob: float 61 | residual_dropout_prob: float 62 | linear_dropout_prob: float 63 | 64 | max_learning_rate: float 65 | min_learning_rate: float 66 | warmup_steps: int 67 | weight_decay: float 68 | momentum = float 69 | beta1: float 70 | beta2: float 71 | 72 | steps: int | Literal["chinchilla_optimal"] 73 | eval_interval: int 74 | batch_size: int 75 | gradient_accumulation_steps: int 76 | 77 | device_idx: int 78 | 79 | mlflow_tracking_uri: str 80 | mlflow_experiment_name: str 81 | 82 | def dict(self) -> dict[str, int | float | str | bool]: 83 | """Convert the configuration to a dictionary. 84 | 85 | Returns: 86 | dict[str, int | float | str | bool]: A dictionary representation of the configuration. 87 | """ 88 | out = {} 89 | for k, v in self.__class__.__dict__.items(): 90 | if k.startswith("__"): 91 | continue 92 | 93 | if callable(v): 94 | continue 95 | out[k] = v 96 | return out 97 | 98 | 99 | class DebugConfig(GPTConfig): 100 | """Configuration for debugging purposes. 101 | 102 | This class inherits from GPTConfig and sets specific values for debugging. 
103 | """ 104 | 105 | embedding_dim = 14 106 | context_window = 13 107 | vocab_size = 11 108 | n_heads = 2 109 | n_layers = 1 110 | expansion_ratio = 4 111 | activation_fn = "gelu" 112 | norm_fn = "layer_norm" 113 | 114 | input_dropout_prob = 0.2 115 | residual_dropout_prob = 0.2 116 | linear_dropout_prob = 0.2 117 | 118 | max_learning_rate = 1e-3 119 | min_learning_rate = 1e-4 120 | warmup_steps = 100 121 | weight_decay = 1e-1 122 | momentum = 0 123 | beta1 = 0.9 124 | beta2 = 0.99 125 | 126 | steps = 250 127 | eval_interval = 1 128 | eval_steps = 1 129 | batch_size = 5 130 | gradient_accumulation_steps = 1 131 | sample_size = 4 132 | 133 | device_idx = 0 134 | 135 | mlflow_enabled = False 136 | mlflow_tracking_uri = "" 137 | 138 | 139 | class ShakespeareConfig(GPTConfig): 140 | """Configuration for Shakespeare-based models. 141 | 142 | This class inherits from GPTConfig and sets specific values for 143 | Shakespeare-based language models. 144 | """ 145 | 146 | embedding_dim = 384 147 | context_window = 256 148 | vocab_size = 1024 149 | n_heads = 6 150 | n_layers = 6 151 | expansion_ratio = 4 152 | activation_fn = "gelu" 153 | norm_fn = "layer_norm" 154 | 155 | input_dropout_prob = 0.2 156 | residual_dropout_prob = 0.2 157 | linear_dropout_prob = 0.2 158 | 159 | max_learning_rate = 1e-2 160 | min_learning_rate = 1e-4 161 | warmup_steps = 100 162 | weight_decay = 1e-1 163 | momentum = 0 164 | beta1 = 0.9 165 | beta2 = 0.99 166 | 167 | steps = 5000 168 | eval_interval = 250 169 | eval_steps = 128 170 | batch_size = 128 171 | gradient_accumulation_steps = 1 172 | sample_size = 512 173 | 174 | device_idx = 1 175 | 176 | mlflow_enabled = True 177 | mlflow_tracking_uri = "http://localhost:5000" 178 | 179 | 180 | class SmolGPTConfig(GPTConfig): 181 | """Configuration for a small GPT model. 182 | 183 | This class inherits from GPTConfig and sets specific values for 184 | a small-scale GPT model. 
185 | """ 186 | 187 | embedding_dim = 768 188 | context_window = 1024 189 | vocab_size = 50256 190 | n_heads = 12 191 | n_layers = 12 192 | expansion_ratio = 4 193 | activation_fn = "gelu" 194 | norm_fn = "layer_norm" 195 | 196 | input_dropout_prob = 0 197 | residual_dropout_prob = 0 198 | linear_dropout_prob = 0 199 | 200 | max_learning_rate = 6e-4 201 | min_learning_rate = 0 202 | warmup_steps = 150 # roughly matches andrej's warmup steps in llm.c 203 | weight_decay = 1e-1 204 | momentum = 0 205 | beta1 = 0.9 206 | beta2 = 0.95 207 | 208 | steps = "chinchilla_optimal" 209 | eval_interval = 100 210 | eval_steps = 128 211 | batch_size = 4 212 | gradient_accumulation_steps = 128 # effective batch size of 524288 tokens 213 | n_tokens_to_generate = 512 214 | 215 | device_idx = 0 216 | 217 | mlflow_enabled = True 218 | mlflow_tracking_uri = "http://localhost:5000" 219 | -------------------------------------------------------------------------------- /src/tricycle/context.py: -------------------------------------------------------------------------------- 1 | """Defines the context for Tricycle operations. 2 | 3 | This module provides a dataclass for storing Tricycle context information, 4 | including mixed precision and loss scaling settings. 5 | """ 6 | 7 | from dataclasses import dataclass 8 | 9 | 10 | @dataclass 11 | class TricycleContext: 12 | """A dataclass to store Tricycle context information. 13 | 14 | Attributes: 15 | use_mixed_precision (bool): Flag to enable mixed precision. Default is False. 16 | Note: It's recommended to use the tricycle/utils.py:UseMixedPrecision 17 | context manager for mixed precision training instead of modifying this directly. 18 | 19 | loss_scale_factor (int): Factor to scale the loss when using mixed precision. 20 | This helps prevent under and overflowing. Default is 128. 
21 | """ 22 | 23 | use_mixed_precision: bool = False 24 | loss_scale_factor: int = 128 25 | 26 | 27 | # Global instance of TricycleContext 28 | TRICYCLE_CONTEXT = TricycleContext() 29 | -------------------------------------------------------------------------------- /src/tricycle/exceptions.py: -------------------------------------------------------------------------------- 1 | class GPUDisabledException(Exception): 2 | """Raised when a GPU operation is attempted while GPU computation is disabled. 3 | 4 | This exception is thrown when a function or method tries to perform 5 | a GPU-based operation, but GPU computation has been explicitly disabled 6 | or is unavailable in the current environment. 7 | 8 | Attributes: 9 | None 10 | 11 | Example: 12 | >>> try: 13 | ... perform_gpu_operation() 14 | ... except GPUDisabledException: 15 | ... print("GPU operations are currently disabled.") 16 | """ 17 | -------------------------------------------------------------------------------- /src/tricycle/functions.py: -------------------------------------------------------------------------------- 1 | """Functions for neural network activation operations. 2 | 3 | This module provides implementations of common activation functions 4 | used in neural networks, including Softmax and Sigmoid. 5 | """ 6 | 7 | from tricycle.context import TRICYCLE_CONTEXT 8 | from tricycle.ops import Op 9 | from tricycle.tensor import Tensor 10 | 11 | 12 | class Softmax(Op): 13 | """Applies the softmax function to the input tensor. 14 | 15 | The softmax function is applied only to the final dimension of the tensor. 16 | The input is normalized for numeric stability. 17 | 18 | Attributes: 19 | _out: The output of the forward pass. 20 | _grad: The gradient computed during the backward pass. 21 | """ 22 | 23 | def back_fn(self, grad: Tensor) -> Tensor: 24 | """Computes the gradient of the softmax function. 25 | 26 | Args: 27 | grad: The gradient tensor from the subsequent layer. 
28 | 29 | Returns: 30 | A Tensor containing the computed gradient. 31 | """ 32 | xp = grad.xp 33 | 34 | inner = xp.sum(grad.array * self._out, axis=-1, keepdims=True) 35 | self._grad = self._out * (grad.array - inner) 36 | 37 | return Tensor( 38 | self._grad, 39 | is_batched=grad.is_batched, 40 | requires_grad=grad.requires_grad, 41 | ) 42 | 43 | def forward(self, tensor: Tensor): 44 | """Applies the softmax function to the input tensor. 45 | 46 | Args: 47 | tensor: The input tensor. 48 | 49 | Returns: 50 | A Tensor with the softmax function applied. 51 | """ 52 | xp = tensor.xp 53 | 54 | # Exponents tend to overflow/underflow when using 16 bit precision 55 | # so we need to switch to 32 bit 56 | if TRICYCLE_CONTEXT.use_mixed_precision: 57 | tensor.array = tensor.array.astype(xp.float32) 58 | 59 | exp = xp.exp( 60 | # subtract the largest value for numeric stability 61 | tensor.array 62 | - xp.max(tensor.array, axis=-1, keepdims=True) 63 | ) 64 | denominator = xp.sum(exp, axis=-1, keepdims=True) 65 | self._out = exp / denominator 66 | if TRICYCLE_CONTEXT.use_mixed_precision: 67 | self._out = self._out.astype(xp.float16) 68 | 69 | return Tensor( 70 | self._out, 71 | args=(tensor,), 72 | name="softmax", 73 | is_batched=tensor.is_batched, 74 | back_fns=(self.back_fn,), 75 | ) 76 | 77 | 78 | class Sigmoid(Op): 79 | """Applies the sigmoid function to the input tensor. 80 | 81 | Attributes: 82 | _out: The output of the forward pass. 83 | _grad: The gradient computed during the backward pass. 84 | """ 85 | 86 | def backward(self, grad: Tensor) -> Tensor: 87 | """Computes the gradient of the sigmoid function. 88 | 89 | Args: 90 | grad: The gradient tensor from the subsequent layer. 91 | 92 | Returns: 93 | A Tensor containing the computed gradient. 
94 | """ 95 | self._grad = self._out * (1 - self._out) * grad.array 96 | return Tensor(self._grad, requires_grad=grad) 97 | 98 | def forward(self, tensor: Tensor) -> Tensor: 99 | """Applies the sigmoid function to the input tensor. 100 | 101 | Args: 102 | tensor: The input tensor. 103 | 104 | Returns: 105 | A Tensor with the sigmoid function applied. 106 | """ 107 | xp = tensor.xp 108 | 109 | # Exponents tend to overflow/underflow when using 16 bit precision 110 | # so we need to switch to 32 bit 111 | if TRICYCLE_CONTEXT.use_mixed_precision: 112 | tensor.array = tensor.array.astype(xp.float32) 113 | 114 | self._out = 1 / (1 + xp.exp(-tensor.array)) 115 | return Tensor( 116 | self._out, 117 | back_fns=(self.backward,), 118 | args=(tensor,), 119 | requires_grad=tensor.requires_grad, 120 | ) 121 | -------------------------------------------------------------------------------- /src/tricycle/initialisers.py: -------------------------------------------------------------------------------- 1 | """Initializers module for tensor initialization. 2 | 3 | This module provides functions for initializing tensors with specific 4 | distributions or patterns. 5 | """ 6 | 7 | import numpy as np 8 | 9 | from tricycle.ops import Tensor 10 | from tricycle.tensor import DEFAULT_DTYPE 11 | 12 | 13 | def init_xavier(shape: tuple[int, int], name: str = "") -> Tensor: 14 | """Initialize a tensor with Xavier/Glorot initialization. 15 | 16 | This function implements Xavier/Glorot initialization, which helps in 17 | setting initial random weights for neural networks. It's particularly 18 | useful for maintaining the scale of gradients across layers. 19 | 20 | Args: 21 | shape: A tuple of two integers (f_in, f_out), where f_in is the number 22 | of input units and f_out is the number of output units. 23 | name: An optional string to name the created tensor. Defaults to an 24 | empty string. 25 | 26 | Returns: 27 | A Tensor object initialized with Xavier/Glorot initialization. 
28 | 29 | Raises: 30 | ValueError: If the shape tuple does not contain exactly two integers. 31 | 32 | Example: 33 | >>> weight = init_xavier((100, 50), name="layer1_weights") 34 | """ 35 | f_in, f_out = shape 36 | bound = np.sqrt(6) / np.sqrt(f_in + f_out) 37 | out = Tensor( 38 | np.random.uniform(low=-bound, high=bound, size=shape), 39 | dtype=DEFAULT_DTYPE, 40 | name=name, 41 | ) 42 | return out 43 | -------------------------------------------------------------------------------- /src/tricycle/loss.py: -------------------------------------------------------------------------------- 1 | """Loss functions for neural network training. 2 | 3 | This module contains implementations of common loss functions used in neural 4 | network training, such as Mean Squared Error and Cross Entropy. 5 | 6 | Classes: 7 | MeanSquaredError: Calculates the Mean Squared Error loss. 8 | CrossEntropy: Calculates the Cross Entropy loss. 9 | """ 10 | 11 | import logging 12 | 13 | from tricycle.context import TRICYCLE_CONTEXT 14 | from tricycle.ops import Op 15 | from tricycle.tensor import Tensor 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class MeanSquaredError(Op): 21 | """Calculates Mean Squared Error loss. 22 | 23 | This class implements the Mean Squared Error (MSE) loss function, which 24 | measures the average squared difference between the predicted and true values. 25 | 26 | Attributes: 27 | diff: The difference between predicted and true values. 28 | divisor: A scaling factor for the loss calculation. 29 | 30 | """ 31 | 32 | def backward(self, grad: Tensor) -> Tensor: 33 | """Computes the backward pass for Mean Squared Error loss. 34 | 35 | Args: 36 | grad: A Tensor containing the gradient from the previous layer. 37 | 38 | Returns: 39 | A Tensor containing the computed gradients. 
40 | """ 41 | xp = grad.xp 42 | 43 | if TRICYCLE_CONTEXT.use_mixed_precision: 44 | grad.array = grad.array.astype(xp.float32) 45 | 46 | out = self.diff * 2 * grad.array * self.divisor 47 | 48 | if TRICYCLE_CONTEXT.use_mixed_precision: 49 | out = out.astype(xp.float16) 50 | 51 | return Tensor(out) 52 | 53 | def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor: 54 | """Computes the forward pass for Mean Squared Error loss. 55 | 56 | Args: 57 | y_true: A Tensor containing the true values. 58 | y_pred: A Tensor containing the predicted values. 59 | 60 | Returns: 61 | A Tensor containing the computed MSE loss. 62 | 63 | Raises: 64 | ValueError: If the computed loss is infinite. 65 | """ 66 | xp = y_pred.xp 67 | 68 | if TRICYCLE_CONTEXT.use_mixed_precision: 69 | y_pred.array = y_pred.array.astype(xp.float32) 70 | y_true.array = y_true.array.astype(xp.float32) 71 | 72 | self.diff = y_pred.array - y_true.array 73 | self.divisor = 1 / xp.prod(y_pred.shape[-1]) 74 | 75 | out = (self.diff**2).sum() * self.divisor 76 | 77 | if TRICYCLE_CONTEXT.use_mixed_precision: 78 | out *= TRICYCLE_CONTEXT.loss_scale_factor 79 | out = out.astype(xp.float16) 80 | 81 | if not xp.isfinite(out): 82 | raise ValueError("Loss is infinite") 83 | 84 | # only y_pred is differentiable: y_true is a constant 85 | return Tensor( 86 | out, 87 | args=(y_pred,), 88 | back_fns=(self.backward,), 89 | name="mean_squared_error", 90 | ) 91 | 92 | 93 | class CrossEntropy(Op): 94 | """Calculates Cross Entropy loss. 95 | 96 | This class implements the Cross Entropy loss function, which is commonly 97 | used for classification tasks. It computes the loss given logits and target 98 | indices (as opposed to one-hot encoded tensors). 99 | 100 | Attributes: 101 | _y_true: The true labels (cached for backward pass). 102 | _log_softmax_pred: The log softmax of predictions (cached for backward pass). 103 | _out: The computed loss (cached for backward pass). 
104 | _grad: The computed gradients (cached for backward pass). 105 | """ 106 | 107 | def log_softmax(self, tensor: Tensor): 108 | """Computes the log softmax of the input tensor. 109 | 110 | Args: 111 | tensor: A Tensor containing the input values. 112 | 113 | Returns: 114 | The log softmax of the input tensor. 115 | """ 116 | xp = tensor.xp 117 | x_max = xp.max(tensor.array, axis=-1, keepdims=True) 118 | log_sum_exp = x_max + xp.log( 119 | xp.sum(xp.exp(tensor.array - x_max), axis=-1, keepdims=True) 120 | ) 121 | return tensor.array - log_sum_exp 122 | 123 | def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor: 124 | """Computes the forward pass for Cross Entropy loss. 125 | 126 | Args: 127 | y_true: A Tensor containing the true labels. 128 | y_pred: A Tensor containing the predicted logits. 129 | 130 | Returns: 131 | A Tensor containing the computed Cross Entropy loss. 132 | 133 | Raises: 134 | NotImplementedError: If the input tensor has an unsupported number of dimensions. 135 | """ 136 | xp = y_pred.xp 137 | # cross entropy reduces a huge matrix to a single number which makes 138 | # it really sensitive to errors. 
To remedy this, we need to use 139 | # full precision 140 | if TRICYCLE_CONTEXT.use_mixed_precision: 141 | y_pred.array = y_pred.array.astype(xp.float32) 142 | 143 | log_softmax_pred = self.log_softmax(y_pred) 144 | 145 | # Cache for backward pass 146 | self._y_true = y_true.array 147 | self._log_softmax_pred = log_softmax_pred 148 | 149 | ndim = log_softmax_pred.ndim 150 | 151 | if ndim == 3: 152 | batch_indices = xp.arange(y_true.shape[0], dtype=int) 153 | token_indices = xp.arange(y_true.shape[1], dtype=int) 154 | loss = -log_softmax_pred[ 155 | batch_indices[:, None], token_indices, y_true.array 156 | ] 157 | elif ndim == 2: 158 | indices = xp.arange(y_true.shape[0], dtype=int) 159 | loss = -log_softmax_pred[indices, y_true.array] 160 | elif ndim == 1: 161 | loss = -log_softmax_pred[y_true.array] 162 | else: 163 | raise NotImplementedError( 164 | f"BinaryCrossEntropy with predictions with ndim: {ndim} are not yet supported" 165 | ) 166 | 167 | # Mean loss over all elements 168 | loss = loss.mean() 169 | 170 | self._out = loss 171 | if TRICYCLE_CONTEXT.use_mixed_precision: 172 | self._out = ( 173 | self._out.astype(xp.float16) 174 | * TRICYCLE_CONTEXT.loss_scale_factor 175 | ) 176 | 177 | return Tensor( 178 | self._out, 179 | is_batched=False, 180 | back_fns=(self.backward,), 181 | args=(y_pred,), 182 | name="cross_entropy", 183 | ) 184 | 185 | def backward(self, grad: Tensor) -> Tensor: 186 | """Computes the backward pass for Cross Entropy loss. 187 | 188 | Args: 189 | grad: A Tensor containing the gradient from the previous layer. 190 | 191 | Returns: 192 | A Tensor containing the computed gradients. 193 | 194 | Raises: 195 | NotImplementedError: If the input tensor has an unsupported number of dimensions.
196 | """ 197 | xp = grad.xp 198 | ndim = self._log_softmax_pred.ndim 199 | 200 | if TRICYCLE_CONTEXT.use_mixed_precision: 201 | grad.array = grad.array.astype(xp.float32) 202 | 203 | if ndim == 3: 204 | batch_indices = xp.arange(self._y_true.shape[0], dtype=int) 205 | token_indices = xp.arange(self._y_true.shape[1], dtype=int) 206 | grad_output = xp.exp(self._log_softmax_pred) 207 | grad_output[ 208 | batch_indices[:, None], token_indices, self._y_true 209 | ] -= 1 210 | grad_output *= grad.array / ( 211 | self._y_true.shape[0] * self._y_true.shape[1] 212 | ) 213 | 214 | elif ndim == 2: 215 | indices = xp.arange(self._y_true.shape[0], dtype=int) 216 | grad_output = xp.exp(self._log_softmax_pred) 217 | grad_output[indices, self._y_true] -= 1 218 | grad_output *= grad.array / self._y_true.shape[0] 219 | elif ndim == 1: 220 | grad_output = xp.exp(self._log_softmax_pred) 221 | grad_output[self._y_true] -= 1 222 | grad_output *= grad.array 223 | else: 224 | raise NotImplementedError( 225 | f"BinaryCrossEntropy with predictions with ndim: {ndim} are not yet supported" 226 | ) 227 | 228 | self._grad = grad_output 229 | 230 | # remember to convert the gradient back to the right precision 231 | if TRICYCLE_CONTEXT.use_mixed_precision: 232 | self._grad = self._grad.astype(xp.float16) 233 | 234 | return Tensor(self._grad, is_batched=grad.is_batched) 235 | -------------------------------------------------------------------------------- /src/tricycle/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPT model implementation using the Tricycle framework. 3 | 4 | This module defines the GPT class, which implements a GPT-style transformer model 5 | using components from the Tricycle framework. 
6 | """ 7 | 8 | import humanize 9 | import numpy as np 10 | 11 | from tricycle.blocks import GPT2TransformerBlock 12 | from tricycle.configs import GPTConfig 13 | from tricycle.layers import ( 14 | Dense, 15 | Dropout, 16 | Embedding, 17 | Layer, 18 | LayerNorm, 19 | RMSNorm, 20 | ) 21 | from tricycle.optimisers import Optimiser 22 | from tricycle.tensor import Tensor 23 | 24 | 25 | class GPT(Layer): 26 | """ 27 | Generative Pre-trained Transformer (GPT) model implementation. 28 | 29 | This class implements a GPT-style transformer model using components from 30 | the Tricycle framework. It includes token and position embeddings, multiple 31 | transformer blocks, and a final output layer. 32 | 33 | Attributes: 34 | embedding_dim (int): Dimension of the embedding space. 35 | context_window (int): Size of the context window for position embeddings. 36 | token_embedding (Embedding): Embedding layer for input tokens. 37 | position_embedding (Embedding): Embedding layer for positional information. 38 | input_dropout (Dropout): Dropout layer applied to the input embeddings. 39 | blocks (list): List of GPT2TransformerBlock instances. 40 | head (Dense): Final dense layer for output. 41 | norm (LayerNorm or RMSNorm): Normalization layer. 42 | layers (list): List of all layers in the model. 43 | """ 44 | 45 | def __init__(self, config: GPTConfig): 46 | """ 47 | Initializes the GPT model with the given configuration. 48 | 49 | Args: 50 | config (GPTConfig): Configuration object containing model parameters. 
51 | """ 52 | self.embedding_dim = config.embedding_dim 53 | self.context_window = config.context_window 54 | self.token_embedding = Embedding( 55 | to_size=self.embedding_dim, 56 | from_size=config.vocab_size, 57 | name="token_embedding", 58 | ) 59 | self.position_embedding = Embedding( 60 | to_size=self.embedding_dim, 61 | from_size=self.context_window, 62 | name="position_embedding", 63 | ) 64 | self.input_dropout = Dropout(config.input_dropout_prob) 65 | 66 | self.blocks = [ 67 | GPT2TransformerBlock( 68 | embedding_dim=self.embedding_dim, 69 | n_heads=config.n_heads, 70 | context_window=self.context_window, 71 | expansion_ratio=config.expansion_ratio, 72 | activation_fn=config.activation_fn, 73 | norm_fn=config.norm_fn, 74 | ) 75 | for _ in range(config.n_layers) 76 | ] 77 | 78 | self.head = Dense( 79 | to_size=config.vocab_size, 80 | from_size=self.embedding_dim, 81 | name="head", 82 | ) 83 | match config.norm_fn: 84 | case "layer_norm": 85 | self.norm = LayerNorm(self.embedding_dim) 86 | case "rms_norm": 87 | self.norm = RMSNorm(self.embedding_dim) 88 | case _: 89 | raise ValueError(f"Unknown norm: {config.norm_fn}") 90 | 91 | self.layers = [ 92 | self.token_embedding, 93 | self.position_embedding, 94 | self.input_dropout, 95 | *self.blocks, 96 | self.norm, 97 | self.head, 98 | ] 99 | 100 | def forward(self, tensor: Tensor) -> Tensor: 101 | """ 102 | Performs a forward pass through the GPT model. 103 | 104 | Args: 105 | tensor (Tensor): Input tensor of token indices whose final dimension is the sequence length (a 1-D input is treated as a single unbatched sequence). 106 | 107 | Returns: 108 | Tensor: Output tensor after passing through the model. 109 | 110 | Raises: 111 | AssertionError: If a multi-dimensional input's final dimension doesn't match the expected context window size.
112 | """ 113 | xp = tensor.xp 114 | if tensor.ndim == 1: 115 | n_tokens = tensor.shape[-1] 116 | tensor.array = xp.expand_dims(tensor.array, 0) 117 | tensor = tensor.to_batched() 118 | else: 119 | n_tokens = tensor.shape[-1] 120 | assert n_tokens == self.context_window, ( 121 | "Expected a full context window. ", 122 | f"Found {n_tokens=} and {self.context_window=}", 123 | ) 124 | 125 | position = Tensor( 126 | xp.arange(self.context_window), 127 | requires_grad=False, 128 | dtype=int, 129 | ) 130 | 131 | pos_embedding = self.position_embedding(position) 132 | token_embedding = self.token_embedding(tensor) 133 | 134 | embedding = token_embedding + pos_embedding 135 | 136 | embedding = self.input_dropout(embedding) 137 | 138 | for i, block in enumerate(self.blocks): 139 | embedding = block(embedding) 140 | 141 | embedding = self.norm(embedding) 142 | 143 | embedding = self.head(embedding) 144 | return embedding 145 | 146 | def zero_grad(self): 147 | """ 148 | Zeroes out the gradients of all layers in the model. 149 | 150 | Returns: 151 | GPT: The current GPT instance. 152 | """ 153 | self.token_embedding.zero_grad() 154 | self.position_embedding.zero_grad() 155 | self.norm.zero_grad() 156 | self.head.zero_grad() 157 | for block in self.blocks: 158 | block.zero_grad() 159 | return self 160 | 161 | def update(self, optimiser: Optimiser): 162 | """ 163 | Updates all layers in the model using the provided optimiser. 164 | 165 | Args: 166 | optimiser (Optimiser): The optimiser to use for updating model parameters. 167 | 168 | Returns: 169 | GPT: The current GPT instance. 170 | """ 171 | self.token_embedding.update(optimiser) 172 | self.position_embedding.update(optimiser) 173 | self.norm.update(optimiser) 174 | self.head.update(optimiser) 175 | for block in self.blocks: 176 | block.update(optimiser) 177 | return self 178 | 179 | def to_gpu(self, device: int = 0): 180 | """ 181 | Moves all layers of the model to the specified GPU device. 
182 | 183 | Args: 184 | device (int, optional): The GPU device number. Defaults to 0. 185 | 186 | Returns: 187 | GPT: The current GPT instance. 188 | """ 189 | self.token_embedding.to_gpu(device) 190 | self.position_embedding.to_gpu(device) 191 | for block in self.blocks: 192 | block.to_gpu(device) 193 | self.norm.to_gpu(device) 194 | self.head.to_gpu(device) 195 | return self 196 | 197 | def from_gpu(self): 198 | """ 199 | Moves all layers of the model from GPU back to CPU. 200 | 201 | Returns: 202 | GPT: The current GPT instance. 203 | """ 204 | self.token_embedding.from_gpu() 205 | self.position_embedding.from_gpu() 206 | for block in self.blocks: 207 | block.from_gpu() 208 | self.norm.from_gpu() 209 | self.head.from_gpu() 210 | return self 211 | 212 | def display(self): 213 | """Prints a string representation of the model.""" 214 | print(self) 215 | 216 | def _contents(self): 217 | """ 218 | Returns a flattened list of the layers in this model, along with 219 | their depth in the tree of layers. 220 | 221 | Returns: 222 | list: A list of tuples containing layer name, size, and depth. 223 | """ 224 | stack = [(self, 0)] 225 | 226 | contents = [] 227 | while stack: 228 | node, indent = stack.pop() 229 | 230 | tensors = list(node.tensors.values()) 231 | shapes = [t.shape for t in tensors] 232 | size = sum(np.prod(shape) for shape in shapes) 233 | contents.append((node.__class__.__name__, size, indent)) 234 | 235 | stack.extend((layer, indent + 1) for layer in node.layers[::-1]) 236 | return contents 237 | 238 | def __str__(self): 239 | """ 240 | Returns a string representation of the model, including layer sizes 241 | and total parameter count. 242 | 243 | Returns: 244 | str: A formatted string representing the model structure and size. 
245 | """ 246 | string = "" 247 | total = 0 248 | for layer, size, n_indent in self._contents(): 249 | total += size 250 | size = humanize.scientific(size) if size else "" 251 | indent = " " * n_indent 252 | 253 | string += f"{indent}{layer}({size})\n" 254 | 255 | PARAM_SIZE = self.head.weights[0][0].dtype.itemsize 256 | total *= PARAM_SIZE 257 | 258 | string += "Total size:\n" 259 | string += f" - {humanize.naturalsize(total)}\n" 260 | string += "Total parameters:\n" 261 | string += f" - {humanize.intword(total/PARAM_SIZE)}\n" 262 | return string 263 | -------------------------------------------------------------------------------- /src/tricycle/ops.py: -------------------------------------------------------------------------------- 1 | """Operations module for tensor manipulations. 2 | 3 | This module contains various operations that can be applied to tensors, 4 | including repeat, split, reshape, and mean operations. 5 | """ 6 | 7 | from abc import abstractmethod 8 | from typing import Sequence 9 | 10 | from numpy.typing import ArrayLike 11 | 12 | from tricycle.context import TRICYCLE_CONTEXT 13 | from tricycle.einsum import Einsum, Subscript 14 | from tricycle.tensor import Tensor 15 | 16 | 17 | class Op: 18 | """Base class for operations.""" 19 | 20 | _out: ArrayLike | None = None 21 | 22 | def __call__(self, *args, **kwargs) -> Tensor: 23 | """Call the forward method of the operation. 24 | 25 | Args: 26 | *args: Variable length argument list. 27 | **kwargs: Arbitrary keyword arguments. 28 | 29 | Returns: 30 | Tensor: The result of the forward operation. 31 | """ 32 | return self.forward(*args, **kwargs) 33 | 34 | @abstractmethod 35 | def forward(self, *args, **kwargs) -> Tensor: 36 | """Abstract method for the forward pass of the operation. 37 | 38 | Args: 39 | *args: Variable length argument list. 40 | **kwargs: Arbitrary keyword arguments. 41 | 42 | Raises: 43 | NotImplementedError: This method should be implemented by subclasses. 
44 | 45 | Returns: 46 | Tensor: The result of the forward operation. 47 | """ 48 | raise NotImplementedError() 49 | 50 | 51 | class Repeat(Op): 52 | """Operation to repeat a tensor along its final axis.""" 53 | 54 | def forward(self, tensor: Tensor, repeats: int): 55 | """Repeat a tensor along its final axis. 56 | 57 | This is done by multiplying with a ones tensor the same shape as the 58 | desired output. 59 | 60 | Args: 61 | tensor (Tensor): The input tensor to repeat. 62 | repeats (int): The number of times to repeat the tensor. 63 | 64 | Returns: 65 | Tensor: The repeated tensor. 66 | """ 67 | xp = tensor.xp 68 | subscript = Subscript("...,...a->...a") 69 | new_shape = tensor.shape + (repeats,) 70 | ones = Tensor( 71 | xp.ones(new_shape), 72 | is_batched=tensor.is_batched, 73 | requires_grad=False, 74 | ) 75 | 76 | return Einsum(subscript)(tensor, ones) 77 | 78 | 79 | class Split(Op): 80 | """Operation to split a tensor along an axis.""" 81 | 82 | _indices: tuple[int] 83 | _axis: int 84 | _n_splits: int 85 | _grad: list[ArrayLike] 86 | 87 | def back_fn(self, grad: Tensor, idx: int) -> Tensor: 88 | """The backwards operation for a split operation. 89 | 90 | Produces a tensor of zeros the same shape as the input 91 | except in the section that was split. 92 | 93 | Args: 94 | grad (Tensor): The gradient tensor. 95 | idx (int): The index of the split. 96 | 97 | Returns: 98 | Tensor: The gradient for the input tensor. 
99 | 100 | Example: 101 | >>> result = split([1,2,3,4], 2) 102 | >>> result 103 | [tensor([1, 2]), tensor([3, 4])] 104 | # set an arbitrary derivative for first split 105 | >>> result[0].grad = Tensor([1,1]) 106 | >>> undo_split(result[0].grad) 107 | [1, 1, 0, 0] 108 | """ 109 | xp = grad.xp 110 | self._grad[idx] = xp.zeros(self._in_shape) 111 | 112 | # TODO: this loop is really slow and should be replaced 113 | indices = [] 114 | for i in range(self._grad[idx].ndim): 115 | if i == self._axis % self._grad[idx].ndim: 116 | step = self._in_shape[i] // self._n_splits 117 | start = step * idx 118 | end = step * (idx + 1) 119 | indices.append(slice(start, end)) 120 | else: 121 | indices.append(slice(None)) 122 | self._grad[idx][tuple(indices)] = grad.array 123 | 124 | result = Tensor(self._grad[idx]) 125 | result.is_batched = grad.is_batched 126 | return result 127 | 128 | def forward( 129 | self, tensor: Tensor, n_splits: int, axis: int = -1 130 | ) -> Sequence[Tensor]: 131 | """Split a tensor along an axis into n_splits partitions. 132 | 133 | Args: 134 | tensor (Tensor): The input tensor to split. 135 | n_splits (int): The number of splits to make. 136 | axis (int, optional): The axis along which to split. Defaults to -1. 137 | 138 | Returns: 139 | Sequence[Tensor]: A sequence of split tensors. 
140 | """ 141 | xp = tensor.xp 142 | 143 | assert isinstance(n_splits, int) 144 | 145 | self._out = xp.split(tensor.array, n_splits, axis=axis) 146 | self._in_shape = tensor.shape 147 | self._axis = axis 148 | self._n_splits = n_splits 149 | self._grad = [None] * n_splits 150 | 151 | # TODO: this loop is really slow and should be replaced 152 | results = [] 153 | for idx, result in enumerate(self._out): 154 | # the back_fn depends on index so we need to 155 | # dynamically create this function 156 | def back_fn(grad, idx=idx): 157 | return self.back_fn(grad, idx=idx) 158 | 159 | result = Tensor(result) 160 | result.back_fns = (back_fn,) 161 | result.args = (tensor,) 162 | result.is_batched = tensor.is_batched 163 | results.append(result) 164 | return results 165 | 166 | 167 | class Reshape(Op): 168 | """Operation to reshape a tensor.""" 169 | 170 | _original_shape: Sequence[int] 171 | 172 | def back_fn(self, grad: Tensor) -> Tensor: 173 | """Backward function for the reshape operation. 174 | 175 | Args: 176 | grad (Tensor): The gradient tensor. 177 | 178 | Returns: 179 | Tensor: The gradient reshaped to the original shape. 180 | """ 181 | xp = grad.xp 182 | 183 | self._grad = xp.reshape(grad.array, self._original_shape) 184 | 185 | return Tensor(array=self._grad, is_batched=grad.is_batched) 186 | 187 | def forward(self, tensor: Tensor, shape: Sequence[int]) -> Tensor: 188 | """Reshape a tensor. 189 | 190 | The new shape needs to have the same number of elements 191 | as the original, but can have any number of dimensions. 192 | 193 | Args: 194 | tensor (Tensor): The input tensor to reshape. 195 | shape (Sequence[int]): The new shape for the tensor. 196 | 197 | Returns: 198 | Tensor: The reshaped tensor. 
199 | """ 200 | xp = tensor.xp 201 | 202 | # if the tensor is batched, don't include the first dimension in 203 | # the reshape 204 | if tensor.is_batched: 205 | shape = [tensor.shape[0]] + list(shape) 206 | 207 | self._out = xp.reshape(tensor.array, shape) 208 | self._original_shape = tensor.shape 209 | 210 | return Tensor( 211 | array=self._out, 212 | args=(tensor,), 213 | back_fns=(self.back_fn,), 214 | name="reshape", 215 | is_batched=tensor.is_batched, 216 | ) 217 | 218 | 219 | class Mean(Op): 220 | """Operation to find the mean of a tensor.""" 221 | 222 | def backward(self, grad: Tensor) -> Tensor: 223 | """Backward function for the mean operation. 224 | 225 | Args: 226 | grad (Tensor): The gradient tensor. 227 | 228 | Returns: 229 | Tensor: The gradient for the input tensor. 230 | """ 231 | xp = grad.xp 232 | 233 | result = xp.full(self._in_shape, self.divisor) 234 | out = grad.array * result 235 | 236 | return Tensor(out, is_batched=self._is_batched) 237 | 238 | def forward(self, tensor: Tensor) -> Tensor: 239 | """Find the mean of a tensor. 240 | 241 | Args: 242 | tensor (Tensor): The input tensor. 243 | 244 | Returns: 245 | Tensor: A tensor containing the mean value. 
246 | """ 247 | xp = tensor.xp 248 | self._is_batched = tensor.is_batched 249 | self._in_shape = tensor.shape 250 | 251 | # we can overflow here with large arrays so we'll use full precision 252 | if TRICYCLE_CONTEXT.use_mixed_precision: 253 | tensor.array = tensor.array.astype(xp.float32) 254 | 255 | self.divisor = 1 / xp.prod(tensor.shape) if tensor.shape else 1 256 | out = tensor.array.sum() * self.divisor 257 | 258 | if TRICYCLE_CONTEXT.use_mixed_precision: 259 | out = out.astype(xp.float16) 260 | 261 | return Tensor( 262 | out, name="mean", back_fns=(self.backward,), args=(tensor,) 263 | ) 264 | -------------------------------------------------------------------------------- /src/tricycle/reduce.py: -------------------------------------------------------------------------------- 1 | """Provides reduction operations for tensors. 2 | 3 | This module contains classes for performing max and min reduction operations 4 | on tensors using einsum notation. 5 | """ 6 | 7 | from tricycle.einsum import Einsum, Subscript 8 | from tricycle.ops import Op 9 | from tricycle.tensor import Tensor 10 | 11 | 12 | class ReduceMax(Op): 13 | """Performs max reduction on a tensor along specified dimensions.""" 14 | 15 | def __call__(self, tensor: Tensor, subscript: Subscript | str): 16 | """Generates an indicator tensor for max reduction using einsum. 17 | 18 | This method creates an indicator tensor that, when einsummed with the 19 | input tensor, results in a tensor equal to the max applied along the 20 | indices that don't appear in the output of the subscript. 21 | 22 | Args: 23 | tensor: The input tensor to perform max reduction on. 24 | subscript: The einsum subscript specifying the reduction. 25 | 26 | Returns: 27 | A Tensor representing the result of the max reduction. 28 | 29 | Raises: 30 | AssertionError: If the subscript suggests more than one input tensor. 
31 | """ 32 | if isinstance(subscript, str): 33 | subscript = Subscript(subscript) 34 | 35 | assert ( 36 | len(subscript.inputs) == 1 37 | ), f"Can only reduce a single tensor at a time. Indices suggeststed: {len(subscript.inputs)} tensors: {subscript.inputs}" 38 | 39 | [idx] = subscript.inputs 40 | 41 | reduce_along_axes = [ 42 | i for i, char in enumerate(idx) if char not in subscript.output 43 | ] 44 | 45 | if not reduce_along_axes: 46 | return tensor 47 | 48 | indicator = tensor.array == tensor.xp.max( 49 | tensor.array, axis=tuple(reduce_along_axes), keepdims=True 50 | ) 51 | indicator = Tensor( 52 | indicator, requires_grad=False, is_batched=tensor.is_batched 53 | ) 54 | indicator.array = indicator.array.astype(tensor.xp.int8) 55 | 56 | new_subscript = Subscript.from_split([idx, idx], subscript.output) 57 | 58 | result = Einsum(new_subscript)(tensor, indicator) 59 | result.name = f"min({new_subscript})" 60 | 61 | return result 62 | 63 | 64 | class ReduceMin(Op): 65 | """Performs min reduction on a tensor along specified dimensions.""" 66 | 67 | def __call__(self, tensor: Tensor, subscript: Subscript | str): 68 | """Generates an indicator tensor for min reduction using einsum. 69 | 70 | This method creates an indicator tensor that, when einsummed with the 71 | input tensor, results in a tensor equal to the min applied along the 72 | indices that don't appear in the output of the subscript. 73 | 74 | Args: 75 | tensor: The input tensor to perform min reduction on. 76 | subscript: The einsum subscript specifying the reduction. 77 | 78 | Returns: 79 | A Tensor representing the result of the min reduction. 80 | 81 | Raises: 82 | AssertionError: If the subscript suggests more than one input tensor. 83 | """ 84 | if isinstance(subscript, str): 85 | subscript = Subscript(subscript) 86 | 87 | assert ( 88 | len(subscript.inputs) == 1 89 | ), f"Can only reduce a single tensor at a time. 
Indices suggeststed: {len(subscript.inputs)} tensors: {subscript.inputs}" 90 | 91 | [idx] = subscript.inputs 92 | 93 | reduce_along_axes = [ 94 | i for i, char in enumerate(idx) if char not in subscript.output 95 | ] 96 | 97 | if not reduce_along_axes: 98 | return tensor 99 | 100 | indicator = tensor.array == tensor.xp.min( 101 | tensor.array, axis=tuple(reduce_along_axes), keepdims=True 102 | ) 103 | indicator = Tensor( 104 | indicator, requires_grad=False, is_batched=tensor.is_batched 105 | ) 106 | indicator.array = indicator.array.astype(tensor.xp.int8) 107 | 108 | new_subscript = Subscript.from_split([idx, idx], subscript.output) 109 | 110 | result = Einsum(new_subscript)(tensor, indicator) 111 | result.name = f"min({new_subscript})" 112 | 113 | return result 114 | -------------------------------------------------------------------------------- /src/tricycle/scheduler.py: -------------------------------------------------------------------------------- 1 | """Provides learning rate scheduling functions and classes. 2 | 3 | This module contains implementations of linear and cosine learning rate 4 | schedules with optional warmup periods. These can be used to dynamically 5 | adjust learning rates during training of machine learning models. 6 | 7 | Typical usage example: 8 | 9 | schedule = CosineSchedule(max_learning_rate=6e-4, min_learning_rate=0, 10 | total_steps=5000, warmup_steps=100) 11 | learning_rate = schedule(current_step) 12 | """ 13 | 14 | import math 15 | 16 | 17 | def linear_schedule( 18 | step: int, 19 | max_learning_rate: float, 20 | min_learning_rate: float, 21 | warmup_steps: int, 22 | total_steps: int, 23 | ) -> float: 24 | """Calculates the learning rate using a linear decay schedule with warmup. 25 | 26 | Args: 27 | step: Current step in the training process. 28 | max_learning_rate: Maximum learning rate. 29 | min_learning_rate: Minimum learning rate. 30 | warmup_steps: Number of warmup steps. 
31 | total_steps: Total number of steps in the training process. 32 | 33 | Returns: 34 | The calculated learning rate for the current step. 35 | 36 | Raises: 37 | ValueError: If warmup_steps is greater than total_steps. 38 | """ 39 | # avoid an off by one error 40 | step += 1 41 | 42 | if warmup_steps: 43 | if total_steps < warmup_steps: 44 | raise ValueError( 45 | "Cannot have a warmup longer than the total number of steps" 46 | ) 47 | if step < warmup_steps: 48 | return (step / warmup_steps) * max_learning_rate 49 | 50 | coef = 1 - ((step - warmup_steps) / total_steps) 51 | coef *= max_learning_rate - min_learning_rate 52 | return min_learning_rate + coef 53 | 54 | 55 | class CosineSchedule: 56 | """A class to implement a cosine decay learning rate schedule with warmup. 57 | 58 | Attributes: 59 | max_learning_rate: Maximum learning rate. 60 | min_learning_rate: Minimum learning rate. 61 | total_steps: Total number of steps in the training process. 62 | warmup_steps: Number of warmup steps. 63 | n_steps: Number of steps after warmup. 64 | coef: Coefficient used in the cosine decay calculation. 65 | """ 66 | 67 | def __init__( 68 | self, 69 | max_learning_rate: float, 70 | min_learning_rate: float, 71 | total_steps: int, 72 | warmup_steps: int = 0, 73 | ): 74 | """Initialises the CosineSchedule with the given parameters. 75 | 76 | Args: 77 | max_learning_rate: Maximum learning rate. 78 | min_learning_rate: Minimum learning rate. 79 | total_steps: Total number of steps in the training process. 80 | warmup_steps: Number of warmup steps. Defaults to 0. 81 | 82 | Raises: 83 | ValueError: If warmup_steps is greater than total_steps. 
84 | """ 85 | self.max_learning_rate = max_learning_rate 86 | self.min_learning_rate = min_learning_rate 87 | self.total_steps = total_steps 88 | self.warmup_steps = warmup_steps 89 | 90 | if self.total_steps < self.warmup_steps: 91 | raise ValueError( 92 | "Cannot have a warmup longer than the total number of steps" 93 | ) 94 | 95 | self.n_steps = total_steps - warmup_steps 96 | self.coef = (self.max_learning_rate - self.min_learning_rate) * 0.5 97 | 98 | def step( 99 | self, 100 | step: int, 101 | ) -> float: 102 | """Calculates the learning rate for a given step. 103 | 104 | Args: 105 | step: Current step in the training process. 106 | 107 | Returns: 108 | The calculated learning rate for the current step. 109 | """ 110 | # use 1-indexing so our initial LR is nonzero 111 | step += 1 112 | 113 | if step < self.warmup_steps: 114 | return (step / self.warmup_steps) * self.max_learning_rate 115 | 116 | if self.warmup_steps < step < self.total_steps: 117 | idx = math.pi * (step - self.warmup_steps) / self.n_steps 118 | 119 | return self.min_learning_rate + self.coef * (math.cos(idx) + 1) 120 | 121 | return self.min_learning_rate 122 | 123 | def __call__(self, step: int) -> float: 124 | """Allows the class to be called as a function. 125 | 126 | Args: 127 | step: Current step in the training process. 128 | 129 | Returns: 130 | The calculated learning rate for the current step.
131 | """ 132 | return self.step(step) 133 | 134 | 135 | if __name__ == "__main__": 136 | from matplotlib import pyplot as plt 137 | 138 | x = [i for i in range(5000)] 139 | schedule = CosineSchedule( 140 | max_learning_rate=6e-4, 141 | min_learning_rate=0, 142 | total_steps=5000, 143 | warmup_steps=100, 144 | ) 145 | y = [schedule(i) for i in x] 146 | plt.plot(x, y) 147 | plt.savefig("out.png") 148 | -------------------------------------------------------------------------------- /src/tricycle/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions and classes for the Tricycle project. 2 | 3 | This module contains various utility functions and classes used throughout 4 | the Tricycle project, including dataset handling, mixed precision training, 5 | tensor shape matching, and performance logging. 6 | 7 | """ 8 | 9 | import time 10 | from abc import abstractmethod 11 | from pathlib import Path 12 | from typing import TYPE_CHECKING, Iterable 13 | from warnings import warn 14 | 15 | import humanize 16 | import numpy as np 17 | 18 | from tricycle import GPU_ENABLED 19 | from tricycle.configs import GPTConfig 20 | from tricycle.context import TRICYCLE_CONTEXT 21 | from tricycle.exceptions import GPUDisabledException 22 | 23 | if TYPE_CHECKING: 24 | from tricycle.models import GPT 25 | from tricycle.tensor import Tensor 26 | 27 | 28 | class Dataset: 29 | """Abstract base class for datasets. 30 | 31 | This class defines the interface for dataset objects used in the project. 32 | Subclasses should implement the __len__ and __getitem__ methods. 
33 | """ 34 | 35 | @abstractmethod 36 | def __len__(self): 37 | """Returns the number of items in the dataset.""" 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def __getitem__(self, index): 42 | """Returns the item at the specified index.""" 43 | raise NotImplementedError 44 | 45 | def __iter__(self): 46 | """Returns an iterator over the dataset.""" 47 | for idx in range(self.__len__()): 48 | yield self.__getitem__(idx) 49 | 50 | def __next__(self): 51 | """Returns the next item in the dataset.""" 52 | return self.__getitem__(next(iter(self))) 53 | 54 | 55 | class UseMixedPrecision: 56 | """Context manager for enabling mixed precision training. 57 | 58 | This class provides a context manager that enables mixed precision training 59 | when entered and disables it when exited. 60 | 61 | Args: 62 | initial_loss_scale_factor (int): The initial loss scale factor for mixed 63 | precision training. Defaults to 128. 64 | """ 65 | 66 | def __init__(self, initial_loss_scale_factor: int = 128): 67 | self.active = False 68 | TRICYCLE_CONTEXT.loss_scale_factor = initial_loss_scale_factor 69 | warn( 70 | "Mixed precision training is unstable. Expect your loss to " 71 | "explode/vanish." 72 | ) 73 | 74 | def __enter__(self): 75 | """Enables mixed precision training.""" 76 | self.active = True 77 | TRICYCLE_CONTEXT.use_mixed_precision = True 78 | 79 | def __exit__(self, *args, **kwargs): 80 | """Disables mixed precision training.""" 81 | self.active = False 82 | TRICYCLE_CONTEXT.use_mixed_precision = False 83 | 84 | 85 | def shapes_match(tensor_1: "Tensor", tensor_2: "Tensor") -> bool: 86 | """Checks if the shapes of two tensors match for binary operations. 87 | 88 | Args: 89 | tensor_1: The first tensor to compare. 90 | tensor_2: The second tensor to compare. 91 | 92 | Returns: 93 | bool: True if the shapes match, False otherwise. 94 | 95 | Raises: 96 | ValueError: If the shapes do not match. 
97 | """ 98 | # sourcery skip: assign-if-exp, merge-duplicate-blocks, remove-redundant-if 99 | if tensor_1.is_batched and tensor_2.is_batched: 100 | shape_1 = tensor_1.shape 101 | shape_2 = tensor_2.shape 102 | elif tensor_1.is_batched: 103 | shape_1 = tensor_1.shape[1:] 104 | shape_2 = tensor_2.shape 105 | elif tensor_2.is_batched: 106 | shape_1 = tensor_1.shape 107 | shape_2 = tensor_2.shape[1:] 108 | else: 109 | shape_1 = tensor_1.shape 110 | shape_2 = tensor_2.shape 111 | 112 | if shape_1 != shape_2: 113 | raise ValueError( 114 | f"Shapes {shape_1} and {shape_2} do not match: " 115 | f"{tensor_1.array.shape}, {tensor_2.array.shape}" 116 | ) 117 | return shape_1 == shape_2 118 | 119 | 120 | def smooth(iterable: Iterable, factor: float): 121 | """Applies exponential smoothing to an iterable. 122 | 123 | Args: 124 | iterable: The input iterable to smooth. 125 | factor: The smoothing factor. 126 | 127 | Yields: 128 | float: The smoothed values. 129 | """ 130 | prev = 0 131 | for val in iterable: 132 | yield prev * factor + (val - prev) * factor 133 | prev = val 134 | 135 | 136 | def r_squared(actual_values, predicted_values): 137 | """Calculates the R-squared metric. 138 | 139 | Args: 140 | actual_values: The actual values. 141 | predicted_values: The predicted values. 142 | 143 | Returns: 144 | float: The R-squared value. 145 | """ 146 | actual_values = np.array(actual_values) 147 | predicted_values = np.array(predicted_values) 148 | 149 | mean_actual = np.mean(actual_values) 150 | tss = np.sum((actual_values - mean_actual) ** 2) 151 | rss = np.sum((actual_values - predicted_values) ** 2) 152 | 153 | return 1 - (rss / tss) 154 | 155 | 156 | def log_memory_and_time(stage: str, path: Path = Path("memory.log")): 157 | """Logs the current GPU memory usage and timestamp to a file. 158 | 159 | Args: 160 | stage: A string describing the current stage of execution. 161 | path: The path to the log file. Defaults to "memory.log". 
162 | 163 | Raises: 164 | GPUDisabledException: If GPU is not enabled. 165 | """ 166 | if not GPU_ENABLED: 167 | raise GPUDisabledException( 168 | "Cannot log GPU memory if GPU is not enabled" 169 | ) 170 | 171 | import cupy 172 | 173 | if not path.exists(): 174 | path.write_text( 175 | "stage,used_bytes_human,total_bytes_human,used_bytes,total_bytes,timestamp\n" # noqa: E501 176 | ) 177 | 178 | pool = cupy.get_default_memory_pool() 179 | now = time.perf_counter() 180 | 181 | used_bytes = humanize.naturalsize(pool.used_bytes()) 182 | total_bytes = humanize.naturalsize(pool.total_bytes()) 183 | with open(path, "a") as f: 184 | f.write( 185 | f"{stage},{used_bytes},{total_bytes},{pool.used_bytes()},{pool.total_bytes()},{now}\n" # noqa: E501 186 | ) 187 | 188 | 189 | def optimal_n_tokens(model: "GPT", config: GPTConfig) -> tuple[int, int]: 190 | """Estimates the compute-optimal number of tokens to train on using Chinchilla scaling. 191 | 192 | Args: 193 | model: The GPT model. 194 | config: The GPT configuration. 195 | 196 | Returns: 197 | tuple: A tuple containing the optimal number of tokens and steps. 
198 | 199 | Reference: 200 | https://arxiv.org/abs/2404.10102 201 | """ 202 | # values from the appendix of the paper 203 | flops = [ 204 | 1.84e19, 205 | 1.20e20, 206 | 1.32e22, 207 | 6.88e23, 208 | 4.54e24, 209 | 1.18e25, 210 | 4.19e25, 211 | 1.59e26, 212 | 1.75e28, 213 | ] 214 | tokens = [ 215 | 7.7e9, 216 | 20e9, 217 | 219.5e9, 218 | 1.7e12, 219 | 4.3e12, 220 | 7.1e12, 221 | 13.4e12, 222 | 26.5e12, 223 | 292e12, 224 | ] 225 | 226 | # fit a linear regression 227 | slope, intercept = np.polyfit(np.log(flops), np.log(tokens), 1) 228 | 229 | n_parameters = sum(size for _, size, _ in model._contents()) 230 | 231 | # rearrange regression to get number of tokens: 232 | # 233 | # assuming flops ~= 6 * n_tokens * n_parameters, we get 234 | # log(tokens) = slope * log(6 * n_tokens * n_parameters) + intercept 235 | # which rearranges to the following: 236 | power = 1 / (1 - slope) 237 | constant = (6**slope) * (n_parameters**slope) * np.exp(intercept) 238 | n_tokens = int(constant**power) 239 | 240 | tokens_per_step = ( 241 | config.batch_size 242 | * config.gradient_accumulation_steps 243 | * config.context_window 244 | ) 245 | tokens_per_parameter = n_tokens / n_parameters 246 | 247 | n_steps = n_tokens // tokens_per_step 248 | 249 | print("Chinchilla Optimal Parameters:") 250 | print(f" - Number of tokens: {humanize.intword(n_tokens)}") 251 | print(f" - Number of steps: {humanize.intword(n_steps)}") 252 | print(f" - Tokens per parameters: {tokens_per_parameter:.1f}") 253 | return n_tokens, n_steps 254 | -------------------------------------------------------------------------------- /src/tricycle/weakset.py: -------------------------------------------------------------------------------- 1 | from collections.abc import MutableSet 2 | from typing import Any 3 | from weakref import WeakValueDictionary 4 | 5 | 6 | class WeakSet(MutableSet): 7 | """A Set that uses weak references and does not check for equality. 
8 | 9 | Normal sets check that two elements are equal by comparing __hash__ 10 | and __eq__. For tensors, __eq__ is slow so this class only checks 11 | __hash__. This is a bad idea normally but because we control the hash 12 | for Tensors, we can (hopefully) use it for gradient calculation 13 | 14 | To avoid circular dependencies, we implement this as a weak set so that 15 | the garbage collector can clean up objects that are referred to by this 16 | class 17 | 18 | Attributes: 19 | _dict: A WeakValueDictionary to store the weak references. 20 | 21 | """ 22 | 23 | def __init__(self, *args, **kwargs): 24 | """Initializes the WeakSet. 25 | 26 | Args: 27 | *args: Variable length argument list. 28 | **kwargs: Arbitrary keyword arguments. 29 | """ 30 | super().__init__(*args, **kwargs) 31 | self._dict = WeakValueDictionary() 32 | 33 | def __contains__(self, x: Any) -> bool: 34 | """Checks if an element is in the set. 35 | 36 | Args: 37 | x: The element to check. 38 | 39 | Returns: 40 | bool: True if the element is in the set, False otherwise. 41 | """ 42 | return hash(x) in self._dict 43 | 44 | def __iter__(self): 45 | """Returns an iterator over the elements in the set. 46 | 47 | Returns: 48 | iterator: An iterator over the values in the WeakValueDictionary. 49 | """ 50 | return self._dict.values() 51 | 52 | def __len__(self): 53 | """Returns the number of elements in the set. 54 | 55 | Returns: 56 | int: The number of elements in the set. 57 | """ 58 | return len(self._dict) 59 | 60 | def add(self, x: Any): 61 | """Adds an element to the set. 62 | 63 | Args: 64 | x: The element to add to the set. 65 | """ 66 | self._dict[hash(x)] = x 67 | 68 | def discard(self, x: Any): 69 | """Removes an element from the set if it is present. 70 | 71 | Args: 72 | x: The element to remove from the set. 
73 | """ 74 | if hash(x) in self._dict: 75 | del self._dict[hash(x)] 76 | -------------------------------------------------------------------------------- /src/tricycle_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import codeparrot, fineweb, shakespeare 2 | 3 | __all__ = ["fineweb", "codeparrot", "shakespeare"] 4 | -------------------------------------------------------------------------------- /src/tricycle_datasets/codeparrot.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module prepares and handles the CodeParrot dataset: a dataset of python 3 | files scraped from github 4 | 5 | It downloads, tokenizes, and processes the CodeParrot dataset, creating memory-mapped 6 | files for efficient data handling during training. The module also provides a 7 | CodeParrot class for easy access to the processed data. 8 | 9 | Typical usage example: 10 | 11 | dataset = CodeParrot(vocab_size=100000, split="train") 12 | tokens = dataset[0:1000] # Get the first 1000 tokens 13 | """ 14 | 15 | import os 16 | from collections import abc 17 | from pathlib import Path 18 | from typing import Literal 19 | 20 | import numpy as np 21 | import tiktoken 22 | from tqdm.auto import tqdm 23 | 24 | from datasets import load_dataset 25 | 26 | N_CORES = os.cpu_count() 27 | SAVE_DIR = Path("datasets/codeparrot") 28 | SAVE_DIR.mkdir(exist_ok=True, parents=True) 29 | DTYPE = np.uint32 30 | 31 | tokeniser = tiktoken.get_encoding("cl100k_base") 32 | 33 | 34 | def tokenise_document(example): 35 | """ 36 | Tokenizes a single document from the dataset. 37 | 38 | Args: 39 | example: A dictionary containing the document content. 40 | 41 | Returns: 42 | A dictionary with tokenized 'ids' and 'len' fields. 
43 | """ 44 | ids = tokeniser.encode_ordinary( 45 | example["content"] 46 | ) # encode_ordinary ignores any special tokens 47 | ids.append(tokeniser.eot_token) # add the end of text token 48 | out = {"ids": ids, "len": len(ids)} 49 | return out 50 | 51 | 52 | def prepare_data(): 53 | """ 54 | Downloads and tokenizes the CodeParrot dataset. 55 | 56 | This function splits the dataset into train and validation sets, 57 | tokenizes the content, and saves the tokenized data as memory-mapped files. 58 | 59 | Note: 60 | This script is adapted from Andrej Karpathy's NanoGPT: 61 | https://github.com/karpathy/nanoGPT/blob/master/data/openwebtext/prepare.py 62 | """ 63 | split_dataset = dataset.train_test_split( 64 | test_size=0.0005, seed=2357, shuffle=True 65 | ) 66 | split_dataset["valid"] = split_dataset.pop( 67 | "test" 68 | ) # rename the test split to val 69 | 70 | # tokenise the dataset 71 | tokenised = split_dataset.map( 72 | tokenise_document, 73 | remove_columns=["content"], 74 | desc="Tokenising", 75 | num_proc=N_CORES, 76 | ) 77 | 78 | # concatenate all the ids in each dataset into one large file we can use 79 | # for training 80 | for split, dset in tokenised.items(): 81 | filename = SAVE_DIR / f"{split}.bin" 82 | 83 | n_tokens = np.sum(dset["len"], dtype=np.uint64) 84 | print(f"Found: {n_tokens} {split} tokens") 85 | 86 | arr = np.memmap(filename, dtype=DTYPE, mode="w+", shape=(n_tokens,)) 87 | total_batches = 1024 88 | 89 | idx = 0 90 | for batch_idx in tqdm( 91 | range(total_batches), desc=f"writing {filename}" 92 | ): 93 | # Batch together samples for faster write 94 | batch = dset.shard( 95 | num_shards=total_batches, index=batch_idx, contiguous=True 96 | ).with_format("numpy") 97 | arr_batch = np.concatenate(batch["ids"]) 98 | # Write into mmap 99 | arr[idx : idx + len(arr_batch)] = arr_batch 100 | idx += len(arr_batch) 101 | arr.flush() 102 | 103 | 104 | class CodeParrot(abc.Sequence): 105 | """ 106 | A class to handle the CodeParrot dataset. 
107 | 108 | This class provides an interface to access the tokenized CodeParrot dataset, 109 | including methods for encoding and decoding text. 110 | 111 | Attributes: 112 | url: The source URL of the dataset. 113 | vocab_size: The size of the vocabulary. 114 | token_path: The path to the tokenized data file. 115 | tokeniser_string: The name of the tokenizer to use. 116 | tokens: The memory-mapped array of tokens. 117 | 118 | Args: 119 | vocab_size: The size of the vocabulary to use. 120 | split: The dataset split to use ("train" or "valid"). 121 | token_path: Optional custom path to the tokenized data file. 122 | """ 123 | 124 | url: str = ( 125 | "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" # noqa: E501 126 | ) 127 | vocab_size: int 128 | token_path: Path 129 | tokeniser_string: str = "cl100k_base" 130 | tokens: np.ndarray 131 | 132 | def __init__( 133 | self, 134 | vocab_size: int, 135 | split: Literal["train"] | Literal["valid"], 136 | token_path: Path | None = None, 137 | ): 138 | self.vocab_size = vocab_size 139 | 140 | self.tokeniser = tiktoken.get_encoding(self.tokeniser_string) 141 | if self.tokeniser.max_token_value != vocab_size: 142 | raise ValueError( 143 | "Expected tokeniser.max_token_value == vocab_size. Found " 144 | f"{self.tokeniser.max_token_value=}, {vocab_size=}" 145 | ) 146 | 147 | if token_path is None: 148 | self.token_path = SAVE_DIR / f"{split}.bin" 149 | else: 150 | self.token_path = token_path 151 | 152 | if not self.token_path.exists(): 153 | prepare_data() 154 | 155 | assert self.token_path.exists() 156 | self.tokens = np.memmap(self.token_path, dtype=DTYPE, mode="r") 157 | 158 | def __getitem__(self, key): 159 | """ 160 | Retrieves tokens at the specified index or slice. 161 | 162 | Args: 163 | key: An integer index or slice object. 164 | 165 | Returns: 166 | The token(s) at the specified index or slice. 
167 | """ 168 | return self.tokens[key] 169 | 170 | def __len__(self): 171 | """ 172 | Returns the total number of tokens in the dataset. 173 | 174 | Returns: 175 | The length of the tokens array. 176 | """ 177 | return len(self.tokens) 178 | 179 | def encode(self, *args): 180 | """ 181 | Encodes the input text into tokens. 182 | 183 | Args: 184 | *args: The text to encode. 185 | 186 | Returns: 187 | A list of token ids. 188 | """ 189 | return self.tokeniser.encode_ordinary(*args) 190 | 191 | def decode(self, *args): 192 | """ 193 | Decodes the input tokens into text. 194 | 195 | Args: 196 | *args: The tokens to decode. 197 | 198 | Returns: 199 | The decoded text as a string. 200 | """ 201 | return self.tokeniser.decode(*args) 202 | -------------------------------------------------------------------------------- /src/tricycle_datasets/fineweb.py: -------------------------------------------------------------------------------- 1 | """Prepares and manages web text data from Fineweb. 2 | 3 | This module provides functionality to download, tokenize, and manage the 4 | fineweb dataset. It includes utilities for data preparation and a custom 5 | dataset class for efficient data loading. 6 | 7 | Typical usage example: 8 | 9 | dataset = FineWeb(vocab_size=50257, split='train') 10 | tokens = dataset[0:1000] # Get the first 1000 tokens 11 | """ 12 | 13 | import os 14 | from collections import abc 15 | from pathlib import Path 16 | from typing import Literal 17 | 18 | import numpy as np 19 | import tiktoken 20 | from tqdm.auto import tqdm 21 | 22 | import datasets 23 | from datasets import load_dataset 24 | 25 | N_CORES = os.cpu_count() 26 | SAVE_DIR = Path("datasets/fineweb") 27 | SAVE_DIR.mkdir(exist_ok=True, parents=True) 28 | DTYPE = np.uint16 29 | 30 | 31 | tokeniser = tiktoken.get_encoding("gpt2") 32 | 33 | 34 | def tokenise_document(example): 35 | """Tokenizes a single document from the dataset. 
36 | 37 | Args: 38 | example: A dictionary containing the 'text' field to be tokenized. 39 | 40 | Returns: 41 | A dictionary with 'ids' (tokenized text) and 'len' (number of tokens). 42 | """ 43 | ids = tokeniser.encode_ordinary( 44 | example["text"] 45 | ) # encode_ordinary ignores any special tokens 46 | ids.append(tokeniser.eot_token) # add the end of text token 47 | out = {"ids": ids, "len": len(ids)} 48 | return out 49 | 50 | 51 | def prepare_data(): 52 | """Downloads and tokenizes the coreparrot dataset. 53 | 54 | This function is adapted from Andrej Karpathy's NanoGPT: 55 | https://github.com/karpathy/nanoGPT/blob/master/data/openwebtext/prepare.py 56 | 57 | The function performs the following steps: 58 | 1. Loads the dataset 59 | 2. Splits it into train and validation sets 60 | 3. Tokenizes the dataset 61 | 4. Saves the tokenized data to binary files 62 | 63 | Note: 64 | This function uses OpenAI's tiktoken for tokenization due to 65 | performance considerations. 66 | """ 67 | datasets.disable_caching() 68 | dataset = load_dataset( 69 | "HuggingFaceFW/fineweb", name="sample-10BT", split="train" 70 | ) 71 | split_dataset = dataset.train_test_split( 72 | test_size=0.0005, seed=2357, shuffle=True 73 | ) 74 | split_dataset["valid"] = split_dataset.pop( 75 | "test" 76 | ) # rename the test split to val 77 | 78 | # tokenise the dataset 79 | tokenised = split_dataset.map( 80 | tokenise_document, 81 | remove_columns=["text"], 82 | desc="Tokenising", 83 | num_proc=N_CORES, 84 | ) 85 | 86 | # concatenate all the ids in each dataset into one large file we can use 87 | # for training 88 | for split, dset in tokenised.items(): 89 | filename = SAVE_DIR / f"{split}.bin" 90 | 91 | n_tokens = np.sum(dset["len"]) 92 | print(f"Found: {n_tokens} {split} tokens") 93 | 94 | arr = np.memmap(filename, dtype=DTYPE, mode="w+", shape=(n_tokens,)) 95 | total_batches = 1024 96 | 97 | idx = 0 98 | for batch_idx in tqdm( 99 | range(total_batches), desc=f"writing {filename}" 100 | ): 101 
| # Batch together samples for faster write 102 | batch = dset.shard( 103 | num_shards=total_batches, index=batch_idx, contiguous=True 104 | ).with_format("numpy") 105 | arr_batch = np.concatenate(batch["ids"]) 106 | # Write into mmap 107 | arr[idx : idx + len(arr_batch)] = arr_batch 108 | idx += len(arr_batch) 109 | arr.flush() 110 | 111 | 112 | class FineWeb(abc.Sequence): 113 | """A custom dataset class for efficient loading of tokenized fineweb data. 114 | 115 | This class provides an interface to access tokenized fineweb data, 116 | supporting indexing and length operations. It also includes methods 117 | for encoding and decoding tokens. 118 | 119 | Attributes: 120 | vocab_size: An integer representing the vocabulary size. 121 | token_path: A Path object pointing to the tokenized data file. 122 | tokeniser_string: A string specifying the tokenizer to use (default: "gpt2"). 123 | tokens: A numpy memmap of the tokenized data. 124 | 125 | Args: 126 | vocab_size: An integer specifying the vocabulary size. 127 | split: A string literal, either "train" or "valid", specifying the dataset split. 128 | token_path: An optional Path object for the tokenized data file. 129 | 130 | Raises: 131 | ValueError: If the tokenizer's max token value doesn't match the specified vocab size. 132 | """ 133 | 134 | vocab_size: int 135 | token_path: Path 136 | tokeniser_string: str = "gpt2" 137 | tokens: np.ndarray 138 | 139 | def __init__( 140 | self, 141 | vocab_size: int, 142 | split: Literal["train"] | Literal["valid"], 143 | token_path: Path | None = None, 144 | ): 145 | self.vocab_size = vocab_size 146 | 147 | self.tokeniser = tiktoken.get_encoding(self.tokeniser_string) 148 | if self.tokeniser.max_token_value != vocab_size: 149 | raise ValueError( 150 | "Expected tokeniser.max_token_value == vocab_size. 
Found " 151 | f"{self.tokeniser.max_token_value=}, {vocab_size=}" 152 | ) 153 | 154 | if token_path is None: 155 | self.token_path = SAVE_DIR / f"{split}.bin" 156 | else: 157 | self.token_path = token_path 158 | 159 | if not self.token_path.exists(): 160 | prepare_data() 161 | 162 | assert self.token_path.exists() 163 | self.tokens = np.memmap(self.token_path, dtype=DTYPE, mode="r") 164 | 165 | def __getitem__(self, key): 166 | """Retrieves token(s) at the specified index or slice. 167 | 168 | Args: 169 | key: An integer index or slice object. 170 | 171 | Returns: 172 | The token(s) at the specified index or slice. 173 | """ 174 | return self.tokens[key] 175 | 176 | def __len__(self): 177 | """Returns the total number of tokens in the dataset. 178 | 179 | Returns: 180 | An integer representing the number of tokens. 181 | """ 182 | return len(self.tokens) 183 | 184 | def encode(self, *args): 185 | """Encodes the input text into tokens. 186 | 187 | Args: 188 | *args: Variable length argument list to be passed to the tokenizer. 189 | 190 | Returns: 191 | A list of integer token IDs. 192 | """ 193 | return self.tokeniser.encode_ordinary(*args) 194 | 195 | def decode(self, *args): 196 | """Decodes the input tokens into text. 197 | 198 | Args: 199 | *args: Variable length argument list to be passed to the tokenizer. 200 | 201 | Returns: 202 | A string of decoded text. 203 | """ 204 | return self.tokeniser.decode(*args) 205 | -------------------------------------------------------------------------------- /src/tricycle_datasets/shakespeare.py: -------------------------------------------------------------------------------- 1 | """Provides classes for handling Shakespeare datasets. 2 | 3 | This module contains two main classes: 4 | 1. Shakespeare: For handling tokenized Shakespeare text using BPE tokenization. 5 | 2. ShakespeareChar: For handling character-level Shakespeare text. 
6 | 7 | Both classes provide methods for downloading, tokenizing, encoding, and decoding 8 | Shakespeare's text. 9 | 10 | Typical usage example: 11 | 12 | shakespeare = Shakespeare(1024) 13 | char_shakespeare = ShakespeareChar() 14 | """ 15 | 16 | import pickle 17 | from collections import abc 18 | from pathlib import Path 19 | 20 | import numpy as np 21 | import requests 22 | 23 | from tricycle.tokeniser import BPETokeniser 24 | 25 | 26 | class Shakespeare(abc.Sequence): 27 | """A class for handling tokenized Shakespeare text using BPE tokenization. 28 | 29 | This class downloads the Shakespeare dataset, tokenizes it using BPE, 30 | and provides methods for encoding and decoding text. 31 | 32 | Attributes: 33 | url: A string containing the URL for the Shakespeare dataset. 34 | vocab_size: An integer representing the size of the vocabulary. 35 | token_path: A Path object for the tokenized data file. 36 | raw_data_path: A Path object for the raw data file. 37 | tokens: A numpy array containing the tokenized data. 38 | tokeniser: A BPETokeniser object for tokenization. 39 | """ 40 | 41 | url: str = ( 42 | "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" # noqa: E501 43 | ) 44 | vocab_size: int 45 | token_path: Path 46 | raw_data_path: Path 47 | tokens: np.ndarray 48 | 49 | def __init__( 50 | self, 51 | vocab_size: int, 52 | token_path: Path | None = None, 53 | raw_data_path: Path = Path("datasets/shakespeare/raw_data.txt"), 54 | tokeniser_path: Path = Path("datasets/shakespeare/tokeniser.pkl"), 55 | ): 56 | """Initializes the Shakespeare object. 57 | 58 | Args: 59 | vocab_size: An integer representing the size of the vocabulary. 60 | token_path: A Path object for the tokenized data file. If None, a default path is used. 61 | raw_data_path: A Path object for the raw data file. 62 | tokeniser_path: A Path object for the tokeniser pickle file. 
63 | """ 64 | if token_path is None: 65 | token_path = Path(f"datasets/shakespeare/tokens_{vocab_size}.pkl") 66 | 67 | self.vocab_size = vocab_size 68 | self.raw_data_path = raw_data_path 69 | self.token_path = token_path 70 | self.tokeniser_path = tokeniser_path 71 | 72 | if self.tokeniser_path.exists(): 73 | with open(self.tokeniser_path, "rb") as f: 74 | self.tokeniser = pickle.load(f) 75 | else: 76 | self.tokeniser = None 77 | 78 | if not self.token_path.exists(): 79 | self.tokeniser = self.generate() 80 | self.tokeniser_path.parent.mkdir(parents=True, exist_ok=True) 81 | with open(self.tokeniser_path, "wb") as f: 82 | pickle.dump(self.tokeniser, f) 83 | 84 | self.tokens = self.tokeniser.tokens 85 | self.token_path.parent.mkdir(parents=True, exist_ok=True) 86 | with open(self.token_path, "wb") as f: 87 | pickle.dump(self.tokens, f) 88 | else: 89 | with open(self.token_path, "rb") as f: 90 | self.tokens = pickle.load(f) 91 | 92 | def download(self): 93 | """Downloads the Shakespeare dataset. 94 | 95 | The downloaded data is saved to the path specified by raw_data_path. 96 | """ 97 | raw_data = requests.get(self.url).text 98 | self.raw_data_path.parent.mkdir(parents=True, exist_ok=True) 99 | with open(self.raw_data_path, "w") as f: 100 | f.write(raw_data) 101 | 102 | def generate(self) -> BPETokeniser: 103 | """Downloads and tokenizes the Shakespeare dataset. 104 | 105 | Returns: 106 | A BPETokeniser object trained on the Shakespeare dataset. 107 | """ 108 | self.download() 109 | raw_data = np.array( 110 | list(self.raw_data_path.read_bytes()), dtype=np.int32 111 | ) 112 | if self.tokeniser is None: 113 | self.tokeniser = BPETokeniser(self.vocab_size) 114 | return self.tokeniser.train_ints(raw_data, loading_bar=True) 115 | 116 | def __getitem__(self, idx: int) -> int | list[int]: 117 | """Returns the token(s) at the specified index. 118 | 119 | Args: 120 | idx: An integer index or slice. 121 | 122 | Returns: 123 | The token(s) at the specified index. 
124 | """ 125 | return self.tokens[idx] 126 | 127 | def __len__(self) -> int: 128 | """Returns the number of tokens in the dataset. 129 | 130 | Returns: 131 | An integer representing the number of tokens. 132 | """ 133 | return len(self.tokens) 134 | 135 | def encode(self, *args): 136 | """Encodes the input using the BPE tokenizer. 137 | 138 | Args: 139 | *args: Arguments to pass to the tokenizer's encode method. 140 | 141 | Returns: 142 | The encoded input. 143 | """ 144 | return self.tokeniser.encode(*args) 145 | 146 | def decode(self, *args): 147 | """Decodes the input using the BPE tokenizer. 148 | 149 | Args: 150 | *args: Arguments to pass to the tokenizer's decode method. 151 | 152 | Returns: 153 | The decoded input. 154 | """ 155 | return self.tokeniser.decode(*args) 156 | 157 | 158 | class ShakespeareChar(abc.Sequence): 159 | """A class for handling character-level Shakespeare text. 160 | 161 | This class downloads the Shakespeare dataset and provides methods for 162 | encoding and decoding text at the character level. 163 | 164 | Attributes: 165 | url: A string containing the URL for the Shakespeare dataset. 166 | vocab_size: An integer representing the size of the vocabulary. 167 | raw_data_path: A Path object for the raw data file. 168 | chars: A list of integers representing the characters in the dataset. 169 | """ 170 | 171 | url: str = ( 172 | "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" # noqa: E501 173 | ) 174 | vocab_size: int 175 | raw_data_path: Path 176 | chars: list[int] 177 | 178 | def __init__( 179 | self, 180 | raw_data_path: Path = Path("datasets/shakespeare/raw_data.txt"), 181 | ): 182 | """Initializes the ShakespeareChar object. 183 | 184 | Args: 185 | raw_data_path: A Path object for the raw data file. 
186 | """ 187 | self.raw_data_path = raw_data_path 188 | self.chars = self.generate() 189 | self.vocab_size = len(set(self.chars)) 190 | 191 | def encode(self, chars: list[int] | str): 192 | """Encodes the input characters into character IDs. 193 | 194 | Args: 195 | chars: A list of integers or a string to encode. 196 | 197 | Returns: 198 | A list of integer character IDs. 199 | """ 200 | if isinstance(chars, str): 201 | chars = [ord(i) for i in chars] 202 | return [self.char_ids[c] for c in chars] 203 | 204 | def decode(self, char_ids: list[int]): 205 | """Decodes the input character IDs into characters. 206 | 207 | Args: 208 | char_ids: A list of integer character IDs to decode. 209 | 210 | Returns: 211 | A list of decoded characters. 212 | """ 213 | inv_char_ids = {i: c for c, i in self.char_ids.items()} 214 | return [inv_char_ids[i] for i in char_ids] 215 | 216 | def download(self): 217 | """Downloads the Shakespeare dataset. 218 | 219 | The downloaded data is saved to the path specified by raw_data_path. 220 | """ 221 | raw_data = requests.get(self.url).text 222 | self.raw_data_path.parent.mkdir(parents=True, exist_ok=True) 223 | with open(self.raw_data_path, "w") as f: 224 | f.write(raw_data) 225 | 226 | def generate(self) -> list[int]: 227 | """Downloads and processes the Shakespeare dataset. 228 | 229 | Returns: 230 | A list of integers representing the characters in the dataset. 231 | """ 232 | if not self.raw_data_path.exists(): 233 | self.download() 234 | 235 | raw_data = list(self.raw_data_path.read_bytes()) 236 | self.char_ids = {c: i for i, c in enumerate(set(raw_data))} 237 | return self.encode(raw_data) 238 | 239 | def __getitem__(self, idx: int) -> int | list[int]: 240 | """Returns the character(s) at the specified index. 241 | 242 | Args: 243 | idx: An integer index or slice. 244 | 245 | Returns: 246 | The character(s) at the specified index. 
247 | """ 248 | return self.chars[idx] 249 | 250 | def __len__(self) -> int: 251 | """Returns the number of characters in the dataset. 252 | 253 | Returns: 254 | An integer representing the number of characters. 255 | """ 256 | return len(self.chars) 257 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bclarkson-code/Tricycle/ebccb095d029c0d504eef11bace1a5287226bbc4/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_addoption(parser): 2 | parser.addoption( 3 | "--run-slow", 4 | action="store_true", 5 | default=False, 6 | help="Run slow tests", 7 | ) 8 | -------------------------------------------------------------------------------- /tests/test_activations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.activation import GLU, GeLU, ReLU, Swish 4 | from tricycle.tensor import Tensor 5 | 6 | 7 | def test_relu(): 8 | x = Tensor([-1, 0, 1]) 9 | relu = ReLU() 10 | y = relu(x) 11 | assert y.close_to([0, 0, 1]) 12 | 13 | 14 | def test_swish(): 15 | x = Tensor([-1, 0, 1]) 16 | swish = Swish() 17 | y = swish(x) 18 | assert y.close_to([-0.26894142, 0.0, 0.73105858], rtol=1e-3) 19 | 20 | 21 | def test_gelu_full(): 22 | x = Tensor([-1, 0, 1]) 23 | gelu = GeLU(approximate=False) 24 | y = gelu(x) 25 | assert y.close_to([-0.158808, 0.0, 0.841192], rtol=1e-3) 26 | 27 | 28 | def test_gelu_batched(): 29 | x = Tensor([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]) 30 | x = x.to_batched() 31 | gelu = GeLU(approximate=False) 32 | y = gelu(x) 33 | assert y.close_to( 34 | [ 35 | [-0.158808, 0.0, 0.841192], 36 | [-0.158808, 0.0, 0.841192], 37 | [-0.158808, 0.0, 0.841192], 38 | ], 39 | 
rtol=1e-3, 40 | ) 41 | 42 | 43 | def test_gelu_approx(): 44 | x = Tensor([-1, 0, 1]) 45 | gelu = GeLU(approximate=True) 46 | y = gelu(x) 47 | 48 | assert y.close_to(GeLU(approximate=False)(x), rtol=1e-3) 49 | 50 | 51 | def test_glu(): 52 | x = Tensor([-1, 0, 2]) 53 | glu = GLU(size=3) 54 | glu.linear.weights = Tensor(np.ones(glu.linear.weights.shape)) 55 | 56 | y = glu(x) 57 | assert y.close_to([0.73105858, 0.73105858, 0.73105858], rtol=1e-3) 58 | -------------------------------------------------------------------------------- /tests/test_attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from tricycle.attention import Attention, build_mask 5 | from tricycle.blocks import MultiHeadSelfAttention 6 | from tricycle.context import TRICYCLE_CONTEXT 7 | from tricycle.einsum import Einsum 8 | from tricycle.functions import Softmax 9 | from tricycle.tensor import DEFAULT_DTYPE, Tensor 10 | 11 | TORCH_DTYPE = ( 12 | torch.float16 if TRICYCLE_CONTEXT.use_mixed_precision else torch.float32 13 | ) 14 | 15 | 16 | def pytorch_attention(q, k, v, B, T, C, n_head): 17 | k = k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) 18 | q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) 19 | v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) 20 | # return k 21 | y = torch.nn.functional.scaled_dot_product_attention( 22 | q, 23 | k, 24 | v, 25 | attn_mask=None, 26 | dropout_p=0, 27 | is_causal=True, 28 | ) 29 | y = y.transpose(1, 2).contiguous().view(B, T, C) 30 | return y 31 | 32 | 33 | def andrej_attention(q, k, v, B, T, C, n_head, block_size=32, bias=None): 34 | """ 35 | Andrej Karpathy's implementation of attention from nanogpt 36 | """ 37 | import math 38 | 39 | from torch.nn import functional as F 40 | 41 | if bias is None: 42 | bias = torch.tril(torch.ones(block_size, block_size)).view( 43 | 1, 1, block_size, block_size 44 | ) 45 | k = 
k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) 46 | q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) 47 | v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) 48 | 49 | att = (q.to(torch.float32) @ k.to(torch.float32).transpose(-2, -1)).to( 50 | TORCH_DTYPE 51 | ) 52 | att *= 1.0 / math.sqrt(k.size(-1)) 53 | att = att.masked_fill(bias[:, :, :T, :T] == 0, float("-inf")) 54 | att = F.softmax(att.to(torch.float32), dim=-1) 55 | y = att @ v.to( 56 | torch.float32 57 | ) # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 58 | return y.to(TORCH_DTYPE).transpose(1, 2).contiguous().view(B, T, C) 59 | 60 | 61 | def andrej_attention_block( 62 | x, B, T, C, n_head, c_attn, c_proj, n_embd, block_size=32 63 | ): 64 | """ 65 | Andrej Karpathy's implementation of an attention block from nanogpt 66 | """ 67 | q, k, v = c_attn(x).split(n_embd, dim=-1) 68 | y = andrej_attention(q, k, v, B, T, C, n_head, block_size) 69 | return c_proj(y.to(torch.float32)).to(TORCH_DTYPE) 70 | 71 | 72 | def test_attention_individually(): 73 | """ 74 | This operation is pretty complex so we'll perform each stage 75 | with pytorch and then compare the results. 
Here, I'm comparing 76 | with Andrej Karpathy's implementation from NanoGPT 77 | For this test, we're doing everything non-batch 78 | """ 79 | # setup 80 | embedding_dim = 15 81 | n_heads = 3 82 | n_tokens = 7 83 | projected_size = embedding_dim * 3 84 | context_window = 32 85 | head_size = embedding_dim // n_heads 86 | head_shape = (n_tokens, n_heads, head_size) 87 | out_shape = (n_tokens, embedding_dim) 88 | T = n_tokens 89 | C = embedding_dim 90 | 91 | # random input tensor 92 | in_tensor = np.random.uniform(-5, 5, (n_tokens, projected_size)) 93 | in_tensor = Tensor(in_tensor) 94 | xp = in_tensor.xp 95 | 96 | x = torch.from_numpy(in_tensor.array) 97 | 98 | qu, k, v = x.split(embedding_dim, dim=-1) # pytorch 99 | query, key, value = in_tensor.split(3, axis=-1) # tricycle 100 | 101 | assert query.close_to(qu, rtol=1e-3) 102 | assert key.close_to(k, rtol=1e-3) 103 | assert value.close_to(v, rtol=1e-3) 104 | 105 | # pytorch 106 | k = k.view(T, n_heads, C // n_heads) 107 | qu = qu.view(T, n_heads, C // n_heads) 108 | v = v.view(T, n_heads, C // n_heads) 109 | k = k.transpose(-3, -2) 110 | qu = qu.transpose(-3, -2) 111 | v = v.transpose(-3, -2) 112 | 113 | # tricycle 114 | key = key.reshape(head_shape).einsum("TNH -> NTH") 115 | query = query.reshape(head_shape).einsum("TNH -> NTH") 116 | value = value.reshape(head_shape).einsum("TNH -> NTH") 117 | 118 | assert query.close_to(qu, rtol=1e-3) 119 | assert key.close_to(k, rtol=1e-3) 120 | assert value.close_to(v, rtol=1e-3) 121 | 122 | # pytorch 123 | att = qu.to(torch.float32) @ k.transpose(-2, -1).to(torch.float32) 124 | att = att.to(TORCH_DTYPE) 125 | att *= 1 / np.sqrt(k.size(-1)) 126 | 127 | # tricycle 128 | attention = Einsum("NIh, NJh -> NIJ")(query, key) / np.sqrt(head_size) 129 | 130 | assert attention.close_to(att, rtol=1e-1, atol=1e-4) 131 | 132 | # pytorch 133 | bias = torch.tril(torch.ones(context_window, context_window)).view( 134 | 1, context_window, context_window 135 | ) 136 | att = 
att.masked_fill(bias[:, :T, :T] == 0, float("-inf")) 137 | 138 | # tricycle 139 | mask = build_mask(context_window, n_heads=n_heads) 140 | attention = xp.where( 141 | mask[:, :n_tokens, :n_tokens], -xp.inf, attention.array 142 | ) 143 | attention = Tensor(attention) 144 | 145 | assert attention.close_to(att.numpy(), rtol=1e-2) 146 | 147 | # pytorch 148 | att = torch.softmax(att.to(torch.float32), dim=-1) 149 | 150 | # tricycle 151 | attention = Softmax()(attention) 152 | 153 | assert attention.close_to(att.numpy(), rtol=1e-2, atol=1e-4) 154 | 155 | # pytorch 156 | att = att.to(torch.float32) @ v.to(torch.float32) 157 | att = att.to(TORCH_DTYPE).transpose(0, 1).contiguous() 158 | 159 | # tricycle 160 | attention = Einsum("NIj, NjH -> INH")(attention, value) 161 | 162 | assert attention.close_to(att.numpy(), rtol=1e-2) 163 | 164 | # pytorch 165 | att = att.view(T, C) 166 | 167 | # tricycle 168 | attention = attention.reshape(out_shape) 169 | 170 | assert attention.close_to(att.numpy(), rtol=1e-2) 171 | 172 | 173 | def test_attention_combined(): 174 | """ 175 | Compare Tricycle's attention with Andrej's 176 | """ 177 | n_heads = 3 178 | embedding_dim = 15 179 | n_tokens = 7 180 | batch_size = 11 181 | projected_size = embedding_dim * 3 182 | context_window = n_tokens 183 | B = batch_size 184 | T = n_tokens 185 | C = embedding_dim 186 | 187 | np.random.seed(0) 188 | # random input tensor 189 | in_tensor = np.random.uniform( 190 | -5, 5, (batch_size, n_tokens, projected_size) 191 | ) 192 | in_tensor = Tensor(in_tensor).to_batched() 193 | 194 | x = torch.from_numpy(in_tensor.array) 195 | x.requires_grad = True 196 | 197 | qu, k, v = x.split(embedding_dim, dim=-1) 198 | 199 | pytorch_result = andrej_attention( 200 | q=qu, 201 | k=k, 202 | v=v, 203 | B=B, 204 | T=T, 205 | C=C, 206 | n_head=n_heads, 207 | ) 208 | 209 | tricycle_attention = Attention( 210 | embedding_dim=embedding_dim, 211 | n_heads=n_heads, 212 | context_window=context_window, 213 | ) 214 | tricycle_result 
= tricycle_attention(in_tensor).from_batched() 215 | 216 | assert tricycle_result.close_to( 217 | pytorch_result.detach(), rtol=1e-1, atol=1e-3 218 | ) 219 | 220 | tricycle_result.from_batched().sum().backward() 221 | pytorch_result.sum().backward() 222 | 223 | # I don't know why the tolerance has to be so large here. 224 | # smells like numerical instability 225 | # TODO: investigate discrepency 226 | assert in_tensor.grad.close_to( 227 | x.grad.detach().numpy(), atol=1e-1, rtol=1e-1 228 | ) 229 | 230 | 231 | def test_attention_block(): 232 | """ 233 | Compare Tricycle attention with pytorch's MultiheadAttention 234 | """ 235 | n_heads = 3 236 | embedding_dim = 15 237 | n_tokens = 32 238 | batch_size = 11 239 | context_window = 32 240 | 241 | np.random.seed(0) 242 | 243 | x = np.random.normal(size=(batch_size, n_tokens, embedding_dim)).astype( 244 | DEFAULT_DTYPE 245 | ) 246 | 247 | in_projection_weights = np.random.normal( 248 | 0, 1, (embedding_dim, embedding_dim * 3) 249 | ).astype(DEFAULT_DTYPE) 250 | out_projection_weights = np.random.normal( 251 | 0, 1, (embedding_dim, embedding_dim) 252 | ).astype(DEFAULT_DTYPE) 253 | 254 | tricycle_attention = MultiHeadSelfAttention( 255 | embedding_dim=embedding_dim, 256 | n_heads=n_heads, 257 | context_window=context_window, 258 | residual_dropout_prob=0, 259 | ) 260 | tricycle_attention.in_projection.weights = Tensor( 261 | in_projection_weights, name="in_proj" 262 | ) 263 | tricycle_attention.out_projection.weights = Tensor( 264 | out_projection_weights, name="out_proj" 265 | ) 266 | 267 | in_tensor = Tensor(x, requires_grad=False).to_batched() 268 | tricycle_result = tricycle_attention(in_tensor) 269 | 270 | c_attn = torch.nn.Linear(embedding_dim, 3 * embedding_dim, bias=False) 271 | c_attn.weight = torch.nn.Parameter(torch.tensor(in_projection_weights.T)) 272 | c_proj = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 273 | c_proj.weight = torch.nn.Parameter(torch.tensor(out_projection_weights.T)) 274 | 275 
| andrej_result = andrej_attention_block( 276 | torch.tensor(x), 277 | batch_size, 278 | n_tokens, 279 | embedding_dim, 280 | n_heads, 281 | c_attn, 282 | c_proj, 283 | embedding_dim, 284 | block_size=32, 285 | ) 286 | 287 | assert tricycle_result.close_to( 288 | andrej_result.detach().numpy(), rtol=1e-1, atol=1e-1 289 | ) 290 | 291 | tricycle_loss = tricycle_result.from_batched().einsum("abc->") 292 | andrej_loss = andrej_result.sum() 293 | 294 | assert tricycle_loss.close_to(andrej_loss.detach().numpy()) 295 | 296 | tricycle_loss.backward() 297 | andrej_loss.backward() 298 | 299 | assert not tricycle_attention.out_projection.weights.is_batched 300 | tricycle_out_weights = tricycle_attention.out_projection.weights.grad 301 | 302 | assert tricycle_out_weights.close_to(c_proj.weight.grad.T.numpy()) 303 | 304 | tricycle_in_weights = tricycle_attention.in_projection.weights.grad 305 | 306 | assert tricycle_in_weights.close_to( 307 | c_attn.weight.grad.T.numpy(), rtol=1e-2, atol=1e-4 308 | ) 309 | -------------------------------------------------------------------------------- /tests/test_binary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.binary import ( 4 | BinaryAdd, 5 | BinaryDivide, 6 | BinaryMax, 7 | BinaryMin, 8 | BinaryMultiply, 9 | BinarySubtract, 10 | ) 11 | from tricycle.tensor import DEFAULT_DTYPE, Tensor 12 | 13 | 14 | def test_can_badd(): # sourcery skip: extract-duplicate-method 15 | in_tensor_1 = Tensor(np.arange(12).reshape(3, 4)) 16 | in_tensor_2 = Tensor(np.arange(1, 13).reshape(3, 4)) 17 | 18 | out_tensor = BinaryAdd()(in_tensor_1, in_tensor_2) 19 | 20 | assert out_tensor.shape == (3, 4) 21 | 22 | correct = Tensor([[1, 3, 5, 7], [9, 11, 13, 15], [17, 19, 21, 23]]) 23 | assert out_tensor.close_to(correct) 24 | 25 | out_tensor.backward() 26 | 27 | correct = Tensor( 28 | [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] 29 | ) 30 | assert 
in_tensor_1.grad is not None 31 | assert in_tensor_1.grad.close_to(correct) 32 | 33 | correct = Tensor( 34 | [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] 35 | ) 36 | assert in_tensor_2.grad is not None 37 | assert in_tensor_2.grad.close_to(correct) 38 | 39 | 40 | def test_can_bsub(): # sourcery skip: extract-duplicate-method 41 | in_tensor_1 = Tensor(np.arange(12).reshape(3, 4), is_batched=True) 42 | in_tensor_2 = Tensor(np.arange(1, 13).reshape(3, 4), is_batched=True) 43 | 44 | out_tensor = BinarySubtract()(in_tensor_1, in_tensor_2) 45 | 46 | assert out_tensor.shape == (3, 4) 47 | correct = Tensor([[-1, -1, -1, -1], [-1, -1, -1, -1], [-1, -1, -1, -1]]) 48 | assert out_tensor.close_to(correct) 49 | 50 | out_tensor.backward() 51 | 52 | correct = Tensor( 53 | [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]] 54 | ) 55 | assert in_tensor_1.grad is not None 56 | assert in_tensor_1.grad.close_to(correct) 57 | 58 | correct = Tensor( 59 | [ 60 | [-1.0, -1.0, -1.0, -1.0], 61 | [-1.0, -1.0, -1.0, -1.0], 62 | [-1.0, -1.0, -1.0, -1.0], 63 | ] 64 | ) 65 | assert in_tensor_2.grad is not None 66 | assert in_tensor_2.grad.close_to(correct) 67 | 68 | 69 | def test_can_bmul(): 70 | in_tensor_1 = Tensor(np.arange(12).reshape(3, 4), is_batched=True) 71 | in_tensor_2 = Tensor(np.arange(1, 13).reshape(3, 4), is_batched=True) 72 | 73 | out_tensor = BinaryMultiply()(in_tensor_1, in_tensor_2) 74 | 75 | assert out_tensor.shape == (3, 4) 76 | correct = Tensor([[0, 2, 6, 12], [20, 30, 42, 56], [72, 90, 110, 132]]) 77 | assert out_tensor.close_to(correct) 78 | 79 | out_tensor.backward() 80 | 81 | assert in_tensor_1.grad is not None 82 | assert in_tensor_2.grad is not None 83 | assert in_tensor_1.grad.close_to(in_tensor_2) 84 | assert in_tensor_2.grad.close_to(in_tensor_1) 85 | 86 | 87 | def test_can_bdiv(): 88 | in_tensor_1 = Tensor( 89 | np.arange(12).reshape(3, 4), is_batched=True, dtype=DEFAULT_DTYPE 90 | ) 91 | in_tensor_2 = Tensor( 92 | np.arange(1, 
13).reshape(3, 4), is_batched=True, dtype=DEFAULT_DTYPE 93 | ) 94 | 95 | out_tensor = BinaryDivide()(in_tensor_1, in_tensor_2) 96 | 97 | assert out_tensor.shape == (3, 4) 98 | correct = Tensor( 99 | [ 100 | [0, 1 / 2, 2 / 3, 3 / 4], 101 | [4 / 5, 5 / 6, 6 / 7, 7 / 8], 102 | [8 / 9, 9 / 10, 10 / 11, 11 / 12], 103 | ], 104 | dtype=DEFAULT_DTYPE, 105 | ) 106 | 107 | assert out_tensor.close_to(correct, rtol=1e-3) 108 | 109 | out_tensor.backward() 110 | 111 | assert in_tensor_1.grad is not None 112 | assert in_tensor_2.grad is not None 113 | assert in_tensor_1.grad.close_to(1 / in_tensor_2.array, rtol=1e-3) 114 | 115 | assert in_tensor_2.grad.close_to( 116 | -in_tensor_1.array / (in_tensor_2.array**2), rtol=1e-3 117 | ) 118 | 119 | 120 | def test_can_bmax(): 121 | in_tensor_1 = Tensor(np.arange(12).reshape(3, 4), is_batched=True) 122 | in_tensor_2 = Tensor( 123 | [[0, 0, 0, 0], [100, 100, 100, 100], [8, 9, 10, 11]], is_batched=True 124 | ) 125 | 126 | out_tensor = BinaryMax()(in_tensor_1, in_tensor_2) 127 | 128 | assert out_tensor.shape == (3, 4) 129 | correct = Tensor([[0, 1, 2, 3], [100, 100, 100, 100], [8, 9, 10, 11]]) 130 | assert out_tensor.close_to(correct) 131 | 132 | out_tensor.backward() 133 | 134 | one_is_bigger = in_tensor_1 > in_tensor_2 135 | two_is_bigger = in_tensor_1 <= in_tensor_2 136 | 137 | assert in_tensor_1.grad is not None 138 | assert in_tensor_2.grad is not None 139 | assert in_tensor_1.grad.close_to(one_is_bigger) 140 | assert in_tensor_2.grad.close_to(two_is_bigger) 141 | 142 | 143 | def test_can_bmin(): 144 | in_tensor_1 = Tensor(np.arange(12).reshape(3, 4), is_batched=True) 145 | in_tensor_2 = Tensor( 146 | [[0, 0, 0, 0], [100, 100, 100, 100], [8, 9, 10, 11]], is_batched=True 147 | ) 148 | 149 | out_tensor = BinaryMin()(in_tensor_1, in_tensor_2) 150 | 151 | assert out_tensor.shape == (3, 4) 152 | correct = Tensor([[0, 0, 0, 0], [4, 5, 6, 7], [8, 9, 10, 11]]) 153 | assert out_tensor.close_to(correct) 154 | 155 | out_tensor.backward() 156 | 157 
| one_is_smaller = in_tensor_1 < in_tensor_2 158 | two_is_smaller = in_tensor_1 >= in_tensor_2 159 | 160 | assert in_tensor_1.grad is not None 161 | assert in_tensor_2.grad is not None 162 | assert in_tensor_1.grad.close_to(one_is_smaller) 163 | assert in_tensor_2.grad.close_to(two_is_smaller) 164 | -------------------------------------------------------------------------------- /tests/test_blocks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.blocks import GPT2TransformerBlock, MLPBlock 4 | from tricycle.tensor import Tensor 5 | 6 | 7 | def test_MLPBlock(): 8 | np.random.seed(0) 9 | in_tensor = Tensor(np.arange(12, dtype=float).reshape(3, 4)) 10 | block = MLPBlock(embedding_dim=4, expansion_ratio=4, dropout_prob=0.5) 11 | 12 | assert block.linear_1.weights.shape == (4, 16) 13 | assert block.linear_2.weights.shape == (16, 4) 14 | 15 | block.linear_1.weights = Tensor(np.ones(block.linear_1.weights.shape)) 16 | block.linear_2.weights = Tensor(np.ones(block.linear_2.weights.shape)) 17 | 18 | out_tensor = block(in_tensor.to_batched()) 19 | 20 | assert out_tensor.shape == (3, 4) 21 | 22 | correct_output = np.array( 23 | [ 24 | [192.0, 0.0, 192.0, 0.0], 25 | [ 26 | 0.0, 27 | 0.0, 28 | 704.0, 29 | 704.0, 30 | ], 31 | [1216.0, 1216.0, 1216.0, 0.0], 32 | ] 33 | ) 34 | correct_output = Tensor(correct_output) 35 | 36 | assert out_tensor.is_batched 37 | assert out_tensor.close_to(correct_output) 38 | 39 | out_tensor.backward() 40 | 41 | assert in_tensor.grad is not None 42 | correct_grad = Tensor( 43 | [ 44 | [64.0, 64.0, 64.0, 64.0], 45 | [64.0, 64.0, 64.0, 64.0], 46 | [96.0, 96.0, 96.0, 96.0], 47 | ] 48 | ) 49 | 50 | assert in_tensor.grad.close_to(correct_grad) 51 | 52 | 53 | def test_GPT2TransformerBlock(): 54 | np.random.seed(0) 55 | batch_size = 11 56 | n_tokens = 32 57 | n_heads = 3 58 | embedding_dim = 7 * n_heads 59 | 60 | in_tensor = Tensor( 61 | np.random.random((batch_size, n_tokens, 
embedding_dim)), 62 | is_batched=True, 63 | ) 64 | block = GPT2TransformerBlock( 65 | embedding_dim=embedding_dim, 66 | n_heads=3, 67 | expansion_ratio=4, 68 | context_window=32, 69 | ) 70 | 71 | out_tensor = block(in_tensor.to_batched()) 72 | 73 | assert out_tensor.shape == (batch_size, n_tokens, embedding_dim) 74 | 75 | out_tensor.backward() 76 | 77 | assert in_tensor.grad is not None 78 | assert in_tensor.grad.shape == in_tensor.shape 79 | -------------------------------------------------------------------------------- /tests/test_composite.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.ops import Split 4 | from tricycle.tensor import Tensor 5 | 6 | 7 | def test_split_first_axis(): 8 | in_tensor = Tensor([1, 2, 3, 4, 5, 6]) 9 | 10 | out_tensors = Split()(in_tensor, 3) 11 | 12 | assert len(out_tensors) == 3 13 | 14 | assert out_tensors[0].shape == (2,) 15 | assert out_tensors[1].shape == (2,) 16 | assert out_tensors[2].shape == (2,) 17 | 18 | assert out_tensors[0].close_to([1, 2]) 19 | assert out_tensors[1].close_to([3, 4]) 20 | assert out_tensors[2].close_to([5, 6]) 21 | 22 | out_tensors[0].backward() 23 | assert in_tensor.grad is not None 24 | assert in_tensor.grad.close_to([1, 1, 0, 0, 0, 0]) 25 | 26 | out_tensors[1].backward() 27 | assert in_tensor.grad is not None 28 | assert in_tensor.grad.close_to([1, 1, 1, 1, 0, 0]) 29 | 30 | out_tensors[2].backward() 31 | assert in_tensor.grad is not None 32 | assert in_tensor.grad.close_to([1, 1, 1, 1, 1, 1]) 33 | 34 | 35 | def test_split_middle_axis(): 36 | in_tensor = Tensor(np.ones((2, 3, 4))) 37 | 38 | out_tensors = Split()(in_tensor, n_splits=2, axis=-1) 39 | 40 | assert len(out_tensors) == 2 41 | 42 | assert out_tensors[0].shape == (2, 3, 2) 43 | assert out_tensors[1].shape == (2, 3, 2) 44 | 45 | assert out_tensors[0].close_to(np.ones((2, 3, 2))) 46 | assert out_tensors[1].close_to(np.ones((2, 3, 2))) 47 | 48 | 
out_tensors[0].backward() 49 | assert in_tensor.grad is not None 50 | assert in_tensor.grad.close_to([[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0]]) 51 | 52 | 53 | def test_reshape(): 54 | in_tensor = Tensor([1, 2, 3, 4, 5, 6]) 55 | 56 | out_tensor = in_tensor.reshape((2, 3)) 57 | 58 | assert out_tensor.shape == (2, 3) 59 | assert out_tensor.close_to([[1, 2, 3], [4, 5, 6]]) 60 | 61 | out_tensor.backward() 62 | 63 | assert in_tensor.grad is not None 64 | assert in_tensor.grad.close_to([1, 1, 1, 1, 1, 1]) 65 | 66 | 67 | def test_mean(): 68 | in_tensor = Tensor([1, 2, 3, 4, 5, 6]) 69 | 70 | out_tensor = in_tensor.mean() 71 | 72 | assert out_tensor.close_to(3.5) 73 | 74 | out_tensor.backward() 75 | 76 | assert in_tensor.grad is not None 77 | assert in_tensor.grad.close_to( 78 | [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6], rtol=1e-3 79 | ) 80 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.dataset import CausalLMDataset 4 | 5 | 6 | def test_can_build_causal_lm_dataset(): 7 | tokens = np.arange(100) 8 | dataset = CausalLMDataset( 9 | tokens=tokens, vocab_size=100, batch_size=10, context_window=10 10 | ) 11 | 12 | inputs, outputs = dataset[0] 13 | assert isinstance(inputs, np.ndarray) 14 | assert isinstance(outputs, np.ndarray) 15 | 16 | assert len(inputs) == 10 17 | expected_tokens = tokens[:11] 18 | 19 | assert np.allclose(inputs, expected_tokens[:-1]) 20 | assert np.allclose(outputs, expected_tokens[1:]) 21 | 22 | dataset.batch() 23 | 24 | inputs, outputs = dataset[0] 25 | assert inputs.shape == (10, 10) 26 | assert outputs.shape == (10, 10) 27 | -------------------------------------------------------------------------------- /tests/test_einsum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.einsum import 
Einsum, Subscript 4 | from tricycle.tensor import Tensor 5 | 6 | 7 | def test_batched_reduce(): 8 | x = Tensor(np.arange(5)) 9 | op = Einsum("a->") 10 | result = op(x) 11 | assert result.close_to(10) 12 | 13 | result.backward() 14 | assert x.grad is not None 15 | assert x.grad.close_to(np.ones(x.shape)) 16 | 17 | 18 | def test_matrix_reduce(): 19 | x = Tensor(np.arange(20).reshape(4, 5)) 20 | op = Einsum("ab->") 21 | assert op(x) == 190 22 | 23 | op(x).backward() 24 | assert x.grad is not None 25 | assert x.grad.close_to(np.ones(x.shape)) 26 | 27 | 28 | def test_matrix_partial_reduce(): 29 | x = Tensor(np.arange(20).reshape(4, 5)) 30 | op = Einsum("ab->b") 31 | assert op(x).close_to([30, 34, 38, 42, 46]) 32 | 33 | op(x).backward() 34 | assert x.grad is not None 35 | assert x.grad.close_to(np.ones(x.shape)) 36 | 37 | 38 | def test_transpose(): 39 | x = Tensor(np.arange(20).reshape(4, 5)) 40 | op = Einsum("ij->ji") 41 | assert op(x).close_to(x.array.T) 42 | 43 | op(x).backward() 44 | assert x.grad is not None 45 | assert x.grad.close_to(np.ones(x.shape)) 46 | 47 | 48 | def test_parse_subscripts(): 49 | subscript = Subscript("a,b->ab") 50 | assert subscript.inputs == [["a"], ["b"]] 51 | assert subscript.output == ["a", "b"] 52 | 53 | subscript = Subscript("a,b->") 54 | assert subscript.inputs == [["a"], ["b"]] 55 | assert subscript.output == [] 56 | 57 | subscript = Subscript("...a,b...->ab...") 58 | assert subscript.inputs == [["...", "a"], ["b", "..."]] 59 | assert subscript.output == ["a", "b", "..."] 60 | 61 | subscript = Subscript("...,...->...") 62 | assert subscript.inputs == [["..."], ["..."]] 63 | assert subscript.output == ["..."] 64 | 65 | subscript = Subscript("z...,z...->z...") 66 | assert subscript.inputs == [["z", "..."], ["z", "..."]] 67 | assert subscript.output == ["z", "..."] 68 | 69 | 70 | def test_can_parse_split(): 71 | inputs = [["a"], ["b"]] 72 | output = ["a", "b"] 73 | assert Subscript.join(inputs, output) == "a,b->ab" 74 | 75 | inputs = 
[["a"], ["b"]] 76 | output = [] 77 | assert Subscript.join(inputs, output) == "a,b->" 78 | 79 | inputs = [["...", "a"], ["b", "..."]] 80 | output = ["a", "b", "..."] 81 | assert Subscript.join(inputs, output) == "...a,b...->ab..." 82 | 83 | inputs = [["..."], ["..."]] 84 | output = ["..."] 85 | assert Subscript.join(inputs, output) == "...,...->..." 86 | 87 | inputs = [["z", "..."], ["z", "..."]] 88 | output = ["z", "..."] 89 | assert Subscript.join(inputs, output) == "z...,z...->z..." 90 | -------------------------------------------------------------------------------- /tests/test_functions.py: -------------------------------------------------------------------------------- 1 | from tricycle.functions import Sigmoid 2 | from tricycle.tensor import Tensor 3 | 4 | 5 | def test_sigmoid(): 6 | in_tensor = Tensor([0, 1, 2, 3]) 7 | out_tensor = Sigmoid()(in_tensor) 8 | 9 | assert out_tensor.shape == (4,) 10 | assert out_tensor.close_to( 11 | [0.5, 0.73105858, 0.88079708, 0.95257413], rtol=1e-3 12 | ) 13 | 14 | out_tensor.backward() 15 | correct_grad = out_tensor * (1 - out_tensor) 16 | 17 | assert in_tensor.grad is not None 18 | assert in_tensor.grad.close_to(correct_grad, rtol=1e-3, atol=1e-3) 19 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | import numpy as np 4 | 5 | from tricycle.layers import ( # noqa: E501 6 | Dense, 7 | Dropout, 8 | Embedding, 9 | LayerNorm, 10 | Sequential, 11 | ) 12 | from tricycle.tensor import Tensor 13 | 14 | 15 | def test_dense_layer(): 16 | layer = Dense(10, 8) 17 | 18 | assert layer.weights.shape == (10, 8) 19 | 20 | x_in = Tensor(np.ones(10)) 21 | 22 | x_out = layer(x_in) 23 | assert x_out.shape == (8,) 24 | 25 | 26 | def test_sequential_layer(): 27 | layer1 = Dense(10, 8) 28 | layer2 = Dense(8, 4) 29 | 30 | model = Sequential(layer1, layer2) 31 | 32 | assert 
model.layers[0].weights.shape == (10, 8) 33 | assert model.layers[1].weights.shape == (8, 4) 34 | 35 | x_in = Tensor(np.ones(10)) 36 | 37 | x_out = model(x_in) 38 | assert x_out.shape == (4,) 39 | 40 | 41 | def test_dropout(): # sourcery skip: square-identity 42 | np.random.seed(0) 43 | size = 100 44 | dropout_prob = 0.3 45 | 46 | # non-batched 47 | in_tensor = Tensor(np.random.normal(size=(size, size)), name="in_tensor") 48 | dropout = Dropout(dropout_prob) 49 | 50 | out_tensor = dropout(in_tensor.to_batched()) 51 | 52 | assert out_tensor.shape == in_tensor.shape 53 | zero_x_idx, zero_y_idx = np.where(out_tensor.array == 0) 54 | n_zeros = len(zero_x_idx) 55 | expected_n_zeros = int(size * size * dropout_prob) 56 | 57 | assert n_zeros / size**2 - expected_n_zeros / size**2 < 0.05 58 | 59 | out_tensor.backward() 60 | 61 | assert in_tensor.grad is not None 62 | assert in_tensor.grad.shape == in_tensor.shape 63 | 64 | coef = 1 / (1 - dropout_prob) 65 | correct_grad = np.full(in_tensor.shape, coef) 66 | correct_grad[zero_x_idx, zero_y_idx] = 0 67 | 68 | assert in_tensor.grad.close_to(correct_grad) 69 | 70 | 71 | def test_layer_norm(): 72 | np.random.seed(0) 73 | in_tensor = Tensor(np.random.normal(size=(100, 100)), name="in_tensor") 74 | layer_norm = LayerNorm(100) 75 | out_tensor = layer_norm(in_tensor.to_batched()) 76 | 77 | assert out_tensor.shape == in_tensor.shape 78 | out_tensor.backward() 79 | 80 | assert copy(out_tensor).mean().close_to(0, atol=1e-3) 81 | assert np.allclose(np.std(out_tensor.array), [1] * 100, atol=1e-7) 82 | 83 | assert in_tensor.grad is not None 84 | assert in_tensor.grad.shape == in_tensor.shape 85 | 86 | # not sure if this is correct. 
TODO: check 87 | assert in_tensor.grad.close_to(np.zeros(in_tensor.shape), atol=1e-3) 88 | 89 | 90 | def test_embedding(): 91 | np.random.seed(0) 92 | vocab_size = 3 93 | out_shape = 5 94 | in_tensor = Tensor( 95 | [0, 1, 2, 0], 96 | requires_grad=False, 97 | dtype=int, 98 | ) 99 | 100 | embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape) 101 | weights = np.indices((vocab_size * out_shape,)).reshape( 102 | vocab_size, out_shape 103 | ) 104 | embedding_layer.weights = Tensor(weights) 105 | 106 | result = embedding_layer(in_tensor) 107 | 108 | assert result.shape == (4, 5) 109 | assert result[0].close_to(embedding_layer.weights[0]) 110 | assert result[1].close_to(embedding_layer.weights[1]) 111 | assert result[2].close_to(embedding_layer.weights[2]) 112 | assert result[3].close_to(embedding_layer.weights[0]) 113 | 114 | result.backward() 115 | 116 | assert embedding_layer.weights.grad is not None 117 | assert embedding_layer.weights.grad.shape == embedding_layer.weights.shape 118 | assert embedding_layer.weights.grad.close_to( 119 | [[2, 2, 2, 2, 2], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]] 120 | ) 121 | 122 | 123 | def test_embedding_batched(): 124 | np.random.seed(0) 125 | vocab_size = 3 126 | out_shape = 5 127 | in_tensor = Tensor( 128 | [[0, 1, 2, 0], [1, 2, 2, 1]], 129 | requires_grad=False, 130 | dtype=np.int8, 131 | ).to_batched() 132 | 133 | embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape) 134 | weights = np.indices((vocab_size * out_shape,)).reshape( 135 | vocab_size, out_shape 136 | ) 137 | embedding_layer.weights = Tensor(weights) 138 | 139 | result = embedding_layer(in_tensor) 140 | 141 | assert result.shape == (2, 4, 5) 142 | assert result[0][0].close_to(embedding_layer.weights[0]) 143 | assert result[0][1].close_to(embedding_layer.weights[1]) 144 | assert result[0][2].close_to(embedding_layer.weights[2]) 145 | assert result[0][3].close_to(embedding_layer.weights[0]) 146 | 147 | assert 
result[1][0].close_to(embedding_layer.weights[1]) 148 | assert result[1][1].close_to(embedding_layer.weights[2]) 149 | assert result[1][2].close_to(embedding_layer.weights[2]) 150 | assert result[1][3].close_to(embedding_layer.weights[1]) 151 | 152 | result.backward() 153 | 154 | assert embedding_layer.weights.grad is not None 155 | assert embedding_layer.weights.grad.shape == (vocab_size, out_shape) 156 | assert embedding_layer.weights.grad.close_to( 157 | [ 158 | [ 159 | [2.0, 2.0, 2.0, 2.0, 2.0], 160 | [3.0, 3.0, 3.0, 3.0, 3.0], 161 | [3.0, 3.0, 3.0, 3.0, 3.0], 162 | ], 163 | ] 164 | ) 165 | -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.datasets import load_diabetes, load_iris, load_linnerud 4 | from sklearn.preprocessing import RobustScaler 5 | 6 | from tricycle.einsum import Einsum 7 | from tricycle.initialisers import init_xavier 8 | from tricycle.loss import CrossEntropy, MeanSquaredError 9 | from tricycle.tensor import Tensor 10 | from tricycle.utils import r_squared, smooth 11 | 12 | slow_test = pytest.mark.skipif( 13 | "not config.getoption('--run-slow')", 14 | reason="Only run when --run-slow is given", 15 | ) 16 | 17 | # TODO: a lot of these tests are outdated and should be modernised 18 | 19 | 20 | def test_can_mean_square_error(): 21 | y_true = Tensor([0, 0, 1]) 22 | y_pred = Tensor([0, 0.5, 0.5]) 23 | 24 | mse = MeanSquaredError()(y_true, y_pred) 25 | 26 | assert mse.close_to(1 / 6, rtol=1e-3) 27 | 28 | 29 | def test_can_CrossEntropy(): 30 | y_true = Tensor([1], dtype=int) 31 | y_pred = Tensor([[0, 0, 0]]) 32 | 33 | loss = CrossEntropy()(y_true, y_pred) 34 | 35 | assert loss.close_to(1.0986122886681098) 36 | 37 | 38 | def test_CrossEntropy_batched(): 39 | batch_size = 3 40 | n_tokens = 5 41 | vocab_size = 7 42 | 43 | y_true = np.random.randint(0, vocab_size, 
size=(batch_size, n_tokens)) 44 | y_pred = np.random.random((batch_size, n_tokens, vocab_size)) 45 | 46 | y_true = Tensor(y_true, dtype=int).to_batched() 47 | y_pred = Tensor(y_pred).to_batched() 48 | 49 | loss = CrossEntropy()(y_true, y_pred) 50 | 51 | assert loss.shape == () 52 | 53 | 54 | # TODO: write a proper backprop test for these loss functions 55 | 56 | 57 | def test_can_single_linear_regression_step(): 58 | """ 59 | A single step of linear regression 60 | """ 61 | x_input = [1] 62 | y_input = [3] 63 | slope = Tensor([0.02]) 64 | intercept = Tensor([0.01]) 65 | 66 | x_input = Tensor(x_input, requires_grad=False, name="x") 67 | y_input = Tensor(y_input, requires_grad=False, name="y") 68 | 69 | y_pred = x_input * slope + intercept 70 | 71 | loss = MeanSquaredError()(y_input, y_pred) 72 | 73 | assert loss.close_to(8.8209, rtol=1e-3) 74 | 75 | loss.backward() 76 | assert slope.grad is not None 77 | assert intercept.grad is not None 78 | assert slope.grad.close_to([-5.94], rtol=1e-3) 79 | assert intercept.grad.close_to([-5.94], rtol=1e-3) 80 | 81 | 82 | @slow_test 83 | def test_can_linear_regression(): 84 | np.random.seed(42) 85 | 86 | n = 4 87 | learning_rate = 3e-3 88 | learning_rate /= n 89 | x = np.linspace(-10, 10, n) 90 | y = x * 2 + 1 + np.random.normal(loc=0, scale=0.01, size=n) 91 | 92 | x = Tensor(x.reshape(-1, 1), requires_grad=False, name="x") 93 | y = Tensor(y.reshape(-1, 1), requires_grad=False, name="y") 94 | 95 | slope = Tensor([0.02], name="slope") 96 | intercept = Tensor([0.0], name="intercept") 97 | 98 | def slope_derivative(x, y, slope, intercept): 99 | return -2 * (y - x * slope - intercept) * x 100 | 101 | def intercept_derivative(x, y, slope, intercept): 102 | return -2 * (y - x * slope - intercept) 103 | 104 | losses = [0] * 100 105 | intercepts = [] 106 | slopes = [] 107 | for idx in range(100): 108 | last_slope_grad = Tensor([0]) 109 | last_intercept_grad = Tensor([0]) 110 | for x_input, y_input in zip(x, y): 111 | x_input = 
Tensor(x_input, requires_grad=False, name="x") 112 | y_input = Tensor(y_input, requires_grad=False, name="y") 113 | y_pred = x_input * slope + intercept 114 | loss = MeanSquaredError()(y_input, y_pred) 115 | losses[idx] += loss 116 | 117 | loss.backward() 118 | 119 | slope_grad = slope_derivative(x_input, y_input, slope, intercept) 120 | intercept_grad = intercept_derivative( 121 | x_input, y_input, slope, intercept 122 | ) 123 | 124 | assert slope.grad is not None 125 | assert intercept.grad is not None 126 | 127 | assert slope.grad.close_to( 128 | last_slope_grad + slope_grad, atol=1e-4 129 | ), f"{slope.grad=}, {last_slope_grad=}, {slope_grad=}" 130 | 131 | assert intercept.grad.close_to( 132 | last_intercept_grad + intercept_grad 133 | ), f"{intercept.grad=}, {last_intercept_grad=}, {intercept_grad=}" 134 | 135 | last_slope_grad = slope.grad 136 | last_intercept_grad = intercept.grad 137 | 138 | slope -= slope.grad * learning_rate 139 | intercept -= intercept.grad * learning_rate 140 | 141 | slopes.append(slope) 142 | intercepts.append(intercept) 143 | slope = slope.zero_grad() 144 | intercept = intercept.zero_grad() 145 | 146 | assert losses[-1] < 1.5 147 | assert slopes[-1].close_to(2, atol=0.01) 148 | # The intercept takes much longer to tune 149 | assert intercepts[-1].close_to(0.455, atol=0.01) 150 | 151 | 152 | @slow_test 153 | def test_linear_regression_real_data(): 154 | X, y = load_diabetes(return_X_y=True) 155 | x_scaler = RobustScaler() 156 | y_scaler = RobustScaler() 157 | X = x_scaler.fit_transform(X) 158 | y = y_scaler.fit_transform(y.reshape(-1, 1)) 159 | 160 | X = Tensor(X) 161 | y = Tensor(y) 162 | 163 | loops = 100 164 | learning_rate = 1e-1 165 | n = len(X) 166 | learning_rate /= n 167 | 168 | slope = init_xavier((X.shape[1], 1)) 169 | intercept = Tensor([0], name="intercept") 170 | 171 | for _ in range(loops): 172 | for x_in, y_in in zip(X, y): 173 | y_pred = Einsum("i,ij->j")(x_in, slope) + intercept 174 | loss = mean_square_error(y_in, 
y_pred) 175 | loss.backward() 176 | 177 | slope = Tensor(slope - slope.grad * learning_rate, name="slope") 178 | intercept = Tensor( 179 | intercept - intercept.grad * learning_rate, name="intercept" 180 | ) 181 | 182 | predicted = X @ np.array(slope) + intercept[0] 183 | r_square = r_squared(np.array(y), predicted) 184 | assert r_square > 0.45 185 | 186 | 187 | @slow_test 188 | def test_linear_regression_multi_input_output(): 189 | X_data, y_data = load_linnerud(return_X_y=True) 190 | x_scaler = RobustScaler() 191 | y_scaler = RobustScaler() 192 | X_data = x_scaler.fit_transform(X_data) 193 | y_data = y_scaler.fit_transform(y_data) 194 | 195 | learning_rate = 1e-0 196 | n = len(X_data) 197 | learning_rate /= n 198 | loops = 100 199 | 200 | slope = init_xavier((X_data.shape[1], y_data.shape[1]), name="slope") 201 | intercept = Tensor([-0.01, 0.01, 0.02], name="intercept") 202 | 203 | losses: list[np.ndarray | int] = [0] * loops 204 | 205 | def model(X, slope, intercept): 206 | return Einsum("i,ij->j")(X, slope) + intercept 207 | 208 | for idx in range(loops): 209 | X = Tensor(X_data).to_batched() 210 | y = Tensor(y_data).to_batched() 211 | 212 | # predict an output 213 | y_pred = model(X, slope, intercept) 214 | 215 | # calculate the loss 216 | loss = mean_square_error(y, y_pred) 217 | # we need to unbatch the loss before finding its average 218 | loss = loss.from_batched().mean() 219 | 220 | losses[idx] = loss.numpy() 221 | loss.backward() 222 | 223 | assert slope.grad is not None 224 | assert intercept.grad is not None 225 | 226 | slope.grad = slope.grad.from_batched().einsum("abc->bc") 227 | intercept.grad = intercept.grad.from_batched().einsum("ab->b") 228 | 229 | slope = (slope - slope.grad * learning_rate).zero_grad() 230 | intercept = (intercept - intercept.grad * learning_rate).zero_grad() 231 | 232 | # the loss should plateau at around 0.5 233 | assert losses[-1] < 0.6 234 | 235 | 236 | @slow_test 237 | def test_CrossEntropy(): 238 | """ 239 | This is a 
really slow test, preserved for reference 240 | """ 241 | X, y = load_iris(return_X_y=True) 242 | x_scaler = RobustScaler() 243 | X = x_scaler.fit_transform(X) 244 | 245 | # one hot encode y 246 | y = np.eye(3)[y.astype(int)] 247 | 248 | X = Tensor(X) 249 | y = Tensor(y) 250 | 251 | learning_rate = 1e0 252 | n = len(X) 253 | learning_rate /= n 254 | loops = 100 255 | 256 | slope = init_xavier((X.shape[1], y.shape[1]), name="slope") 257 | intercept = Tensor([-0.01, 0.01, 0.02], name="intercept") 258 | 259 | losses = [0] * loops 260 | for idx in range(loops): 261 | for x_in, y_in in zip(X, y): 262 | y_pred = Einsum("i,ij->j")(x_in, slope) + intercept 263 | loss = CrossEntropy()(y_in, y_pred) 264 | losses[idx] += loss 265 | loss.backward() 266 | 267 | slope = Tensor(slope - slope.grad * learning_rate, name="slope") 268 | intercept = Tensor( 269 | intercept - intercept.grad * learning_rate, name="intercept" 270 | ) 271 | 272 | assert losses[-1] < 35 273 | 274 | 275 | @slow_test 276 | def test_CrossEntropy_minibatch(): 277 | """ 278 | This is a really slow test, preserved for reference 279 | """ 280 | np.random.seed(42) 281 | 282 | def dataset(X, y, batch_size): 283 | while True: 284 | indices = np.arange(len(X)) 285 | batch_indices = np.random.choice( 286 | indices, size=batch_size, replace=False 287 | ) 288 | yield zip(X[batch_indices], y[batch_indices]) 289 | 290 | X, y = load_iris(return_X_y=True) 291 | x_scaler = RobustScaler() 292 | X = x_scaler.fit_transform(X) 293 | y = np.eye(3)[y.astype(int)] 294 | 295 | learning_rate = 1e0 296 | learning_rate /= 16 297 | loops = 500 298 | 299 | slope = init_xavier((X.shape[1], y.shape[1]), name="slope") 300 | intercept = Tensor([-0.01, 0.01, 0.02], name="intercept") 301 | 302 | losses = [] 303 | for idx, batch in enumerate(dataset(X, y, batch_size=16)): 304 | if idx > loops: 305 | break 306 | batch_loss = 0 307 | for x_in, y_in in batch: 308 | x_in = Tensor(x_in) 309 | y_in = Tensor(y_in) 310 | 311 | y_pred = 
Einsum("i,ij->j")(x_in, slope) + intercept 312 | loss = CrossEntropy()(y_in, y_pred) 313 | batch_loss += loss 314 | loss.backward() 315 | 316 | losses.append(batch_loss) 317 | 318 | slope = Tensor(slope - slope.grad * learning_rate, name="slope") 319 | intercept = Tensor( 320 | intercept - intercept.grad * learning_rate, name="intercept" 321 | ) 322 | 323 | losses = list(smooth(losses, 0.99)) 324 | assert losses[-1] < 6 325 | -------------------------------------------------------------------------------- /tests/test_mixed_precision.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | 5 | from tricycle.context import TRICYCLE_CONTEXT 6 | from tricycle.layers import Dense, Layer 7 | from tricycle.loss import MeanSquaredError 8 | from tricycle.optimisers import StochasticGradientDescent 9 | from tricycle.tensor import Tensor 10 | from tricycle.utils import UseMixedPrecision 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class LongBoi(Layer): 16 | """ 17 | A very deep MLP with no nonlinearities, designed to underflow in mixed 18 | precision training 19 | """ 20 | 21 | def __init__(self, n_layers: int = 16): 22 | self.layers = [ 23 | Dense(to_size=16, from_size=16, name=f"layer_{i}") 24 | for i in range(n_layers) 25 | ] 26 | 27 | def forward(self, tensor: Tensor) -> Tensor: 28 | for layer in self.layers: 29 | tensor = layer(tensor) 30 | return tensor 31 | 32 | def zero_grad(self): 33 | for layer in self.layers: 34 | layer.zero_grad() 35 | 36 | def update(self, optimiser): 37 | for layer in self.layers: 38 | layer.update(optimiser) 39 | 40 | 41 | def test_can_train_in_mixed_precision(): 42 | """ 43 | Check that a model can be trained in mixed precision without overflowing 44 | 45 | We're using a very deep model with no nonlinearities that should cause 46 | gradient issues if mixed precision is broken 47 | """ 48 | np.random.seed(0) 49 | learning_rate = 1e-3 50 | weight_decay 
= 1e-1 51 | model = LongBoi(64) 52 | 53 | loss_fn = MeanSquaredError() 54 | optimiser = StochasticGradientDescent( 55 | learning_rate=learning_rate, weight_decay=weight_decay, logger=logger 56 | ) 57 | 58 | inputs = Tensor( 59 | np.random.random( 60 | (32, 16), 61 | ), 62 | is_batched=True, 63 | requires_grad=False, 64 | ) 65 | outputs = Tensor( 66 | np.random.random( 67 | (32, 16), 68 | ), 69 | is_batched=True, 70 | requires_grad=False, 71 | ) 72 | 73 | with UseMixedPrecision(): 74 | first_loop = True 75 | for step in range(100): 76 | logits = model(inputs) 77 | loss = loss_fn(outputs, logits) 78 | loss.backward() 79 | loss = loss.numpy().item() / TRICYCLE_CONTEXT.loss_scale_factor 80 | if first_loop: 81 | # make sure we start with a big loss 82 | assert loss > 50 83 | first_loop = False 84 | logger.info(f"{loss=}, {TRICYCLE_CONTEXT.loss_scale_factor=}") 85 | model.update(optimiser) 86 | 87 | # make sure the loss has decreased as expected 88 | assert 7.5 < loss < 8 89 | -------------------------------------------------------------------------------- /tests/test_optimisers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.datasets import load_iris 3 | 4 | from tricycle.activation import ReLU 5 | from tricycle.dataset import InfiniteBatchDataset 6 | from tricycle.layers import Dense, Sequential 7 | from tricycle.loss import CrossEntropy 8 | from tricycle.optimisers import StochasticGradientDescent 9 | 10 | 11 | def test_can_train_simple_neural_network_no_wd(): 12 | """ 13 | Train a simple neural network on the iris dataset 14 | """ 15 | BATCH_SIZE = 16 16 | N_STEPS = 10 17 | 18 | np.random.seed(42) 19 | X, y = load_iris(return_X_y=True) 20 | # one hot encode y 21 | y = y.astype(int) 22 | 23 | # create a dataset 24 | ds = InfiniteBatchDataset(X, y, batch_size=BATCH_SIZE) 25 | 26 | # create a model 27 | layer_1 = Dense(4, 16) 28 | layer_2 = Dense(16, 3) 29 | relu = ReLU() 30 | model = 
Sequential(layer_1, relu, layer_2) 31 | loss_fn = CrossEntropy() 32 | optimiser = StochasticGradientDescent(learning_rate=1e-2) 33 | 34 | losses = [] 35 | batches = ds.to_tensor() 36 | # sourcery skip: no-loop-in-tests 37 | # sourcery skip: no-conditionals-in-tests 38 | for step, (x, y) in enumerate(batches): 39 | if step > N_STEPS: 40 | break 41 | 42 | y_pred = model(x) 43 | loss = loss_fn(y, y_pred) 44 | loss.backward() 45 | losses.append(loss) 46 | 47 | model.update(optimiser) 48 | model.zero_grad() 49 | 50 | assert losses[-1] < 1.5 51 | 52 | 53 | def test_can_train_simple_neural_network_wd(): 54 | """ 55 | Train a simple neural network on the iris dataset with weight decay 56 | """ 57 | BATCH_SIZE = 16 58 | N_STEPS = 10 59 | 60 | np.random.seed(42) 61 | X, y = load_iris(return_X_y=True) 62 | # one hot encode y 63 | y = y.astype(int) 64 | 65 | # create a dataset 66 | ds = InfiniteBatchDataset(X, y, batch_size=BATCH_SIZE) 67 | 68 | # create a model 69 | layer_1 = Dense(4, 16) 70 | layer_2 = Dense(16, 3) 71 | relu = ReLU() 72 | model = Sequential(layer_1, relu, layer_2) 73 | loss_fn = CrossEntropy() 74 | optimiser = StochasticGradientDescent(learning_rate=1e-2, weight_decay=1e1) 75 | 76 | losses = [] 77 | batches = ds.to_tensor() 78 | # sourcery skip: no-loop-in-tests 79 | # sourcery skip: no-conditionals-in-tests 80 | for step, (x, y) in enumerate(batches): 81 | if step > N_STEPS: 82 | break 83 | 84 | y_pred = model(x) 85 | loss = loss_fn(y, y_pred) 86 | loss.backward() 87 | losses.append(loss) 88 | 89 | model.update(optimiser) 90 | model.zero_grad() 91 | 92 | assert losses[-1] < 1.5 93 | 94 | 95 | def test_can_train_simple_neural_network_momentum(): 96 | """ 97 | Train a simple neural network on the iris dataset with momentum 98 | """ 99 | BATCH_SIZE = 16 100 | N_STEPS = 10 101 | 102 | np.random.seed(42) 103 | X, y = load_iris(return_X_y=True) 104 | # one hot encode y 105 | y = y.astype(int) 106 | 107 | # create a dataset 108 | ds = InfiniteBatchDataset(X, y, 
batch_size=BATCH_SIZE) 109 | 110 | # create a model 111 | layer_1 = Dense(4, 16) 112 | layer_2 = Dense(16, 3) 113 | relu = ReLU() 114 | model = Sequential(layer_1, relu, layer_2) 115 | loss_fn = CrossEntropy() 116 | optimiser = StochasticGradientDescent(learning_rate=1e-2, momentum=0.9) 117 | 118 | losses = [] 119 | batches = ds.to_tensor() 120 | # sourcery skip: no-loop-in-tests 121 | # sourcery skip: no-conditionals-in-tests 122 | for step, (x, y) in enumerate(batches): 123 | if step > N_STEPS: 124 | break 125 | 126 | y_pred = model(x) 127 | loss = loss_fn(y, y_pred) 128 | loss.backward() 129 | losses.append(loss) 130 | 131 | model.update(optimiser) 132 | model.zero_grad() 133 | 134 | assert losses[-1] < 1.5 135 | -------------------------------------------------------------------------------- /tests/test_python_and_numpy.py: -------------------------------------------------------------------------------- 1 | # tests/test_python_and_numpy.py 2 | import numpy as np 3 | 4 | 5 | def test_can_sum(): 6 | a = np.array([1, 2, 3]) 7 | assert a.sum() == 6 8 | 9 | 10 | def test_can_do_matrix_algebra(): 11 | a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 12 | b = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 13 | assert np.allclose(b @ a, a @ b) 14 | assert np.allclose(a, a @ b) 15 | -------------------------------------------------------------------------------- /tests/test_reduce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.reduce import ReduceMax, ReduceMin 4 | from tricycle.tensor import Tensor 5 | 6 | 7 | def test_can_rmax(): 8 | in_tensor = Tensor(np.arange(3 * 4 * 5).reshape(3, 4, 5)) 9 | 10 | out_tensor = ReduceMax()(in_tensor, "ijk->ik") 11 | 12 | assert out_tensor.shape == (3, 5) 13 | assert out_tensor.close_to( 14 | [[15, 16, 17, 18, 19], [35, 36, 37, 38, 39], [55, 56, 57, 58, 59]] 15 | ) 16 | 17 | out_tensor.backward() 18 | correct = [ 19 | [ 20 | [0, 0, 0, 0, 0], 21 | [0, 0, 0, 
0, 0], 22 | [0, 0, 0, 0, 0], 23 | [1, 1, 1, 1, 1], 24 | ], 25 | [ 26 | [0, 0, 0, 0, 0], 27 | [0, 0, 0, 0, 0], 28 | [0, 0, 0, 0, 0], 29 | [1, 1, 1, 1, 1], 30 | ], 31 | [ 32 | [0, 0, 0, 0, 0], 33 | [0, 0, 0, 0, 0], 34 | [0, 0, 0, 0, 0], 35 | [1, 1, 1, 1, 1], 36 | ], 37 | ] 38 | 39 | assert in_tensor.grad is not None 40 | assert in_tensor.grad.close_to(correct) 41 | 42 | 43 | def test_can_rmin(): 44 | in_tensor = Tensor(np.arange(3 * 4 * 5).reshape(3, 4, 5)) 45 | 46 | out_tensor = ReduceMin()(in_tensor, "ijk->ik") 47 | 48 | assert out_tensor.shape == (3, 5) 49 | assert out_tensor.close_to( 50 | [[0, 1, 2, 3, 4], [20, 21, 22, 23, 24], [40, 41, 42, 43, 44]] 51 | ) 52 | 53 | out_tensor.backward() 54 | 55 | correct = [ 56 | [ 57 | [1, 1, 1, 1, 1], 58 | [0, 0, 0, 0, 0], 59 | [0, 0, 0, 0, 0], 60 | [0, 0, 0, 0, 0], 61 | ], 62 | [ 63 | [1, 1, 1, 1, 1], 64 | [0, 0, 0, 0, 0], 65 | [0, 0, 0, 0, 0], 66 | [0, 0, 0, 0, 0], 67 | ], 68 | [ 69 | [1, 1, 1, 1, 1], 70 | [0, 0, 0, 0, 0], 71 | [0, 0, 0, 0, 0], 72 | [0, 0, 0, 0, 0], 73 | ], 74 | ] 75 | 76 | assert in_tensor.grad.close_to(correct) 77 | -------------------------------------------------------------------------------- /tests/test_simple_neural_network.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import pytest 5 | from sklearn.datasets import load_iris 6 | 7 | from tricycle import GPU_ENABLED 8 | from tricycle.activation import ReLU 9 | from tricycle.dataset import InfiniteBatchDataset 10 | from tricycle.layers import Dense, Sequential 11 | from tricycle.loss import CrossEntropy 12 | from tricycle.optimisers import StochasticGradientDescent 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | slow_test = pytest.mark.skipif( 17 | "not config.getoption('--run-slow')", 18 | reason="Only run when --run-slow is given", 19 | ) 20 | 21 | 22 | def test_can_train_simple_neural_network(): 23 | """ 24 | Train a simple neural network on the iris 
dataset 25 | """ 26 | BATCH_SIZE = 64 27 | LEARNING_RATE = 3e-2 28 | N_STEPS = 100 29 | 30 | np.random.seed(42) 31 | X, y = load_iris(return_X_y=True) 32 | 33 | # one hot encode y 34 | y = y.astype(int) 35 | 36 | # create a dataset 37 | ds = InfiniteBatchDataset(X, y, batch_size=BATCH_SIZE) 38 | 39 | # create a model 40 | layer_1 = Dense(4, 16) 41 | layer_2 = Dense(16, 3) 42 | relu = ReLU() 43 | model = Sequential(layer_1, relu, layer_2) 44 | loss_fn = CrossEntropy() 45 | optimiser = StochasticGradientDescent(learning_rate=LEARNING_RATE) 46 | 47 | losses = [] 48 | 49 | # sourcery skip: no-loop-in-tests 50 | # sourcery skip: no-conditionals-in-tests 51 | i = 0 52 | batches = ds.to_tensor() 53 | for step, (x_in, y_out) in enumerate(batches): 54 | if step > N_STEPS: 55 | break 56 | 57 | y_pred = model(x_in) 58 | loss = loss_fn(y_out, y_pred) 59 | loss.backward() 60 | losses.append(loss) 61 | 62 | model.update(optimiser) 63 | model.zero_grad() 64 | i += 1 65 | 66 | # Final loss should be 0.45 but we need to account for randomness 67 | assert losses[-1] < 0.6 68 | 69 | 70 | @slow_test 71 | def test_can_train_simple_neural_network_gpu(): 72 | """ 73 | Train a simple neural network on the iris dataset 74 | """ 75 | if not GPU_ENABLED: 76 | pytest.skip() 77 | 78 | BATCH_SIZE = 64 79 | LEARNING_RATE = 3e-2 80 | N_STEPS = 100 81 | 82 | np.random.seed(42) 83 | X, y = load_iris(return_X_y=True) 84 | 85 | # one hot encode y 86 | y = np.eye(3)[y.astype(int)] 87 | 88 | # create a dataset 89 | ds = InfiniteBatchDataset(X, y, batch_size=BATCH_SIZE) 90 | 91 | # create a model 92 | layer_1 = Dense(4, 16) 93 | layer_2 = Dense(16, 3) 94 | relu = ReLU() 95 | model = Sequential(layer_1, relu, layer_2) 96 | model.to_gpu() 97 | loss_fn = CrossEntropy() 98 | optimiser = StochasticGradientDescent(learning_rate=LEARNING_RATE) 99 | 100 | losses = [] 101 | 102 | # sourcery skip: no-loop-in-tests 103 | # sourcery skip: no-conditionals-in-tests 104 | i = 0 105 | batches = ds.to_tensor() 106 | for 
step, (x_in, y_out) in enumerate(batches): 107 | if step > N_STEPS: 108 | break 109 | x_in = x_in.to_gpu() 110 | y_out = y_out.to_gpu() 111 | 112 | y_pred = model(x_in) 113 | loss = loss_fn(y_out, y_pred).from_batched().einsum("a->") / BATCH_SIZE 114 | loss.backward() 115 | 116 | model.update(optimiser) 117 | model.zero_grad() 118 | losses.append(loss.from_gpu()) 119 | i += 1 120 | 121 | # Final loss should be 0.45 but we need to account for randomness 122 | assert losses[-1] < 0.6 123 | -------------------------------------------------------------------------------- /tests/test_tensor_api.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | 5 | from tricycle.binary import ( 6 | BinaryAdd, 7 | BinaryDivide, 8 | BinaryMultiply, 9 | BinarySubtract, 10 | ) 11 | from tricycle.ops import Tensor 12 | from tricycle.unary import UnaryAdd, UnaryDivide, UnaryMultiply, UnarySubtract 13 | 14 | 15 | def test_can_add_tensors(): 16 | tensor_1 = Tensor(np.arange(12).reshape(3, 4)) 17 | tensor_2 = Tensor(np.arange(12).reshape(3, 4)) 18 | 19 | assert (tensor_1 + 1).close_to(UnaryAdd()(tensor_1, 1)) 20 | 21 | assert (tensor_1 + tensor_2).close_to(BinaryAdd()(tensor_1, tensor_2)) 22 | 23 | before = deepcopy(tensor_1) 24 | tensor_1 += 1 25 | 26 | assert tensor_1.close_to(UnaryAdd()(before, 1)) 27 | 28 | 29 | def test_can_subtract_tensors(): 30 | tensor_1 = Tensor(np.arange(12).reshape(3, 4)) 31 | tensor_2 = Tensor(np.arange(12).reshape(3, 4)) 32 | 33 | assert (tensor_1 - 1).close_to(UnarySubtract()(tensor_1, 1)) 34 | 35 | assert (tensor_1 - tensor_2).close_to(BinarySubtract()(tensor_1, tensor_2)) 36 | 37 | before = deepcopy(tensor_1) 38 | tensor_1 -= 1 39 | 40 | assert (tensor_1).close_to(UnarySubtract()(before, 1)) 41 | 42 | 43 | def test_can_multiply_tensors(): 44 | tensor_1 = Tensor(np.arange(12).reshape(3, 4)) 45 | tensor_2 = Tensor(np.arange(12).reshape(3, 4)) 46 | 47 | assert (tensor_1 * 
2).close_to(UnaryMultiply()(tensor_1, 2)) 48 | 49 | assert (tensor_1 * tensor_2).close_to(BinaryMultiply()(tensor_1, tensor_2)) 50 | 51 | before = deepcopy(tensor_1) 52 | tensor_1 *= 2 53 | 54 | assert (tensor_1).close_to(UnaryMultiply()(before, 2)) 55 | 56 | 57 | def test_can_divide_tensors(): 58 | tensor_1 = Tensor(np.arange(1, 13).reshape(3, 4).astype(float)) 59 | tensor_2 = Tensor(np.arange(1, 13).reshape(3, 4).astype(float)) 60 | 61 | assert (2 / tensor_1).close_to(UnaryDivide()(2, tensor_1)) 62 | 63 | assert (tensor_1 / tensor_2).close_to(BinaryDivide()(tensor_1, tensor_2)) 64 | 65 | 66 | def test_can_pow_tensors(): 67 | tensor_1 = Tensor(np.arange(12).reshape(3, 4)) 68 | 69 | assert (tensor_1**2).close_to(pow(tensor_1, 2)) 70 | -------------------------------------------------------------------------------- /tests/test_tensor_pbt.py: -------------------------------------------------------------------------------- 1 | """ 2 | A lot of features of Tricycle are quite hard to test. For example, there are 3 | a lot of different ways to combine `Op`s. 4 | 5 | To fix this, this file contains several property-based tests where, instead 6 | of defining a specific situation and checking the output, inputs are generated 7 | randomly and then the outputs are checked against some predefined properties. 8 | 9 | For example, in our tokeniser, we want decode(encode()) to return 10 | the input so we can build a property based test that tries a whole bunch 11 | of random inputs and checks that they are unmodified by the decode(encode()) 12 | operation. 
13 | """ 14 | 15 | import numbers 16 | from warnings import warn 17 | 18 | import hypothesis.strategies as st 19 | import numpy as np 20 | import pytest 21 | from hypothesis import assume, given, settings 22 | from hypothesis.extra import numpy as xp 23 | 24 | from tricycle import GPU_ENABLED 25 | from tricycle.binary import ( 26 | BinaryAdd, 27 | BinaryDivide, 28 | BinaryMax, 29 | BinaryMin, 30 | BinaryMultiply, 31 | BinarySubtract, 32 | ) 33 | from tricycle.einsum import EinsumBackOp 34 | from tricycle.tensor import Tensor 35 | from tricycle.tokeniser import BPETokeniser 36 | from tricycle.unary import ( 37 | UnaryAdd, 38 | UnaryCos, 39 | UnaryDivide, 40 | UnaryExp, 41 | UnaryLog, 42 | UnaryMax, 43 | UnaryMin, 44 | UnaryMultiply, 45 | UnaryPower, 46 | UnarySin, 47 | UnarySquareRoot, 48 | UnarySubtract, 49 | nothing, 50 | ) 51 | from tricycle.utils import shapes_match 52 | 53 | 54 | @st.composite 55 | def scalar(draw): 56 | """ 57 | Generate a single, initial scalar 58 | """ 59 | group = draw(st.sampled_from(["int", "float", "complex"])) 60 | if group == "int": 61 | return draw(st.integers()) 62 | if group == "float": 63 | return draw(st.floats()) 64 | if group == "complex": 65 | return draw(st.complex_numbers()) 66 | 67 | 68 | @st.composite 69 | def string(draw): 70 | return draw(st.text()) 71 | 72 | 73 | @st.composite 74 | def integer(draw): 75 | return draw(st.integers(min_value=1, max_value=1024)) 76 | 77 | 78 | @st.composite 79 | def unary_op(draw): 80 | """ 81 | Generate a single, initial unary operation 82 | """ 83 | ops = [ 84 | UnarySin(), 85 | UnaryCos(), 86 | UnaryExp(), 87 | UnaryLog(), 88 | UnarySquareRoot(), 89 | ] 90 | needs_constant = [ 91 | UnaryAdd(), 92 | UnarySubtract(), 93 | UnaryMultiply(), 94 | UnaryPower(), 95 | UnaryDivide(), 96 | UnaryMax(), 97 | UnaryMin(), 98 | ] 99 | op = draw(st.sampled_from(ops)) 100 | if op in needs_constant: 101 | constant = draw(scalar()) 102 | return op, constant 103 | return op, None 104 | 105 | 106 | 
@st.composite 107 | def binary_op(draw): 108 | """ 109 | Generate a single, initial binary operation 110 | """ 111 | ops = [ 112 | BinaryAdd(), 113 | BinaryDivide(), 114 | BinaryMax(), 115 | BinaryMin(), 116 | BinaryMultiply(), 117 | BinarySubtract(), 118 | ] 119 | return draw(st.sampled_from(ops)) 120 | 121 | 122 | @st.composite 123 | def tensor(draw): 124 | """ 125 | Generate a single, initial tensor (not as the result of an operation) 126 | """ 127 | shape = draw( 128 | xp.array_shapes(min_dims=1, max_dims=4, min_side=1, max_side=32) 129 | ) 130 | data = draw(xp.arrays(dtype=np.float32, shape=shape)) 131 | match len(shape): 132 | case 1: 133 | is_batched = False 134 | case 2: 135 | is_batched = draw(st.booleans()) 136 | case 3: 137 | is_batched = draw(st.booleans()) 138 | case 4: 139 | is_batched = True 140 | requires_grad = draw(st.booleans()) 141 | return Tensor( 142 | data, 143 | is_batched=is_batched, 144 | requires_grad=requires_grad, 145 | ) 146 | 147 | 148 | @st.composite 149 | def small_tensor(draw): 150 | """ 151 | Generate a single, initial tensor (not as the result of an operation). 
152 | The tensor can be 1, 2 or 3d 153 | """ 154 | shape = draw(st.integers(min_value=1, max_value=4)) 155 | data = draw(xp.arrays(dtype=np.float64, shape=shape)) 156 | is_batched = len(shape) in {3, 4} 157 | requires_grad = draw(st.booleans()) 158 | if GPU_ENABLED: 159 | on_gpu = draw(st.booleans()) 160 | else: 161 | warn("GPU_ENABLED = False so GPU tests have been disabled") 162 | on_gpu = False 163 | 164 | tensor = Tensor( 165 | data, 166 | is_batched=is_batched, 167 | requires_grad=requires_grad, 168 | ) 169 | if on_gpu: 170 | tensor = tensor.to_gpu() 171 | return tensor 172 | 173 | 174 | @st.composite 175 | def tensor_pair_same_shape(draw): 176 | """ 177 | Generate two tensors with the same shape 178 | """ 179 | shape = draw(st.integers(min_value=1, max_value=10)) 180 | if isinstance(shape, int): 181 | shape = (shape,) 182 | 183 | tensors = [] 184 | for _ in range(2): 185 | data = draw(xp.arrays(dtype=np.float64, shape=shape)) 186 | is_batched = draw(st.booleans()) 187 | 188 | if draw(st.booleans()): 189 | data = data[1:] 190 | 191 | tensor = Tensor(data, is_batched=is_batched) 192 | tensors.append(tensor) 193 | 194 | return tensors 195 | 196 | 197 | @given(tensor_pair_same_shape()) 198 | def test_tensor_addition_same_shape(tensors): 199 | # sourcery skip: no-conditionals-in-tests 200 | tensor_1, tensor_2 = tensors 201 | 202 | try: 203 | _shapes_match = shapes_match(tensor_1, tensor_2) 204 | except ValueError: 205 | _shapes_match = False 206 | 207 | # sourcery skip: no-conditionals-in-tests 208 | if not _shapes_match: 209 | assume(False) 210 | 211 | result = tensor_1 + tensor_2 212 | largest_input_shape = max(tensor_1.shape, tensor_2.shape) 213 | assert result.shape == largest_input_shape 214 | 215 | assert result.args == (tensor_1, tensor_2) 216 | assert result.back_fns == (nothing, nothing) 217 | 218 | assert result.is_batched == tensor_1.is_batched or tensor_2.is_batched 219 | 220 | 221 | @given(tensor(), scalar()) 222 | def 
test_tensor_addition_scalar(tensor, scalar): 223 | assume(isinstance(scalar, numbers.Number)) 224 | try: 225 | assume(abs(scalar) < 2**64) 226 | except OverflowError: 227 | assume(False) 228 | assume(not isinstance(scalar, np.datetime64)) 229 | assume(not isinstance(scalar, np.timedelta64)) 230 | 231 | result = tensor + scalar 232 | assert result.shape == tensor.shape 233 | 234 | assert result.args == (tensor,) 235 | assert result.back_fns == (nothing,) 236 | 237 | assert result.is_batched == tensor.is_batched 238 | 239 | 240 | @given(tensor_pair_same_shape()) 241 | def test_tensor_multiplication(tensors): 242 | # sourcery skip: no-conditionals-in-tests 243 | tensor_1, tensor_2 = tensors 244 | 245 | try: 246 | _shapes_match = shapes_match(tensor_1, tensor_2) 247 | except ValueError: 248 | _shapes_match = False 249 | 250 | # sourcery skip: no-conditionals-in-tests 251 | if not _shapes_match: 252 | assume(False) 253 | 254 | result = tensor_1 * tensor_2 255 | largest_input_shape = max(tensor_1.shape, tensor_2.shape) 256 | assert result.shape == largest_input_shape 257 | 258 | assert result.args == (tensor_1, tensor_2) 259 | assert len(result.back_fns) == 2 260 | 261 | assert isinstance(result.back_fns[0], EinsumBackOp) 262 | assert isinstance(result.back_fns[1], EinsumBackOp) 263 | 264 | assert result.is_batched == tensor_1.is_batched or tensor_2.is_batched 265 | 266 | 267 | @given(tensor()) 268 | def test_close_to(tensor): 269 | equal_nan = np.isnan(tensor.array).any() 270 | 271 | assert tensor.close_to(tensor, equal_nan=equal_nan, rtol=1e-6, atol=1e-8) 272 | 273 | 274 | @given(tensor()) 275 | def test_can_batch_and_unbatch(tensor): 276 | assume(not tensor.is_batched) 277 | 278 | batched = tensor.to_batched() 279 | assert batched.is_batched 280 | 281 | unbatched = batched.from_batched() 282 | assert not unbatched.is_batched 283 | 284 | assert tensor.close_to(unbatched, equal_nan=True) 285 | 286 | # sourcery skip: no-conditionals-in-tests 287 | if 
tensor.requires_grad: 288 | assert len(unbatched.args) == 1 289 | assert unbatched.args[0].close_to(tensor, equal_nan=True) 290 | 291 | assert len(unbatched.args[0].args) == 1 292 | assert unbatched.args[0].args[0].close_to(tensor, equal_nan=True) 293 | 294 | assert unbatched.requires_grad 295 | 296 | 297 | @given(tensor()) 298 | def test_can_move_to_and_from_gpu(tensor): 299 | # only run this test if we have a gpu enabled 300 | if not GPU_ENABLED: 301 | pytest.skip("GPU not enabled") 302 | assume(not tensor.on_gpu) 303 | 304 | gpu_tensor = tensor.to_gpu() 305 | assert gpu_tensor.on_gpu 306 | 307 | cpu_tensor = gpu_tensor.from_gpu() 308 | assert not cpu_tensor.on_gpu 309 | 310 | 311 | @given(tensor(), unary_op()) 312 | def test_unary_ops(tensor, op): 313 | # sourcery skip: no-conditionals-in-tests 314 | op, constant = op 315 | if constant is not None: 316 | try: 317 | assume(abs(constant) < 2**64) 318 | except OverflowError: 319 | assume(False) 320 | result = op(tensor=tensor, constant=constant) 321 | else: 322 | result = op(tensor) 323 | assert result.shape == tensor.shape 324 | assert result.is_batched == tensor.is_batched 325 | assert result.on_gpu == tensor.on_gpu 326 | 327 | 328 | @given(tensor_pair_same_shape(), binary_op()) 329 | def test_binary_ops(tensors, op): 330 | # sourcery skip: no-conditionals-in-tests 331 | tensor_1, tensor_2 = tensors 332 | 333 | try: 334 | _shapes_match = shapes_match(tensor_1, tensor_2) 335 | except ValueError: 336 | _shapes_match = False 337 | assume(_shapes_match) 338 | 339 | result = op(tensor_1, tensor_2) 340 | 341 | assert result.shape in [tensor_1.shape, tensor_2.shape] 342 | assert result.is_batched == any([tensor_1.is_batched, tensor_2.is_batched]) 343 | assert result.on_gpu == any([tensor_1.on_gpu, tensor_2.on_gpu]) 344 | 345 | 346 | @given(string()) 347 | @settings(deadline=1000) 348 | def test_tokeniser_train_encode_decode(text): 349 | tokeniser = BPETokeniser(vocab_size=1024) 350 | 351 | tokeniser.train(text) 352 | 
353 | encoded = tokeniser.encode(text) 354 | assert np.allclose(encoded, tokeniser.tokens) 355 | 356 | decoded = tokeniser.decode(encoded) 357 | assert text == decoded 358 | -------------------------------------------------------------------------------- /tests/test_tokeniser.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from tricycle.tokeniser import BPETokeniser, count_pairs, replace_pair 5 | 6 | slow_test = pytest.mark.skipif( 7 | "not config.getoption('--run-slow')", 8 | reason="Only run when --run-slow is given", 9 | ) 10 | 11 | 12 | def test_count_pairs(): 13 | data = np.array([0, 0, 1, 0, 1], dtype=np.int32) 14 | 15 | got = count_pairs(data=data, token_id=1) 16 | want = np.array([1, 2, 1, 0]) 17 | 18 | assert np.allclose(got, want) 19 | 20 | 21 | def test_replace_pair(): 22 | to_replace = (1, 2) 23 | data = np.array([1, 1, 2, 1, 2, 1]) 24 | 25 | got = replace_pair(data, to_replace, 3) 26 | want = np.array([1, 3, 3, 1]) 27 | 28 | assert np.allclose(got, want) 29 | 30 | 31 | def test_replace_pair_when_final_tokens_are_pair(): 32 | tokeniser = BPETokeniser(256) 33 | 34 | to_replace = (1, 2) 35 | data = np.array([1, 1, 2, 1, 2]) 36 | 37 | got = tokeniser.replace_pair(data, to_replace, 3) 38 | want = np.array([1, 3, 3]) 39 | 40 | assert np.allclose(got, want) 41 | 42 | 43 | def test_can_train_simple_text(): 44 | tokeniser = BPETokeniser(256 + 3) 45 | sample_text = "aababa" 46 | 47 | with pytest.warns(UserWarning): 48 | tokeniser.train(sample_text) 49 | 50 | assert tokeniser.vocab_size == 256 + 3 51 | 52 | assert (ord("a"), ord("b")) in tokeniser.merges 53 | assert (ord("a"), ord("b")) in tokeniser.pairs 54 | 55 | assert len(tokeniser.merges) == len(tokeniser.pairs) == 257 56 | 57 | 58 | def test_can_tokenise_simple_text(): 59 | tokeniser = BPETokeniser(257) 60 | tokeniser.merges[(ord("a"), ord("b"))] = 256 61 | 62 | sample_text = "aababa" 63 | got = 
tokeniser.encode(sample_text) 64 | want = np.array([ord("a"), 256, 256, ord("a")]) 65 | 66 | assert np.allclose(got, want) 67 | 68 | 69 | def test_can_tokenise_paragraph(): 70 | tokeniser = BPETokeniser(300) 71 | 72 | sample_text = """(Barry is picking out a shirt) 73 | Yellow, black. Yellow, black. 74 | Yellow, black. Yellow, black. 75 | : 76 | Ooh, black and yellow! 77 | Let's shake it up a little. 78 | """ 79 | tokeniser.train(sample_text) 80 | got = tokeniser.encode(sample_text) 81 | want = np.array( 82 | [ 83 | 40, 84 | 66, 85 | 97, 86 | 114, 87 | 114, 88 | 121, 89 | 271, 90 | 115, 91 | 32, 92 | 112, 93 | 105, 94 | 256, 95 | 105, 96 | 110, 97 | 103, 98 | 32, 99 | 111, 100 | 117, 101 | 116, 102 | 269, 103 | 275, 104 | 105, 105 | 114, 106 | 116, 107 | 41, 108 | 274, 109 | 274, 110 | 10, 111 | 32, 112 | 58, 113 | 10, 114 | 79, 115 | 111, 116 | 104, 117 | 263, 118 | 269, 119 | 110, 120 | 100, 121 | 32, 122 | 121, 123 | 265, 124 | 33, 125 | 10, 126 | 76, 127 | 101, 128 | 116, 129 | 39, 130 | 115, 131 | 275, 132 | 97, 133 | 107, 134 | 101, 135 | 271, 136 | 116, 137 | 32, 138 | 117, 139 | 112, 140 | 269, 141 | 32, 142 | 108, 143 | 105, 144 | 116, 145 | 116, 146 | 108, 147 | 101, 148 | 46, 149 | 10, 150 | ] 151 | ) 152 | assert np.allclose(got, want) 153 | 154 | 155 | def test_can_decode_tokens(): 156 | tokeniser = BPETokeniser(257) 157 | tokeniser.vocab.append(b"ab") 158 | 159 | sample_tokens = np.array([ord("a"), 256, 256, ord("a")]) 160 | got = tokeniser.decode(sample_tokens) 161 | want = "aababa" 162 | 163 | assert got == want 164 | 165 | 166 | def test_can_tokenise_longer_text(): 167 | tokeniser = BPETokeniser(1000) 168 | 169 | with open("datasets/bee_movie.txt", "r") as f: 170 | sample_text = f.read() 171 | 172 | tokeniser.train(sample_text) 173 | 174 | assert len(tokeniser.merges) == len(tokeniser.pairs) == 1000 175 | 176 | got = tokeniser.encode(sample_text) 177 | 178 | assert np.allclose(got[:10], [78, 279, 82, 65, 84, 829, 684, 66, 337, 386]) 179 | assert 
np.allclose( 180 | got[-10:], [644, 617, 339, 454, 266, 115, 600, 437, 468, 262] 181 | ) 182 | -------------------------------------------------------------------------------- /tests/test_unary_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.tensor import Tensor 4 | from tricycle.unary import ( 5 | UnaryAdd, 6 | UnaryCos, 7 | UnaryDivide, 8 | UnaryExp, 9 | UnaryLog, 10 | UnaryMask, 11 | UnaryMax, 12 | UnaryMin, 13 | UnaryMultiply, 14 | UnaryPower, 15 | UnarySin, 16 | UnarySquareRoot, 17 | UnarySubtract, 18 | ) 19 | 20 | 21 | def test_can_add(): 22 | in_tensor = Tensor([0, 1, 2, 3]) 23 | out_tensor = UnaryAdd()(in_tensor, 1) 24 | 25 | correct = np.array([1, 2, 3, 4]) 26 | 27 | assert out_tensor.close_to(correct) 28 | 29 | out_tensor.backward() 30 | 31 | correct = np.ones_like(correct) 32 | assert in_tensor.grad is not None 33 | assert in_tensor.grad.close_to(correct) 34 | 35 | 36 | def test_can_umul(): 37 | in_tensor = Tensor([0, 1, 2, 3]) 38 | out_tensor = UnaryMultiply()(in_tensor, 2) 39 | 40 | assert out_tensor.shape == (4,) 41 | assert out_tensor.close_to(np.array([0, 2, 4, 6])) 42 | 43 | out_tensor.backward() 44 | 45 | assert in_tensor.grad is not None 46 | assert in_tensor.grad.close_to([2, 2, 2, 2]) 47 | 48 | 49 | def test_can_usub(): 50 | in_tensor = Tensor([1, 2, 3, 4]) 51 | 52 | # subtract 1 from each element 53 | out_tensor = UnarySubtract()(in_tensor, 1) 54 | 55 | assert out_tensor.shape == (4,) 56 | assert out_tensor.close_to(np.array([0, 1, 2, 3])) 57 | 58 | out_tensor.backward() 59 | correct = np.ones(in_tensor.shape) 60 | 61 | assert in_tensor.grad is not None 62 | assert in_tensor.grad.close_to(correct) 63 | 64 | 65 | def test_can_upow(): 66 | in_tensor = Tensor([1, 2, 3, 4]) 67 | out_tensor = UnaryPower()(in_tensor, 3) 68 | 69 | assert out_tensor.shape == (4,) 70 | assert out_tensor.close_to(np.array([1, 8, 27, 64])) 71 | 72 | out_tensor.backward() 73 | correct = 
np.array([3, 12, 27, 48]) 74 | 75 | assert in_tensor.grad is not None 76 | assert in_tensor.grad.close_to(correct) 77 | 78 | 79 | def test_can_udiv(): 80 | # 2 divided by each element 81 | in_tensor = Tensor(np.arange(12, dtype=float).reshape(3, 4)) 82 | with np.errstate(divide="ignore"): 83 | out_tensor = UnaryDivide()(2, in_tensor) 84 | 85 | assert out_tensor.shape == (3, 4) 86 | assert out_tensor.close_to( 87 | [ 88 | [np.inf, 2, 1, 2 / 3], 89 | [2 / 4, 2 / 5, 2 / 6, 2 / 7], 90 | [2 / 8, 2 / 9, 2 / 10, 2 / 11], 91 | ], 92 | rtol=1e-3, 93 | ) 94 | with np.errstate(divide="ignore"): 95 | out_tensor.backward() 96 | correct = -np.power(in_tensor.array, -2) * 2 97 | 98 | assert in_tensor.grad is not None 99 | assert in_tensor.grad.close_to(correct, rtol=1e-3) 100 | 101 | 102 | def test_can_umax(): 103 | in_tensor = Tensor([1, 2, 3, 4]) 104 | out_tensor = UnaryMax()(in_tensor, 2) 105 | 106 | assert out_tensor.shape == (4,) 107 | assert out_tensor.close_to([2, 2, 3, 4]) 108 | 109 | out_tensor.backward() 110 | 111 | correct = [0, 0, 1, 1] 112 | 113 | assert in_tensor.grad is not None 114 | assert in_tensor.grad.close_to(correct) 115 | 116 | 117 | def test_can_umin(): 118 | in_tensor = Tensor([1, 2, 3, 4]) 119 | out_tensor = UnaryMin()(in_tensor, 3) 120 | 121 | assert out_tensor.shape == (4,) 122 | assert out_tensor.close_to([1, 2, 3, 3]) 123 | 124 | out_tensor.backward() 125 | 126 | correct = [1, 1, 0, 0] 127 | 128 | assert in_tensor.grad is not None 129 | assert in_tensor.grad.close_to(correct) 130 | 131 | 132 | def test_can_uexp(): 133 | in_tensor = Tensor([1, 2, 3, 4]) 134 | out_tensor = UnaryExp()(in_tensor) 135 | 136 | assert out_tensor.shape == (4,) 137 | 138 | correct = np.exp([1, 2, 3, 4]) 139 | assert out_tensor.close_to(correct, rtol=1e-3) 140 | 141 | out_tensor.backward() 142 | 143 | correct = np.exp([1, 2, 3, 4]) 144 | 145 | assert in_tensor.grad is not None 146 | assert in_tensor.grad.close_to(correct, rtol=1e-3) 147 | 148 | 149 | def test_can_ulog(): 150 | 
in_tensor = Tensor([1, 2, 3, 4]) 151 | out_tensor = UnaryLog()(in_tensor) 152 | 153 | assert out_tensor.shape == (4,) 154 | assert out_tensor.close_to([0, np.log(2), np.log(3), np.log(4)], rtol=1e-3) 155 | 156 | out_tensor.backward() 157 | 158 | correct = [1, 1 / 2, 1 / 3, 1 / 4] 159 | 160 | assert in_tensor.grad is not None 161 | assert in_tensor.grad.close_to(correct, rtol=1e-3) 162 | 163 | 164 | def test_can_usin(): 165 | in_tensor = Tensor([1, 2, 3, 4]) 166 | out_tensor = UnarySin()(in_tensor) 167 | 168 | assert out_tensor.shape == (4,) 169 | assert out_tensor.close_to( 170 | [np.sin(1), np.sin(2), np.sin(3), np.sin(4)], rtol=1e-3 171 | ) 172 | 173 | out_tensor.backward() 174 | 175 | assert in_tensor.grad is not None 176 | assert in_tensor.grad.close_to( 177 | [np.cos(1), np.cos(2), np.cos(3), np.cos(4)], rtol=1e-3 178 | ) 179 | 180 | 181 | def test_can_ucos(): 182 | in_tensor = Tensor([1, 2, 3, 4]) 183 | out_tensor = UnaryCos()(in_tensor) 184 | 185 | assert out_tensor.shape == (4,) 186 | assert out_tensor.close_to( 187 | [np.cos(1), np.cos(2), np.cos(3), np.cos(4)], rtol=1e-3 188 | ) 189 | 190 | out_tensor.backward() 191 | 192 | correct = [-np.sin(1), -np.sin(2), -np.sin(3), -np.sin(4)] 193 | 194 | assert in_tensor.grad is not None 195 | assert in_tensor.grad.close_to(correct, rtol=1e-3) 196 | 197 | 198 | def test_can_usqrt(): 199 | in_tensor = Tensor([1, 2, 3, 4]) 200 | out_tensor = UnarySquareRoot()(in_tensor) 201 | 202 | assert out_tensor.shape == (4,) 203 | assert out_tensor.close_to( 204 | [1, np.sqrt(2), np.sqrt(3), np.sqrt(4)], rtol=1e-3 205 | ) 206 | 207 | out_tensor.backward() 208 | 209 | correct = [0.5, 0.35355339, 0.28867513, 0.25] 210 | 211 | assert in_tensor.grad is not None 212 | assert in_tensor.grad.close_to(correct, rtol=1e-3) 213 | 214 | 215 | def test_can_bmask(): 216 | in_tensor = Tensor(np.arange(12).reshape(3, 4), is_batched=True) 217 | mask = Tensor( 218 | [[0, 0, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], 219 | is_batched=True, 220 | 
requires_grad=False, 221 | ) 222 | out_tensor = UnaryMask()(in_tensor, mask) 223 | 224 | assert out_tensor.shape == (3, 4) 225 | assert out_tensor.close_to([[0, 0, 0, 0], [4, 0, 6, 0], [8, 9, 10, 11]]) 226 | 227 | out_tensor.backward() 228 | 229 | assert mask.grad is None 230 | assert in_tensor.grad is not None 231 | assert in_tensor.grad.close_to([[0, 0, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) 232 | -------------------------------------------------------------------------------- /tests/test_vectorise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tricycle.activation import ReLU 4 | from tricycle.einsum import Einsum 5 | from tricycle.functions import Softmax 6 | from tricycle.layers import Dense, Sequential 7 | from tricycle.loss import MeanSquaredError 8 | from tricycle.tensor import Tensor 9 | from tricycle.unary import Batch, Unbatch 10 | 11 | 12 | def test_can_batch_single_einsum(): 13 | input_1 = np.arange(1, 4) 14 | input_2 = np.arange(2, 5) 15 | input_3 = np.arange(3, 6) 16 | 17 | op = Einsum("a->") 18 | 19 | output_1 = op(Tensor(input_1)) 20 | output_2 = op(Tensor(input_2)) 21 | output_3 = op(Tensor(input_3)) 22 | 23 | assert output_1 == 6 24 | assert output_2 == 9 25 | assert output_3 == 12 26 | 27 | input_batch = Tensor([input_1, input_2, input_3]) 28 | input_batch = Batch()(input_batch) 29 | op = Einsum("a->") 30 | output_batch = op(input_batch) 31 | output_batch = Unbatch()(output_batch) 32 | 33 | assert output_batch.close_to([6, 9, 12]) 34 | 35 | 36 | def test_can_batch_entire_model(): 37 | np.random.seed(42) 38 | layer_1 = Dense(4, 16) 39 | layer_2 = Dense(16, 3) 40 | relu = ReLU() 41 | model = Sequential(layer_1, relu, layer_2) 42 | 43 | input_1 = np.arange(1, 5) 44 | input_2 = np.arange(2, 6) 45 | input_3 = np.arange(3, 7) 46 | 47 | output_1 = model(Tensor(input_1)) 48 | output_2 = model(Tensor(input_2)) 49 | output_3 = model(Tensor(input_3)) 50 | 51 | input_batch = 
Tensor([input_1, input_2, input_3]) 52 | correct_output = Tensor([output_1.array, output_2.array, output_3.array]) 53 | 54 | input_batch = Batch()(input_batch) 55 | correct_output = Batch()(correct_output) 56 | output_batch = model(input_batch) 57 | output_batch = Unbatch()(output_batch) 58 | 59 | assert output_batch.close_to(correct_output) 60 | 61 | 62 | def test_can_batch_mse(): 63 | y_true = Tensor([0, 0, 1, 0]) 64 | 65 | input_1 = Tensor(np.arange(1, 5)) 66 | input_2 = Tensor(np.arange(2, 6)) 67 | input_3 = Tensor(np.arange(3, 7)) 68 | 69 | output_1 = MeanSquaredError()(y_true, input_1) 70 | output_2 = MeanSquaredError()(y_true, input_2) 71 | output_3 = MeanSquaredError()(y_true, input_3) 72 | 73 | input_y_true = Tensor(np.array([y_true.array] * 3)) 74 | input_batch = Tensor( 75 | np.array([input_1.array, input_2.array, input_3.array]) 76 | ) 77 | correct_output = Tensor( 78 | np.array([output_1.array, output_2.array, output_3.array]).sum() 79 | ) 80 | 81 | input_y_true = Batch()(input_y_true) 82 | input_batch = Batch()(input_batch) 83 | output_batch = MeanSquaredError()(input_y_true, input_batch) 84 | output_batch = Unbatch()(output_batch) 85 | 86 | assert output_batch.close_to(correct_output) 87 | 88 | 89 | def test_can_batch_softmax(): 90 | input_1 = Tensor(np.arange(1, 5)) 91 | input_2 = Tensor(np.arange(2, 6)) 92 | input_3 = Tensor(np.arange(3, 7)) 93 | 94 | output_1 = Softmax()(input_1) 95 | output_2 = Softmax()(input_2) 96 | output_3 = Softmax()(input_3) 97 | 98 | input_batch = Tensor( 99 | np.array([input_1.array, input_2.array, input_3.array]) 100 | ) 101 | correct_output = Tensor( 102 | np.array([output_1.array, output_2.array, output_3.array]) 103 | ) 104 | 105 | input_batch = Batch()(input_batch) 106 | output_batch = Softmax()(input_batch) 107 | output_batch = Unbatch()(output_batch) 108 | 109 | assert output_batch.close_to(correct_output) 110 | 111 | 112 | def test_can_batch_split(): 113 | in_tensor = Tensor( 114 | [[1, 2, 3, 4, 5, 6], [1, 2, 3, 
4, 5, 6]], name="in_tensor" 115 | ) 116 | 117 | out_tensors = in_tensor.to_batched().split(3) 118 | 119 | assert len(out_tensors) == 3 120 | assert out_tensors[0].shape == (2, 2) 121 | assert out_tensors[1].shape == (2, 2) 122 | assert out_tensors[2].shape == (2, 2) 123 | 124 | assert out_tensors[0].close_to([[1, 2], [1, 2]]) 125 | assert out_tensors[1].close_to([[3, 4], [3, 4]]) 126 | assert out_tensors[2].close_to([[5, 6], [5, 6]]) 127 | 128 | assert out_tensors[0].is_batched 129 | assert out_tensors[1].is_batched 130 | assert out_tensors[2].is_batched 131 | 132 | out_tensors[0].backward() 133 | 134 | assert in_tensor.grad is not None 135 | assert in_tensor.grad.close_to([[1, 1, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0]]) 136 | -------------------------------------------------------------------------------- /train_smol_gpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script for SmolGPT a replication of GPT-2 3 | 4 | The training script is pretty generic. You can tune the parameters in by 5 | modifying the config. 
6 | 7 | Currently, we train the model on fineweb, a cleaned dump of ~10B tokens of web 8 | data 9 | """ 10 | 11 | import os 12 | import pickle 13 | import uuid 14 | from pathlib import Path 15 | 16 | from tricycle import GPU_ENABLED 17 | from tricycle.context import TRICYCLE_CONTEXT 18 | from tricycle.ops import Op 19 | from tricycle.tensor import Tensor 20 | from tricycle.utils import UseMixedPrecision, optimal_n_tokens 21 | 22 | if GPU_ENABLED: 23 | import cupy as xp 24 | else: 25 | import numpy as xp 26 | 27 | import mlflow 28 | from tqdm import tqdm 29 | 30 | from inference import get_sample 31 | from tricycle.configs import SmolGPTConfig 32 | from tricycle.dataset import CausalLMDataset 33 | from tricycle.loss import CrossEntropy 34 | from tricycle.models import GPT 35 | from tricycle.optimisers import AdamW 36 | from tricycle.scheduler import CosineSchedule 37 | from tricycle_datasets.fineweb import FineWeb 38 | 39 | # fix the seed for reproducibility 40 | xp.random.seed(0) 41 | config = SmolGPTConfig() 42 | 43 | 44 | def load_datasets(config: SmolGPTConfig): 45 | """ 46 | Load tokens, batch and shuffle them. 47 | """ 48 | 49 | # if you are loading this for the first time, this can take a while. 50 | # it will create some big cache files in ~/.cache/huggingface that you 51 | # might want to clean up once you are done with the dataset 52 | print("Loading dataset") 53 | train_dataset = FineWeb(config.vocab_size, split="train") 54 | 55 | # we cant fit more than 3B indices in memory on my computer which is more 56 | # than we need anyway. 
57 | # TODO: figure out how to shuffle without using much memory 58 | train_dataset.tokens = train_dataset.tokens[: int(3e9)] 59 | valid_dataset = FineWeb(config.vocab_size, split="valid") 60 | 61 | print("Loading dataloaders") 62 | train_dataloader = ( 63 | CausalLMDataset( 64 | tokens=train_dataset.tokens, 65 | vocab_size=train_dataset.vocab_size, 66 | batch_size=config.batch_size, 67 | context_window=config.context_window, 68 | ) 69 | .batch() 70 | .shuffle() # only shuffle train dataset. 71 | .to_tensor() 72 | ) 73 | valid_dataloader = ( 74 | CausalLMDataset( 75 | tokens=valid_dataset.tokens, 76 | vocab_size=valid_dataset.vocab_size, 77 | batch_size=config.batch_size, 78 | context_window=config.context_window, 79 | ) 80 | .batch() 81 | .to_tensor() 82 | ) 83 | return ( 84 | train_dataset, 85 | valid_dataset, 86 | train_dataloader, 87 | valid_dataloader, 88 | ) 89 | 90 | 91 | def estimate_loss( 92 | model: GPT, 93 | valid_dataloader: CausalLMDataset, 94 | config: SmolGPTConfig, 95 | loss_fn: Op, 96 | ) -> float: 97 | """ 98 | Run the model on the validation dataset to estimate its loss 99 | """ 100 | batch_loss = 0 101 | for valid_step, (inputs, outputs) in tqdm( 102 | enumerate(valid_dataloader), total=config.eval_steps, desc="Validation" 103 | ): 104 | if valid_step == config.eval_steps: 105 | break 106 | 107 | assert isinstance(inputs, Tensor) 108 | assert isinstance(outputs, Tensor) 109 | if GPU_ENABLED: 110 | inputs = inputs.to_gpu(config.device_idx) 111 | outputs = outputs.to_gpu(config.device_idx) 112 | 113 | # forward pass 114 | logits = model(inputs) 115 | loss = loss_fn(outputs, logits) 116 | batch_loss += loss.array / config.eval_steps 117 | 118 | return batch_loss 119 | 120 | 121 | def validate( 122 | model: GPT, 123 | valid_dataset: CausalLMDataset, 124 | config: SmolGPTConfig, 125 | loss_fn: Op, 126 | best_loss: float, 127 | ): 128 | """ 129 | Check the performance of the model on validation data. 
Both in terms of 130 | loss and by generating some sample text and storing it in MLFlow. 131 | 132 | If the new validation loss is better than the previous validation loss, 133 | save the model to disk 134 | """ 135 | # generate some text 136 | predicted = get_sample( 137 | model=model, 138 | tokeniser=valid_dataset.tokeniser, 139 | sample_tokens=valid_dataset.tokens[: config.context_window], 140 | ) 141 | mlflow.log_text(predicted, f"generated/{step}.txt") 142 | 143 | # esimate validation loss 144 | valid_loss = estimate_loss( 145 | model=model, 146 | valid_dataloader=valid_dataloader, 147 | config=config, 148 | loss_fn=loss_fn, 149 | ) 150 | mlflow.log_metric("valid_loss", valid_loss, step=step) 151 | 152 | # checkpoint if new model better than old 153 | if valid_loss < best_loss: 154 | Path("models").mkdir(exist_ok=True) 155 | with open(f"models/model_{unique_id}.pkl", "wb") as f: 156 | pickle.dump(model, f) 157 | best_loss = valid_loss 158 | return best_loss 159 | 160 | 161 | # Create our model from the config 162 | model = GPT(config) 163 | model.display() 164 | 165 | # Use corrected Chinchilla scaling to estimate the compute-optimal number of 166 | # tokens and steps we should train for 167 | n_tokens, n_steps = optimal_n_tokens(model, config) 168 | 169 | loss_fn = CrossEntropy() 170 | scheduler = CosineSchedule( 171 | max_learning_rate=config.max_learning_rate, 172 | min_learning_rate=config.min_learning_rate, 173 | warmup_steps=config.warmup_steps, 174 | total_steps=n_steps, 175 | ) 176 | optimiser = AdamW( 177 | learning_rate=scheduler(0), 178 | weight_decay=config.weight_decay, 179 | betas=(config.beta1, config.beta2), 180 | ) 181 | 182 | train_dataset, valid_dataset, train_dataloader, valid_dataloader = ( 183 | load_datasets(config) 184 | ) 185 | 186 | if GPU_ENABLED: 187 | model.to_gpu(config.device_idx) 188 | 189 | 190 | # start tracking the experiment in mlflow 191 | mlflow.set_tracking_uri(config.mlflow_tracking_uri) 192 | 
mlflow.set_experiment("SmolGPT:fineweb:base") 193 | os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true" 194 | with mlflow.start_run() as run, UseMixedPrecision(): 195 | unique_id = uuid.uuid4() 196 | 197 | best_loss = xp.inf 198 | 199 | losses = xp.zeros(n_steps) 200 | for step in tqdm(range(n_steps), position=0): 201 | mlflow.log_params(config.dict()) 202 | 203 | optimiser.step() 204 | batch_loss = 0 205 | 206 | # perform several forward and backward passes before doing a gradient 207 | # update to increase the effective batch size 208 | for _ in range(config.gradient_accumulation_steps): 209 | inputs, outputs = next(train_dataloader) 210 | assert isinstance(inputs, Tensor) 211 | assert isinstance(outputs, Tensor) 212 | if GPU_ENABLED: 213 | inputs = inputs.to_gpu(config.device_idx) 214 | outputs = outputs.to_gpu(config.device_idx) 215 | 216 | # forward and backward pass 217 | logits = model(inputs) 218 | loss = loss_fn(outputs, logits) 219 | batch_loss += loss.array / config.gradient_accumulation_steps 220 | loss.backward() 221 | 222 | # Use the optimiser to update weights 223 | model.update(optimiser) 224 | 225 | if TRICYCLE_CONTEXT.use_mixed_precision: 226 | mlflow.log_metric( 227 | "loss", 228 | batch_loss / TRICYCLE_CONTEXT.loss_scale_factor, 229 | step=step, 230 | ) 231 | else: 232 | mlflow.log_metric("loss", batch_loss, step=step) 233 | 234 | mlflow.log_metric("lr", float(optimiser.learning_rate), step=step) 235 | 236 | # step the learning rate 237 | optimiser.learning_rate = scheduler(step) 238 | 239 | # run validation every eval_intervals 240 | if step % config.eval_interval == 0: 241 | best_loss = validate( 242 | model=model, 243 | valid_dataset=valid_dataset, 244 | config=config, 245 | loss_fn=loss_fn, 246 | best_loss=best_loss, 247 | ) 248 | 249 | # run a final validation at the end of training 250 | validate( 251 | model=model, 252 | valid_dataset=valid_dataset, 253 | config=config, 254 | loss_fn=loss_fn, 255 | best_loss=best_loss, 256 | ) 257 
| --------------------------------------------------------------------------------