├── .github ├── actions │ └── mike-docs │ │ └── action.yaml └── workflows │ ├── lint-and-test.yaml │ └── release.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── CONTRIBUTING.md ├── _images │ ├── aai-favicon.png │ └── aai-logo-cropped.png ├── _scripts │ └── gen_api_ref_pages.py ├── _styles │ ├── extra.css │ ├── neoteroi-mkdocs.css │ └── theme.css ├── _theme_overrides │ └── main.html ├── cli │ ├── cli.md │ ├── comparisons.md │ ├── fixtures.md │ ├── index.md │ └── pyproject.md ├── guides │ ├── benchmarks.md │ ├── customization.md │ ├── index.md │ ├── organization.md │ └── runners.md ├── index.md ├── quickstart.md └── tutorials │ ├── bq.md │ ├── duckdb.md │ ├── huggingface.md │ ├── index.md │ ├── mnist.md │ ├── prefect.md │ ├── prefect_resources │ ├── deployments.png │ ├── detail_flow_runs.png │ └── flow_runs.png │ ├── streamlit.md │ └── streamlit_resources │ └── initial_ui.png ├── examples ├── bq │ ├── benchmarks.py │ └── bq.py ├── huggingface │ ├── benchmark.py │ ├── runner.py │ └── training.py ├── mnist │ ├── benchmarks.py │ └── mnist.py ├── prefect │ ├── pyproject.toml │ └── src │ │ ├── __init__.py │ │ ├── benchmark.py │ │ ├── runner.py │ │ └── training.py ├── streamlit │ ├── pyproject.toml │ └── streamlit_example.py └── zenml │ ├── README.md │ ├── benchmarks.py │ └── pipeline.py ├── mkdocs.yml ├── pyproject.toml ├── src └── nnbench │ ├── __init__.py │ ├── __main__.py │ ├── cli.py │ ├── compare.py │ ├── config.py │ ├── context.py │ ├── core.py │ ├── fixtures.py │ ├── py.typed │ ├── reporter │ ├── __init__.py │ ├── console.py │ ├── file.py │ ├── mlflow.py │ ├── sqlite.py │ └── util.py │ ├── runner.py │ ├── types.py │ └── util.py ├── tests ├── __init__.py ├── benchmarks │ ├── argchecks.py │ ├── standard.py │ └── tags.py ├── cli │ ├── __init__.py │ ├── benchmarks │ │ ├── a.py │ │ └── b.py │ ├── conf.py │ └── test_parallel_exec.py ├── conftest.py ├── integration │ ├── __init__.py │ └── test_benchmark_family_memory_consumption.py ├── test_config.py ├── test_context.py ├── test_core.py ├── test_file_reporter.py ├── test_runner.py ├── test_types.py └── test_utils.py ├── uv.lock └── zizmor.yml /.github/actions/mike-docs/action.yaml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | description: Build and publish documentation using mike 3 | inputs: 4 | version: 5 | description: Version number 6 | required: true 7 | alias: 8 | description: Alias name 9 | push: 10 | description: Whether to push the built documentation to the repository 11 | required: true 12 | default: "false" 13 | pre_release: 14 | description: Whether this version is a pre-release version (to render a notification banner) 15 | default: "false" 16 | runs: 17 | using: "composite" 18 | steps: 19 | - run: | 20 | # https://github.com/jimporter/mike#deploying-via-ci 21 | git fetch origin gh-pages --depth=1 22 | 23 | # For proper UI integration: https://github.com/actions/checkout/pull/1184 24 | git config user.name "github-actions[bot]" 25 | git config user.email "41898282+github-actions[bot]@users.noreply.github.com" 26 | shell: bash 27 | - env: 28 | DOCS_PRERELEASE: ${{ inputs.pre_release }} 29 | INPUTS_PUSH: ${{ inputs.push }} 30 | INPUTS_VERSION: ${{ inputs.version }} 31 | INPUTS_ALIAS: ${{ inputs.alias }} 32 | run: | 33 | MIKE_OPTIONS=( "--update-aliases" ) 34 | if [ "true" = "${INPUTS_PUSH}" ]; then 35 | MIKE_OPTIONS+=( "--push" ) 36 | fi 37 | uv run mike deploy "${INPUTS_VERSION}" ${INPUTS_ALIAS} 
"${MIKE_OPTIONS[@]}" 38 | shell: bash 39 | -------------------------------------------------------------------------------- /.github/workflows/lint-and-test.yaml: -------------------------------------------------------------------------------- 1 | name: Lint and test nnbench 2 | 3 | permissions: {} 4 | 5 | on: 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | branches: 11 | - main 12 | 13 | jobs: 14 | lint: 15 | name: Run code checks and formatting hooks 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | with: 20 | persist-credentials: false 21 | - name: Install uv 22 | uses: astral-sh/setup-uv@v5 23 | with: 24 | python-version: "3.10" 25 | - name: Run pre-commit checks 26 | run: uv run pre-commit run --all-files --verbose --show-diff-on-failure 27 | test: 28 | name: Test nnbench on ${{ matrix.os }} on Python ${{ matrix.python-version }} 29 | runs-on: ${{ matrix.os }} 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | os: [ ubuntu-latest, macos-latest, windows-latest ] 34 | python-version: [ "3.10", 3.11, 3.12, 3.13 ] 35 | steps: 36 | - uses: actions/checkout@v4 37 | with: 38 | persist-credentials: false 39 | - name: Install uv 40 | uses: astral-sh/setup-uv@v5 41 | with: 42 | python-version: ${{ matrix.python-version }} 43 | - name: Test with Python ${{ matrix.python-version }} 44 | run: uv run --frozen pytest -s 45 | docs: 46 | name: Publish latest documentation for nnbench 47 | runs-on: ubuntu-latest 48 | permissions: 49 | contents: write 50 | steps: 51 | - uses: actions/checkout@v4 52 | with: 53 | fetch-depth: 0 54 | persist-credentials: true # needed for mike to publish. 55 | - name: Install uv 56 | uses: astral-sh/setup-uv@v5 57 | with: 58 | python-version: "3.11" 59 | - name: Install the project 60 | run: uv sync --group docs 61 | - name: Build documentation using mike 62 | uses: ./.github/actions/mike-docs 63 | with: 64 | version: latest 65 | pre_release: true # include pre-release notification banner 66 | push: ${{ github.ref == 'refs/heads/main' }} # build always, publish on 'main' only to prevent version clutter 67 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Build and publish Python wheel and sdist 2 | 3 | permissions: {} 4 | 5 | on: 6 | workflow_dispatch: 7 | release: 8 | types: 9 | - published 10 | 11 | jobs: 12 | build: 13 | name: Build source distribution and wheel 14 | runs-on: ubuntu-latest 15 | permissions: 16 | # docs build pushes to the gh-pages branch 17 | contents: write 18 | steps: 19 | - name: Check out repository 20 | uses: actions/checkout@v4 21 | with: 22 | fetch-depth: 0 23 | persist-credentials: true 24 | - name: Install uv 25 | uses: astral-sh/setup-uv@v5 26 | with: 27 | python-version: "3.11" 28 | - name: Install the project 29 | run: uv sync --all-groups 30 | - name: Build and check 31 | run: uv run -m build 32 | - name: Upload build artifacts 33 | uses: actions/upload-artifact@v4 34 | with: 35 | name: dist 36 | path: dist 37 | if-no-files-found: error 38 | - name: Publish stable documentation for nnbench 39 | uses: ./.github/actions/mike-docs 40 | with: 41 | version: stable 42 | push: true 43 | publish_pypi: 44 | name: Publish wheels to PyPI 45 | needs: [build] 46 | runs-on: ubuntu-latest 47 | permissions: 48 | id-token: write 49 | steps: 50 | - name: Download build artifacts 51 | uses: actions/download-artifact@v4 52 | with: 53 | name: dist 54 | path: dist 55 
| - name: Publish distribution 📦 to PyPI 56 | uses: pypa/gh-action-pypi-publish@release/v1 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # Installer logs 30 | pip-log.txt 31 | pip-delete-this-directory.txt 32 | 33 | # Unit test / coverage reports 34 | htmlcov/ 35 | .tox/ 36 | .nox/ 37 | .coverage 38 | .coverage.* 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | *.cover 43 | *.py,cover 44 | .hypothesis/ 45 | .pytest_cache/ 46 | cover/ 47 | pytest-junit.xml 48 | 49 | # Sphinx documentation 50 | docs/_build/ 51 | 52 | # Jupyter Notebook 53 | .ipynb_checkpoints 54 | 55 | # IPython 56 | profile_default/ 57 | ipython_config.py 58 | 59 | # pyenv 60 | # For a library or package, you might want to ignore these files since the code is 61 | # intended to run in multiple environments; otherwise, check them in: 62 | .python-version 63 | 64 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 65 | __pypackages__/ 66 | 67 | # Environments 68 | .env 69 | .venv 70 | env/ 71 | venv/ 72 | ENV/ 73 | env.bak/ 74 | venv.bak/ 75 | 76 | # mkdocs documentation 77 | /site 78 | 79 | # mypy 80 | .mypy_cache/ 81 | .dmypy.json 82 | dmypy.json 83 | mypy-report.xml 84 | 85 | # Pyre type checker 86 | .pyre/ 87 | 88 | # Cython debug symbols 89 | cython_debug/ 90 | 91 | # PyCharm 92 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 93 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 94 | # and can be added to the global gitignore or merged into this file. For a more nuclear 95 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 96 | .idea/ 97 | 98 | # VS Code 99 | .vscode/ 100 | 101 | # NPM-based tooling (Danger, ...) 102 | node_modules/ 103 | package.json 104 | yarn.lock 105 | package.lock 106 | 107 | # GitLab Pages 108 | public/ 109 | 110 | # Syft / SBOM Outputs 111 | cyclonedx.json 112 | spdx.json 113 | syft-output.json 114 | 115 | .pre-commit-cache 116 | .ruff_cache 117 | 118 | # direnv 119 | .envrc 120 | .direnv 121 | 122 | # Generated documentation 123 | public/ 124 | 125 | # NumPy zip archives created by the examples. 
126 | *.npz 127 | 128 | # Memray flamegraphs and their HTML views 129 | *.bin 130 | *.html 131 | !docs/**/*.html 132 | 133 | # Written records 134 | *.json 135 | *.parquet 136 | *.csv 137 | *.ndjson 138 | *.db 139 | 140 | # Zed 141 | .ropeproject 142 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-json 7 | - id: check-toml 8 | - id: check-yaml 9 | exclude: "mkdocs.yml" 10 | - id: end-of-file-fixer 11 | - id: mixed-line-ending 12 | - repo: https://github.com/pre-commit/mirrors-mypy 13 | rev: v1.16.0 14 | hooks: 15 | - id: mypy 16 | types_or: [ python, pyi ] 17 | args: [--ignore-missing-imports, --explicit-package-bases] 18 | - repo: https://github.com/astral-sh/ruff-pre-commit 19 | rev: v0.11.12 20 | hooks: 21 | - id: ruff-check 22 | args: [--fix, --exit-non-zero-on-fix] 23 | - id: ruff-format 24 | - repo: https://github.com/astral-sh/uv-pre-commit 25 | rev: 0.7.10 26 | hooks: 27 | - id: uv-lock 28 | name: Lock project dependencies 29 | - repo: https://github.com/woodruffw/zizmor-pre-commit 30 | rev: v1.9.0 31 | hooks: 32 | - id: zizmor 33 | args: [--min-severity=medium] 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to nnbench 2 | 3 | Thank you for your interest in contributing to this project! 4 | 5 | We appreciate issue reports, pull requests for code and documentation, 6 | as well as any project-related communication through [GitHub Discussions](https://github.com/aai-institute/nnbench/discussions). 7 | 8 | ## Getting Started 9 | 10 | To get started with development, you can follow these steps: 11 | 12 | 1. Clone this repository: 13 | 14 | ```shell 15 | git clone https://github.com/aai-institute/nnbench.git 16 | ``` 17 | 18 | 2. Navigate to the directory and install the development dependencies into a virtual environment, e.g. using `uv`: 19 | 20 | ```shell 21 | cd nnbench 22 | uv venv --seed -p 3.11 23 | source .venv/bin/activate 24 | ``` 25 | 26 | 3. After making your changes, verify they adhere to our Python code style by running `pre-commit`: 27 | 28 | ```shell 29 | uvx pre-commit run --all-files --verbose --show-diff-on-failure 30 | ``` 31 | 32 | You can also set up Git hooks through `pre-commit` to perform these checks automatically: 33 | 34 | ```shell 35 | pre-commit install 36 | ``` 37 | 38 | 4. To run the tests, just invoke `pytest` from the package root directory: 39 | ```shell 40 | uv run pytest -s 41 | ``` 42 | 43 | ## Updating dependencies 44 | 45 | Dependencies should stay locked for as long as possible, ideally for a whole release. 46 | If you have to update a dependency during development, you should do the following: 47 | 48 | 1. If it is a core dependency needed for the package, add it to the `dependencies` section in the `pyproject.toml` via `uv add `. 49 | 2. In case of a development dependency, add it to the `dev` section of the `project.dependency-groups` table instead (`uv add --group dev `). 50 | 3. Dependencies needed for documentation generation are found in the `docs` sections of `project.dependency-groups` (`uv add --group docs `). 
51 | 52 | After adding the dependency in either of these sections, use `uv lock` to pin all dependencies again: 53 | 54 | ```shell 55 | uv lock 56 | ``` 57 | 58 | > [!IMPORTANT] 59 | > Since the official development version is Python 3.11, please run the above commands in a virtual environment with Python 3.11. 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nnbench: A small framework for benchmarking machine learning models 2 | 3 | Welcome to nnbench, a framework for benchmarking machine learning models. 4 | The main goals of this project are 5 | 6 | 1. To provide a portable, easy-to-use solution for model evaluation that leads to better ML experiment organization, and 7 | 2. To integrate with experiment and metadata tracking solutions for easy adoption. 8 | 9 | On a high level, you can think of nnbench as "pytest for ML models" - you define benchmarks similarly to test cases, collect them, and selectively run them based on model type, markers, and environment info. 10 | 11 | What's new is that upon completion, you can stream the resulting data to any sink of your choice (including multiple at the same), which allows easy integration with experiment trackers and metadata stores. 12 | 13 | See the [quickstart](https://aai-institute.github.io/nnbench/latest/quickstart/) for a lightning-quick demo, or the [examples](https://aai-institute.github.io/nnbench/latest/tutorials/) for more advanced usages. 14 | 15 | ## Installation 16 | 17 | ⚠️ nnbench is an experimental project - expect bugs and sharp edges. 18 | 19 | Install it directly from source, for example either using `pip` or `uv`: 20 | 21 | ```shell 22 | pip install nnbench 23 | # or 24 | uv add nnbench 25 | ``` 26 | 27 | ## A ⚡️- quick demo 28 | 29 | To understand how nnbench works, you can run the following in your Python interpreter: 30 | 31 | ```python 32 | # example.py 33 | import nnbench 34 | from nnbench.reporter import ConsoleReporter 35 | 36 | 37 | @nnbench.benchmark 38 | def product(a: int, b: int) -> int: 39 | return a * b 40 | 41 | 42 | @nnbench.benchmark 43 | def power(a: int, b: int) -> int: 44 | return a ** b 45 | 46 | 47 | reporter = ConsoleReporter() 48 | # first, collect the above benchmarks directly from the current module... 49 | benchmarks = nnbench.collect("__main__") 50 | # ... then run the benchmarks with the parameters `a=2, b=10`... 51 | result = nnbench.run(benchmarks, params={"a": 2, "b": 10}) 52 | reporter.write(result) # ...and print the results to the terminal. 53 | 54 | # results in a table look like the following: 55 | # ┏━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ 56 | # ┃ Benchmark ┃ Value ┃ Wall time (ns) ┃ Parameters ┃ 57 | # ┡━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ 58 | # │ product │ 20 │ 1917 │ {'a': 2, 'b': 10} │ 59 | # │ power │ 1024 │ 583 │ {'a': 2, 'b': 10} │ 60 | # └───────────┴───────┴────────────────┴───────────────────┘ 61 | ``` 62 | Watch the following video for a high level overview of the capabilities and inner workings of nnbench. 63 | 64 | [![nnbench overview video thumbnail](https://img.youtube.com/vi/CT9bKq-U8ZQ/0.jpg)](https://www.youtube.com/watch?v=CT9bKq-U8ZQ) 65 | 66 | For a more realistic example of how to evaluate a trained model with a benchmark suite, check the [Quickstart](https://aai-institute.github.io/nnbench/latest/quickstart/). 
67 | For even more advanced usages of the library, you can check out the [Examples](https://aai-institute.github.io/nnbench/latest/tutorials/) in the documentation. 68 | 69 | ## Contributing 70 | 71 | We encourage and welcome contributions from the community to enhance the project. 72 | Please check [discussions](https://github.com/aai-institute/nnbench/discussions) or raise an [issue](https://github.com/aai-institute/nnbench/issues) on GitHub for any problems you encounter with the library. 73 | 74 | For information on the general development workflow, see the [contribution guide](CONTRIBUTING.md). 75 | 76 | ## License 77 | 78 | The nnbench library is distributed under the [Apache-2 license](LICENSE). 79 | -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ../CONTRIBUTING.md -------------------------------------------------------------------------------- /docs/_images/aai-favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/docs/_images/aai-favicon.png -------------------------------------------------------------------------------- /docs/_images/aai-logo-cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/docs/_images/aai-logo-cropped.png -------------------------------------------------------------------------------- /docs/_scripts/gen_api_ref_pages.py: -------------------------------------------------------------------------------- 1 | """Automatically generate API reference pages from source files. 2 | 3 | Source: https://mkdocstrings.github.io/recipes/#automatic-code-reference-pages 4 | 5 | Note: this script assumes a source layout with a `src/` folder. 
6 | """ 7 | 8 | import ast 9 | import logging 10 | from pathlib import Path 11 | 12 | import docstring_parser 13 | import mkdocs_gen_files 14 | 15 | nav = mkdocs_gen_files.Nav() 16 | 17 | for path in sorted(Path("src").rglob("*.py")): 18 | module_path = path.relative_to("src").with_suffix("") 19 | doc_path = path.relative_to("src").with_suffix(".md") 20 | full_doc_path = Path("reference", doc_path) 21 | 22 | parts = list(module_path.parts) 23 | 24 | if parts[-1] == "__init__": 25 | parts = parts[:-1] 26 | doc_path = doc_path.with_name("index.md") 27 | full_doc_path = full_doc_path.with_name("index.md") 28 | elif parts[-1] == "__main__": 29 | continue 30 | 31 | nav[parts] = doc_path.as_posix() 32 | 33 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 34 | identifier = ".".join(parts) 35 | print("::: " + identifier, file=fd) 36 | 37 | mkdocs_gen_files.set_edit_path(full_doc_path, path) 38 | 39 | # Add links for top-level modules to root page 40 | root_page = next(it for it in nav.items() if it.level == 0) 41 | children = [it for it in nav.items() if it.level == 1] 42 | 43 | with mkdocs_gen_files.open(f"reference/{root_page.filename}", "a") as f: 44 | f.write("## Modules\n") 45 | for ch in children: 46 | f.write(f"### [{ch.title}](../{ch.filename})\n") 47 | 48 | try: 49 | source_file = Path("src", ch.filename).with_suffix(".py") 50 | 51 | # Index page for submodules maps to __init__.py of the module 52 | if source_file.stem == "index": 53 | source_file = source_file.with_stem("__init__") 54 | 55 | tree = ast.parse(source_file.read_text()) 56 | docstring = ast.get_docstring(tree, clean=False) 57 | doc = docstring_parser.parse(docstring) 58 | 59 | if doc.short_description: 60 | f.write(f"{doc.short_description}\n\n") 61 | except Exception as e: 62 | logging.warning(f"Could not parse module docstring: {ch.filename}", exc_info=True) 63 | 64 | with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: 65 | nav_file.writelines(nav.build_literate_nav()) 66 | -------------------------------------------------------------------------------- /docs/_styles/extra.css: -------------------------------------------------------------------------------- 1 | /*Remove leading `!` from code cells, which use the cell magic for shell commands*/ 2 | code>.err:first-of-type { 3 | display: none; 4 | } 5 | 6 | /* Hide empty root heading (-> `#`) */ 7 | h1#_1 { 8 | display: none; 9 | } 10 | 11 | .logo { 12 | /* No wider than the screen on smaller screens */ 13 | max-width: min(100%, 100vw) !important; 14 | max-height: 196px; 15 | 16 | /* Horizontally centered */ 17 | display: block; 18 | margin: 0 auto; 19 | } 20 | 21 | table { 22 | display: block; 23 | max-width: -moz-fit-content; 24 | max-width: fit-content; 25 | margin: 0 auto; 26 | overflow-x: auto; 27 | white-space: nowrap; 28 | } 29 | 30 | .nt-card .landing-page-icon { 31 | color: #b7c6cf; 32 | } 33 | 34 | .nt-card:hover .landing-page-icon { 35 | color: var(--md-accent-fg-color); 36 | } 37 | 38 | .celltag_Remove_all_output .output_wrapper, 39 | .celltag_Remove_single_output .output_wrapper { 40 | display: none !important 41 | } 42 | 43 | 44 | .celltag_Remove_input .input_wrapper { 45 | display: none !important 46 | } 47 | -------------------------------------------------------------------------------- /docs/_styles/theme.css: -------------------------------------------------------------------------------- 1 | [data-md-color-scheme="aai-light"] { 2 | /* Primary color shades */ 3 | --md-primary-fg-color: #084059; 4 | --md-primary-fg-color--light: 
#46add5; 5 | --md-primary-fg-color--dark: #04212f; 6 | 7 | --md-typeset-a-color: var(--md-primary-fg-color); 8 | 9 | /* Accent color shades */ 10 | --md-accent-fg-color: var(--md-primary-fg-color--light); 11 | 12 | --md-code-font-family: "Source Code Pro", monospace 13 | } 14 | 15 | [data-md-color-scheme="slate"] { 16 | --md-primary-fg-color: #46add5; 17 | --md-primary-fg-color--dark: #04212f; 18 | --md-accent-fg-color: hsl(197, 63%, 75%); 19 | 20 | --md-typeset-a-color: var(--md-primary-fg-color) !important; 21 | 22 | --md-accent-fg-color: var(--md-primary-fg-color); 23 | 24 | --md-code-font-family: "Source Code Pro", monospace 25 | } 26 | 27 | [data-md-color-scheme="aai-light"] img[src$="#only-dark"], 28 | [data-md-color-scheme="aai-light"] img[src$="#gh-dark-mode-only"] { 29 | display: none; 30 | /* Hide dark images in light mode */ 31 | } 32 | 33 | .md-header { 34 | background: rgb(24, 165, 167); 35 | background: linear-gradient(30deg, rgb(24, 165, 167), 90%, rgb(191, 255, 199)); 36 | } 37 | -------------------------------------------------------------------------------- /docs/_theme_overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block announce %} 3 | {%- if config.extra.pre_release -%} 4 | You are currently viewing pre-release documentation. This may contain features that have not yet been 5 | released. For the most accurate information, please access the latest 6 | release. 7 | {%- endif -%} 8 | {% endblock %} 9 | -------------------------------------------------------------------------------- /docs/cli/cli.md: -------------------------------------------------------------------------------- 1 | # `nnbench` command reference 2 | 3 | While you can always use nnbench to directly run your benchmarks in your Python code, for example as part of a workflow, there is also the option of running benchmarks from the command line. 4 | This way of using nnbench is especially useful for integrating into a machine learning pipeline as part of a continuous training/delivery scenario. 5 | 6 | ## General options 7 | 8 | The `nnbench` CLI has the following top-level options: 9 | 10 | ```commandline 11 | $ nnbench 12 | usage: nnbench [-h] [--version] [--log-level ] ... 13 | 14 | options: 15 | -h, --help show this help message and exit 16 | --version show program's version number and exit 17 | --log-level Log level to use for the nnbench package, defaults to NOTSET (no logging). 18 | 19 | Available commands: 20 | 21 | run Run a benchmark workload. 22 | compare Compare results from multiple benchmark runs. 23 | ``` 24 | 25 | Supported log levels are `"DEBUG", "INFO", "WARNING", "ERROR"`, and `"CRITICAL"`. 26 | 27 | ## Running benchmark workloads on the command line 28 | 29 | This is the responsibility of the `nnbench run` subcommand. 30 | 31 | ```commandline 32 | $ nnbench run -h 33 | usage: nnbench run [-h] [-n ] [-j ] [--context ] [-t ] [-o ] [--jsonifier ] [] 34 | 35 | positional arguments: 36 | A Python file or directory of files containing benchmarks to run. 37 | 38 | options: 39 | -h, --help show this help message and exit 40 | -n, --name A name to assign to the benchmark run, for example for record keeping in a database. 41 | -j Number of processes to use for running benchmarks in parallel, default: -1 (no parallelism) 42 | --context 43 | Additional context values giving information about the benchmark run. 44 | -t, --tag Only run benchmarks marked with one or more given tag(s). 
45 | -o, --output-file 46 | File or stream to write results to, defaults to stdout. 47 | --jsonifier 48 | Function to create a JSON representation of input parameters with, helping make runs reproducible. 49 | ``` 50 | 51 | To run a benchmark workload contained in a single `benchmarks.py` file, you would run `nnbench run benchmarks.py`. 52 | For tips on how to structure and annotate your benchmarks, refer to the [organization](../guides/organization.md) guide. 53 | 54 | For injecting context values on the command line, you need to give the key-value pair explicitly by passing the `--context` switch. 55 | For example, to look up and persist the `pyyaml` version in the current environment, you could run the following: 56 | 57 | ```commandline 58 | nnbench run .sandbox/example.py --context=pyyaml=`python3 -c "from importlib.metadata import version; print(version('pyyaml'))"` 59 | ``` 60 | 61 | !!! tip 62 | Both `--context` and `--tag` are appending options, so you can pass multiple context values and multiple tags per run. 63 | 64 | !!! tip 65 | For more complex calculations of context values, it is recommended to register a *custom context provider* in your pyproject.toml file. 66 | An introductory example can be found in the [nnbench CLI configuration guide](pyproject.md). 67 | 68 | ### Streaming results to different locations with URIs 69 | 70 | Like in the nnbench Python SDK, you can use builtin and custom reporters to write benchmark results to various locations. 71 | To select a reporter implementation, specify the output file as a URI, with a leading protocol, and suffixed with `://`. 72 | 73 | !!! Example 74 | The builtin file reporter supports multiple file system protocols of the `fsspec` project. 75 | If you have `fsspec` installed, you can stream benchmark results to different cloud storage providers like so: 76 | 77 | ```commandline 78 | # Write a result to disk... 79 | nnbench run benchmarks.py -o result.json 80 | # ...or to an S3 storage bucket... 81 | nnbench run benchmarks.py -o s3://my-bucket/result.json 82 | # ...or to GCS... 83 | nnbench run benchmarks.py -o gs://my-bucket/result.json 84 | # ...or lakeFS: 85 | nnbench run benchmarks.py -o lakefs://my-repo/my-branch/result.json 86 | ``` 87 | 88 | For a comprehensive list of supported protocols, see the [fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#other-known-implementations). 89 | 90 | If `fsspec` is not installed, only local files can be written (i.e. to the executing machine's filesystem). 91 | 92 | ## Comparing results across multiple benchmark runs 93 | 94 | To create a comparison table between multiple benchmark runs, use the `nnbench compare` command. 95 | 96 | ```commandline 97 | $ nnbench compare -h 98 | usage: nnbench compare [-h] [--comparison-file ] [-C ] [-E ] results [results ...] 99 | 100 | positional arguments: 101 | results Results to compare. Can be given as local files or remote URIs. 102 | 103 | options: 104 | -h, --help show this help message and exit 105 | --comparison-file 106 | A file containing comparison functions to run on benchmarking metrics. 107 | -C, --include-context 108 | Context values to display in the comparison table. Use dotted syntax for nested context values. 109 | -E, --extra-column 110 | Additional result data to display in the comparison table. 111 | ``` 112 | 113 | Supposing we have the following results from previous runs, for a benchmark `add(a,b)` that adds two integers: 114 | 115 | ```json 116 | // Pretty-printed JSON, obtained as jq . 
result1.json 117 | { 118 | "run": "nnbench-3ff188b4", 119 | "context": { 120 | "foo": "bar" 121 | }, 122 | "benchmarks": [ 123 | { 124 | "name": "add", 125 | "function": "add", 126 | "description": "", 127 | "timestamp": 1733157676, 128 | "error_occurred": false, 129 | "error_message": "", 130 | "parameters": { 131 | "a": 200, 132 | "b": 100 133 | }, 134 | "value": 300, 135 | "time_ns": 1291 136 | } 137 | ] 138 | } 139 | ``` 140 | 141 | and 142 | 143 | ```json 144 | // jq . result2.json 145 | { 146 | "run": "nnbench-5cbb85f8", 147 | "context": { 148 | "foo": "baz" 149 | }, 150 | "benchmarks": [ 151 | { 152 | "name": "add", 153 | "function": "add", 154 | "description": "", 155 | "timestamp": 1733157724, 156 | "error_occurred": false, 157 | "error_message": "", 158 | "parameters": { 159 | "a": 200, 160 | "b": 100 161 | }, 162 | "value": 300, 163 | "time_ns": 1792 164 | } 165 | ] 166 | } 167 | ``` 168 | 169 | we can compare them in a table view by running `nnbench compare result1.json result2.json`: 170 | 171 | ```commandline 172 | $ nnbench compare result1.json result2.json 173 | ┏━━━━━━━━━━━━━━━━━━┳━━━━━┓ 174 | ┃ Benchmark run ┃ add ┃ 175 | ┡━━━━━━━━━━━━━━━━━━╇━━━━━┩ 176 | │ nnbench-3ff188b4 │ 300 │ 177 | │ nnbench-5cbb85f8 │ 300 │ 178 | └──────────────────┴─────┘ 179 | ``` 180 | 181 | To include context values in the table - in our case, we might want to display the `foo` value - use the `-C` switch (you can use it multiple times to include multiple values): 182 | 183 | ```commandline 184 | $ nnbench compare result1.json result2.json -C foo 185 | ┏━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┓ 186 | ┃ Benchmark run ┃ add ┃ foo ┃ 187 | ┡━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━┩ 188 | │ nnbench-3ff188b4 │ 300 │ bar │ 189 | │ nnbench-5cbb85f8 │ 300 │ baz │ 190 | └──────────────────┴─────┴─────┘ 191 | ``` 192 | 193 | To learn how to define per-metric comparisons and use comparisons in a continuous training pipeline, refer to the [comparison documentation](comparisons.md). 194 | -------------------------------------------------------------------------------- /docs/cli/comparisons.md: -------------------------------------------------------------------------------- 1 | # Comparing different benchmark runs with `nnbench compare` 2 | 3 | To compare benchmark results across different runs, you can use the `nnbench compare` subcommand of the nnbench CLI. 4 | 5 | ```commandline 6 | $ nnbench compare -h 7 | usage: nnbench compare [-h] 8 | [--comparison-file ] 9 | [-C ] [-E ] 10 | results [results ...] 11 | 12 | positional arguments: 13 | results Results to compare. Can be given as local files or remote URIs. 14 | 15 | options: 16 | -h, --help show this help message and exit 17 | --comparison-file 18 | A file containing comparison functions to run on benchmarking metrics. 19 | -C, --include-context 20 | Context values to display in the comparison table. 21 | Use dotted syntax for nested context values. 22 | -E, --extra-column 23 | Additional result data to display in the comparison table. 24 | ``` 25 | 26 | ## A quick example 27 | 28 | Suppose you run the following set of benchmarks: 29 | 30 | ```python 31 | import nnbench 32 | 33 | 34 | @nnbench.benchmark 35 | def add(a: int, b: int) -> int: 36 | return a + b 37 | 38 | 39 | @nnbench.benchmark 40 | def sub(a: int, b: int) -> int: 41 | return a - b 42 | ``` 43 | 44 | two times, with the parameters `a = 1, b = 2` and `a = 1, b = 3`, respectively.
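One way to produce two such runs is to execute the benchmarks twice with the Python API and write each result to its own file. The following is only a sketch: it assumes the two benchmarks above live in a local `benchmarks.py` file, reuses the `FileReporter` shown in the tutorials, and uses hypothetical output file names that mirror the `result1.json`/`result2.json` examples from the CLI reference.

```python
import nnbench
from nnbench.reporter.file import FileReporter

# Collect the benchmarks defined above (assumed to live in benchmarks.py).
benchmarks = nnbench.collect("benchmarks.py")
reporter = FileReporter()

# Run once per parameter set and write each result to its own file,
# so the runs can be compared afterwards with `nnbench compare`.
for idx, params in enumerate(({"a": 1, "b": 2}, {"a": 1, "b": 3}), start=1):
    result = nnbench.run(benchmarks, params=params)
    reporter.write(result, f"result{idx}.json")
```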
45 | We will now set up a comparison between the two runs by crafting a `comparisons.json` file that defines the comparisons we will apply on the respective metrics: 46 | 47 | ```json 48 | # file: comparisons.json 49 | { 50 | "add": { 51 | "class": "nnbench.compare.GreaterEqual" 52 | }, 53 | "sub": { 54 | "class": "nnbench.compare.AbsDiffLessEqual", 55 | "kwargs": { 56 | "thresh": 0.5 57 | } 58 | } 59 | } 60 | ``` 61 | 62 | The structure itself is simple: The contents map the metric names (`"add"` and `"sub"` in this case) to their respective comparison classes. 63 | 64 | Each comparison is encoded by the class name, given as a fully qualified Python module path that can be imported via `importlib`, and a `kwargs` dictionary, which will be passed to the chosen class on construction (i.e. to its `__init__()`) method. 65 | 66 | !!! Note 67 | You are responsible for only passing keyword arguments that are expected by the chosen comparison class. 68 | 69 | 70 | Now, running `nnbench compare --comparison-file=comparisons.json` will produce a table like the following (with an sqlite database holding the two results as an example): 71 | 72 | ```commandline 73 | $ nnbench compare sqlite://hello.db --comparison-file=comp.json 74 | Comparison strategy: All vs. first 75 | Comparisons: 76 | add: x ≥ y 77 | sub: |x - y| <= 0.50 78 | ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━┓ 79 | ┃ Run Name ┃ add ┃ sub ┃ Status ┃ 80 | ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━┩ 81 | │ nnbench-1749042461424137000 │ 3.00 │ -1.0 │ │ 82 | │ nnbench-1749042477901108000 │ 4.00 (vs. 3.00) │ -2.0 (vs. -1.0) │ ✅❌ │ 83 | └─────────────────────────────┴─────────────────┴─────────────────┴────────┘ 84 | ``` 85 | 86 | The comparison strategy is printed first, listed as "All vs. first" (the only strategy currently supported), meaning that all results are compared against the result of the first given run. 87 | After that, the comparisons are listed as concise human-readable mathematical expressions. 88 | 89 | In this case, we want to ensure that the candidate results have a greater value for "add" (indicated by the "x ≥ y" formula, where x is the candidate value and y is the value in the first record), and a value for sub that is within `0.5` compared to the first result. 90 | 91 | As the value for `add` is greater in the second run, this comparison succeeded, as indicated by the green checkmark, but the `sub` comparison failed, since both results differ by `1.0`, which is more than the allowed `0.5`. 92 | 93 | ## Extending comparisons 94 | 95 | To craft your own comparison, subclass the `nnbench.compare.AbstractComparison` class. 96 | You can also define your own comparison functions, which need to implement the following protocol: 97 | 98 | ```python 99 | class Comparator(Protocol): 100 | def __call__(self, val1: Any, val2: Any) -> bool: ... 101 | ``` 102 | 103 | !!! Note 104 | One-versus-all metric comparisons and running user-defined comparisons are planned for a future release of `nnbench`. 105 | -------------------------------------------------------------------------------- /docs/cli/fixtures.md: -------------------------------------------------------------------------------- 1 | # Using fixtures to supply parameters to benchmarks 2 | 3 | One of the main problems when running `nnbench` on the command line is how to supply parameters. 
4 | Default values for benchmarks are one solution, but that does not scale well, and requires frequent code changes when values change. 5 | 6 | Instead, nnbench borrows a bit of pytest's [fixture concept](https://docs.pytest.org/en/stable/how-to/fixtures.html) to source parameters from special marker files, named `conf.py` in reference to pytest's `conftest.py`. 7 | 8 | ## How to define fixture values for benchmarks 9 | 10 | Suppose you have a benchmark defined in a single file, `metrics.py`: 11 | 12 | ```python 13 | # metrics.py 14 | import nnbench 15 | 16 | 17 | @nnbench.benchmark 18 | def accuracy(model, data): 19 | ... 20 | ``` 21 | 22 | To supply `model` and `data` to the benchmark, define both values as return values of similarly named functions in a `conf.py` file in the same directory. 23 | The layout of your benchmark directory should look like this: 24 | 25 | ```commandline 26 | 📂 benchmarks 27 | ┣━━ conf.py 28 | ┣━━ metrics.py 29 | ┣━━ ... 30 | ``` 31 | 32 | Inside your `conf.py` file, you might define your values as shown below. Note that currently, all fixtures must be raw Python callables, and their names must match input values of benchmarks exactly. 33 | 34 | ```python 35 | # benchmarks/conf.py 36 | def model(): 37 | return MyModel() 38 | 39 | 40 | def data(): 41 | return TestDataset.load("path/to/my/dataset") 42 | ``` 43 | 44 | Then, nnbench will discover and auto-use these values when running this benchmark from the command line: 45 | 46 | ```commandline 47 | $ nnbench run benchmarks.py 48 | ``` 49 | 50 | !!! Warning 51 | Benchmarks with default values for their arguments will unconditionally use those defaults over potential fixtures. 52 | That is, for a benchmark `def add(a: int, b: int = 1)`, only the named parameter `a` will be resolved. 53 | 54 | ## Fixtures with inputs 55 | 56 | Like in pytest, fixtures can consume inputs. However, in nnbench, fixtures can consume other inputs by name only within the same module scope, i.e. members within the same `conf.py`. 57 | 58 | ```python 59 | # conf.py 60 | 61 | # Revisiting the above example, we could also write the following: 62 | def path() -> str: 63 | return "path/to/my/dataset" 64 | 65 | 66 | def data(path): 67 | return TestDataset.load(path) 68 | 69 | # ... but not this, since `config` is not a member of the conf.py module: 70 | def model(config): 71 | return MyModel.load(config) 72 | ``` 73 | 74 | !!! Warning 75 | nnbench fixtures cannot have cycles in them - two fixtures may never depend on each other. 76 | 77 | ## Hierarchical `conf.py` files 78 | 79 | nnbench also supports sourcing fixtures from different levels in a directory hierarchy. 80 | Suppose we have a benchmark directory layout like this: 81 | 82 | ```commandline 83 | 📂 benchmarks 84 | ┣━━ 📂 nested 85 | ┃ ┣━━ conf.py 86 | ┃ ┗━━ model.py 87 | ┣━━ base.py 88 | ┗━━ conf.py 89 | ``` 90 | 91 | Let's assume that the benchmarks in `nested/model.py` consume some fixture values specific to them, and reuse some top-level fixtures as well. 
92 | 93 | ```python 94 | # benchmarks/conf.py 95 | 96 | def path() -> str: 97 | return "path/to/my/dataset" 98 | 99 | 100 | def data(path: str): 101 | """Test dataset, to be reused by all benchmarks.""" 102 | return TestDataset.load(path) 103 | 104 | # ------------------------- 105 | # benchmarks/nested/conf.py 106 | # ------------------------- 107 | 108 | def model(): 109 | """Model, needed only by the nested benchmarks.""" 110 | return MyModel.load() 111 | ``` 112 | 113 | If we have a benchmark in `benchmarks/nested/model.py` defined like this: 114 | 115 | ```python 116 | # benchmarks/nested/model.py 117 | 118 | def accuracy(model, data): 119 | ... 120 | ``` 121 | 122 | Now nnbench will source the `model` fixture from `benchmarks/nested/conf.py` and fall back to the top-level `benchmarks/conf.py` to obtain `data`. 123 | 124 | !!! Info 125 | Just like pytest, nnbench collects fixture values bottom-up, starting with the benchmark file's parent directory. 126 | 127 | For example, if the `benchmarks/nested/conf.py` above also defined a `data` fixture, the `accuracy` benchmark would use that instead. 128 | -------------------------------------------------------------------------------- /docs/cli/index.md: -------------------------------------------------------------------------------- 1 | # Command-line interface (CLI) 2 | 3 | This subdirectory contains resources on how to run ML benchmarks in command line mode with the `nnbench` CLI. 4 | -------------------------------------------------------------------------------- /docs/cli/pyproject.md: -------------------------------------------------------------------------------- 1 | # Configuring the CLI experience in `pyproject.toml` 2 | 3 | To create your custom CLI profile for nnbench, you can set certain options directly in your `pyproject.toml` file. 4 | Like other tools, nnbench will look for a `[tool.nnbench]` table inside the pyproject.toml file, and if found, use it to set certain values. 5 | 6 | Currently, you can set the log level, and register custom context provider classes. 7 | 8 | ### General options 9 | 10 | ```toml 11 | [tool.nnbench] 12 | # This sets the `nnbench` logger's level to "DEBUG", enabling debug log collections. 13 | log-level = "DEBUG" 14 | ``` 15 | 16 | ### Registering custom context providers 17 | 18 | As a quick refresher, in nnbench, a *context provider* is a function taking no arguments, and returning a Python dictionary with string keys: 19 | 20 | ```python 21 | import os 22 | 23 | def foo() -> dict[str, str]: 24 | """Returns a context value named 'foo', containing the value of the FOO environment variable.""" 25 | return {"foo": os.getenv("FOO", "")} 26 | ``` 27 | 28 | If you would like to use a custom context provider to collect metadata before a CLI benchmark run, you can give its details in a `[tool.nnbench.context]` table. 29 | 30 | ```toml 31 | [tool.nnbench.context.myctx] 32 | name = "myctx" 33 | classpath = "nnbench.context.PythonInfo" 34 | arguments = { packages = ["rich", "pyyaml"] } 35 | ``` 36 | 37 | In this case, we are augmenting `nnbench.context.PythonInfo`, a builtin provider class, to also collect the versions of the `rich` and `pyyaml` packages from the current environment, and registering it under the name "myctx". 38 | 39 | The `name` field is used to register the context provider. 40 | The `classpath` field needs to be a fully qualified Python module path to the context provider class or function. 
41 | Any arguments needed to instantiate a context provider class can be given under the `arguments` key in an inline table, which will be passed to the class found under `classpath` as keyword arguments. 42 | 43 | !!! Warning 44 | If you register a context provider *function*, you **must** leave the `arguments` key out of the above TOML table, since by definition, context providers do not take any arguments in their `__call__()` signature. 45 | 46 | Now we can use said provider in a benchmark run by passing just the provider name: 47 | 48 | ```commandline 49 | $ nnbench run benchmarks.py --context=myctx 50 | Context values: 51 | { 52 | "python": { 53 | "version": "3.11.10", 54 | "implementation": "CPython", 55 | "buildno": "main", 56 | "buildtime": "Sep 7 2024 01:03:31", 57 | "packages": { 58 | "rich": "13.9.3", 59 | "pyyaml": "6.0.2" 60 | } 61 | } 62 | } 63 | 64 | 65 | ``` 66 | 67 | !!! Tip 68 | This feature is a work in progress. The ability to register custom IO and comparison classes will be implemented in future releases of nnbench. 69 | -------------------------------------------------------------------------------- /docs/guides/customization.md: -------------------------------------------------------------------------------- 1 | # Defining setup/teardown tasks, context, and `nnbench.Parameters` 2 | 3 | This page introduces some customization options for benchmark runs. 4 | These options can be helpful for tasks surrounding benchmark state management, such as automatic setup and cleanup, contextualizing results with context values, and defining typed parameters with the `nnbench.Parameters` class. 5 | 6 | ## Defining setup and teardown tasks 7 | 8 | For some benchmarks, it is important to set certain configuration values and prepare the execution environment before running. 9 | To do this, you can pass a setup task to all of the nnbench decorators via the `setUp` keyword: 10 | 11 | ```python 12 | import os 13 | 14 | import nnbench 15 | 16 | 17 | def set_envvar(**params): 18 | os.environ["MY_ENV"] = "MY_VALUE" 19 | 20 | 21 | @nnbench.benchmark(setUp=set_envvar) 22 | def prod(a: int, b: int) -> int: 23 | return a * b 24 | ``` 25 | 26 | Similarly, to revert the environment state back to its previous form (or clean up any created resources), you can supply a finalization task with the `tearDown` keyword: 27 | 28 | ```python 29 | import os 30 | 31 | import nnbench 32 | 33 | 34 | def set_envvar(**params): 35 | os.environ["MY_ENV"] = "MY_VALUE" 36 | 37 | 38 | def pop_envvar(**params): 39 | os.environ.pop("MY_ENV") 40 | 41 | 42 | @nnbench.benchmark(setUp=set_envvar, tearDown=pop_envvar) 43 | def prod(a: int, b: int) -> int: 44 | return a * b 45 | ``` 46 | 47 | Both the setup and teardown task must take the exact same set of parameters as the benchmark function. To simplify function declaration, it is easiest to use a variadic keyword-only interface, i.e. `setup(**kwargs)`, as shown. 48 | 49 | !!! tip 50 | This facility works exactly the same for the `@nnbench.parametrize` and `@nnbench.product` decorators. 51 | There, the specified setup and teardown tasks are run once before or after each of the resulting benchmarks respectively. 52 | 53 | ## Enriching benchmark metadata with context values 54 | 55 | It is often useful to log specific environment metadata in addition to the benchmark's target metrics. 56 | Such metadata can give a clearer picture of how certain models perform on a given hardware, how model architectures compare in performance, and much more. 
57 | In `nnbench`, you can give additional metadata to your benchmarks as **context values**. 58 | 59 | A _context value_ is defined here as a key-value pair where `key` is a string, and `value` is any valid JSON value holding the desired information. 60 | As an example, the context value `{"cpuarch": "arm64"}` gives information about the CPU architecture of the host machine running the benchmark. 61 | 62 | A _context provider_ is a function taking no arguments and returning a Python dictionary of context values. The following is a basic example of a context provider: 63 | 64 | ```python 65 | import platform 66 | 67 | def platinfo() -> dict[str, str]: 68 | """Returns CPU arch, system name (Windows/Linux/Darwin), and Python version.""" 69 | return { 70 | "system": platform.system(), 71 | "cpuarch": platform.machine(), 72 | "python_version": platform.python_version(), 73 | } 74 | ``` 75 | 76 | To supply context to your benchmarks, you can give a sequence of context providers to `nnbench.run()`: 77 | 78 | ```python 79 | import nnbench 80 | 81 | # uses the `platinfo` context provider from above to log platform metadata. 82 | benchmarks = nnbench.collect(__name__) 83 | result = nnbench.run(benchmarks, params={}, context=[platinfo]) 84 | ``` 85 | 86 | ## Being type safe by using `nnbench.Parameters` 87 | 88 | Instead of specifying your benchmark's parameters by using a raw Python dictionary, you can define a custom subclass of `nnbench.Parameters`: 89 | 90 | ```python 91 | import nnbench 92 | from dataclasses import dataclass 93 | 94 | 95 | @dataclass(frozen=True) 96 | class MyParams(nnbench.Parameters): 97 | a: int 98 | b: int 99 | 100 | 101 | @nnbench.benchmark 102 | def prod(a: int, b: int) -> int: 103 | return a * b 104 | 105 | 106 | params = MyParams(a=1, b=2) 107 | benchmarks = nnbench.collect(__name__) 108 | result = nnbench.run(benchmarks, params=params) 109 | ``` 110 | 111 | While this does not have a concrete advantage in terms of type safety over a raw dictionary, it guards against accidental modification of parameters breaking reproducibility. 112 | -------------------------------------------------------------------------------- /docs/guides/index.md: -------------------------------------------------------------------------------- 1 | # User Guide 2 | 3 | The nnbench user guide provides documentation for users of the library looking to solve specific tasks. 4 | See the [Quickstart guide](../quickstart.md) for an introductory tutorial. 5 | 6 | The following guides are available, covering the core concepts of nnbench: 7 | 8 | * [Defining benchmarks and benchmark families with decorators](benchmarks.md) 9 | * [Adding context, customizing benchmarks, and supplying parameters](customization.md) 10 | * [Organizing benchmark code efficiently](organization.md) 11 | * [Collecting and running benchmarks by hand](runners.md) 12 | -------------------------------------------------------------------------------- /docs/guides/organization.md: -------------------------------------------------------------------------------- 1 | # How to efficiently organize benchmark code 2 | 3 | To efficiently organize benchmarks and keeping your setup modular, you can follow a few guidelines. 4 | 5 | ## Tip 1: Separate benchmarks from project code 6 | 7 | This tip is well known from other software development practices such as unit testing. 8 | To improve project organization, consider splitting off your benchmarks into their own modules or even directories, if you have multiple benchmark workloads. 
9 | 10 | An example project layout can look like this, with benchmarks as a separate directory at the top-level: 11 | 12 | ``` 13 | my-project/ 14 | ├── benchmarks/ # <- contains all benchmarking Python files. 15 | ├── docs/ 16 | ├── src/ 17 | ├── .pre-commit-config.yaml 18 | ├── pyproject.toml 19 | ├── README.md 20 | └── ... 21 | ``` 22 | 23 | This keeps the benchmarks neatly grouped together while siloing them away from the actual project code. 24 | Since you will most likely not run your benchmarks in a production setting, this is also advantageous for packaging, as the `benchmarks/` directory does not ship by default in this configuration. 25 | 26 | ## Tip 2: Group benchmarks by common attributes 27 | 28 | To maintain good organization within your benchmark directory, you can group similar benchmarks into their own Python files. 29 | As an example, if you have a set of benchmarks to establish data quality, and benchmarks for scoring trained models on curated data, you could structure them as follows: 30 | 31 | ``` 32 | benchmarks/ 33 | ├── data_quality.py 34 | ├── model_perf.py 35 | └── ... 36 | ``` 37 | 38 | This is helpful when running multiple benchmark workloads separately, as you can just point your benchmark runner to each of these separate files: 39 | 40 | ```python 41 | import nnbench 42 | 43 | data_quality_benchmarks = nnbench.collect("benchmarks/data_quality.py") 44 | data_metrics = nnbench.run(data_quality_benchmarks, params=...) 45 | # same for model metrics, where instead you pass benchmarks/model_perf.py. 46 | model_perf_benchmarks = nnbench.collect("benchmarks/model_perf.py") 47 | model_metrics = nnbench.run(model_perf_benchmarks, params=...) 48 | ``` 49 | 50 | ## Tip 3: Attach tags to benchmarks for selective filtering 51 | 52 | For structuring benchmarks within files, you can also use **tags**, which are tuples of strings attached to a benchmark: 53 | 54 | ```python 55 | # benchmarks/data_quality.py 56 | import nnbench 57 | 58 | 59 | @nnbench.benchmark(tags=("foo",)) 60 | def foo1(data) -> float: 61 | ... 62 | 63 | 64 | @nnbench.benchmark(tags=("foo",)) 65 | def foo2(data) -> int: 66 | ... 67 | 68 | 69 | @nnbench.benchmark(tags=("bar",)) 70 | def bar(data) -> int: 71 | ... 72 | ``` 73 | 74 | Now, to only run data quality benchmarks marked "foo", pass the corresponding tag to `nnbench.run()`: 75 | 76 | ```python 77 | import nnbench 78 | 79 | benchmarks = nnbench.collect("benchmarks/data_quality.py", tags=("foo",)) 80 | foo_data_metrics = nnbench.run(benchmarks, params=..., ) 81 | ``` 82 | 83 | !!!tip 84 | This concept works exactly the same when creating benchmarks with the `@nnbench.parametrize` and `@nnbench.product` decorators. 85 | -------------------------------------------------------------------------------- /docs/guides/runners.md: -------------------------------------------------------------------------------- 1 | # Collecting and running benchmarks 2 | 3 | nnbench provides the `nnbench.collect` and `nnbench.run` APIs as a compact interface to collect and run benchmarks selectively. 4 | 5 | Use the `nnbench.collect()` method to collect benchmarks from files or directories. 
6 | Assume we have the following benchmark setup: 7 | ```python 8 | # dir_a/bm1.py 9 | import nnbench 10 | 11 | @nnbench.benchmark 12 | def dummy_benchmark(a: int) -> int: 13 | return a 14 | ``` 15 | 16 | ```python 17 | # dir_b/bm2.py 18 | import nnbench 19 | 20 | @nnbench.benchmark(tags=("tag",)) 21 | def another_benchmark(b: int) -> int: 22 | return b 23 | 24 | @nnbench.benchmark 25 | def yet_another_benchmark(c: int) -> int: 26 | return c 27 | ``` 28 | 29 | ```python 30 | # dir_b/bm3.py 31 | import nnbench 32 | @nnbench.benchmark(tags=("tag",)) 33 | def the_last_benchmark(d: int) -> int: 34 | return d 35 | ``` 36 | 37 | Now we can collect benchmarks from files: 38 | 39 | ```python 40 | import nnbench 41 | 42 | 43 | benchmarks = nnbench.collect('dir_a/bm1.py') 44 | ``` 45 | Or directories: 46 | 47 | ```python 48 | benchmarks = nnbench.collect('dir_b') 49 | ``` 50 | 51 | You can also supply tags to `nnbench.collect()` to selectively collect only benchmarks with the appropriate tag. 52 | For example, you can collect all benchmarks with the `"tag"` tag as such: 53 | 54 | ```python 55 | import nnbench 56 | 57 | 58 | tagged_benchmarks = nnbench.collect('dir_b', tags=("tag",)) 59 | ``` 60 | 61 | To run the benchmarks, call the `nnbench.run()` method and supply the necessary parameters required by the collected benchmarks. 62 | 63 | ```python 64 | result = nnbench.run(benchmarks, params={"b": 1, "c": 2, "d": 3}) 65 | ``` 66 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # 2 | 3 | Welcome to nnbench, a framework for reproducibly benchmarking machine learning models. 4 | The main goals of this project are portable and customizable benchmarking for ML models, and easy integration into existing ML pipelines. 5 | 6 | Highlights: 7 | 8 | - Easy definition, bookkeeping and organization of machine learning benchmarks, 9 | - Enriching benchmark results with context to properly track and annotate results, 10 | - Streaming results to a variety of data sinks. 11 | 12 | ::cards:: cols=3 13 | 14 | - title: Quickstart 15 | content: Step-by-step installation and first operations 16 | icon: ":octicons-flame-24:{ .landing-page-icon }" 17 | url: quickstart.md 18 | 19 | - title: Examples 20 | content: Examples on how to use nnbench 21 | icon: ":octicons-repo-clone-24:{ .landing-page-icon }" 22 | url: tutorials/index.md 23 | 24 | - title: API Reference 25 | content: Full documentation of the Python API 26 | icon: ":octicons-file-code-24:{ .landing-page-icon }" 27 | url: reference/nnbench/index.md 28 | 29 | - title: User Guide 30 | content: Solving specific tasks with nnbench 31 | icon: ":octicons-tasklist-24:{ .landing-page-icon }" 32 | url: guides/index.md 33 | 34 | - title: Contributing 35 | content: How to contribute to the project 36 | icon: ":octicons-code-of-conduct-24:{ .landing-page-icon }" 37 | url: CONTRIBUTING.md 38 | 39 | ::/cards:: 40 | -------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | Welcome! This quickstart guide will convey the basics needed to use nnbench. 4 | You will define a benchmark, initialize a runner and reporter, and execute the benchmark, obtaining the results in the console in tabular format.
5 | 6 | ## A short scikit-learn model benchmark 7 | 8 | In the following simple example, we put the training and benchmarking logic in the same file. For more complex workloads, we recommend structuring your code into multiple files to improve project organization, similarly to unit tests. 9 | Check the [user guide on structuring your benchmarks](guides/organization.md) for inspiration. 10 | 11 | ```python 12 | from sklearn.datasets import load_iris 13 | from sklearn.ensemble import RandomForestClassifier 14 | from sklearn.model_selection import train_test_split 15 | 16 | data = load_iris() 17 | X, y = data.data, data.target 18 | 19 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) 20 | 21 | model = RandomForestClassifier() 22 | model.fit(X_train, y_train) 23 | ``` 24 | 25 | To benchmark your model, you encapsulate the benchmark code into a function and apply the `@nnbench.benchmark` decorator. 26 | This marks the function for collection to our benchmark runner later. 27 | 28 | ```python 29 | import nnbench 30 | import numpy as np 31 | from sklearn import base, metrics 32 | 33 | 34 | @nnbench.benchmark 35 | def accuracy(model: base.BaseEstimator, X_test: np.ndarray, y_test: np.ndarray) -> float: 36 | y_pred = model.predict(X_test) 37 | accuracy = metrics.accuracy_score(y_test, y_pred) 38 | return accuracy 39 | ``` 40 | 41 | Now we can instantiate a benchmark runner to collect and run the accuracy benchmark. 42 | Then, using the `ConsoleReporter` we report the resulting accuracy metric by printing it to the terminal in a table. 43 | 44 | ```python 45 | import nnbench 46 | from nnbench.reporter import ConsoleReporter 47 | 48 | benchmarks = nnbench.collect("__main__") 49 | reporter = ConsoleReporter() 50 | # To collect in the current file, pass "__main__" as module name. 51 | result = nnbench.run(benchmarks, params={"model": model, "X_test": X_test, "y_test": y_test}) 52 | reporter.write(result) 53 | ``` 54 | 55 | The resulting output might look like this: 56 | 57 | ```bash 58 | python benchmarks.py 59 | 60 | name value 61 | -------- -------- 62 | accuracy 0.933333 63 | ``` 64 | -------------------------------------------------------------------------------- /docs/tutorials/bq.md: -------------------------------------------------------------------------------- 1 | # Streaming benchmarks to a cloud database 2 | 3 | Once you obtain the results of your benchmarks, you will most likely want to store them somewhere. 4 | Whether that is in storage as flat files, on a server, or in a database, `nnbench` allows you to write records anywhere, provided the destination supports JSON. 5 | 6 | This is a small guide containing a snippet on how to stream benchmark results to a Google Cloud BigQuery table. 7 | 8 | ## The benchmarks 9 | 10 | Configure your benchmarks as normal, for example by separating them into a Python file. 11 | The following is a very simple example benchmark setup. 12 | 13 | ```python 14 | --8<-- "examples/bq/benchmarks.py" 15 | ``` 16 | 17 | ## Setting up a BigQuery client 18 | 19 | In order to authenticate with BigQuery, follow the official [Google Cloud documentation](https://cloud.google.com/bigquery/docs/authentication#client-libs). 20 | In this case, we rely on Application Default Credentials (ADC), which can be configured with the `gcloud` CLI. 21 | 22 | To interact with BigQuery from Python, the `google-cloud-bigquery` package has to be installed. 23 | You can do this e.g. using pip via `pip install --upgrade google-cloud-bigquery`. 
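Once authentication is configured, creating a client requires no explicit credentials. The following is a minimal sketch, not part of the original tutorial snippets; the project is inferred from the environment, or can be passed explicitly via the `project` argument:

```python
from google.cloud import bigquery

# Uses Application Default Credentials, e.g. as configured via
# `gcloud auth application-default login`.
client = bigquery.Client()
```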
24 | 25 | ## Creating a table 26 | 27 | Within your configured project, proceed by creating a destination table to write the benchmarks to. 28 | Consider the [BigQuery Python documentation on tables](https://cloud.google.com/bigquery/docs/tables#python) for how to create a table programmatically. 29 | 30 | !!! Note 31 | If the configured dataset does not exist, you will have to create it as well, either programmatically via the `bigquery.Client.create_dataset` API or in the Google Cloud console. 32 | 33 | ## Using BigQuery's schema auto-detection 34 | 35 | In order to skip tedious schema inference by hand, we can use BigQuery's [schema auto-detection from JSON records](https://cloud.google.com/bigquery/docs/schema-detect). 36 | All we have to do is configure a BigQuery load job to auto-detect the schema from the Python dictionaries in memory: 37 | 38 | ```python 39 | --8<-- "examples/bq/bq.py:13:16" 40 | ``` 41 | 42 | After that, write and stream the compacted benchmark record directly to your destination table. 43 | In this example, we decide to flatten the benchmark context to be able to extract scalar context values directly from the result table using raw SQL queries. 44 | Note that you have to use a custom separator (an underscore `"_"` in this case) for the context data, since BigQuery does not allow dots in column names. 45 | 46 | ```python 47 | --8<-- "examples/bq/bq.py:21:25" 48 | ``` 49 | 50 | !!! Tip 51 | If you would like to save the context dictionary as a struct instead, use `mode = "inline"` in the call to `BenchmarkRecord.compact()`. 52 | 53 | And that's all! To check that the records appear as expected, you can now query the data e.g. like so: 54 | 55 | ```python 56 | # check that the insert worked. 57 | query = f'SELECT name, value, time_ns, git_commit AS commit FROM {table_id}' 58 | r = client.query(query) 59 | for row in r.result(): 60 | print(r) 61 | ``` 62 | 63 | ## Recap and the full source code 64 | 65 | In this tutorial, we 66 | 67 | 1) defined and ran a benchmark workload using `nnbench`. 68 | 2) configured a Google Cloud BigQuery client and a load job to insert benchmark records into a table, and 69 | 3) inserted the records into the destination table. 70 | 71 | The full source code for this tutorial is included below, and also in the [nnbench repository](https://github.com/aai-institute/nnbench/tree/main/examples/bq). 72 | 73 | ```python 74 | --8<-- "examples/bq/bq.py" 75 | ``` 76 | -------------------------------------------------------------------------------- /docs/tutorials/duckdb.md: -------------------------------------------------------------------------------- 1 | # Querying benchmark results at scale with duckDB 2 | 3 | For a powerful way to query, filter, and visualize benchmark results, [duckdb](https://duckdb.org/) is a great choice. 4 | This page contains a quick tutorial for analyzing benchmark results with duckDB. 5 | 6 | ## Prerequisites and installation 7 | 8 | To use duckdb, install the duckdb Python package by running `pip install --upgrade duckdb`. 9 | In this tutorial, we are going to be using the in-memory database only, but you can easily persist SQL views of results on disk as well. 10 | 11 | ## Writing and ingesting benchmark results 12 | 13 | We consider the following easy benchmark example: 14 | 15 | ```python 16 | --8<-- "examples/bq/benchmarks.py" 17 | ``` 18 | 19 | Running both of these benchmarks produces a benchmark result, which we can save to disk using the `FileIO` class. 
20 | 21 | ```python 22 | import nnbench 23 | from nnbench.context import GitEnvironmentInfo 24 | from nnbench.reporter.file import FileReporter 25 | 26 | benchmarks = nnbench.collect("benchmarks.py") 27 | result = nnbench.run(benchmarks, params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),)) 28 | 29 | file_reporter = FileReporter() 30 | file_reporter.write(result, "result.json") 31 | ``` 32 | 33 | This writes a newline-delimited JSON file as `result.json` into the current directory. We choose this format because it is ideal for duckdb to work with. 34 | 35 | Now, we can easily ingest the result into a duckDB database: 36 | 37 | ```python 38 | import duckdb 39 | 40 | duckdb.sql( 41 | """ 42 | SELECT name, value FROM read_ndjson_auto('result.json') 43 | """ 44 | ).show() 45 | 46 | # ----- prints: ----- 47 | # ┌─────────┬───────┐ 48 | # │ name │ value │ 49 | # │ varchar │ int64 │ 50 | # ├─────────┼───────┤ 51 | # │ prod │ 1 │ 52 | # │ sum │ 2 │ 53 | # └─────────┴───────┘ 54 | ``` 55 | 56 | ## Querying metadata directly in SQL by flattening the context struct 57 | 58 | By default, the benchmark context struct, which holds additional information about the benchmark runs, is inlined into the raw dictionary before saving it to a file. 59 | This is not ideal for some SQL implementations, where you might not be able to filter results easily by interpreting the serialized `context` struct. 60 | 61 | To improve, you can pass `ctxmode="flatten"` to the `FileIO.write()` method to flatten the context and inline all nested values instead. 62 | This comes at the expense of an inflated schema, i.e. more columns in the database. 63 | 64 | ```python 65 | fio = FileIO() 66 | fio.write(result, "result.json", driver="ndjson", ctxmode="flatten") 67 | ``` 68 | 69 | In the example above, we used the `GitEnvironmentInfo` context provider to log some information on the git environment we ran our benchmarks in. 70 | In flat mode, this includes the `git.commit` and `git.repository` values, telling us at which commit and in which repository the benchmarks were run, respectively. 71 | 72 | To log this information in a duckDB view, we run the following on a flat-context NDJSON result: 73 | 74 | ```python 75 | duckdb.sql( 76 | """ 77 | SELECT name, value, \"git.commit\", \"git.repository\" FROM read_ndjson_auto('result.json') 78 | """ 79 | ).show() 80 | 81 | # ---------------------------------- prints ------------------------------------------ 82 | # ┌─────────┬───────┬──────────────────────────────────────────┬───────────────────────┐ 83 | # │ name │ value │ git.commit │ git.repository │ 84 | # │ varchar │ int64 │ varchar │ varchar │ 85 | # ├─────────┼───────┼──────────────────────────────────────────┼───────────────────────┤ 86 | # │ prod │ 1 │ 0d47d7bcd2d2c13b69796355fe9d4ef5f50b1edb │ aai-institute/nnbench │ 87 | # │ sum │ 2 │ 0d47d7bcd2d2c13b69796355fe9d4ef5f50b1edb │ aai-institute/nnbench │ 88 | # └─────────┴───────┴──────────────────────────────────────────┴───────────────────────┘ 89 | ``` 90 | 91 | This method is great to select a subset of context values when the full context metadata structure is not required. 92 | 93 | !!! Tip 94 | In duckDB specifically, this is equivalent to dotted access of the "context" column if `ctxmode="inline"`. 
95 | This means that the following also works to obtain the git commit mentioned above: 96 | ```python 97 | duckdb.sql( 98 | """ 99 | SELECT name, value, context.git.commit AS \"git.commit\" FROM read_ndjson_auto('result.json') 100 | """ 101 | ).show() 102 | ``` 103 | -------------------------------------------------------------------------------- /docs/tutorials/huggingface.md: -------------------------------------------------------------------------------- 1 | # Benchmarking HuggingFace models on a dataset 2 | There is a high likelihood that you, at some point, find yourself wanting to benchmark previously trained models. 3 | This guide shows you how to do it for a HuggingFace model with nnbench. 4 | 5 | ## Example: Named Entity Recognition 6 | We start with a small tangent about the example setup that we will use in this guide. 7 | If you are only interested in the application of nnbench, you can skip this section. 8 | 9 | There are lots of reasons why you could want to retrieve saved models for benchmarking. 10 | Among them these are reviewing the work of colleagues, comparing model performance to an existing benchmark, or dealing with models that require significant compute such that in-place retraining is impractical. 11 | 12 | For this example, we look at a named entity recognition (NER) model that is based on the pre-trained encoder-decoder transformer [BERT](https://arxiv.org/abs/1810.04805) from HuggingFace. 13 | The model is trained on the [CoNLLpp dataset](https://huggingface.co/datasets/conllpp) which consists of sentences from news stories where words were tagged with Person, Organization, Location, or Miscellaneous if they referred to entities. 14 | Words are assigned an out-of-entity label if they do not represent an entity. 15 | 16 | ## Model Training 17 | You find the code to train the model in the nnbench [repository](https://github.com/aai-institute/nnbench/tree/main/examples/huggingface). 18 | If you want to skip running the training script but still want to reproduce this example, you can take any BERT model fine tuned for NER with the CoNLL dataset family. 19 | You find many on the Huggingface model hub, for example [this one](https://huggingface.co/dslim/bert-base-NER). You need to download the `model.safetensors`, `config.json`, `tokenizer_config.json`, and `tokenizer.json` files. 20 | If you want to train your own model, continue below. 21 | 22 | There is some necessary preprocessing and data wrangling to train the model. 23 | We will not go into the details here, but if you are interested in a more thorough walkthrough, look into this [resource](https://huggingface.co/learn/nlp-course/chapter7/2?fw=pt) by Huggingface which served as the basis for this example. 24 | 25 | It is not feasible to train the model on a CPU. If you do not have access to a GPU, you can use free GPU instances on [Google Colab](https://colab.research.google.com/). 26 | When opening a new Colab notebook, make sure to select a GPU instance in the upper right corner. 27 | Then, you can upload the `training.py`. You can ignore any warnings about the data not being persisted. 28 | 29 | Next, install the necessary dependencies: `!pip install datasets transformers[torch]`. Google Colab comes with some dependencies already installed in the environment. 30 | Hence, if you are working with a different GPU instance, make sure to install everything from the `pyproject.toml` in the `examples/artifact_benchmarking` folder. 31 | 32 | Finally, you can execute the `training.py` with `!python training.py`. 
33 | This will train two BERT models ("bert-base-uncased" and "distilbert-base-uncased") which we can compare using nnbench. 34 | If you want, you can adapt the training script to train other models by editing the tuples in the `tokenizers_and_models` list at the bottom of the training script. 35 | The training of the models takes around 10 minutes. 36 | 37 | Once it is done, download the respective files and save them to your disk. 38 | They should be the same mentioned above. We will need the paths to the files for benchmarking later. 39 | 40 | ## The benchmarks 41 | 42 | The benchmarking code is found in the `examples/huggingface/benchmark.py`. 43 | We calculate precision, recall, accuracy, and f1 scores for the whole test set and specific labels. 44 | Additionally, we obtain information about the model such as its memory footprint and inference time. 45 | 46 | We are not walking through the whole file but instead point out certain design choices as an inspiration to you. 47 | If you are interested in a more detailed walkthrough on how to set up benchmarks, you can find it [here](../guides/benchmarks.md). 48 | 49 | Notable design choices in this benchmark are that we factored out the evaluation loop as it is necessary for all evaluation metrics. 50 | We cache it using the `functools.cache` decorator so the evaluation loop runs only once per benchmark run instead of once per metric which greatly reduces runtime. 51 | 52 | We also use `nnbench.parametrize` to get the per-class metrics. 53 | As the parametrization method needs the same arguments for each benchmark, we use Python's builtin `functools.partial` to fill the arguments. 54 | 55 | ```python 56 | --8<-- "examples/huggingface/benchmark.py:63:67" 57 | ``` 58 | 59 | After this, the benchmarking code is actually very simple, as in most of the other examples. 60 | You find it in the nnbench repository in `examples/huggingface/runner.py`. 61 | -------------------------------------------------------------------------------- /docs/tutorials/index.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This page showcases some examples of applications for nnbench. 4 | Click any of the links below for inspiration on how to use nnbench in your projects. 5 | 6 | * [Integrating nnbench into an existing ML pipeline](mnist.md) 7 | * [Integrating nnbench with workflow orchestrators](prefect.md) 8 | * [Using a streamlit web app to dispatch benchmarks](streamlit.md) 9 | * [Analyzing benchmark results at scale with duckDB](duckdb.md) 10 | * [Streaming benchmark results to a cloud database (Google BigQuery)](bq.md) 11 | * [How to benchmark pre-trained HuggingFace models](huggingface.md) 12 | -------------------------------------------------------------------------------- /docs/tutorials/mnist.md: -------------------------------------------------------------------------------- 1 | # Integrating nnbench into an existing ML pipeline 2 | 3 | Thanks to nnbench's modularity, we can easily integrate it into existing ML experiment code. 4 | 5 | As an example, we use an MNIST pipeline written for the popular ML framework [JAX](https://jax.readthedocs.io/en/latest/). 6 | While the actual data sourcing and training code is interesting on its own, we focus solely on the nnbench application part. 
7 | You can find the full example code in the nnbench [repository](https://github.com/aai-institute/nnbench/tree/main/examples/mnist). 8 | 9 | ## Defining and organizing benchmarks 10 | 11 | To properly structure our project, we avoid mixing training pipeline code and benchmark code by placing all benchmarks in a standalone file, similarly to how you might structure unit tests for your code. 12 | 13 | ```python 14 | --8<-- "examples/mnist/benchmarks.py" 15 | ``` 16 | 17 | This definition is short and sweet, and contains a few important details: 18 | 19 | * Both functions are given the `@nnbench.benchmark` decorator - this allows us to find and collect them before starting the benchmark run. 20 | * The `modelsize` benchmark is given a custom name (`"Model size (MB)"`), indicating that the resulting number is the combined size of the model weights in megabytes. 21 | This is done for display purposes, to improve interpretability when reporting results. 22 | * The `params` argument is the same in both benchmarks, both in name and type. This is important, since it ensures that both benchmarks will be run with the same model weights. 23 | 24 | That's all - now we can shift over to our main pipeline code and see what is necessary to execute the benchmarks and visualize the results. 25 | 26 | ## Setting up a benchmark run and parameters 27 | 28 | After finishing the benchmark setup, we only need a few more lines to augment our pipeline. 29 | 30 | We assume that the benchmark file is located in the same folder as the training pipeline - thus, we can specify our parent directory as the place in which to search for benchmarks: 31 | 32 | ```python 33 | --8<-- "examples/mnist/mnist.py:26:26" 34 | ``` 35 | 36 | Next, we can define a custom subclass of `nnbench.Parameters` to hold our benchmark parameters. 37 | Benchmark parameters are a set of variables used as inputs to the benchmark functions collected during the benchmark run. 38 | 39 | Since our benchmarks above are parametrized by the model weights (named `params` in the function signatures) and the MNIST data split (called `data`), we define our parameters to take exactly these two values. 40 | 41 | ```python 42 | --8<-- "examples/mnist/mnist.py:38:41" 43 | ``` 44 | 45 | And that's it! After we implement all training code, we just run nnbench directly after training in our top-level pipeline function: 46 | 47 | ```python 48 | --8<-- "examples/mnist/mnist.py:213:223" 49 | ``` 50 | 51 | We use the `FileReporter` to write the results to a JSON file; a `ConsoleReporter` would print them to the terminal in a table instead. 52 | Notice how we can reuse the training artifacts in nnbench as parameters to obtain results right after training! 53 | 54 | Rendered as a table, the results might look like this: 55 | 56 | ``` 57 | name value 58 | --------------- ------- 59 | accuracy 0.9712 60 | Model size (MB) 3.29783 61 | ``` 62 | 63 | This can be improved in a number of ways - for example by enriching it with metadata about the model architecture, the used GPU, etc. 64 | For more information on how to supply context to benchmarks, check the [user guide](../guides/index.md) section. 65 | -------------------------------------------------------------------------------- /docs/tutorials/prefect.md: -------------------------------------------------------------------------------- 1 | # Integrating nnbench with Prefect 2 | 3 | If you have more complex workflows, it is sensible to use a workflow orchestration tool to manage them. 4 | Benchmarking with nnbench can be integrated with such orchestrators.
We will present an example integration with Prefect. 5 | We will explain the orchestration concepts at a high level and link to the corresponding parts of the Prefect [docs](https://docs.prefect.io/). 6 | The full example code can be found in the nnbench [repository](https://github.com/aai-institute/nnbench/tree/main/examples/prefect). 7 | 8 | In this example we want to orchestrate the training and benchmarking of a linear regression model. 9 | 10 | ## Project Structure 11 | ### Defining the training tasks and workflows 12 | We recommend separating the training and benchmarking logic. 13 | 14 | ```python 15 | --8<-- "examples/prefect/src/training.py" 16 | ``` 17 | 18 | The `training.py` file contains functions to generate synthetic data for our regression model, facilitate a train-test split, and finally train the regression model. 19 | We have applied Prefect's [`@task` decorator](https://docs.prefect.io/latest/concepts/tasks/), which marks the contained logic as a discrete unit of work for Prefect. 20 | Two other functions prepare the regression data and train the estimator. 21 | They carry the [`@flow` decorator](https://docs.prefect.io/latest/concepts/flows), which marks a function as a workflow that can depend on other flows or tasks. 22 | The `prepare_regressor_and_test_data` function returns the model and test data so that we can use them in our benchmarks. 23 | 24 | ### Defining Benchmarks 25 | The benchmarks are in the `benchmark.py` file. 26 | We have two functions to calculate the mean absolute error and the mean squared error. 27 | These benchmarks are tagged to indicate that they are metrics. 28 | Another two benchmarks calculate metadata about the model, namely its inference time and its size. 29 | All of these benchmarks take the trained model as a parameter; the metric benchmarks additionally receive the test data. 30 | 31 | ```python 32 | --8<-- "examples/prefect/src/benchmark.py" 33 | ``` 34 | 35 | We did not apply any Prefect decorators here, as we will assign `@task`s - Prefect's smallest unit of work - to run a benchmark family. 36 | 37 | ### Defining benchmark runners 38 | In the `runner.py` file, we define the logic to run our benchmarks. 39 | There, we collect the benchmarks from the specified file. 40 | We can filter by tags and use this to define two separate tasks, one to run the metric benchmarks and the other to run the metadata benchmarks. We have applied the `@task` decorator to these functions. 41 | 42 | ```python 43 | --8<-- "examples/prefect/src/runner.py:32:55" 44 | ``` 45 | 46 | We have also defined a basic reporter that we will use to save the benchmark results with Prefect's artifact storage machinery. 47 | 48 | ```python 49 | --8<-- "examples/prefect/src/runner.py:18:29" 50 | ``` 51 | In a real-world scenario, we would report to a database and use a dedicated frontend to look at the benchmark results. But artifact logging will suffice, as we are only discussing integration with orchestrators here. 52 | 53 | A final compound flow executes the model training, obtains the test set, and supplies it to the benchmarks we defined earlier. 54 | 55 | ```python 56 | --8<-- "examples/prefect/src/runner.py:58:91" 57 | ``` 58 | 59 | The final lines in `runner.py` serve the `train_and_benchmark` flow to make it available to Prefect for execution. 60 | 61 | ```python 62 | --8<-- "examples/prefect/src/runner.py:94:95" 63 | ``` 64 | 65 | ## Running Prefect 66 | To run Prefect, we have to do several things. 67 | First, we have to make sure it is installed. You can use `pip install -U prefect`. For reference, the whole command sequence is sketched below.
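The sketch assumes a local setup; the deployment name `train-and-benchmark/benchmark-runner` is the one registered by `runner.py`, and the two long-running commands each need their own terminal:

```bash
pip install -U prefect

# terminal 1: start the local Prefect server (dashboard at localhost:4200)
prefect server start

# terminal 2: register and serve the deployment defined in runner.py
python runner.py

# terminal 3: trigger a run of the deployment
prefect deployment run 'train-and-benchmark/benchmark-runner'
```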
68 | Then we have to run a Prefect server using `prefect server start`. 69 | We make our benchmark flow available for execution by running `python runner.py`. 70 | This enables us to trigger a run with the following command: `prefect deployment run 'train-and-benchmark/benchmark-runner'`. 71 | The command is also displayed in the output of the `runner.py` execution. 72 | 73 | Now we can visit the local Prefect dashboard. By default, it is served at `localhost:4200`. 74 | Here we see the executed tasks and workflows. 75 | 76 | ![Workflow runs](./prefect_resources/flow_runs.png) 77 | 78 | If we navigate to the "Flow Runs" tab, we see more details of the flow runs. 79 | 80 | ![Flows](./prefect_resources/detail_flow_runs.png) 81 | 82 | In the "Deployments" tab, you see all deployed flows. 83 | Currently, there is only our `train_and_benchmark` flow under the `benchmark-runner` name. 84 | We can trigger a custom run of the workflow from the menu behind the three dots. 85 | 86 | ![Deployments](./prefect_resources/deployments.png) 87 | 88 | You can find the benchmark results in the "Artifacts" tab, or by navigating to the "Artifacts" section of a specific flow run. 89 | 90 | As you can see, nnbench is easily integrated with workflow orchestrators by registering the execution of a set of benchmarks as a task in the orchestrator. 91 | 92 | For more functionality of Prefect, you can check out their [documentation](https://docs.prefect.io). 93 | -------------------------------------------------------------------------------- /docs/tutorials/prefect_resources/deployments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/docs/tutorials/prefect_resources/deployments.png -------------------------------------------------------------------------------- /docs/tutorials/prefect_resources/detail_flow_runs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/docs/tutorials/prefect_resources/detail_flow_runs.png -------------------------------------------------------------------------------- /docs/tutorials/prefect_resources/flow_runs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/docs/tutorials/prefect_resources/flow_runs.png -------------------------------------------------------------------------------- /docs/tutorials/streamlit.md: -------------------------------------------------------------------------------- 1 | ## Integrating nnbench with Streamlit and Prefect 2 | 3 | In a project, you may want to execute benchmarks or investigate their results with a dedicated frontend. Several frameworks can help you set up a frontend quickly, for example [Streamlit](https://streamlit.io/), [Gradio](https://www.gradio.app/), or [Dash](https://dash.plotly.com/tutorial); alternatively, you could roll your own implementation using a backend framework such as [Flask](https://flask.palletsprojects.com/en/3.0.x/). 4 | In this guide, we will use Streamlit and integrate it with the [orchestration setup](prefect.md) we've developed with Prefect. 5 | That guide is a prerequisite for this one.
6 | The full example code can be found in the nnbench [repository](https://github.com/aai-institute/nnbench/tree/main/examples/streamlit). 7 | 8 | ## The Streamlit UI 9 | The Streamlit UI is launched by executing 10 | ```bash 11 | streamlit run streamlit_example.py 12 | ``` 13 | and initially looks like this: 14 | ![Workflow runs](./streamlit_resources/initial_ui.png) 15 | 16 | The user interface is assembled in the final part of `streamlit_example.py`. 17 | 18 | ```python 19 | --8<-- "examples/streamlit/streamlit_example.py:55:69" 20 | ``` 21 | 22 | The user inputs are generated via the custom `setup_ui()` function, which then processes the input values once the "Run Benchmark" button is clicked. 23 | 24 | ```python 25 | --8<-- "examples/streamlit/streamlit_example.py:23:31" 26 | ``` 27 | 28 | We use a session state to keep track of all the benchmarks we ran in the current session, which are then displayed within expander elements at the bottom. 29 | 30 | ```python 31 | --8<-- "examples/streamlit/streamlit_example.py:19:20" 32 | ``` 33 | 34 | ## Integrating Prefect with Streamlit 35 | To integrate Streamlit with Prefect, we have to do some initial housekeeping. Namely, we specify the URL for the `PrefectClient`, as well as the storage location of run artifacts from which we retrieve the benchmark results. 36 | 37 | ```python 38 | --8<-- "examples/streamlit/streamlit_example.py:15:17" 39 | ``` 40 | 41 | In this example, there is no direct integration of Streamlit with nnbench; all interactions pass through Prefect to make use of its orchestration benefits, such as task caching. 42 | Another thing to note is that we are working with local instances for easier reproducibility of this example. Adapting it to work with a remote orchestration server and object storage should be straightforward. 43 | 44 | The main interaction of the Streamlit frontend with Prefect takes place in the `run_bms` and `get_bm_artifacts` functions. 45 | 46 | The former searches for a Prefect deployment `"train-and-benchmark/benchmark-runner"` and executes it with the benchmark parameters specified by the user. It returns the `storage_key`, which we use to retrieve the persisted benchmark results. 47 | 48 | ```python 49 | --8<-- "examples/streamlit/streamlit_example.py:34:38" 50 | ``` 51 | 52 | The `get_bm_artifacts` function gets a storage key and retrieves the corresponding results. As the results are stored in raw bytes, we have some logic to reconstruct the `nnbench.types.BenchmarkRecord` object. We transform the data into Pandas `DataFrame`s, which are later processed by Streamlit to display the results in tables. 53 | 54 | ```python 55 | --8<-- "examples/streamlit/streamlit_example.py:41:52" 56 | ``` 57 | 58 | ## Running the example 59 | To run the example, we have to do several things. 60 | First, we need to start Prefect using `prefect server start` in the command line. 61 | Next, we need to make the `"train-and-benchmark/benchmark-runner"` deployment available. 62 | We do so by running the corresponding Python file, `python runner.py`. You can find that file in the `examples/prefect/src` directory. 63 | If you are recreating this example on your machine, make sure you have the full contents of the `prefect` directory available in addition to `streamlit_example.py`. For more information, you can look into the [Prefect Guide](prefect.md). 64 | 65 | Now that Prefect is set up, you can launch a local instance of Streamlit with `streamlit run streamlit_example.py`.
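Putting the steps together, a local session might look roughly like this; `runner.py` lives in `examples/prefect/src`, `streamlit_example.py` in `examples/streamlit`, and each long-running process needs its own terminal:

```bash
# terminal 1: start the local Prefect server
prefect server start

# terminal 2: serve the 'train-and-benchmark/benchmark-runner' deployment
python runner.py

# terminal 3: launch the Streamlit frontend
streamlit run streamlit_example.py
```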
66 | 67 | For more information on how to work with Streamlit, visit their [docs](https://docs.streamlit.io/). 68 | -------------------------------------------------------------------------------- /docs/tutorials/streamlit_resources/initial_ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/docs/tutorials/streamlit_resources/initial_ui.png -------------------------------------------------------------------------------- /examples/bq/benchmarks.py: -------------------------------------------------------------------------------- 1 | import nnbench 2 | 3 | 4 | @nnbench.benchmark 5 | def prod(a: int, b: int) -> int: 6 | return a * b 7 | 8 | 9 | @nnbench.benchmark 10 | def sum(a: int, b: int) -> int: 11 | return a + b 12 | -------------------------------------------------------------------------------- /examples/bq/bq.py: -------------------------------------------------------------------------------- 1 | from google.cloud import bigquery 2 | 3 | import nnbench 4 | from nnbench.context import GitEnvironmentInfo 5 | 6 | 7 | def main(): 8 | client = bigquery.Client() 9 | 10 | # TODO: Fill these out with your appropriate resource names. 11 | table_id = ".." 12 | 13 | job_config = bigquery.LoadJobConfig( 14 | autodetect=True, source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON 15 | ) 16 | 17 | benchmarks = nnbench.collect("benchmarks.py") 18 | res = nnbench.run(benchmarks, params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),)) 19 | 20 | load_job = client.load_table_from_json(res.to_json(), table_id, job_config=job_config) 21 | load_job.result() 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /examples/huggingface/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import time 4 | from functools import cache, lru_cache, partial 5 | 6 | import torch 7 | from datasets import Dataset 8 | from torch.nn import Module 9 | from torch.utils.data import DataLoader 10 | from training import tokenize_and_align_labels 11 | from transformers import DataCollatorForTokenClassification 12 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 13 | 14 | import nnbench 15 | 16 | 17 | @cache 18 | def make_dataloader(tokenizer, dataset): 19 | tokenized_dataset = dataset.map( 20 | lambda examples: tokenize_and_align_labels(tokenizer, examples), 21 | batched=True, 22 | remove_columns=dataset.column_names, 23 | ) 24 | return DataLoader( 25 | tokenized_dataset, 26 | shuffle=False, 27 | collate_fn=DataCollatorForTokenClassification(tokenizer, padding=True), 28 | batch_size=8, 29 | ) 30 | 31 | 32 | @lru_cache 33 | def run_eval_loop(model, dataloader, padding_id=-100): 34 | true_positives = torch.zeros(model.config.num_labels) 35 | false_positives = torch.zeros(model.config.num_labels) 36 | true_negatives = torch.zeros(model.config.num_labels) 37 | false_negatives = torch.zeros(model.config.num_labels) 38 | 39 | model.eval() 40 | with torch.no_grad(): 41 | for batch in dataloader: 42 | outputs = model(**batch) 43 | predictions = outputs.logits.argmax(dim=-1) 44 | labels = batch["labels"] 45 | 46 | valid_indices = labels.view(-1) != padding_id 47 | predictions = predictions.view(-1)[valid_indices] 48 | labels = labels.view(-1)[valid_indices] 49 | for idx in range(model.config.num_labels): 50 | tp 
= ((predictions == idx) & (labels == idx)).sum() 51 | fp = ((predictions == idx) & (labels != idx)).sum() 52 | fn = ((predictions != idx) & (labels == idx)).sum() 53 | tn = ((predictions != idx) & (labels != idx)).sum() 54 | 55 | true_positives[idx] += tp 56 | false_positives[idx] += fp 57 | false_negatives[idx] += fn 58 | true_negatives[idx] += tn 59 | 60 | return true_positives, false_positives, true_negatives, false_negatives 61 | 62 | 63 | parametrize_label = partial( 64 | nnbench.parametrize, 65 | ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"], 66 | tags=("metric", "per-class"), 67 | )() 68 | 69 | 70 | @nnbench.benchmark(tags=("metric", "aggregate")) 71 | def precision( 72 | model: Module, 73 | tokenizer: PreTrainedTokenizerBase, 74 | valdata: Dataset, 75 | padding_id: int = -100, 76 | ) -> float: 77 | dataloader = make_dataloader(tokenizer, valdata) 78 | tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id) 79 | precision = tp / (tp + fp + 1e-6) 80 | return torch.mean(precision).item() 81 | 82 | 83 | @parametrize_label 84 | def precision_per_class( 85 | class_label: str, 86 | model: Module, 87 | tokenizer: PreTrainedTokenizerBase, 88 | valdata: Dataset, 89 | index_label_mapping: dict[int, str], 90 | padding_id: int = -100, 91 | ) -> float: 92 | dataloader = make_dataloader(tokenizer, valdata) 93 | 94 | tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id) 95 | precision_values = tp / (tp + fp + 1e-6) 96 | for idx, lbl in index_label_mapping.items(): 97 | if lbl == class_label: 98 | return precision_values[idx] 99 | raise ValueError(f" Key {class_label} not in test labels") 100 | 101 | 102 | @nnbench.benchmark(tags=("metric", "aggregate")) 103 | def recall( 104 | model: Module, 105 | tokenizer: PreTrainedTokenizerBase, 106 | valdata: Dataset, 107 | padding_id: int = -100, 108 | ) -> float: 109 | dataloader = make_dataloader(tokenizer, valdata) 110 | 111 | tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id) 112 | recall = tp / (tp + fn + 1e-6) 113 | return torch.mean(recall).item() 114 | 115 | 116 | @parametrize_label 117 | def recall_per_class( 118 | class_label: str, 119 | model: Module, 120 | tokenizer: PreTrainedTokenizerBase, 121 | valdata: Dataset, 122 | index_label_mapping: dict[int, str], 123 | padding_id: int = -100, 124 | ) -> float: 125 | dataloader = make_dataloader(tokenizer, valdata) 126 | 127 | tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id) 128 | recall_values = tp / (tp + fn + 1e-6) 129 | for idx, lbl in index_label_mapping.items(): 130 | if lbl == class_label: 131 | return recall_values[idx] 132 | raise ValueError(f" Key {class_label} not in test labels") 133 | 134 | 135 | @nnbench.benchmark(tags=("metric", "aggregate")) 136 | def f1( 137 | model: Module, 138 | tokenizer: PreTrainedTokenizerBase, 139 | valdata: Dataset, 140 | padding_id: int = -100, 141 | ) -> float: 142 | dataloader = make_dataloader(tokenizer, valdata) 143 | 144 | tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id) 145 | precision = tp / (tp + fp + 1e-6) 146 | recall = tp / (tp + fn + 1e-6) 147 | f1 = 2 * (precision * recall) / (precision + recall + 1e-6) 148 | return torch.mean(f1).item() 149 | 150 | 151 | @parametrize_label 152 | def f1_per_class( 153 | class_label: str, 154 | model: Module, 155 | tokenizer: PreTrainedTokenizerBase, 156 | valdata: Dataset, 157 | index_label_mapping: dict[int, str], 158 | padding_id: int = -100, 159 | ) -> float: 160 | dataloader = make_dataloader(tokenizer, valdata) 161 | 162 | tp, 
fp, tn, fn = run_eval_loop(model, dataloader, padding_id) 163 | precision = tp / (tp + fp + 1e-6) 164 | recall = tp / (tp + fn + 1e-6) 165 | f1_values = 2 * (precision * recall) / (precision + recall + 1e-6) 166 | for idx, lbl in index_label_mapping.items(): 167 | if lbl == class_label: 168 | return f1_values[idx] 169 | raise ValueError(f" Key {class_label} not in test labels") 170 | 171 | 172 | @nnbench.benchmark(tags=("metric", "aggregate")) 173 | def accuracy( 174 | model: Module, 175 | tokenizer: PreTrainedTokenizerBase, 176 | valdata: Dataset, 177 | padding_id: int = -100, 178 | ) -> float: 179 | dataloader = make_dataloader(tokenizer, valdata) 180 | 181 | tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id) 182 | accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-6) 183 | return torch.mean(accuracy).item() 184 | 185 | 186 | @parametrize_label 187 | def accuracy_per_class( 188 | class_label: str, 189 | model: Module, 190 | tokenizer: PreTrainedTokenizerBase, 191 | valdata: Dataset, 192 | index_label_mapping: dict[int, str], 193 | padding_id: int = -100, 194 | ) -> dict[str, float]: 195 | dataloader = make_dataloader(tokenizer, valdata) 196 | 197 | tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id) 198 | accuracy_values = (tp + tn) / (tp + tn + fp + fn + 1e-6) 199 | for idx, lbl in index_label_mapping.items(): 200 | if lbl == class_label: 201 | return accuracy_values[idx] 202 | raise ValueError(f" Key {class_label} not in test labels") 203 | 204 | 205 | @nnbench.benchmark(tags=("config",)) 206 | def model_configuration(model: Module) -> dict: 207 | model.eval() 208 | config = model.config.to_dict() 209 | return config 210 | 211 | 212 | @nnbench.benchmark(tags=("model-meta", "inference-time")) 213 | def avg_inference_time_ns( 214 | model: Module, 215 | tokenizer: PreTrainedTokenizerBase, 216 | valdata: Dataset, 217 | avg_n: int = 100, 218 | ) -> float: 219 | dataloader = make_dataloader(tokenizer, valdata) 220 | 221 | start_time = time.perf_counter() 222 | model.eval() 223 | num_datapoints = 0 224 | with torch.no_grad(): 225 | for batch in dataloader: 226 | if num_datapoints >= avg_n: 227 | break 228 | num_datapoints += len(batch) 229 | _ = model(**batch) 230 | end_time = time.perf_counter() 231 | 232 | total_time = end_time - start_time 233 | average_time = total_time / num_datapoints 234 | return average_time 235 | 236 | 237 | @nnbench.benchmark(tags=("model-meta", "size-on-disk")) 238 | def model_size_mb(model: Module) -> float: 239 | model.eval() 240 | with tempfile.NamedTemporaryFile() as tmp: 241 | torch.save(model.state_dict(), tmp.name) 242 | tmp.flush() 243 | tmp.seek(0, os.SEEK_END) 244 | tmp_size = tmp.tell() 245 | size_mb = tmp_size / (1024 * 1024) 246 | return size_mb 247 | -------------------------------------------------------------------------------- /examples/huggingface/runner.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | from datasets import Dataset, load_dataset 4 | from transformers import ( 5 | AutoModelForTokenClassification, 6 | AutoTokenizer, 7 | ) 8 | 9 | import nnbench 10 | from nnbench.reporter import ConsoleReporter 11 | 12 | dataset = load_dataset("conllpp") 13 | path = dataset.cache_files["validation"][0]["filename"] 14 | 15 | 16 | def main() -> None: 17 | model = AutoModelForTokenClassification.from_pretrained( 18 | "dslim/distilbert-NER", use_safetensors=True 19 | ) 20 | tokenizer = AutoTokenizer.from_pretrained("dslim/distilbert-NER") 21 | valdata = 
Dataset.from_file(path) 22 | label_names: Iterable[str] = valdata.features["ner_tags"].feature.names 23 | index_label_mapping = {i: label for i, label in enumerate(label_names)} 24 | 25 | params = { 26 | "model": model, 27 | "tokenizer": tokenizer, 28 | "valdata": valdata, 29 | "index_label_mapping": index_label_mapping, 30 | } 31 | 32 | benchmarks = nnbench.collect("benchmark.py", tags=("per-class",)) 33 | reporter = ConsoleReporter() 34 | result = nnbench.run(benchmarks, params=params) 35 | reporter.write(result) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /examples/huggingface/training.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from torch.optim import AdamW 6 | from torch.utils.data import DataLoader 7 | from tqdm.auto import tqdm 8 | from transformers import ( 9 | AutoModelForTokenClassification, 10 | AutoTokenizer, 11 | BatchEncoding, 12 | DataCollatorForTokenClassification, 13 | PreTrainedTokenizer, 14 | get_scheduler, 15 | ) 16 | 17 | OUTPUT_DIR = "artifacts" 18 | 19 | 20 | def align_labels_with_tokens(labels: list[int], word_ids: list[int]) -> list[int]: 21 | new_labels = [] 22 | current_word = None 23 | for word_id in word_ids: 24 | if word_id != current_word: 25 | # Start of a new word! 26 | current_word = word_id 27 | label = -100 if word_id is None else labels[word_id] 28 | new_labels.append(label) 29 | elif word_id is None: 30 | # Special token 31 | new_labels.append(-100) 32 | else: 33 | # Same word as previous token 34 | label = labels[word_id] 35 | # If the label is B-XXX we change it to I-XXX 36 | if label % 2 == 1: 37 | label += 1 38 | new_labels.append(label) 39 | 40 | return new_labels 41 | 42 | 43 | def tokenize_and_align_labels( 44 | tokenizer: PreTrainedTokenizer, examples: dict[str, Any] 45 | ) -> BatchEncoding: 46 | tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True) 47 | 48 | labels = [] 49 | for i, label in enumerate(examples["ner_tags"]): 50 | word_ids = tokenized_inputs.word_ids(batch_index=i) 51 | labels.append(align_labels_with_tokens(label, word_ids)) 52 | 53 | tokenized_inputs["labels"] = labels 54 | return tokenized_inputs 55 | 56 | 57 | def train_model(tokenizer_name: str, model_name: str, output_dir: str) -> None: 58 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 59 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 60 | model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=9).to(device) 61 | 62 | dataset = load_dataset("conllpp", split="train") 63 | tokenized_datasets = dataset.map( 64 | lambda examples: tokenize_and_align_labels(tokenizer, examples), 65 | batched=True, 66 | remove_columns=dataset.column_names, 67 | ) 68 | 69 | optimizer = AdamW(model.parameters(), lr=2e-5) 70 | 71 | data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True) 72 | 73 | train_dataloader = DataLoader( 74 | tokenized_datasets, 75 | shuffle=True, 76 | collate_fn=data_collator, 77 | batch_size=8, 78 | ) 79 | 80 | num_train_epochs = 3 81 | num_update_steps_per_epoch = len(train_dataloader) 82 | num_training_steps = num_train_epochs * num_update_steps_per_epoch 83 | 84 | lr_scheduler = get_scheduler( 85 | "linear", 86 | optimizer=optimizer, 87 | num_warmup_steps=0, 88 | num_training_steps=num_training_steps, 89 | ) 90 | 91 | 
progress_bar = tqdm(range(num_training_steps)) 92 | for epoch in range(num_train_epochs): 93 | model.train() 94 | for batch in train_dataloader: 95 | batch = {k: v.to(device) for k, v in batch.items()} 96 | outputs = model(**batch) 97 | loss = outputs.loss 98 | loss.backward() 99 | 100 | optimizer.step() 101 | lr_scheduler.step() 102 | optimizer.zero_grad() 103 | progress_bar.update(1) 104 | 105 | model_save_path = f"{output_dir}/{model_name}" 106 | model.save_pretrained(model_save_path) 107 | tokenizer.save_pretrained(model_save_path) 108 | 109 | 110 | if __name__ == "__main__": 111 | tokenizers_and_models = [ 112 | ("distilbert-base-uncased", "distilbert-base-uncased"), 113 | ("bert-base-uncased", "bert-base-uncased"), 114 | ] 115 | for t, m in tokenizers_and_models: 116 | train_model(t, m, OUTPUT_DIR) 117 | -------------------------------------------------------------------------------- /examples/mnist/benchmarks.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from mnist import ArrayMapping, ConvNet 4 | 5 | import nnbench 6 | 7 | 8 | @nnbench.benchmark 9 | def accuracy(params: ArrayMapping, data: ArrayMapping) -> float: 10 | x_test, y_test = data["x_test"], data["y_test"] 11 | 12 | cn = ConvNet() 13 | y_pred = cn.apply({"params": params}, x_test) 14 | return jnp.mean(jnp.argmax(y_pred, -1) == y_test).item() 15 | 16 | 17 | @nnbench.benchmark(name="Model size (MB)") 18 | def modelsize(params: ArrayMapping) -> float: 19 | nbytes = sum(x.size * x.dtype.itemsize for x in jax.tree_util.tree_leaves(params)) 20 | return nbytes / 1e6 21 | -------------------------------------------------------------------------------- /examples/mnist/mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | JAX MNIST example with nnbench. 3 | 4 | Demonstrates the use of nnbench to collect and log metrics right after training. 
5 | 6 | Source: https://github.com/google/flax/blob/main/examples/mnist 7 | """ 8 | 9 | import random 10 | from collections.abc import Mapping 11 | from dataclasses import dataclass 12 | from pathlib import Path 13 | 14 | import flax.linen as nn 15 | import fsspec 16 | import jax 17 | import jax.numpy as jnp 18 | import jax.random as jr 19 | import numpy as np 20 | import optax 21 | from flax.training.train_state import TrainState 22 | 23 | import nnbench 24 | from nnbench.reporter import FileReporter 25 | 26 | HERE = Path(__file__).parent 27 | 28 | ArrayMapping = dict[str, jax.Array | np.ndarray] 29 | 30 | INPUT_SHAPE = (28, 28, 1) # H x W x C (= 1, BW grayscale images) 31 | NUM_CLASSES = 10 32 | BATCH_SIZE = 128 33 | EPOCHS = 1 34 | LEARNING_RATE = 0.1 35 | MOMENTUM = 0.9 36 | 37 | 38 | @dataclass(frozen=True) 39 | class MNISTTestParameters(nnbench.Parameters): 40 | params: Mapping[str, jax.Array] 41 | data: ArrayMapping 42 | 43 | 44 | class ConvNet(nn.Module): 45 | @nn.compact 46 | def __call__(self, x): 47 | x = nn.Conv(features=32, kernel_size=(3, 3))(x) 48 | x = nn.relu(x) 49 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 50 | x = nn.Conv(features=64, kernel_size=(3, 3))(x) 51 | x = nn.relu(x) 52 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 53 | x = x.reshape(x.shape[0], -1) # flatten 54 | x = nn.Dense(features=256)(x) 55 | x = nn.relu(x) 56 | x = nn.Dense(features=NUM_CLASSES)(x) 57 | return x 58 | 59 | 60 | @jax.jit 61 | def apply_model(state, images, labels): 62 | """Computes gradients, loss and accuracy for a single batch.""" 63 | 64 | def loss_fn(params): 65 | logits = state.apply_fn({"params": params}, images) 66 | one_hot = jax.nn.one_hot(labels, 10) 67 | loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) 68 | return loss, logits 69 | 70 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 71 | (loss, logits), grads = grad_fn(state.params) 72 | accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) 73 | return grads, loss, accuracy 74 | 75 | 76 | @jax.jit 77 | def update_model(state, grads): 78 | return state.apply_gradients(grads=grads) 79 | 80 | 81 | def create_train_state(rng): 82 | """Creates initial `TrainState`.""" 83 | convnet = ConvNet() 84 | params = convnet.init(rng, jnp.ones([1, *INPUT_SHAPE]))["params"] 85 | tx = optax.sgd(learning_rate=LEARNING_RATE, momentum=MOMENTUM) 86 | return TrainState.create(apply_fn=convnet.apply, params=params, tx=tx) 87 | 88 | 89 | def load_mnist() -> ArrayMapping: 90 | """ 91 | Load MNIST dataset using fsspec. 92 | 93 | Returns 94 | ------- 95 | ArrayMapping 96 | Versioned dataset as numpy arrays, split into training and test data. 97 | """ 98 | 99 | if Path(HERE / "mnist.npz").exists(): 100 | data = np.load(HERE / "mnist.npz") 101 | return dict(data) 102 | 103 | mnist: ArrayMapping = {} 104 | 105 | baseurl = "https://storage.googleapis.com/cvdf-datasets/mnist/" 106 | 107 | for key, file in [ 108 | ("x_train", "train-images-idx3-ubyte.gz"), 109 | ("x_test", "t10k-images-idx3-ubyte.gz"), 110 | ("y_train", "train-labels-idx1-ubyte.gz"), 111 | ("y_test", "t10k-labels-idx1-ubyte.gz"), 112 | ]: 113 | with fsspec.open(baseurl + file, compression="gzip") as f: 114 | if key.startswith("x"): 115 | mnist[key] = np.frombuffer(f.read(), np.uint8, offset=16).reshape((-1, 28, 28)) 116 | else: 117 | mnist[key] = np.frombuffer(f.read(), np.uint8, offset=8) 118 | 119 | # save the data locally after download. 
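# (on subsequent runs, load_mnist() finds this file and returns it directly
# instead of re-downloading, see the existence check at the top of this function)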
120 | np.savez_compressed(HERE / "mnist.npz", **mnist) 121 | 122 | return mnist 123 | 124 | 125 | def train_epoch(state, train_ds, train_labels, batch_size, rng): 126 | """Train for a single epoch.""" 127 | train_ds_size = len(train_ds) 128 | steps_per_epoch = train_ds_size // batch_size 129 | 130 | perms = jax.random.permutation(rng, len(train_ds)) 131 | # skip incomplete batch to avoid a recompile of apply_model 132 | perms = perms[: steps_per_epoch * batch_size] 133 | perms = perms.reshape((steps_per_epoch, batch_size)) 134 | 135 | epoch_loss = [] 136 | epoch_accuracy = [] 137 | 138 | for perm in perms: 139 | batch_images = train_ds[perm, ...] 140 | batch_labels = train_labels[perm, ...] 141 | grads, loss, accuracy = apply_model(state, batch_images, batch_labels) 142 | state = update_model(state, grads) 143 | epoch_loss.append(loss) 144 | epoch_accuracy.append(accuracy) 145 | 146 | train_loss = np.mean(epoch_loss) 147 | train_accuracy = np.mean(epoch_accuracy) 148 | return state, train_loss, train_accuracy 149 | 150 | 151 | def preprocess(data: ArrayMapping) -> ArrayMapping: 152 | """ 153 | Expand dimensions of images. 154 | 155 | Parameters 156 | ---------- 157 | data: ArrayMapping 158 | Raw input dataset, as a compressed NumPy array collection. 159 | 160 | Returns 161 | ------- 162 | ArrayMapping 163 | Dataset with expanded dimensions. 164 | """ 165 | 166 | data["x_train"] = jnp.float32(data["x_train"]) / 255.0 167 | data["y_train"] = jnp.float32(data["y_train"]) 168 | data["x_test"] = jnp.float32(data["x_test"]) / 255.0 169 | data["y_test"] = jnp.float32(data["y_test"]) 170 | 171 | # add a fake channel axis to make sure images have shape (28, 28, 1) 172 | if not data["x_train"].shape[-1] == 1: 173 | data["x_train"] = jnp.expand_dims(data["x_train"], -1) 174 | data["x_test"] = jnp.expand_dims(data["x_test"], -1) 175 | 176 | return data 177 | 178 | 179 | def train(data: ArrayMapping) -> tuple[TrainState, ArrayMapping]: 180 | """Train a ConvNet model on the preprocessed data.""" 181 | 182 | x_train, y_train = data["x_train"], data["y_train"] 183 | x_test, y_test = data["x_test"], data["y_test"] 184 | 185 | train_perm = np.random.permutation(len(x_train)) 186 | train_perm = train_perm[: int(0.5 * len(x_train))] 187 | train_data, train_labels = x_train[train_perm, ...], y_train[train_perm, ...] 188 | 189 | test_perm = np.random.permutation(len(x_test)) 190 | test_perm = test_perm[: int(0.5 * len(x_test))] 191 | test_data, test_labels = x_test[test_perm, ...], y_test[test_perm, ...] 192 | 193 | rng = jr.PRNGKey(random.randint(0, 1000)) 194 | rng, init_rng = jr.split(rng) 195 | state = create_train_state(init_rng) 196 | 197 | for epoch in range(EPOCHS): 198 | rng, input_rng = jax.random.split(rng) 199 | state, train_loss, train_accuracy = train_epoch( 200 | state, train_data, train_labels, BATCH_SIZE, input_rng 201 | ) 202 | 203 | data = { 204 | "x_train": train_data, 205 | "y_train": train_labels, 206 | "x_test": test_data, 207 | "y_test": test_labels, 208 | } 209 | 210 | return state, data 211 | 212 | 213 | def mnist_jax(): 214 | """Load MNIST data and train a simple ConvNet model.""" 215 | mnist = load_mnist() 216 | mnist = preprocess(mnist) 217 | state, data = train(mnist) 218 | 219 | # the nnbench portion. 
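# Collect the benchmarks defined next to this file, run them with the freshly
# trained weights and the test split, then write the resulting records to a JSON file.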
220 | benchmarks = nnbench.collect(HERE) 221 | reporter = FileReporter() 222 | params = MNISTTestParameters(params=state.params, data=data) 223 | result = nnbench.run(benchmarks, name="nnbench-mnist-run", params=params) 224 | reporter.write(result, "mnist-result.json") 225 | 226 | 227 | if __name__ == "__main__": 228 | mnist_jax() 229 | -------------------------------------------------------------------------------- /examples/prefect/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools-scm[toml]>=7.1"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "nnbench-prefect-example" 7 | requires-python = ">= 3.10" 8 | version = "0.1.0" 9 | description = "Integration example of Prefect with nnbench." 10 | readme = "docs/guides/prefect.md" 11 | license = { text = "Apache-2.0" } 12 | dependencies = [ 13 | "prefect", 14 | "numpy", 15 | "scikit-learn", 16 | "nnbench@git+https://github.com/aai-institute/nnbench.git" 17 | ] 18 | maintainers = [{name="Max Mynter", email="m.mynter@appliedai-institute.de"}] 19 | authors = [{ name = "appliedAI Initiative", email = "info+oss@appliedai.de" }] 20 | -------------------------------------------------------------------------------- /examples/prefect/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/examples/prefect/src/__init__.py -------------------------------------------------------------------------------- /examples/prefect/src/benchmark.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | import time 4 | 5 | import numpy as np 6 | from sklearn import base, metrics 7 | 8 | import nnbench 9 | 10 | 11 | @nnbench.benchmark(tags=("metric",)) 12 | def mae(model: base.BaseEstimator, X_test: np.ndarray, y_test: np.ndarray) -> float: 13 | y_pred = model.predict(X_test) 14 | return metrics.mean_absolute_error(y_true=y_test, y_pred=y_pred) 15 | 16 | 17 | @nnbench.benchmark(tags=("metric",)) 18 | def mse(model: base.BaseEstimator, X_test: np.ndarray, y_test: np.ndarray) -> float: 19 | y_pred = model.predict(X_test) 20 | return metrics.mean_squared_error(y_true=y_test, y_pred=y_pred) 21 | 22 | 23 | @nnbench.benchmark(name="Model size (bytes)", tags=("model-meta",)) 24 | def modelsize(model: base.BaseEstimator) -> int: 25 | model_bytes = pickle.dumps(model) 26 | return sys.getsizeof(model_bytes) 27 | 28 | 29 | @nnbench.benchmark(name="Inference time (s)", tags=("model-meta",)) 30 | def inference_time(model: base.BaseEstimator, X: np.ndarray, n_iter: int = 100) -> float: 31 | start = time.perf_counter() 32 | for i in range(n_iter): 33 | _ = model.predict(X) 34 | end = time.perf_counter() 35 | return (end - start) / n_iter 36 | -------------------------------------------------------------------------------- /examples/prefect/src/runner.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | import numpy as np 5 | import training 6 | from prefect import flow, get_run_logger, task 7 | from prefect.artifacts import create_table_artifact 8 | from sklearn import base 9 | 10 | import nnbench 11 | from nnbench import context, reporter 12 | 13 | dir_path = os.path.dirname(__file__) 14 | 15 | 16 | class PrefectReporter(reporter.BenchmarkReporter): 17 | def __init__(self): 18 
| self.logger = get_run_logger() 19 | 20 | async def write( 21 | self, 22 | result: nnbench.BenchmarkResult, 23 | key: str, 24 | description: str = "Benchmark and Context", 25 | ) -> None: 26 | await create_table_artifact( 27 | key=key, 28 | table=result.to_json(), 29 | description=description, 30 | ) 31 | 32 | 33 | @task 34 | def run_metric_benchmarks( 35 | model: base.BaseEstimator, X_test: np.ndarray, y_test: np.ndarray 36 | ) -> nnbench.BenchmarkResult: 37 | benchmarks = nnbench.collect(os.path.join(dir_path, "benchmark.py"), tags=("metric",)) 38 | results = nnbench.run( 39 | benchmarks, 40 | params={"model": model, "X_test": X_test, "y_test": y_test}, 41 | ) 42 | return results 43 | 44 | 45 | @task 46 | def run_metadata_benchmarks(model: base.BaseEstimator, X: np.ndarray) -> nnbench.BenchmarkResult: 47 | benchmarks = nnbench.collect(os.path.join(dir_path, "benchmark.py"), tags=("model-meta",)) 48 | result = nnbench.run( 49 | benchmarks, 50 | params={"model": model, "X": X}, 51 | ) 52 | return result 53 | 54 | 55 | @flow(persist_result=True) 56 | async def train_and_benchmark( 57 | data_params: dict[str, int | float] | None = None, 58 | ) -> tuple[nnbench.BenchmarkResult, ...]: 59 | if data_params is None: 60 | data_params = {} 61 | 62 | reporter = PrefectReporter() 63 | 64 | regressor_and_test_data: tuple[ 65 | base.BaseEstimator, np.ndarray, np.ndarray 66 | ] = await training.prepare_regressor_and_test_data(data_params=data_params) # type: ignore 67 | 68 | model = regressor_and_test_data[0] 69 | X_test = regressor_and_test_data[1] 70 | y_test = regressor_and_test_data[2] 71 | 72 | metadata_results: nnbench.BenchmarkResult = run_metadata_benchmarks(model=model, X=X_test) 73 | 74 | metadata_results.context.update(data_params) 75 | metadata_results.context.update(context.PythonInfo()()) 76 | 77 | await reporter.write( 78 | result=metadata_results, key="model-attributes", description="Model Attributes" 79 | ) 80 | 81 | metric_results: nnbench.BenchmarkResult = run_metric_benchmarks( 82 | model=model, X_test=X_test, y_test=y_test 83 | ) 84 | 85 | metric_results.context.update(data_params) 86 | metric_results.context.update(context.PythonInfo()()) 87 | await reporter.write(metric_results, key="model-performance", description="Model Performance") 88 | return metadata_results, metric_results 89 | 90 | 91 | if __name__ == "__main__": 92 | asyncio.run(train_and_benchmark.serve(name="benchmark-runner")) 93 | -------------------------------------------------------------------------------- /examples/prefect/src/training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from prefect import flow, task 3 | from sklearn import base 4 | from sklearn.datasets import make_regression 5 | from sklearn.linear_model import LinearRegression 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | @task 10 | def make_regression_data( 11 | random_state: int, n_samples: int = 100, n_features: int = 1, noise: float = 0.2 12 | ) -> tuple[np.ndarray, np.ndarray]: 13 | X, y = make_regression( 14 | n_samples=n_samples, n_features=n_features, noise=noise, random_state=random_state 15 | ) 16 | return X, y 17 | 18 | 19 | @task 20 | def make_train_test_split( 21 | X: np.ndarray, y: np.ndarray, random_state: int, test_size: float = 0.2 22 | ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 23 | X_train, X_test, y_train, y_test = train_test_split( 24 | X, y, test_size=test_size, random_state=random_state 25 | ) 26 | return X_train, 
y_train, X_test, y_test 27 | 28 | 29 | @task 30 | def train_linear_regression(X: np.ndarray, y: np.ndarray) -> base.BaseEstimator: 31 | model = LinearRegression() 32 | model.fit(X, y) 33 | return model 34 | 35 | 36 | @flow 37 | def prepare_regression_data( 38 | random_state: int = 42, n_samples: int = 100, n_features: int = 1, noise: float = 0.2 39 | ) -> tuple[np.ndarray, ...]: 40 | X, y = make_regression_data( 41 | random_state=random_state, n_samples=n_samples, n_features=n_features, noise=noise 42 | ) 43 | X_train, y_train, X_test, y_test = make_train_test_split(X=X, y=y, random_state=random_state) 44 | return X_train, y_train, X_test, y_test 45 | 46 | 47 | @flow 48 | async def prepare_regressor_and_test_data( 49 | data_params: dict[str, int | float] | None = None, 50 | ) -> tuple[base.BaseEstimator, np.ndarray, np.ndarray]: 51 | if data_params is None: 52 | data_params = {} 53 | X_train, y_train, X_test, y_test = prepare_regression_data(**data_params) 54 | model = train_linear_regression(X=X_train, y=y_train) 55 | return model, X_test, y_test 56 | -------------------------------------------------------------------------------- /examples/streamlit/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools-scm[toml]>=7.1"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "nnbench-streamlit-example" 7 | requires-python = ">= 3.10" 8 | version = "0.1.0" 9 | description = "Integration example of Streamlit with Prefect and nnbench." 10 | readme = "docs/guides/streamlit.md" 11 | license = { text = "Apache-2.0" } 12 | dependencies = [ 13 | "prefect", 14 | "numpy", 15 | "scikit-learn", 16 | "streamlit", 17 | "nnbench@git+https://github.com/aai-institute/nnbench.git" 18 | ] 19 | maintainers = [{name="Max Mynter", email="m.mynter@appliedai-institute.de"}] 20 | authors = [{ name = "appliedAI Initiative", email = "info+oss@appliedai.de" }] 21 | -------------------------------------------------------------------------------- /examples/streamlit/streamlit_example.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | import pickle 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | import pandas as pd 8 | import prefect 9 | import streamlit as st 10 | from prefect import deployments 11 | from prefect.results import PersistedResultBlob 12 | 13 | import nnbench 14 | 15 | LOCAL_PREFECT_URL = "http://127.0.0.1:4200" 16 | LOCAL_PREFECT_PERSISTENCE_FOLDER = Path.home() / ".prefect" / "storage" 17 | prefect_client = prefect.PrefectClient(api=LOCAL_PREFECT_URL) 18 | 19 | if "benchmarks" not in st.session_state: 20 | st.session_state["benchmarks"] = [] 21 | 22 | 23 | def setup_ui(): 24 | st.title("Streamlit & Prefect Demo") 25 | bm_params = { 26 | "random_state": st.number_input("Random State", value=42, step=1), 27 | "n_features": st.number_input("N Features", value=2, step=1), 28 | "n_samples": st.number_input("N Samples", value=100, step=5), 29 | "noise": st.number_input("Noise", value=0.2, format="%f"), 30 | } 31 | return bm_params 32 | 33 | 34 | async def run_bms(params: dict[str, Any]) -> str: 35 | result = await deployments.run_deployment( 36 | "train-and-benchmark/benchmark-runner", parameters={"data_params": params} 37 | ) 38 | return result.state.result().storage_key 39 | 40 | 41 | def get_bm_artifacts(storage_key: str) -> None: 42 | blob_path = LOCAL_PREFECT_PERSISTENCE_FOLDER / storage_key 
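# Prefect persisted the flow's return value to local storage; the blob's payload is a
# base64-encoded pickle of the benchmark results, which is decoded back below.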
43 | blob = PersistedResultBlob.parse_raw(blob_path.read_bytes()) 44 | bm_results: tuple[nnbench.BenchmarkResult, ...] = pickle.loads(base64.b64decode(blob.data)) 45 | 46 | bms = [pd.DataFrame(result.benchmarks) for result in bm_results] 47 | for df in bms: 48 | df["value"] = df["value"].apply(lambda x: f"{x:.2e}") 49 | cxs = [pd.DataFrame([result.context.data]) for result in bm_results] 50 | 51 | display_data = [bms + [cxs[0]]] # Only need context once 52 | st.session_state["benchmarks"].extend(display_data) 53 | 54 | 55 | if __name__ == "__main__": 56 | bm_params = setup_ui() 57 | if st.button("Run Benchmarks"): 58 | storage_key = asyncio.run(run_bms(params=bm_params)) 59 | get_bm_artifacts(storage_key) 60 | st.write("Benchmark Results") 61 | for i, benchmark in reversed(list(enumerate(st.session_state["benchmarks"]))): 62 | with st.expander(f"Benchmark Run {i + 1}"): 63 | meta, metric, ctx = benchmark 64 | st.write("Model Attributes") 65 | st.table(meta) 66 | st.write("Model Metrics") 67 | st.table(metric) 68 | st.write("Context Configuration") 69 | st.table(ctx) 70 | -------------------------------------------------------------------------------- /examples/zenml/README.md: -------------------------------------------------------------------------------- 1 | # An example on benchmarking models with ZenML 2 | 3 | This example contains a ZenML pipeline training a random forest classifier on the Iris dataset, including an evaluation step that collects and runs a benchmark suite on the newly trained model. 4 | It logs the results to the step directly as metadata, where the data scientist can immediately inspect them in the ZenML dashboard, or process them in scripts with the client's model metadata APIs. 5 | -------------------------------------------------------------------------------- /examples/zenml/benchmarks.py: -------------------------------------------------------------------------------- 1 | from pipeline import ArrayMapping 2 | from sklearn.base import BaseEstimator 3 | from sklearn.metrics import accuracy_score 4 | 5 | import nnbench 6 | 7 | 8 | @nnbench.benchmark 9 | def accuracy(model: BaseEstimator, data: ArrayMapping) -> float: 10 | X_test, y_test = data["X_test"], data["y_test"] 11 | y_pred = model.predict(X_test) 12 | accuracy = accuracy_score(y_test, y_pred) 13 | return accuracy 14 | -------------------------------------------------------------------------------- /examples/zenml/pipeline.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = ">=3.10" 3 | # dependencies = [ 4 | # "scikit-learn", 5 | # "zenml", 6 | # "nnbench", 7 | # ] 8 | # /// 9 | 10 | """ 11 | The scikit-learn random forest example with nnbench, as a ZenML pipeline. 12 | 13 | Demonstrates the use of nnbench to collect and log metrics right after training. 
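To try it, run this file directly (``python pipeline.py``) with the dependencies from the inline script metadata above installed; the benchmark results should then appear as metadata on the evaluation step in the ZenML dashboard.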
14 | """ 15 | 16 | from dataclasses import dataclass 17 | from pathlib import Path 18 | 19 | import numpy as np 20 | from sklearn.datasets import load_iris 21 | from sklearn.ensemble import RandomForestClassifier 22 | from sklearn.model_selection import train_test_split 23 | from zenml import log_metadata, pipeline, step 24 | 25 | import nnbench 26 | 27 | HERE = Path(__file__).parent 28 | 29 | ArrayMapping = dict[str, np.ndarray] 30 | 31 | 32 | MAX_DEPTH = 5 33 | N_ESTIMATORS = 100 34 | RANDOM_STATE = 42 35 | 36 | 37 | @dataclass(frozen=True) 38 | class MNISTTestParameters(nnbench.Parameters): 39 | model: RandomForestClassifier 40 | data: ArrayMapping 41 | 42 | 43 | @step 44 | def load_iris_dataset() -> ArrayMapping: 45 | """ 46 | Load and split Iris data from scikit-learn. 47 | 48 | Returns 49 | ------- 50 | ArrayMapping 51 | Iris dataset as NumPy arrays, split into training and test data. 52 | """ 53 | 54 | data = load_iris() 55 | X, y = data.data, data.target 56 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) 57 | 58 | iris = { 59 | "X_train": X_train, 60 | "X_test": X_test, 61 | "y_train": y_train, 62 | "y_test": y_test, 63 | } 64 | 65 | return iris 66 | 67 | 68 | @step 69 | def train_model( 70 | data: ArrayMapping, 71 | n_estimators: int = N_ESTIMATORS, 72 | max_depth: int = MAX_DEPTH, 73 | random_state: int = RANDOM_STATE, 74 | ) -> RandomForestClassifier: 75 | model = RandomForestClassifier( 76 | n_estimators=n_estimators, max_depth=max_depth, random_state=random_state 77 | ) 78 | X_train, y_train = data["X_train"], data["y_train"] 79 | model.fit(X_train, y_train) 80 | return model 81 | 82 | 83 | @step 84 | def benchmark_model(model: RandomForestClassifier, data: ArrayMapping) -> None: 85 | """Evaluate the model and log metrics in nnbench.""" 86 | 87 | # the nnbench portion. 88 | benchmarks = nnbench.collect(HERE) 89 | params = MNISTTestParameters(model=model, data=data) 90 | result = nnbench.run(benchmarks, name="nnbench-iris-run", params=params) 91 | 92 | # Log metrics to the step. 93 | log_metadata(metadata=result.to_json()) 94 | 95 | 96 | @pipeline 97 | def mnist_jax(): 98 | """Load Iris data, train, and benchmark a simple random forest model.""" 99 | iris = load_iris_dataset() 100 | model = train_model(iris) 101 | benchmark_model(model, iris) 102 | 103 | 104 | if __name__ == "__main__": 105 | mnist_jax() 106 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: nnbench 2 | site_dir: public/docs 3 | site_url: https://aai-institute.github.io/nnbench 4 | repo_url: https://github.com/aai-institute/nnbench 5 | edit_uri: edit/main/docs/ 6 | 7 | # To validate all internal links exist. Does not work in ipynb files 8 | strict: true 9 | validation: 10 | omitted_files: warn 11 | absolute_links: warn 12 | unrecognized_links: warn 13 | 14 | copyright: Copyright © 2024 appliedAI Institute for Europe gGmbH
The appliedAI Institute for Europe gGmbH is supported by the KI-Stiftung Heilbronn gGmbH. 15 | 16 | nav: 17 | - Home: index.md 18 | - quickstart.md 19 | - Command-line interface (CLI): 20 | - cli/index.md 21 | - cli/cli.md 22 | - cli/pyproject.md 23 | - cli/fixtures.md 24 | - cli/comparisons.md 25 | - User Guide: 26 | - guides/index.md 27 | - guides/benchmarks.md 28 | - guides/customization.md 29 | - guides/organization.md 30 | - guides/runners.md 31 | - Examples: 32 | - tutorials/index.md 33 | - tutorials/huggingface.md 34 | - tutorials/mnist.md 35 | - tutorials/prefect.md 36 | - tutorials/streamlit.md 37 | - tutorials/bq.md 38 | - tutorials/duckdb.md 39 | - API Reference: reference/ 40 | - Contributing: CONTRIBUTING.md 41 | 42 | # Rebuild docs in `mkdocs serve` for changes in source code 43 | watch: 44 | - src/ 45 | 46 | plugins: 47 | - callouts 48 | - gen-files: 49 | scripts: 50 | - docs/_scripts/gen_api_ref_pages.py 51 | - literate-nav: 52 | nav_file: SUMMARY.md 53 | - section-index 54 | - mkdocstrings: 55 | handlers: 56 | python: 57 | paths: [src] 58 | options: 59 | docstring_style: numpy 60 | docstring_section_style: spacy 61 | line_length: 100 62 | show_bases: true 63 | show_if_no_docstring: true 64 | members_order: source 65 | separate_signature: true 66 | show_signature_annotations: true 67 | signature_crossrefs: true 68 | merge_init_into_class: false 69 | filters: ["!^_{1,2}"] 70 | - mike: 71 | canonical_version: latest 72 | - privacy 73 | - search: 74 | - include_dir_to_nav: 75 | file_pattern: '.*\.md$' 76 | 77 | markdown_extensions: 78 | - neoteroi.cards # https://www.neoteroi.dev/mkdocs-plugins/cards/ 79 | # python-markdown extensions: https://python-markdown.github.io/extensions/ 80 | - admonition 81 | - attr_list 82 | - sane_lists 83 | - toc: 84 | permalink: true 85 | toc_depth: 3 86 | # pymdown-extensions: https://facelessuser.github.io/pymdown-extensions/ 87 | - pymdownx.details 88 | - pymdownx.emoji: 89 | emoji_index: !!python/name:material.extensions.emoji.twemoji 90 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 91 | - pymdownx.highlight: 92 | anchor_linenums: true 93 | line_spans: __span 94 | pygments_lang_class: true 95 | - pymdownx.inlinehilite 96 | - pymdownx.snippets: 97 | url_download: true 98 | - pymdownx.superfences 99 | - pymdownx.tabbed: 100 | alternate_style: true 101 | 102 | theme: 103 | name: "material" 104 | custom_dir: docs/_theme_overrides 105 | logo: _images/aai-logo-cropped.png 106 | favicon: _images/aai-favicon.png 107 | font: 108 | text: IBM Plex Sans # Arial replacement 109 | code: Source Code Pro 110 | icon: 111 | logo: _images/aai-favicon.png 112 | repo: fontawesome/brands/github 113 | features: 114 | - content.tabs.link 115 | - content.code.copy 116 | - content.code.annotate 117 | - content.action.edit 118 | palette: 119 | # Palette toggle for light mode 120 | - scheme: aai-light 121 | toggle: 122 | icon: material/brightness-7 123 | name: Switch to dark mode 124 | 125 | # Palette toggle for dark mode 126 | - scheme: slate 127 | toggle: 128 | icon: material/brightness-4 129 | name: Switch to light mode 130 | 131 | extra: 132 | copyright_link: https://appliedai-institute.de 133 | homepage: https://aai-institute.github.io/nnbench 134 | generator: false 135 | pre_release: !ENV [DOCS_PRERELEASE, false] 136 | version: 137 | provider: mike 138 | default: latest 139 | social: 140 | - icon: fontawesome/brands/github 141 | link: https://github.com/aai-institute/nnbench 142 | - icon: fontawesome/brands/python 143 | link: 
https://pypi.org/project/nnbench 144 | - icon: fontawesome/brands/linkedin 145 | link: https://www.linkedin.com/company/appliedai-institute-for-europe-ggmbh/ 146 | - icon: fontawesome/solid/section 147 | link: https://appliedai-institute.de/impressum 148 | name: Impressum / Imprint 149 | 150 | extra_css: 151 | - _styles/extra.css 152 | - _styles/theme.css 153 | - _styles/neoteroi-mkdocs.css 154 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "nnbench" 7 | description = "A small framework for benchmarking machine learning models." 8 | keywords = ["Benchmarking", "Machine Learning"] 9 | requires-python = ">=3.10" 10 | license = { text = "Apache-2.0" } 11 | authors = [{ name = "appliedAI Initiative", email = "info+oss@appliedai.de" }] 12 | maintainers = [ 13 | { name = "Nicholas Junge", email = "n.junge@appliedai-institute.de" }, 14 | { name = "Max Mynter", email = "m.mynter@appliedai-institute.de" }, 15 | { name = "Adrian Rumpold", email = "a.rumpold@appliedai-institute.de" }, 16 | ] 17 | classifiers = [ 18 | "Development Status :: 4 - Beta", 19 | "Intended Audience :: Developers", 20 | "Intended Audience :: Science/Research", 21 | "License :: OSI Approved :: Apache Software License", 22 | "Operating System :: OS Independent", 23 | "Programming Language :: Python :: 3", 24 | "Topic :: Scientific/Engineering", 25 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 26 | "Topic :: Scientific/Engineering :: Information Analysis", 27 | "Topic :: Software Development :: Libraries :: Python Modules", 28 | "Topic :: System :: Benchmark", 29 | "Topic :: Utilities", 30 | "Typing :: Typed", 31 | ] 32 | 33 | dependencies = [ 34 | "rich", 35 | "tomli >= 1.1.0 ; python_version < '3.11'", 36 | "typing-extensions; python_version < '3.11'", 37 | ] 38 | 39 | dynamic = ["readme", "version"] 40 | 41 | [project.urls] 42 | Homepage = "https://github.com/aai-institute/nnbench" 43 | Repository = "https://github.com/aai-institute/nnbench.git" 44 | Issues = "https://github.com/aai-institute/nnbench/issues" 45 | Discussions = "https://github.com/aai-institute/nnbench/discussions" 46 | 47 | [dependency-groups] 48 | dev = [ 49 | "build>=1.0.0", 50 | "fsspec", 51 | "numpy>=2.2.1", 52 | "pre-commit>=3.3.3", 53 | "psutil", 54 | "pyarrow", 55 | "pytest>=7.4.0", 56 | "pytest-memray; platform_system != 'Windows'", 57 | ] 58 | docs = [ 59 | "black", 60 | "docstring-parser", 61 | "mike", 62 | "mkdocs", 63 | "mkdocs-callouts", 64 | "mkdocs-gen-files", 65 | "mkdocs-literate-nav", 66 | "mkdocs-section-index", 67 | "mkdocs-material", 68 | "mkdocstrings[python]", 69 | "mkdocs-include-dir-to-nav", 70 | "neoteroi-mkdocs", 71 | ] 72 | 73 | [project.scripts] 74 | nnbench = "nnbench.cli:main" 75 | 76 | [tool.setuptools] 77 | package-dir = { "" = "src" } 78 | 79 | [tool.setuptools.dynamic] 80 | readme = { file = "README.md", content-type = "text/markdown" } 81 | version = { attr = "nnbench.__version__" } 82 | 83 | [tool.setuptools.packages.find] 84 | where = ["src"] 85 | 86 | [tool.setuptools.package-data] 87 | nnbench = ["py.typed"] 88 | 89 | [tool.mypy] 90 | allow_redefinition = true 91 | check_untyped_defs = true 92 | disallow_incomplete_defs = true 93 | pretty = true 94 | python_version = "3.10" 95 | strict_optional = false 96 | warn_unreachable = true 97 | 98 | 
[[tool.mypy.overrides]] 99 | module = ["tabulate", "yaml", "fsspec"] 100 | ignore_missing_imports = true 101 | 102 | [tool.ruff] 103 | # explicitly set src folder for isort to understand first-party imports correctly. 104 | src = ["src"] 105 | line-length = 100 106 | target-version = "py310" 107 | 108 | [tool.ruff.lint] 109 | # Enable pycodestyle errors & warnings (`E`, `W`), Pyflakes (`F`), isort (`I`), 110 | # and pyupgrade (`UP`) by default. 111 | select = ["E", "F", "I", "W", "UP"] 112 | ignore = [ 113 | # Line too long 114 | "E501", 115 | # Allow capitalized variable names 116 | "F841", 117 | ] 118 | 119 | [tool.ruff.lint.per-file-ignores] 120 | # Ignore unused imports in all `__init__.py` files 121 | "__init__.py" = ["F401"] 122 | 123 | [tool.pytest.ini_options] 124 | log_cli = true 125 | log_cli_level = "DEBUG" 126 | -------------------------------------------------------------------------------- /src/nnbench/__init__.py: -------------------------------------------------------------------------------- 1 | """A framework for organizing and running benchmark workloads on machine learning models.""" 2 | 3 | from .core import benchmark, parametrize, product 4 | from .runner import collect, run 5 | from .types import Benchmark, BenchmarkFamily, BenchmarkResult, Parameters 6 | 7 | __version__ = "0.5.0" 8 | -------------------------------------------------------------------------------- /src/nnbench/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from nnbench.cli import main 4 | 5 | if __name__ == "__main__": 6 | sys.exit(main()) 7 | -------------------------------------------------------------------------------- /src/nnbench/compare.py: -------------------------------------------------------------------------------- 1 | """Contains machinery to compare multiple benchmark results side by side.""" 2 | 3 | import operator 4 | from collections.abc import Iterable 5 | from typing import Any, Protocol 6 | 7 | from rich.console import Console 8 | from rich.table import Table 9 | 10 | from nnbench.types import BenchmarkResult 11 | from nnbench.util import collate, flatten 12 | 13 | _MISSING = "N/A" 14 | _STATUS_KEY = "Status" 15 | 16 | 17 | class Comparator(Protocol): 18 | def __call__(self, val1: Any, val2: Any) -> bool: ... 
19 | 20 | 21 | class GreaterEqual(Comparator): 22 | def __call__(self, val1: Any, val2: Any) -> bool: 23 | return operator.ge(val1, val2) 24 | 25 | def __str__(self): 26 | return "x ≥ y" 27 | 28 | 29 | class LessEqual(Comparator): 30 | def __call__(self, val1: Any, val2: Any) -> bool: 31 | return operator.le(val1, val2) 32 | 33 | def __str__(self): 34 | return "x ≤ y" 35 | 36 | 37 | class AbsDiffLessEqual(Comparator): 38 | def __init__(self, thresh: float): 39 | self.thresh = thresh 40 | 41 | def __call__(self, val1: float, val2: float) -> bool: 42 | return operator.le(operator.abs(val1 - val2), self.thresh) 43 | 44 | def __str__(self): 45 | return f"|x - y| <= {self.thresh:.2f}" 46 | 47 | 48 | def make_row(result: BenchmarkResult) -> dict[str, Any]: 49 | d = dict() 50 | d["run"] = result.run 51 | for bm in result.benchmarks: 52 | # TODO: Guard against key errors from database queries 53 | name, value = bm["function"], bm["value"] 54 | d[name] = value 55 | # TODO: Check for errors 56 | return d 57 | 58 | 59 | def get_value_by_name(result: BenchmarkResult, name: str, missing: str) -> Any: 60 | """ 61 | Get the value of a metric by name from a benchmark result, or a placeholder 62 | if the metric name is not present in the result. 63 | 64 | If the name is found, but the benchmark did not complete successfully 65 | (i.e. the ``error_occurred`` value is set to ``True``), the returned value 66 | will be set to the value of the ``error_message`` field. 67 | 68 | Parameters 69 | ---------- 70 | result: BenchmarkResult 71 | The benchmark result to extract a metric value from. 72 | name: str 73 | The name of the target metric. 74 | missing: str 75 | A placeholder string to return in the event of a missing metric. 76 | 77 | Returns 78 | ------- 79 | str 80 | A string containing the metric value (or error message) formatted as rich text. 81 | """ 82 | metric_names = [b["function"] for b in result.benchmarks] 83 | if name not in metric_names: 84 | return missing 85 | 86 | res = result.benchmarks[metric_names.index(name)] 87 | if res.get("error_occurred", False): 88 | errmsg = res.get("error_message", "") 89 | return f"[red]ERROR: {errmsg} [/red]" 90 | return res.get("value", missing) 91 | 92 | 93 | class AbstractComparison: 94 | def compare2(self, name: str, val1: Any, val2: Any) -> bool: 95 | """A subroutine to compare two values of a metric ``name`` against each other.""" 96 | raise NotImplementedError 97 | 98 | def render(self) -> None: 99 | """A method to render a previously computed comparison to a stream.""" 100 | raise NotImplementedError 101 | 102 | @property 103 | def success(self) -> bool: 104 | """ 105 | Indicates whether a comparison has been succesful based on the criteria 106 | expressed by the chosen set of comparators. 107 | """ 108 | raise NotImplementedError 109 | 110 | 111 | class TabularComparison(AbstractComparison): 112 | def __init__( 113 | self, 114 | results: Iterable[BenchmarkResult], 115 | comparators: dict[str, Comparator] | None = None, 116 | placeholder: str = _MISSING, 117 | contextvals: list[str] | None = None, 118 | ): 119 | """ 120 | Initialize a tabular comparison class, rendering the result to a rich table. 121 | 122 | Parameters 123 | ---------- 124 | results: Iterable[BenchmarkResult] 125 | The benchmark results to compare. 126 | comparators: dict[str, Comparator] | None 127 | A mapping from benchmark functions to comparators, i.e. a function 128 | comparing two results and returning a boolean indicating a favourable or 129 | unfavourable comparison. 
130 | placeholder: str 131 | A placeholder string to show in the event of a missing metric. 132 | contextvals: list[str] | None 133 | A list of context values to display in the comparison table. 134 | Supply nested context values via dotted syntax. 135 | """ 136 | self.placeholder = placeholder 137 | self.comparators = comparators or {} 138 | self.contextvals = contextvals or [] 139 | self.results: tuple[BenchmarkResult, ...] = tuple(collate(results)) 140 | self.data: list[dict[str, Any]] = [make_row(rec) for rec in self.results] 141 | self.metrics: list[str] = [] 142 | self._success: bool = True 143 | 144 | self.display_names: dict[str, str] = {} 145 | for res in self.results: 146 | for bm in res.benchmarks: 147 | name, func = bm["name"], bm["function"] 148 | if func not in self.display_names: 149 | self.display_names[func] = name 150 | if func not in self.metrics: 151 | self.metrics.append(func) 152 | 153 | if len(self.data) < 2: 154 | raise ValueError("must give at least two results to compare") 155 | 156 | def compare2(self, name: str, val1: Any, val2: Any) -> bool: 157 | """ 158 | Compare two values of a metric across runs. 159 | 160 | A comparison is a function taking two values and returning a boolean 161 | indicating whether val2 compares favorably (the ``True`` case) 162 | or unfavorably to val1 (the ``False`` case). 163 | 164 | This method should generally be overwritten by child classes. 165 | 166 | Parameters 167 | ---------- 168 | name: str 169 | Name of the metric to compare. 170 | val1: Any 171 | Value of the metric in the first benchmark result. 172 | val2: Any 173 | Value of the metric in the second benchmark result. 174 | 175 | Returns 176 | ------- 177 | bool 178 | The comparison between the two values. 179 | """ 180 | if any(v == self.placeholder for v in (val1, val2)): 181 | return False 182 | return self.comparators[name](val1, val2) 183 | 184 | def format_value(self, name: str, val: Any) -> str: 185 | if val == self.placeholder: 186 | return self.placeholder 187 | if name == "accuracy": 188 | return f"{val:.2%}" 189 | else: 190 | return f"{val:.2f}" 191 | 192 | def render(self) -> None: 193 | c = Console() 194 | t = Table() 195 | 196 | has_comparable_metrics = set(self.metrics) & self.comparators.keys() 197 | if has_comparable_metrics: 198 | c.print("Comparison strategy: All vs. first") 199 | c.print("Comparisons:") 200 | for k, v in self.comparators.items(): 201 | c.print(f" {k}: {v}") 202 | else: 203 | print(f"warning: no comparators found for metrics {', '.join(self.metrics)}") 204 | 205 | # TODO: Support parameter prints 206 | rows: list[list[str]] = [] 207 | columns: list[str] = ["Run Name"] + list(self.display_names.values()) 208 | columns += self.contextvals 209 | if has_comparable_metrics: 210 | columns += [_STATUS_KEY] 211 | 212 | for i, d in enumerate(self.data): 213 | row = [d["run"]] 214 | status = "" 215 | for metric in self.metrics: 216 | val = d.get(metric, self.placeholder) 217 | sval = self.format_value(metric, val) 218 | comparator = self.comparators.get(metric, None) 219 | if i and comparator is not None: 220 | cd = self.data[0] 221 | compval = cd.get(metric, self.placeholder) 222 | success = self.compare2(metric, val, compval) 223 | self._success &= success 224 | status += ":white_check_mark:" if success else ":x:" 225 | sval += " (vs. 
" + self.format_value(metric, compval) + ")" 226 | row += [sval] 227 | 228 | ctx = flatten(self.results[i].context) 229 | for cval in self.contextvals: 230 | row.append(ctx.get(cval, _MISSING)) 231 | 232 | if _STATUS_KEY in columns: 233 | row += [status] 234 | rows.append(row) 235 | 236 | for column in columns: 237 | t.add_column(column) 238 | for row in rows: 239 | t.add_row(*row) 240 | 241 | c.print(t) 242 | 243 | @property 244 | def success(self) -> bool: 245 | return self._success 246 | -------------------------------------------------------------------------------- /src/nnbench/config.py: -------------------------------------------------------------------------------- 1 | """Utilities for parsing a ``[tool.nnbench]`` config block out of a pyproject.toml file.""" 2 | 3 | import importlib 4 | import logging 5 | import os 6 | import sys 7 | from collections.abc import Callable 8 | from dataclasses import dataclass, field 9 | from pathlib import Path 10 | from typing import Any 11 | 12 | if sys.version_info >= (3, 11): 13 | from typing import Self 14 | 15 | import tomllib 16 | else: 17 | import tomli as tomllib 18 | from typing_extensions import Self 19 | 20 | logger = logging.getLogger("nnbench.config") 21 | 22 | DEFAULT_JSONIFIER = "nnbench.runner.jsonify" 23 | 24 | 25 | @dataclass 26 | class ContextProviderDef: 27 | """ 28 | A POD struct representing a custom context provider definition in a 29 | pyproject.toml table. 30 | """ 31 | 32 | # TODO: Extend this def to a generic typedef, reusable by context 33 | # providers, IOs, and comparisons (with a type enum). 34 | 35 | name: str 36 | """Name under which the provider should be registered by nnbench.""" 37 | classpath: str 38 | """Full path to the class or callable returning the context dict.""" 39 | arguments: dict[str, Any] = field(default_factory=dict) 40 | """ 41 | Arguments needed to instantiate the context provider class, 42 | given as key-value pairs in the table. 43 | If the class path points to a function, no arguments may be given.""" 44 | 45 | 46 | @dataclass(frozen=True) 47 | class NNBenchConfig: 48 | log_level: str 49 | """Log level to use for the ``nnbench`` module root logger.""" 50 | context: list[ContextProviderDef] 51 | """A list of context provider definitions found in pyproject.toml.""" 52 | # TODO: Move this down one level to [tool.nnbench.run] 53 | jsonifier: str | Callable[[dict[str, Any]], dict[str, Any]] = DEFAULT_JSONIFIER 54 | 55 | @classmethod 56 | def from_toml(cls, d: dict[str, Any]) -> Self: 57 | """ 58 | Returns an nnbench CLI config by processing fields obtained from 59 | parsing a [tool.nnbench] block in a pyproject.toml file. 60 | 61 | Parameters 62 | ---------- 63 | d: dict[str, Any] 64 | Mapping containing the [tool.nnbench] table contents, 65 | as obtained by ``tomllib.load()``. 66 | 67 | Returns 68 | ------- 69 | Self 70 | An nnbench config instance with the values from pyproject.toml, 71 | and defaults for values that were not set explicitly. 
72 | """ 73 | log_level = d.get("log-level", "NOTSET") 74 | provider_map = d.get("context", {}) 75 | jsonifier = d.get("jsonifier", DEFAULT_JSONIFIER) 76 | context = [ContextProviderDef(**cpd) for cpd in provider_map.values()] 77 | return cls(log_level=log_level, context=context, jsonifier=jsonifier) 78 | 79 | 80 | def locate_pyproject(stop: os.PathLike[str] = Path.home()) -> os.PathLike[str] | None: 81 | """ 82 | Locate a pyproject.toml file by walking up from the current directory, 83 | and checking for file existence, stopping at ``stop`` (by default, the 84 | current user home directory). 85 | 86 | If no pyproject.toml file can be found at any level, returns None. 87 | 88 | Returns 89 | ------- 90 | os.PathLike[str] | None 91 | The path to pyproject.toml. 92 | """ 93 | cwd = Path.cwd() 94 | for p in (cwd, *cwd.parents): 95 | if (pyproject_cand := (p / "pyproject.toml")).exists(): 96 | return pyproject_cand 97 | if p == stop: 98 | break 99 | logger.debug(f"could not locate pyproject.toml in directory {cwd}") 100 | return None 101 | 102 | 103 | def parse_nnbench_config(pyproject_path: str | os.PathLike[str] | None = None) -> NNBenchConfig: 104 | """ 105 | Load an nnbench config from a given pyproject.toml file. 106 | 107 | If no path to the pyproject.toml file is given, an attempt at autodiscovery 108 | will be made. If that is unsuccessful, an empty config is returned. 109 | 110 | Parameters 111 | ---------- 112 | pyproject_path: str | os.PathLike[str] | None 113 | Path to the current project's pyproject.toml file, optional. 114 | 115 | Returns 116 | ------- 117 | NNBenchConfig 118 | The loaded config if found, or a default config. 119 | """ 120 | pyproject_path = pyproject_path or locate_pyproject() 121 | if pyproject_path is None: 122 | # pyproject.toml could not be found, so return an empty config. 123 | return NNBenchConfig.from_toml({}) 124 | 125 | with open(pyproject_path, "rb") as fp: 126 | pyproject = tomllib.load(fp) 127 | return NNBenchConfig.from_toml(pyproject.get("tool", {}).get("nnbench", {})) 128 | 129 | 130 | def import_(resource: str) -> Any: 131 | # If the current directory is not on sys.path, insert it in front. 132 | if "" not in sys.path and "." not in sys.path: 133 | sys.path.insert(0, "") 134 | 135 | # NB: This assumes that every resource is a top-level member of the 136 | # target module, and not nested in a class or other construct. 137 | modname, classname = resource.rsplit(".", 1) 138 | klass = getattr(importlib.import_module(modname), classname) 139 | return klass 140 | -------------------------------------------------------------------------------- /src/nnbench/context.py: -------------------------------------------------------------------------------- 1 | """Utilities for collecting context key-value pairs as metadata in benchmark runs.""" 2 | 3 | import logging 4 | import platform 5 | import sys 6 | from collections.abc import Callable, Iterable 7 | from typing import Any, Literal 8 | 9 | logger = logging.getLogger("nnbench.context") 10 | 11 | Context = dict[str, Any] 12 | ContextProvider = Callable[[], Context] 13 | """A function providing a dictionary of context values.""" 14 | 15 | 16 | class PythonInfo: 17 | """ 18 | A context helper returning version info for requested installed packages. 19 | 20 | If a requested package is not installed, an empty string is returned instead. 21 | 22 | Parameters 23 | ---------- 24 | packages: Iterable[str] 25 | Names of the requested packages under which they exist in the current environment. 
26 | For packages installed through ``pip``, this equals the PyPI package name. 27 | """ 28 | 29 | key = "python" 30 | 31 | def __init__(self, packages: Iterable[str] = ()): 32 | self.packages = tuple(packages) 33 | 34 | def __call__(self) -> dict[str, Any]: 35 | from importlib.metadata import PackageNotFoundError, version 36 | 37 | result: dict[str, Any] = dict() 38 | 39 | result["version"] = platform.python_version() 40 | result["implementation"] = platform.python_implementation() 41 | buildno, buildtime = platform.python_build() 42 | result["buildno"] = buildno 43 | result["buildtime"] = buildtime 44 | 45 | packages: dict[str, str] = {} 46 | for pkg in self.packages: 47 | try: 48 | packages[pkg] = version(pkg) 49 | except PackageNotFoundError: 50 | packages[pkg] = "" 51 | 52 | result["packages"] = packages 53 | return {self.key: result} 54 | 55 | 56 | class GitEnvironmentInfo: 57 | """ 58 | A context helper providing the current git commit, latest tag, and upstream repository name. 59 | 60 | Parameters 61 | ---------- 62 | remote: str 63 | Remote name for which to provide info, by default ``"origin"``. 64 | """ 65 | 66 | key = "git" 67 | 68 | def __init__(self, remote: str = "origin"): 69 | self.remote = remote 70 | 71 | def __call__(self) -> dict[str, dict[str, Any]]: 72 | import subprocess 73 | 74 | def git_subprocess(args: list[str]) -> subprocess.CompletedProcess: 75 | if platform.system() == "Windows": 76 | git = "git.exe" 77 | else: 78 | git = "git" 79 | 80 | return subprocess.run( 81 | [git, *args], 82 | capture_output=True, 83 | encoding="utf-8", 84 | ) 85 | 86 | result: dict[str, Any] = { 87 | "commit": "", 88 | "provider": "", 89 | "repository": "", 90 | "tag": "", 91 | "dirty": None, 92 | } 93 | 94 | # first, check if inside a repo. 95 | p = git_subprocess(["rev-parse", "--is-inside-work-tree"]) 96 | # if not, return empty info. 97 | if p.returncode: 98 | return {"git": result} 99 | 100 | # secondly: get the current commit. 101 | p = git_subprocess(["rev-parse", "HEAD"]) 102 | if not p.returncode: 103 | result["commit"] = p.stdout.strip() 104 | 105 | # thirdly, get the latest tag, without a short commit SHA attached. 106 | p = git_subprocess(["describe", "--tags", "--abbrev=0"]) 107 | if not p.returncode: 108 | result["tag"] = p.stdout.strip() 109 | 110 | # and finally, get the remote repo name pointed to by the given remote. 111 | p = git_subprocess(["remote", "get-url", self.remote]) 112 | if not p.returncode: 113 | remotename: str = p.stdout.strip() 114 | if "@" in remotename: 115 | # it's an SSH remote. 116 | prefix, sep = "git@", ":" 117 | else: 118 | # it is HTTPS. 119 | prefix, sep = "https://", "/" 120 | 121 | remotename = remotename.removeprefix(prefix) 122 | provider, reponame = remotename.split(sep, 1) 123 | 124 | result["provider"] = provider 125 | result["repository"] = reponame.removesuffix(".git") 126 | 127 | p = git_subprocess(["status", "--porcelain"]) 128 | if not p.returncode: 129 | result["dirty"] = bool(p.stdout.strip()) 130 | 131 | return {"git": result} 132 | 133 | 134 | class CPUInfo: 135 | """ 136 | A context helper providing information about the host machine's CPU 137 | capabilities, operating system, and amount of memory. 138 | 139 | Parameters 140 | ---------- 141 | memunit: Literal["kB", "MB", "GB"] 142 | The unit to display memory size in (either "kB" for kilobytes, 143 | "MB" for Megabytes, or "GB" for Gigabytes). 
144 | frequnit: Literal["kHz", "MHz", "GHz"] 145 | The unit to display CPU clock speeds in (either "kHz" for kilohertz, 146 | "MHz" for Megahertz, or "GHz" for Gigahertz). 147 | """ 148 | 149 | key = "cpu" 150 | 151 | def __init__( 152 | self, 153 | memunit: Literal["kB", "MB", "GB"] = "MB", 154 | frequnit: Literal["kHz", "MHz", "GHz"] = "MHz", 155 | ): 156 | self.memunit = memunit 157 | self.frequnit = frequnit 158 | self.conversion_table: dict[str, float] = {"k": 1e3, "M": 1e6, "G": 1e9} 159 | 160 | def __call__(self) -> dict[str, Any]: 161 | try: 162 | import psutil 163 | except ModuleNotFoundError: 164 | raise ModuleNotFoundError( 165 | f"context provider {self.__class__.__name__}() needs `psutil` installed. " 166 | f"To install, run `{sys.executable} -m pip install --upgrade psutil`." 167 | ) 168 | 169 | result: dict[str, Any] = dict() 170 | 171 | # first, the platform info. 172 | result["architecture"] = platform.machine() 173 | result["bitness"] = platform.architecture()[0] 174 | result["processor"] = platform.processor() 175 | result["system"] = platform.system() 176 | result["system-version"] = platform.release() 177 | 178 | try: 179 | # The CPU frequency is not available on some ARM devices 180 | freq_struct = psutil.cpu_freq() 181 | result["min_frequency"] = float(freq_struct.min) 182 | result["max_frequency"] = float(freq_struct.max) 183 | freq_conversion = self.conversion_table[self.frequnit[0]] 184 | # result is in MHz, so we convert to Hz and apply the conversion factor. 185 | result["frequency"] = freq_struct.current * 1e6 / freq_conversion 186 | except RuntimeError: 187 | result["frequency"] = 0.0 188 | result["min_frequency"] = 0.0 189 | result["max_frequency"] = 0.0 190 | 191 | result["frequency_unit"] = self.frequnit 192 | result["num_cpus"] = psutil.cpu_count(logical=False) 193 | result["num_logical_cpus"] = psutil.cpu_count() 194 | 195 | mem_struct = psutil.virtual_memory() 196 | mem_conversion = self.conversion_table[self.memunit[0]] 197 | # result is in bytes, so no need for base conversion. 198 | result["total_memory"] = mem_struct.total / mem_conversion 199 | result["memory_unit"] = self.memunit 200 | return {self.key: result} 201 | 202 | 203 | builtin_providers: dict[str, ContextProvider] = { 204 | "cpu": CPUInfo(), 205 | "git": GitEnvironmentInfo(), 206 | "python": PythonInfo(), 207 | } 208 | 209 | 210 | def register_context_provider( 211 | name: str, typ: type[ContextProvider] | ContextProvider, kwargs: Any 212 | ) -> None: 213 | logger.debug(f"Registering context provider {name!r}") 214 | 215 | if isinstance(typ, type): 216 | # classes can be instantiated with arguments, 217 | # while functions cannot. 218 | builtin_providers[name] = typ(**kwargs) 219 | else: 220 | builtin_providers[name] = typ 221 | -------------------------------------------------------------------------------- /src/nnbench/core.py: -------------------------------------------------------------------------------- 1 | """Data model, registration, and parametrization facilities for defining benchmarks.""" 2 | 3 | import itertools 4 | from collections.abc import Callable, Iterable 5 | from typing import Any, overload 6 | 7 | from nnbench.types import Benchmark, BenchmarkFamily, NoOp 8 | 9 | 10 | def _default_namegen(fn: Callable, **kwargs: Any) -> str: 11 | return fn.__name__ + "_" + "_".join(f"{k}={v}" for k, v in kwargs.items()) 12 | 13 | 14 | # Overloads for the ``benchmark`` decorator. 
15 | # Case #1: Application with arguments 16 | # @nnbench.benchmark(name="My foo experiment", tags=("hello", "world")) 17 | # def foo() -> int: 18 | #     return 0 19 | @overload 20 | def benchmark( 21 | func: None = None, 22 | name: str = "", 23 | setUp: Callable[..., None] = NoOp, 24 | tearDown: Callable[..., None] = NoOp, 25 | tags: tuple[str, ...] = (), 26 | ) -> Callable[[Callable], Benchmark]: ... 27 | 28 | 29 | # Case #2: Bare application without parentheses 30 | # @nnbench.benchmark 31 | # def foo() -> int: 32 | #     return 0 33 | @overload 34 | def benchmark( 35 | func: Callable[..., Any], 36 | name: str = "", 37 | setUp: Callable[..., None] = NoOp, 38 | tearDown: Callable[..., None] = NoOp, 39 | tags: tuple[str, ...] = (), 40 | ) -> Benchmark: ... 41 | 42 | 43 | def benchmark( 44 | func: Callable[..., Any] | None = None, 45 | name: str = "", 46 | setUp: Callable[..., None] = NoOp, 47 | tearDown: Callable[..., None] = NoOp, 48 | tags: tuple[str, ...] = (), 49 | ) -> Benchmark | Callable[[Callable], Benchmark]: 50 | """ 51 | Define a benchmark from a function. 52 | 53 | The resulting benchmark can either be completely (i.e., the resulting function takes no 54 | more arguments) or incompletely parametrized. In the latter case, the remaining free 55 | parameters need to be passed in the calls to `nnbench.run()`. 56 | 57 | Parameters 58 | ---------- 59 | func: Callable[..., Any] | None 60 | The function to benchmark. This slot only exists to allow application of the decorator 61 | without parentheses; you should never fill it explicitly. 62 | name: str 63 | A display name to give to the benchmark. Useful in summaries and reports. 64 | setUp: Callable[..., None] 65 | A setup hook to run before the benchmark. 66 | tearDown: Callable[..., None] 67 | A teardown hook to run after the benchmark. 68 | tags: tuple[str, ...] 69 | Additional tags to attach for bookkeeping and selective filtering during runs. 70 | 71 | Returns 72 | ------- 73 | Benchmark | Callable[[Callable], Benchmark] 74 | The resulting benchmark (if no arguments were given), or a parametrized decorator 75 | returning the benchmark. 76 | """ 77 | 78 | def decorator(fun: Callable) -> Benchmark: 79 | return Benchmark(fun, name=name, setUp=setUp, tearDown=tearDown, tags=tags) 80 | 81 | if func is not None: 82 | return decorator(func) 83 | else: 84 | return decorator 85 | 86 | 87 | def parametrize( 88 | parameters: Iterable[dict[str, Any]], 89 | setUp: Callable[..., None] = NoOp, 90 | tearDown: Callable[..., None] = NoOp, 91 | namegen: Callable[..., str] = _default_namegen, 92 | tags: tuple[str, ...] = (), 93 | ) -> Callable[[Callable], BenchmarkFamily]: 94 | """ 95 | Define a family of benchmarks over a function with varying parameters. 96 | 97 | The resulting benchmarks can either be completely (i.e., the resulting function takes no 98 | more arguments) or incompletely parametrized. In the latter case, the remaining free 99 | parameters need to be passed in the call to `nnbench.run()`. 100 | 101 | Parameters 102 | ---------- 103 | parameters: Iterable[dict[str, Any]] 104 | The different sets of parameters defining the benchmark family. 105 | setUp: Callable[..., None] 106 | A setup hook to run before each of the benchmarks. 107 | tearDown: Callable[..., None] 108 | A teardown hook to run after each of the benchmarks. 109 | namegen: Callable[..., str] 110 | A function taking the benchmark function and given parameters that generates a unique 111 | custom name for the benchmark.
The default name generated is the benchmark function's name 112 | followed by the keyword arguments in ``key=value`` format separated by underscores. 113 | tags: tuple[str, ...] 114 | Additional tags to attach for bookkeeping and selective filtering during runs. 115 | 116 | Returns 117 | ------- 118 | Callable[[Callable], BenchmarkFamily] 119 | A parametrized decorator returning the benchmark family. 120 | """ 121 | 122 | def decorator(fn: Callable) -> BenchmarkFamily: 123 | return BenchmarkFamily( 124 | fn, 125 | parameters, 126 | name=namegen, 127 | setUp=setUp, 128 | tearDown=tearDown, 129 | tags=tags, 130 | ) 131 | 132 | return decorator 133 | 134 | 135 | def product( 136 | setUp: Callable[..., None] = NoOp, 137 | tearDown: Callable[..., None] = NoOp, 138 | namegen: Callable[..., str] = _default_namegen, 139 | tags: tuple[str, ...] = (), 140 | **iterables: Iterable, 141 | ) -> Callable[[Callable], BenchmarkFamily]: 142 | """ 143 | Define a family of benchmarks over a cartesian product of one or more iterables. 144 | 145 | The resulting benchmarks can either be completely (i.e., the resulting function takes no 146 | more arguments) or incompletely parametrized. In the latter case, the remaining free 147 | parameters need to be passed in the call to `nnbench.run()`. 148 | 149 | Parameters 150 | ---------- 151 | setUp: Callable[..., None] 152 | A setup hook to run before each of the benchmarks. 153 | tearDown: Callable[..., None] 154 | A teardown hook to run after each of the benchmarks. 155 | namegen: Callable[..., str] 156 | A function taking the benchmark function and given parameters that generates a unique 157 | custom name for the benchmark. The default name generated is the benchmark function's name 158 | followed by the keyword arguments in ``key=value`` format separated by underscores. 159 | tags: tuple[str, ...] 160 | Additional tags to attach for bookkeeping and selective filtering during runs. 161 | **iterables: Iterable 162 | The iterables parametrizing the benchmarks. 163 | 164 | Returns 165 | ------- 166 | Callable[[Callable], BenchmarkFamily] 167 | A parametrized decorator returning the benchmark family. 168 | """ 169 | 170 | def decorator(fn: Callable) -> BenchmarkFamily: 171 | names, values = iterables.keys(), iterables.values() 172 | 173 | # NB: This line forces the exhaustion of all input iterables by nature of the 174 | # cartesian product (values need to be persisted to be accessed multiple times). 175 | parameters = (dict(zip(names, vals)) for vals in itertools.product(*values)) 176 | 177 | return BenchmarkFamily( 178 | fn, 179 | parameters, 180 | name=namegen, 181 | setUp=setUp, 182 | tearDown=tearDown, 183 | tags=tags, 184 | ) 185 | 186 | return decorator 187 | -------------------------------------------------------------------------------- /src/nnbench/fixtures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collect values ('fixtures') by name for benchmark runs from certain files, 3 | similarly to pytest and its ``conftest.py``. 
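For example (an illustrative sketch; ``tokenizer`` and ``model`` are placeholder
names, not fixtures shipped with nnbench), a ``conf.py`` next to a benchmark file
provides values for benchmark parameters of the same name::

    # conf.py
    def tokenizer() -> str:
        return "whitespace"

    def model(tokenizer: str) -> dict:
        # fixtures may consume other fixtures from this module by name
        return {"tokenizer": tokenizer, "weights": [0.0, 1.0]}

Any benchmark parameter literally named ``model`` (and lacking a default value)
is then filled with the return value of ``model()``.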
4 | """ 5 | 6 | import inspect 7 | import os 8 | from collections.abc import Callable, Iterable 9 | from pathlib import Path 10 | from types import ModuleType 11 | from typing import Any 12 | 13 | from nnbench.types import Benchmark, Interface 14 | from nnbench.util import import_file_as_module 15 | 16 | 17 | def get_transitive_closure(mod: ModuleType, name: str) -> tuple[list[Callable], list[Interface]]: 18 | fixture = getattr(mod, name) 19 | if not callable(fixture): 20 | raise ValueError(f"fixture input {name!r} needs to be a callable") 21 | closure: list[Callable] = [fixture] 22 | interfaces: list[Interface] = [] 23 | 24 | def recursive_closure_collection(fn, _closure): 25 | _if = Interface.from_callable(fn, {}) 26 | interfaces.append(_if) 27 | # if the fixture itself takes arguments, 28 | # resolve all of them within the module. 29 | for closure_name in _if.names: 30 | _closure_obj = getattr(mod, closure_name, None) 31 | if _closure_obj is None: 32 | raise ImportError(f"fixture {name!r}: missing closure value {closure_name!r}") 33 | if not callable(_closure_obj): 34 | raise ValueError(f"input {name!r} to fixture {fn} needs to be a callable") 35 | _closure.append(_closure_obj) 36 | recursive_closure_collection(_closure_obj, _closure) 37 | 38 | recursive_closure_collection(fixture, closure) 39 | return closure, interfaces 40 | 41 | 42 | class FixtureManager: 43 | """ 44 | A lean class responsible for resolving parameter values (aka 'fixtures') 45 | of benchmarks from provider functions. 46 | 47 | To resolve a benchmark parameter (in ``FixtureManager.resolve()``), the class 48 | does the following: 49 | 50 | 1. Obtain the path to the file containing the benchmark, as 51 | the ``__file__`` attribute of the benchmark function's origin module. 52 | 53 | 2. Look for a `conf.py` file in the same directory. 54 | 55 | 3. Import the `conf.py` module, look for a function named the same as 56 | the benchmark parameter. 57 | 58 | 4. If necessary, resolve any named inputs to the function **within** 59 | the module scope. 60 | 61 | 5. If no function member is found, and the benchmark file is not in `root`, 62 | fall back to the parent directory, repeat steps 2-5, until `root` is reached. 63 | 64 | 6. If no `conf.py` contains any function matching the name, throw an 65 | error. 66 | """ 67 | 68 | def __init__(self, root: str | os.PathLike[str]) -> None: 69 | self.root = Path(root) 70 | self.cache: dict[Path, dict[str, Any]] = {} 71 | """ 72 | Cache architecture: 73 | key: directory 74 | value: key-value mapping of fixture name -> fixture value within directory. 75 | """ 76 | 77 | def collect(self, mod: ModuleType, names: Iterable[str]) -> dict[str, Any]: 78 | """ 79 | Given a module containing fixtures (contents of a ``conf.py`` file imported 80 | as a module), and a list of required fixture names (for a 81 | selected benchmark), collect values, computing transitive closures in the 82 | process (i.e., all inputs required to compute the set of fixtures). 83 | 84 | Parameters 85 | ---------- 86 | mod: ModuleType 87 | The module to import fixture values from. 88 | names: Iterable[str] 89 | Names of fixture values to compute and use in the invoking benchmark. 90 | 91 | Returns 92 | ------- 93 | dict[str, Any] 94 | The mapping of fixture names to their values. 
95 | """ 96 | res: dict[str, Any] = {} 97 | for name in names: 98 | fixture_cand = getattr(mod, name, None) 99 | if fixture_cand is None: 100 | continue 101 | else: 102 | closure, interfaces = get_transitive_closure(mod, name) 103 | # easy case first, fixture without arguments - just call the function. 104 | if len(closure) == 1: 105 | (fn,) = closure 106 | res[name] = fn() 107 | else: 108 | # compute the closure in reverse to get the fixture value. 109 | # the last fixture takes no arguments, otherwise this would 110 | # be an infinite loop. 111 | idx = -1 112 | _temp_res: dict[str, Any] = {} 113 | for iface in reversed(interfaces): 114 | iface_names = iface.names 115 | # each fixture can take values in the respective closure, 116 | # but only values that have already been computed. 117 | kwargs = {k: v for k, v in _temp_res.items() if k in iface_names} 118 | fn = closure[idx] 119 | _temp_res[iface.funcname] = fn(**kwargs) 120 | idx -= 1 121 | assert name in _temp_res, f"internal error computing fixture {name!r}" 122 | res[name] = _temp_res[name] 123 | return res 124 | 125 | def resolve(self, bm: Benchmark) -> dict[str, Any]: 126 | """ 127 | Resolve fixture values for a benchmark. 128 | 129 | Fixtures will be resolved only for benchmark inputs that do not have a 130 | default value in place in the interface. 131 | 132 | Fixtures need to be functions in a ``conf.py`` module in the benchmark 133 | directory structure, and must *exactly* match input parameters by name. 134 | 135 | Parameters 136 | ---------- 137 | bm: Benchmark 138 | The benchmark to resolve fixtures for. 139 | 140 | Returns 141 | ------- 142 | dict[str, Any] 143 | The mapping of fixture values to use for the given benchmark. 144 | """ 145 | fixturevals: dict[str, Any] = {} 146 | # first, get the candidate fixture names, aka the benchmark param names. 147 | # We limit ourselves to names that do not have a default value. 148 | names = [ 149 | n 150 | for n, d in zip(bm.interface.names, bm.interface.defaults) 151 | if d is inspect.Parameter.empty 152 | ] 153 | nameset, fixtureset = set(names), set() 154 | # then, load the benchmark function's origin module, 155 | # should be fast as it's a sys.modules lookup. 156 | # Each user module loaded via spec_from_file_location *should* have its 157 | # __file__ attribute set, so inspect.getsourcefile can find it. 158 | bm_origin_module = inspect.getmodule(bm.fn) 159 | sourcefile = inspect.getsourcefile(bm_origin_module) 160 | 161 | if sourcefile is None: 162 | raise ValueError( 163 | "during fixture collection: " 164 | f"could not locate origin module for benchmark {bm.name!r}()" 165 | ) 166 | 167 | # then, look for a `conf.py` file in the benchmark file's directory, 168 | # and all parents up to "root" (i.e., the path in `nnbench run `) 169 | bm_file = Path(sourcefile) 170 | for p in bm_file.parents: 171 | conf_candidate = p / "conf.py" 172 | if conf_candidate.exists(): 173 | bm_dir_cache: dict[str, Any] | None = self.cache.get(p, None) 174 | if bm_dir_cache is None: 175 | mod = import_file_as_module(conf_candidate) 176 | # contains fixture values for the benchmark that could be resolved 177 | # on the current directory level. 178 | res = self.collect(mod, names) 179 | # hydrate the directory cache with the found fixture values. 180 | # some might be missing, so we could have to continue traversal. 
181 | self.cache[p] = res 182 | fixturevals.update(res) 183 | else: 184 | # at this point, the cache entry might have other fixture values 185 | # that this benchmark may not consume, so we need to filter. 186 | fixturevals.update({k: v for k, v in bm_dir_cache.items() if k in names}) 187 | 188 | fixtureset |= set(fixturevals) 189 | 190 | if p == self.root or nameset == fixtureset: 191 | break 192 | 193 | # TODO: This should not throw an error, inline the typecheck into before benchmark 194 | # execution, then handle it there. 195 | # if fixtureset < nameset: 196 | # missing, *_ = nameset - fixtureset 197 | # raise RuntimeError(f"could not locate fixture {missing!r} for benchmark {bm.name!r}") 198 | 199 | return fixturevals 200 | -------------------------------------------------------------------------------- /src/nnbench/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/src/nnbench/py.typed -------------------------------------------------------------------------------- /src/nnbench/reporter/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | An interface for displaying, writing, or streaming benchmark results to 3 | files, databases, or web services. 4 | """ 5 | 6 | import os 7 | 8 | from nnbench.types import BenchmarkReporter 9 | 10 | from .console import ConsoleReporter 11 | from .file import FileReporter 12 | from .mlflow import MLFlowReporter 13 | from .sqlite import SQLiteReporter 14 | 15 | _known_reporters: dict[str, type[BenchmarkReporter]] = { 16 | "stdout": ConsoleReporter, 17 | "s3": FileReporter, 18 | "gs": FileReporter, 19 | "gcs": FileReporter, 20 | "az": FileReporter, 21 | "lakefs": FileReporter, 22 | "file": FileReporter, 23 | "mlflow": MLFlowReporter, 24 | "sqlite": SQLiteReporter, 25 | } 26 | 27 | 28 | def get_reporter_implementation(uri: str | os.PathLike[str]) -> BenchmarkReporter: 29 | import sys 30 | 31 | if uri is sys.stdout: 32 | proto = "stdout" 33 | else: 34 | from .util import get_protocol 35 | 36 | proto = get_protocol(uri) 37 | try: 38 | return _known_reporters[proto]() 39 | except KeyError: 40 | raise ValueError(f"no benchmark reporter registered for format {proto!r}") from None 41 | 42 | 43 | def register_reporter_implementation( 44 | name: str, klass: type[BenchmarkReporter], clobber: bool = False 45 | ) -> None: 46 | if name in _known_reporters and not clobber: 47 | raise RuntimeError( 48 | f"benchmark reporter {name!r} is already registered " 49 | f"(to force registration, rerun with clobber=True)" 50 | ) 51 | _known_reporters[name] = klass 52 | 53 | 54 | del os 55 | -------------------------------------------------------------------------------- /src/nnbench/reporter/console.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any 4 | 5 | from rich.console import Console 6 | from rich.table import Table 7 | 8 | from nnbench.types import BenchmarkReporter, BenchmarkResult 9 | 10 | _MISSING = "-----" 11 | _STDOUT = "-" 12 | 13 | 14 | def get_value_by_name(result: dict[str, Any]) -> str: 15 | if result.get("error_occurred", False): 16 | errmsg = result.get("error_message", "") 17 | return "[red]ERROR: [/red]" + errmsg 18 | return str(result.get("value", _MISSING)) 19 | 20 | 21 | class ConsoleReporter(BenchmarkReporter): 22 | """ 23 | The base interface for a console reporter class. 
24 | 25 | Wraps a ``rich.Console()`` to display values in a rich-text table. 26 | """ 27 | 28 | def __init__(self, *args, **kwargs): 29 | """ 30 | Initialize a console reporter. 31 | 32 | Parameters 33 | ---------- 34 | *args: Any 35 | Positional arguments, unused. 36 | **kwargs: Any 37 | Keyword arguments, forwarded directly to ``rich.Console()``. 38 | """ 39 | super().__init__(*args, **kwargs) 40 | # TODO: Add context manager to register live console prints 41 | self.console = Console(**kwargs) 42 | 43 | def read( 44 | self, path: str | os.PathLike[str], **kwargs: Any 45 | ) -> BenchmarkResult | list[BenchmarkResult]: 46 | raise NotImplementedError 47 | 48 | def write( 49 | self, 50 | result: BenchmarkResult, 51 | path: str | os.PathLike[str] = _STDOUT, 52 | **options: Any, 53 | ) -> None: 54 | """ 55 | Display a benchmark result in the console as a rich-text table. 56 | 57 | Gives a summary of all present context values directly above the table, 58 | as a pretty-printed JSON result. 59 | 60 | By default, displays only the benchmark name, value, execution wall time, 61 | and parameters. 62 | 63 | Parameters 64 | ---------- 65 | result: BenchmarkResult 66 | The benchmark result to display. 67 | path: str | os.PathLike[str] 68 | For compatibility with the `BenchmarkReporter` protocol, unused. 69 | options: Any 70 | Display options used to format the resulting table. 71 | """ 72 | del path 73 | t = Table() 74 | 75 | rows: list[list[str]] = [] 76 | columns: list[str] = ["Benchmark", "Value", "Wall time (ns)", "Parameters"] 77 | 78 | # print context values 79 | print("Context values:") 80 | print(json.dumps(result.context, indent=4)) 81 | 82 | for bm in result.benchmarks: 83 | row = [bm["name"], get_value_by_name(bm), str(bm["time_ns"]), str(bm["parameters"])] 84 | rows.append(row) 85 | 86 | for column in columns: 87 | t.add_column(column) 88 | for row in rows: 89 | t.add_row(*row) 90 | 91 | self.console.print(t, overflow="ellipsis") 92 | -------------------------------------------------------------------------------- /src/nnbench/reporter/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import IO, Any, Literal, cast 4 | 5 | from nnbench.reporter.util import get_extension, get_protocol 6 | from nnbench.types import BenchmarkReporter, BenchmarkResult 7 | 8 | 9 | def make_file_descriptor( 10 | file: str | os.PathLike[str] | IO, 11 | mode: Literal["r", "w", "a", "x", "rb", "wb", "ab", "xb"], 12 | **open_kwargs: Any, 13 | ) -> IO: 14 | if hasattr(file, "read") or hasattr(file, "write"): 15 | return cast(IO, file) 16 | elif isinstance(file, str | os.PathLike): 17 | protocol = get_protocol(file) 18 | fd: IO 19 | if protocol == "file": 20 | fd = open(file, mode, **open_kwargs) 21 | else: 22 | try: 23 | import fsspec 24 | except ImportError: 25 | raise RuntimeError("non-local URIs require the fsspec package") 26 | fs = fsspec.filesystem(protocol) 27 | fd = fs.open(file, mode, **open_kwargs) 28 | return fd 29 | raise TypeError("filename must be a str, bytes, file or PathLike object") 30 | 31 | 32 | class FileReporter(BenchmarkReporter): 33 | @staticmethod 34 | def filter_open_kwds(kwargs: dict[str, Any]) -> dict[str, Any]: 35 | _OPEN_KWDS = ("buffering", "encoding", "errors", "newline", "closefd", "opener") 36 | return {k: v for k, v in kwargs.items() if k in _OPEN_KWDS} 37 | 38 | def from_json(self, path: str | os.PathLike[str], **kwargs: Any) -> list[BenchmarkResult]: 39 | import json 40 | 41 | 
newline_delimited = Path(path).suffix == ".ndjson" 42 | with make_file_descriptor(path, mode="r") as fd: 43 | if newline_delimited: 44 | benchmarks = [json.loads(line, **kwargs) for line in fd] 45 | return BenchmarkResult.from_records(benchmarks) 46 | else: 47 | benchmarks = json.load(fd, **kwargs) 48 | return [BenchmarkResult.from_json(benchmarks)] 49 | 50 | def to_json( 51 | self, 52 | result: BenchmarkResult, 53 | path: str | os.PathLike[str], 54 | **kwargs: Any, 55 | ) -> None: 56 | import json 57 | 58 | newline_delimited = Path(path).suffix == ".ndjson" 59 | with make_file_descriptor(path, mode="w") as fd: 60 | if newline_delimited: 61 | fd.write("\n".join([json.dumps(r, **kwargs) for r in result.to_records()])) 62 | else: 63 | json.dump(result.to_json(), fd, **kwargs) 64 | 65 | def from_yaml(self, path: str | os.PathLike[str], **kwargs: Any) -> list[BenchmarkResult]: 66 | import yaml 67 | 68 | del kwargs 69 | with make_file_descriptor(path, mode="r") as fd: 70 | bms = yaml.safe_load(fd) 71 | return BenchmarkResult.from_records(bms) 72 | 73 | def to_yaml(self, result: BenchmarkResult, path: str | os.PathLike[str], **kwargs: Any) -> None: 74 | import yaml 75 | 76 | with make_file_descriptor(path, mode="w") as fd: 77 | yaml.safe_dump(result.to_records(), fd, **kwargs) 78 | 79 | def from_parquet(self, path: str | os.PathLike[str], **kwargs: Any) -> list[BenchmarkResult]: 80 | import pyarrow.parquet as pq 81 | 82 | table = pq.read_table(str(path), **kwargs) 83 | benchmarks: list[dict[str, Any]] = table.to_pylist() 84 | return BenchmarkResult.from_records(benchmarks) 85 | 86 | def to_parquet( 87 | self, result: BenchmarkResult, path: str | os.PathLike[str], **kwargs: Any 88 | ) -> None: 89 | import pyarrow.parquet as pq 90 | from pyarrow import Table 91 | 92 | table = Table.from_pylist(result.to_records()) 93 | pq.write_table(table, str(path), **kwargs) 94 | 95 | def read(self, path: str | os.PathLike[str], **kwargs: Any) -> list[BenchmarkResult]: 96 | """ 97 | Reads a benchmark record from the given file path. 98 | 99 | The reading implementation is chosen based on the extension in the ``file`` path. 100 | 101 | Extensions ``json``, ``ndjson``, ``yaml``, and ``parquet`` are supported, 102 | as well as abbreviations ``.yml`` and ``.pq``. 103 | 104 | Parameters 105 | ---------- 106 | path: str | os.PathLike[str] 107 | The path name, or path-like object, to read from. 108 | **kwargs: Any | None 109 | Options to pass to the respective file IO implementation. 110 | 111 | Returns 112 | ------- 113 | list[BenchmarkResult] 114 | The benchmark results contained in the file. 115 | 116 | Raises 117 | ------ 118 | ValueError 119 | If the extension of the given filename is not supported. 120 | """ 121 | 122 | ext = get_extension(path) 123 | 124 | # TODO: Filter open keywords (in methods?) 125 | if ext in (".json", ".ndjson"): 126 | return self.from_json(path, **kwargs) 127 | elif ext in (".yml", ".yaml"): 128 | return self.from_yaml(path, **kwargs) 129 | elif ext in (".parquet", ".pq"): 130 | return self.from_parquet(path, **kwargs) 131 | else: 132 | raise ValueError(f"unsupported benchmark file format {ext!r}") 133 | 134 | def write(self, result: BenchmarkResult, path: str | os.PathLike[str], **kwargs: Any) -> None: 135 | """ 136 | Writes multiple benchmark results to the given file path. 137 | 138 | The writing is chosen based on the extension found on the ``file`` path. 
139 | 140 | Extensions ``json``, ``ndjson``, ``yaml``, and ``parquet`` are supported, 141 | as well as abbreviations ``.yml`` and ``.pq``. 142 | 143 | Parameters 144 | ---------- 145 | result: BenchmarkResult 146 | The benchmark result to write to a file. 147 | path: str | os.PathLike[str] 148 | The file name, or path-like object, to write to. 149 | **kwargs: Any | None 150 | Options to pass to the respective file writer implementation. 151 | 152 | Raises 153 | ------ 154 | ValueError 155 | If the extension of the given filename is not supported. 156 | """ 157 | ext = get_extension(path) 158 | 159 | # TODO: Filter open keywords (in methods?) 160 | if ext in (".json", ".ndjson"): 161 | return self.to_json(result, path, **kwargs) 162 | elif ext in (".yml", ".yaml"): 163 | return self.to_yaml(result, path, **kwargs) 164 | elif ext in (".parquet", ".pq"): 165 | return self.to_parquet(result, path, **kwargs) 166 | else: 167 | raise ValueError(f"unsupported benchmark file format {ext!r}") 168 | -------------------------------------------------------------------------------- /src/nnbench/reporter/mlflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import ExitStack 3 | from pathlib import Path 4 | from typing import TYPE_CHECKING, Any 5 | 6 | from nnbench.types import BenchmarkReporter, BenchmarkResult 7 | 8 | if TYPE_CHECKING: 9 | from mlflow import ActiveRun as ActiveRun 10 | 11 | 12 | class MLFlowReporter(BenchmarkReporter): 13 | def __init__(self): 14 | self.stack = ExitStack() 15 | 16 | @staticmethod 17 | def strip_protocol(uri: str | os.PathLike[str]) -> str: 18 | s = str(uri) 19 | if s.startswith("mlflow://"): 20 | return s[9:] 21 | return s 22 | 23 | def get_or_create_run(self, run_name: str, nested: bool = False) -> "ActiveRun": 24 | import mlflow 25 | 26 | existing_runs = mlflow.search_runs( 27 | filter_string=f"attributes.`run_name`={run_name!r}", output_format="list" 28 | ) 29 | if existing_runs: 30 | run_id = existing_runs[0].info.run_id 31 | return mlflow.start_run(run_id=run_id, nested=nested) 32 | else: 33 | return mlflow.start_run(run_name=run_name, nested=nested) 34 | 35 | def read(self, path: str | os.PathLike[str], **kwargs: Any) -> list[BenchmarkResult]: 36 | raise NotImplementedError 37 | 38 | def write( 39 | self, 40 | result: BenchmarkResult, 41 | path: str | os.PathLike[str], 42 | **kwargs: Any, 43 | ) -> None: 44 | import mlflow 45 | 46 | uri = self.strip_protocol(path) 47 | try: 48 | experiment, run_name, *subruns = Path(uri).parts 49 | except ValueError: 50 | raise ValueError(f"expected URI of form <experiment>/<run_name>[/<subruns>...], got {uri!r}") 51 | 52 | # setting experiment removes the need for passing the `experiment_id` kwarg 53 | # in the subsequent API calls. 54 | mlflow.set_experiment(experiment) 55 | 56 | run = self.stack.enter_context(self.get_or_create_run(run_name=run_name)) 57 | for s in subruns: 58 | # reassignment ensures that we log into the max-depth subrun specified.
59 | run = self.stack.enter_context(self.get_or_create_run(run_name=s, nested=True)) 60 | 61 | run_id = run.info.run_id 62 | timestamp = result.timestamp 63 | mlflow.log_dict(result.context, f"context-{result.run}.json", run_id=run_id) 64 | for bm in result.benchmarks: 65 | name, value = bm["name"], bm["value"] 66 | mlflow.log_metric(name, value, timestamp=timestamp, run_id=run_id) 67 | -------------------------------------------------------------------------------- /src/nnbench/reporter/sqlite.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sqlite3 3 | from typing import Any 4 | 5 | from nnbench.types import BenchmarkReporter, BenchmarkResult 6 | 7 | # TODO: Add tablename state (f-string) 8 | _DEFAULT_COLS = ("run", "benchmark", "context", "timestamp") 9 | _DEFAULT_CREATION_QUERY = "CREATE TABLE IF NOT EXISTS nnbench(" + ", ".join(_DEFAULT_COLS) + ")" 10 | _DEFAULT_INSERT_QUERY = "INSERT INTO nnbench VALUES(:run, :benchmark, :context, :timestamp)" 11 | _DEFAULT_READ_QUERY = """SELECT * FROM nnbench""" 12 | 13 | 14 | class SQLiteReporter(BenchmarkReporter): 15 | @staticmethod 16 | def strip_protocol(uri: str | os.PathLike[str]) -> str: 17 | s = str(uri) 18 | if s.startswith("sqlite://"): 19 | return s[9:] 20 | return s 21 | 22 | def read( 23 | self, 24 | path: str | os.PathLike[str], 25 | query: str = _DEFAULT_READ_QUERY, 26 | **kwargs: Any, 27 | ) -> list[BenchmarkResult]: 28 | path = self.strip_protocol(path) 29 | # query: str | None = options.pop("query", _DEFAULT_READ_QUERY) 30 | if query is None: 31 | raise ValueError(f"need a query to read from SQLite database {path!r}") 32 | 33 | db = f"file:{path}?mode=ro" # open DB in read-only mode 34 | conn = sqlite3.connect(db, uri=True) 35 | conn.row_factory = sqlite3.Row 36 | cursor = conn.cursor() 37 | cursor.execute(query) 38 | records = [dict(r) for r in cursor.fetchall()] 39 | conn.close() 40 | return BenchmarkResult.from_records(records) 41 | 42 | def write( 43 | self, 44 | result: BenchmarkResult, 45 | path: str | os.PathLike[str], 46 | query: str = _DEFAULT_INSERT_QUERY, 47 | **kwargs: Any, 48 | ) -> None: 49 | path = self.strip_protocol(path) 50 | 51 | conn = sqlite3.connect(path) 52 | cursor = conn.cursor() 53 | 54 | # TODO: Guard by exists_ok state 55 | cursor.execute(_DEFAULT_CREATION_QUERY) 56 | cursor.executemany(query, result.to_records()) 57 | conn.commit() 58 | -------------------------------------------------------------------------------- /src/nnbench/reporter/util.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import re 4 | from pathlib import Path 5 | from typing import IO, Any 6 | 7 | 8 | def nullcols(_benchmarks: list[dict[str, Any]]) -> set[str]: 9 | """ 10 | Extracts columns that only contain false-ish data from a list of benchmarks. 11 | 12 | Since this data is most often not interesting, the result of this 13 | can be used to filter out these columns from the benchmark dictionaries. 14 | 15 | Parameters 16 | ---------- 17 | _benchmarks: list[dict[str, Any]] 18 | The benchmarks to filter. 19 | 20 | Returns 21 | ------- 22 | set[str] 23 | Set of the columns (key names) that only contain false-ish values 24 | across all benchmarks. 
25 | """ 26 | nulls: dict[str, bool] = collections.defaultdict(bool) 27 | for bm in _benchmarks: 28 | for k, v in bm.items(): 29 | nulls[k] = nulls[k] or bool(v) 30 | return set(k for k, v in nulls.items() if not v) 31 | 32 | 33 | def get_protocol(url: str | os.PathLike[str]) -> str: 34 | url = str(url) 35 | parts = re.split(r"(::|://)", url, maxsplit=1) 36 | if len(parts) > 1: 37 | return parts[0] 38 | return "file" 39 | 40 | 41 | def get_extension(f: str | os.PathLike[str] | IO) -> str: 42 | """ 43 | Given a path or file-like object, returns file extension 44 | (can be the empty string, if the file has no extension). 45 | """ 46 | if isinstance(f, str | os.PathLike): 47 | return Path(f).suffix 48 | else: 49 | return Path(f.name).suffix 50 | -------------------------------------------------------------------------------- /src/nnbench/runner.py: -------------------------------------------------------------------------------- 1 | """The abstract benchmark runner interface, which can be overridden for custom benchmark workloads.""" 2 | 3 | import collections 4 | import inspect 5 | import logging 6 | import os 7 | import platform 8 | import sys 9 | import time 10 | import uuid 11 | from collections.abc import Callable, Iterable 12 | from dataclasses import asdict 13 | from pathlib import Path 14 | from typing import Any 15 | 16 | from nnbench.context import Context, ContextProvider 17 | from nnbench.fixtures import FixtureManager 18 | from nnbench.types import Benchmark, BenchmarkFamily, BenchmarkResult, Parameters, State 19 | from nnbench.util import ( 20 | all_python_files, 21 | collate, 22 | exists_module, 23 | import_file_as_module, 24 | qualname, 25 | timer, 26 | ) 27 | 28 | Benchmarkable = Benchmark | BenchmarkFamily 29 | 30 | logger = logging.getLogger("nnbench.runner") 31 | 32 | 33 | def jsonify( 34 | params: dict[str, Any], repr_hooks: dict[type, Callable] | None = None 35 | ) -> dict[str, Any]: 36 | """ 37 | Construct a JSON representation of benchmark parameters. 38 | 39 | This is necessary to break reference cycles from the parameters to the results, 40 | which prevent garbage collection of memory-intensive values. 41 | 42 | Parameters 43 | ---------- 44 | params: dict[str, Any] 45 | Benchmark parameters to compute a JSON representation of. 46 | repr_hooks: dict[type, Callable] | None 47 | A dictionary mapping parameter types to functions returning a JSON representation 48 | of an instance of the type. Allows fine-grained control to achieve lossless, 49 | reproducible serialization of input parameter information. 50 | 51 | Returns 52 | ------- 53 | dict[str, Any] 54 | A JSON-serializable representation of the benchmark input parameters. 55 | """ 56 | repr_hooks = repr_hooks or {} 57 | natives = (float, int, str, bool, bytes, complex) 58 | json_params: dict[str, Any] = {} 59 | 60 | def _jsonify(val): 61 | vtype = type(val) 62 | if vtype in repr_hooks: 63 | return repr_hooks[vtype](val) 64 | if isinstance(val, natives): 65 | return val 66 | elif hasattr(val, "to_json"): 67 | try: 68 | return val.to_json() 69 | except TypeError: 70 | # if to_json() needs arguments, we're SOL. 
71 | pass 72 | 73 | return repr(val) 74 | 75 | for k, v in params.items(): 76 | if isinstance(v, tuple | list | set | frozenset): 77 | container_type = type(v) 78 | json_params[k] = container_type(map(_jsonify, v)) 79 | elif isinstance(v, dict): 80 | json_params[k] = jsonify(v) 81 | else: 82 | json_params[k] = _jsonify(v) 83 | return json_params 84 | 85 | 86 | def collect( 87 | path_or_module: str | os.PathLike[str], tags: tuple[str, ...] = () 88 | ) -> list[Benchmark | BenchmarkFamily]: 89 | # TODO: functools.cache this guy 90 | """ 91 | Discover benchmarks in a module or source file. 92 | 93 | Parameters 94 | ---------- 95 | path_or_module: str | os.PathLike[str] 96 | Name or path of the module to discover benchmarks in. Can also be a directory, 97 | in which case benchmarks are collected from the Python files therein. 98 | tags: tuple[str, ...] 99 | Tags to filter for when collecting benchmarks. Only benchmarks containing either of 100 | these tags are collected. 101 | 102 | Raises 103 | ------ 104 | ValueError 105 | If the given path is not a Python file, directory, or module name. 106 | """ 107 | benchmarks: list[Benchmark] = [] 108 | ppath = Path(path_or_module) 109 | if ppath.is_dir(): 110 | pythonpaths = all_python_files(ppath) 111 | for py in pythonpaths: 112 | benchmarks.extend(collect(py, tags)) 113 | return benchmarks 114 | elif ppath.is_file(): 115 | logger.debug(f"Collecting benchmarks from file {ppath}.") 116 | module = import_file_as_module(path_or_module) 117 | elif exists_module(path_or_module): 118 | module = sys.modules[str(path_or_module)] 119 | else: 120 | raise ValueError( 121 | f"expected a module name, Python file, or directory, got {str(path_or_module)!r}" 122 | ) 123 | 124 | # iterate through the module dict members to register 125 | for k, v in module.__dict__.items(): 126 | if k.startswith("__") and k.endswith("__"): 127 | # dunder names are ignored. 128 | continue 129 | elif isinstance(v, Benchmarkable): 130 | if not tags or set(tags) & set(v.tags): 131 | benchmarks.append(v) 132 | elif isinstance(v, list | tuple | set | frozenset): 133 | for bm in v: 134 | if isinstance(bm, Benchmarkable): 135 | if not tags or set(tags) & set(bm.tags): 136 | benchmarks.append(bm) 137 | return benchmarks 138 | 139 | 140 | def run( 141 | benchmarks: Benchmark | BenchmarkFamily | Iterable[Benchmark | BenchmarkFamily], 142 | name: str | None = None, 143 | params: dict[str, Any] | Parameters | None = None, 144 | context: Context | Iterable[ContextProvider] = (), 145 | jsonifier: Callable[[dict[str, Any]], dict[str, Any]] = jsonify, 146 | ) -> BenchmarkResult: 147 | """ 148 | Run a previously collected benchmark workload. 149 | 150 | Parameters 151 | ---------- 152 | benchmarks: Benchmark | BenchmarkFamily | Iterable[Benchmark | BenchmarkFamily] 153 | A benchmark, family of benchmarks, or collection of discovered benchmarks to run. 154 | name: str | None 155 | A name for the currently started run. If None, a name will be automatically generated. 156 | params: dict[str, Any] | Parameters | None 157 | Parameters to use for the benchmark run. Names have to match positional and keyword 158 | argument names of the benchmark functions. 159 | context: Iterable[ContextProvider] 160 | Additional context to log with the benchmarks in the output JSON result. Useful for 161 | obtaining environment information and configuration, like CPU/GPU hardware info, 162 | ML model metadata, and more. 
163 | jsonifier: Callable[[dict[str, Any]], dict[str, Any]] 164 | A function constructing a string representation from the input parameters. 165 | Defaults to ``nnbench.runner.jsonify()``. Must produce a dictionary containing 166 | only JSON-serializable values. 167 | 168 | Returns 169 | ------- 170 | BenchmarkResult 171 | A JSON output representing the benchmark results. Has three top-level keys, 172 | "run" giving the benchmark run name, "context" holding the context information, 173 | and "benchmarks", holding an array with the benchmark results. 174 | """ 175 | 176 | _run = name or "nnbench-" + platform.node() + "-" + uuid.uuid1().hex[:8] 177 | 178 | family_sizes: dict[str, Any] = collections.defaultdict(int) 179 | family_indices: dict[str, Any] = collections.defaultdict(int) 180 | 181 | if isinstance(context, dict): 182 | ctx = context 183 | else: 184 | ctx = dict() 185 | for provider in context: 186 | val = provider() 187 | duplicates = set(ctx.keys()) & set(val.keys()) 188 | if duplicates: 189 | dupe, *_ = duplicates 190 | raise ValueError(f"got multiple values for context key {dupe!r}") 191 | ctx.update(val) 192 | 193 | if isinstance(benchmarks, Benchmarkable): 194 | benchmarks = [benchmarks] 195 | 196 | if isinstance(params, Parameters): 197 | dparams = asdict(params) 198 | else: 199 | dparams = params or {} 200 | 201 | results: list[dict[str, Any]] = [] 202 | timestamp = int(time.time()) 203 | 204 | for benchmark in collate(benchmarks): 205 | bm_family = benchmark.interface.funcname 206 | state = State( 207 | name=benchmark.name, 208 | family=bm_family, 209 | family_size=family_sizes[bm_family], 210 | family_index=family_indices[bm_family], 211 | ) 212 | family_indices[bm_family] += 1 213 | 214 | # Assemble benchmark parameters. First grab all defaults from the interface, 215 | bmparams = { 216 | name: val 217 | for name, _, val in benchmark.interface.variables 218 | if val is not inspect.Parameter.empty 219 | } 220 | # ... then hydrate with the appropriate subset of input parameters. 221 | bmparams |= {k: v for k, v in dparams.items() if k in benchmark.interface.names} 222 | # If any arguments are still unresolved, go look them up as fixtures. 223 | if set(bmparams) < set(benchmark.interface.names): 224 | # TODO: This breaks for a module name (like __main__). 225 | # Since that only means that we cannot resolve fixtures when benchmarking 226 | # a module name (which makes sense), and we can always pass extra 227 | # parameters in the module case, fixing this is not as urgent.
228 | mod = benchmark.__module__ 229 | file = sys.modules[mod].__file__ 230 | p = Path(file).parent 231 | fm = FixtureManager(p) 232 | bmparams |= fm.resolve(benchmark) 233 | 234 | res: dict[str, Any] = { 235 | "name": benchmark.name, 236 | "function": qualname(benchmark.fn), 237 | "description": benchmark.fn.__doc__ or "", 238 | "error_occurred": False, 239 | "error_message": "", 240 | "parameters": jsonifier(bmparams), 241 | } 242 | try: 243 | benchmark.setUp(state, bmparams) 244 | with timer(res): 245 | res["value"] = benchmark.fn(**bmparams) 246 | except Exception as e: 247 | res["error_occurred"] = True 248 | res["error_message"] = str(e) 249 | finally: 250 | benchmark.tearDown(state, bmparams) 251 | results.append(res) 252 | 253 | return BenchmarkResult( 254 | run=_run, 255 | context=ctx, 256 | benchmarks=results, 257 | timestamp=timestamp, 258 | ) 259 | -------------------------------------------------------------------------------- /src/nnbench/util.py: -------------------------------------------------------------------------------- 1 | """Various utilities related to benchmark collection, filtering, and more.""" 2 | 3 | import contextlib 4 | import importlib.util 5 | import itertools 6 | import os 7 | import sys 8 | import time 9 | from collections.abc import Callable, Generator, Iterable 10 | from importlib.machinery import ModuleSpec 11 | from pathlib import Path 12 | from types import ModuleType 13 | from typing import Any, TypeVar 14 | 15 | T = TypeVar("T") 16 | 17 | 18 | def collate(_its: Iterable[T | Iterable[T]]) -> Generator[T, None, None]: 19 | for _it in _its: 20 | if isinstance(_it, Iterable): 21 | yield from _it 22 | else: 23 | yield _it 24 | 25 | 26 | def flatten(d: dict[str, Any], sep: str = ".", prefix: str = "") -> dict: 27 | """ 28 | Given a nested dictionary and a separator, returns another dictionary 29 | of depth 1, containing values under nested keys joined by the separator. 30 | 31 | Parameters 32 | ---------- 33 | d: dict[str, Any] 34 | A dictionary to be flattened. All nested dictionaries must contain 35 | string keys only. 36 | sep: str 37 | The separator string to join keys on. 38 | prefix: str 39 | A prefix to apply to keys when calling ``flatten()`` recursively. 40 | You shouldn't need to pass this yourself. 41 | 42 | Returns 43 | ------- 44 | dict[str, Any] 45 | The flattened dictionary. 46 | 47 | Examples 48 | -------- 49 | >>> flatten({"a": 1, "b": {"c": 2}}) 50 | {"a": 1, "b.c": 2} 51 | """ 52 | d_flat = {} 53 | for k, v in d.items(): 54 | new_key = prefix + sep + k if prefix else k 55 | if isinstance(v, dict): 56 | d_flat.update(flatten(v, sep=sep, prefix=new_key)) 57 | else: 58 | d_flat[new_key] = v 59 | return d_flat 60 | 61 | 62 | def unflatten(d: dict[str, Any], sep: str = ".") -> dict[str, Any]: 63 | """ 64 | Unflatten a previously flattened dictionary. 65 | 66 | Any key that does not contain the separator is passed through unchanged. 67 | 68 | This is, as the name suggests, the inverse operation to ``nnbench.util.flatten()``. 69 | 70 | Parameters 71 | ---------- 72 | d: dict[str, Any] 73 | The dictionary to unflatten. 74 | sep: str 75 | The separator to split keys on, introducing dictionary nesting. 
76 | 77 | Returns 78 | ------- 79 | dict[str, Any] 80 | 81 | Examples 82 | -------- 83 | >>> unflatten({"a": 1, "b.c": 2}) 84 | {"a": 1, "b": {"c": 2}} 85 | 86 | >>> d = {"a": 1, "b": {"c": 2}} 87 | >>> unflatten(flatten(d)) == d 88 | True 89 | """ 90 | sorted_keys = sorted(d.keys()) 91 | unflattened = {} 92 | for prefix, keys in itertools.groupby(sorted_keys, key=lambda key: key.split(sep, 1)[0]): 93 | key_group = list(keys) 94 | if len(key_group) == 1 and sep not in key_group[0]: 95 | unflattened[prefix] = d[prefix] 96 | else: 97 | nested_dict = {key.split(sep, 1)[1]: d[key] for key in key_group} 98 | unflattened[prefix] = unflatten(nested_dict, sep=sep) 99 | return unflattened 100 | 101 | 102 | def exists_module(name: str | os.PathLike[str]) -> bool: 103 | """Checks if the current interpreter has an available Python module named `name`.""" 104 | name = str(name) 105 | if name in sys.modules: 106 | return True 107 | 108 | root, *parts = name.split(".") 109 | 110 | for part in parts: 111 | spec = importlib.util.find_spec(root) 112 | if spec is None: 113 | return False 114 | root += f".{part}" 115 | 116 | return importlib.util.find_spec(name) is not None 117 | 118 | 119 | def modulename(file: str | os.PathLike[str]) -> str: 120 | """ 121 | Convert a file name to its corresponding Python module name. 122 | 123 | Examples 124 | -------- 125 | >>> modulename("path/to/my/file.py") 126 | "path.to.my.file" 127 | """ 128 | fpath = Path(file).with_suffix("") 129 | if len(fpath.parts) == 1: 130 | return str(fpath) 131 | 132 | filename = fpath.as_posix() 133 | return filename.replace("/", ".") 134 | 135 | 136 | def import_file_as_module(file: str | os.PathLike[str]) -> ModuleType: 137 | """ 138 | Import a Python file as a module using importlib. 139 | 140 | Raises an error if the given path is not a Python file, or if the 141 | module spec could not be constructed. 142 | 143 | Parameters 144 | ---------- 145 | file: str | os.PathLike[str] 146 | The file to import as a Python module. 147 | 148 | Returns 149 | ------- 150 | ModuleType 151 | The imported module, with its file location set as ``__file__``. 152 | 153 | """ 154 | fpath = Path(file) 155 | if not fpath.is_file() or fpath.suffix != ".py": 156 | raise ValueError(f"path {str(file)!r} is not a Python file") 157 | 158 | # TODO: Recomputing this map in a loop can be expensive if many modules are loaded. 159 | modmap = {m.__file__: m for m in sys.modules.values() if getattr(m, "__file__", None)} 160 | spath = str(fpath) 161 | if spath in modmap: 162 | # if the module under "file" has already been loaded, return it, 163 | # otherwise we get nasty type errors in collection. 164 | return modmap[spath] 165 | 166 | modname = modulename(fpath) 167 | if modname in sys.modules: 168 | # return already loaded module 169 | return sys.modules[modname] 170 | 171 | spec: ModuleSpec | None = importlib.util.spec_from_file_location(modname, fpath) 172 | if spec is None: 173 | raise RuntimeError(f"could not import module {fpath}") 174 | 175 | module = importlib.util.module_from_spec(spec) 176 | sys.modules[modname] = module 177 | spec.loader.exec_module(module) 178 | return module 179 | 180 | 181 | def all_python_files(_dir: str | os.PathLike[str]) -> Generator[Path, None, None]: 182 | if sys.version_info >= (3, 12): 183 | pathgen = Path(_dir).walk(top_down=True) 184 | else: 185 | pathgen = os.walk(_dir, topdown=True) 186 | 187 | for root, dirs, files in pathgen: 188 | proot = Path(root) 189 | # do not descend into potentially large __pycache__ dirs.
190 | if "__pycache__" in dirs: 191 | dirs.remove("__pycache__") 192 | for file in files: 193 | fp = proot / file 194 | if fp.suffix == ".py": 195 | yield fp 196 | 197 | 198 | def qualname(fn: Callable) -> str: 199 | if fn.__name__ == fn.__qualname__: 200 | return fn.__name__ 201 | return f"{fn.__qualname__}.{fn.__name__}" 202 | 203 | 204 | @contextlib.contextmanager 205 | def timer(bm: dict[str, Any]) -> Generator[None, None, None]: 206 | start = time.perf_counter_ns() 207 | try: 208 | yield 209 | finally: 210 | end = time.perf_counter_ns() 211 | bm["time_ns"] = end - start 212 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/tests/__init__.py -------------------------------------------------------------------------------- /tests/benchmarks/argchecks.py: -------------------------------------------------------------------------------- 1 | import nnbench 2 | 3 | 4 | @nnbench.benchmark(tags=("with_default",)) 5 | def add_two(a: int = 0) -> int: 6 | return a + 2 7 | 8 | 9 | @nnbench.benchmark(tags=("duplicate",)) 10 | def triple_int(x: int) -> int: 11 | return x * 3 12 | 13 | 14 | @nnbench.benchmark(tags=("duplicate",)) 15 | def triple_str(x: str) -> str: 16 | return x * 3 17 | 18 | 19 | @nnbench.benchmark(tags=("untyped",)) 20 | def increment(value): 21 | return value + 1 22 | -------------------------------------------------------------------------------- /tests/benchmarks/standard.py: -------------------------------------------------------------------------------- 1 | import nnbench 2 | 3 | 4 | @nnbench.benchmark(tags=("standard", "runner-collect")) 5 | def double(x: int) -> int: 6 | return x * 2 7 | 8 | 9 | @nnbench.benchmark(tags=("standard",)) 10 | def triple(y: int) -> int: 11 | return y * 3 12 | 13 | 14 | @nnbench.benchmark(tags=("standard",)) 15 | def prod(x: int, y: int) -> int: 16 | return x * y 17 | -------------------------------------------------------------------------------- /tests/benchmarks/tags.py: -------------------------------------------------------------------------------- 1 | import nnbench 2 | 3 | 4 | @nnbench.benchmark(tags=("tag1",)) 5 | def subtract(a: int, b: int) -> int: 6 | return a - b 7 | 8 | 9 | @nnbench.benchmark(tags=("tag1", "tag2")) 10 | def decrement(a: int) -> int: 11 | return a - 1 12 | 13 | 14 | @nnbench.benchmark() 15 | def identity(x: int) -> int: 16 | return x 17 | -------------------------------------------------------------------------------- /tests/cli/__init__.py: -------------------------------------------------------------------------------- 1 | DELAY_SECONDS = 10 2 | -------------------------------------------------------------------------------- /tests/cli/benchmarks/a.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from tests.cli import DELAY_SECONDS 4 | 5 | import nnbench 6 | 7 | 8 | @nnbench.benchmark 9 | def add(a: int, b: int) -> int: 10 | time.sleep(DELAY_SECONDS) 11 | return a + b 12 | -------------------------------------------------------------------------------- /tests/cli/benchmarks/b.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from tests.cli import DELAY_SECONDS 4 | 5 | import nnbench 6 | 7 | 8 | @nnbench.benchmark 9 | def mul(a: int, b: int) -> int: 10 | 
time.sleep(DELAY_SECONDS) 11 | return a * b 12 | -------------------------------------------------------------------------------- /tests/cli/conf.py: -------------------------------------------------------------------------------- 1 | def a() -> int: 2 | return 1 3 | 4 | 5 | def b() -> int: 6 | return 2 7 | -------------------------------------------------------------------------------- /tests/cli/test_parallel_exec.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | from nnbench.cli import main 5 | from tests.cli import DELAY_SECONDS 6 | 7 | 8 | def test_parallel_execution_for_slow_benchmarks(): 9 | """ 10 | Verifies that benchmarks with long execution time finish faster 11 | when using process parallelism. 12 | All discovered benchmarks contain a time.sleep(DELAY_SECONDS) call, 13 | and are just a simple add/mul of two integers, so we expect the 14 | execution time to be slightly above DELAY_SECONDS. 15 | """ 16 | n_jobs = 2 17 | bm_path = Path(__file__).parent / "benchmarks" 18 | start = time.time() 19 | args = ["run", f"{bm_path}", "-j2"] 20 | rc = main(args) 21 | end = time.time() 22 | assert rc == 0, f"running nnbench {' '.join(args)} failed with exit code {rc}" 23 | assert end - start < n_jobs * DELAY_SECONDS 24 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | HERE = Path(__file__).parent 7 | 8 | logger = logging.getLogger("nnbench") 9 | logger.setLevel(logging.DEBUG) 10 | 11 | 12 | @pytest.fixture(scope="session") 13 | def testfolder() -> str: 14 | """A test directory for benchmark collection.""" 15 | return str(HERE / "benchmarks") 16 | 17 | 18 | @pytest.fixture 19 | def local_file(tmp_path: Path) -> Path: 20 | file_path = tmp_path / "test_file.txt" 21 | file_path.write_text("Test content") 22 | return file_path 23 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aai-institute/nnbench/9a4e37cd7fa99178a2dab747bbc5b153ac1cd42a/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/test_benchmark_family_memory_consumption.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import nnbench 5 | 6 | N = 1024 7 | NUM_REPLICAS = 3 8 | NP_MATSIZE_BYTES = N**2 * np.float64().itemsize 9 | # for a matmul, we alloc LHS, RHS, and RESULT. 10 | # math.ceil gives us a cushion of up to 1MiB, which is for other system allocs. 11 | EXPECTED_MEM_USAGE_MB = 3 * NP_MATSIZE_BYTES / 1048576 + 0.5 12 | 13 | 14 | @pytest.mark.limit_memory(f"{EXPECTED_MEM_USAGE_MB}MB") 15 | def test_parametrize_memory_consumption(): 16 | """ 17 | Checks that a benchmark family works with GC in the parametrization case, 18 | and produces the theoretically optimal memory usage pattern for a matmul. 19 | 20 | Note: We do not have a similar "best case" memory guarantee for @nnbench.product, 21 | because the evaluation of the cartesian product via `itertools.product()` forces 22 | the eager exhaustion of all iterables (generators can be used only once). 
23 | """ 24 | 25 | @nnbench.parametrize({"b": np.zeros((N, N), dtype=np.int64)} for _ in range(NUM_REPLICAS)) 26 | def matmul(a: np.ndarray, b: np.ndarray) -> np.float64: 27 | return np.dot(a, b).sum() 28 | 29 | a = np.zeros((N, N), dtype=np.int64) 30 | nnbench.run(matmul, params={"a": a}) 31 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | from nnbench.config import NNBenchConfig, parse_nnbench_config 8 | 9 | empty = NNBenchConfig.from_toml({}) 10 | 11 | test_toml = """ 12 | [tool.nnbench] 13 | log-level = "DEBUG" 14 | 15 | [tool.nnbench.context.myctx] 16 | name = "myctx" 17 | classpath = "nnbench.context.PythonInfo" 18 | arguments = { packages = ["rich", "pyyaml"] } 19 | """ 20 | 21 | test_toml_with_unknown_key = ( 22 | test_toml 23 | + """ 24 | 25 | [tool.nnbench.what] 26 | hello = "world" 27 | """ 28 | ) 29 | 30 | 31 | def test_config_load_and_parse(tmp_path: Path) -> None: 32 | tmp_pyproject = tmp_path / "pyproject.toml" 33 | tmp_pyproject.write_text(test_toml) 34 | 35 | cfg = parse_nnbench_config(tmp_pyproject) 36 | assert cfg.log_level == "DEBUG" 37 | assert len(cfg.context) == 1 38 | assert cfg.context[0].name == "myctx" 39 | 40 | 41 | def test_config_load_with_unknown_key(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: 42 | tmp_pyproject = tmp_path / "pyproject.toml" 43 | tmp_pyproject.write_text(test_toml_with_unknown_key) 44 | 45 | # if this doesn't crash, we know that the unknown key does not make it into the config. 46 | cfg = parse_nnbench_config(tmp_pyproject) 47 | assert cfg != empty 48 | 49 | # autodiscovery with no config available should fail. 50 | with caplog.at_level(logging.DEBUG): 51 | tmp_pyproject.unlink() 52 | os.chdir(tmp_path) 53 | cfg = parse_nnbench_config() 54 | assert cfg == empty 55 | assert "could not locate pyproject.toml" in caplog.text 56 | -------------------------------------------------------------------------------- /tests/test_context.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from nnbench.context import CPUInfo, GitEnvironmentInfo, PythonInfo 5 | 6 | 7 | def test_cpu_info_provider() -> None: 8 | """Tests CPU info integrity, along with some assumptions about metrics.""" 9 | 10 | c = CPUInfo() 11 | ctx = c()["cpu"] 12 | for k in [ 13 | "architecture", 14 | "system", 15 | "frequency", 16 | "min_frequency", 17 | "max_frequency", 18 | "frequency_unit", 19 | "memory_unit", 20 | ]: 21 | assert k in ctx 22 | 23 | assert isinstance(ctx["frequency"], float) 24 | assert isinstance(ctx["min_frequency"], float) 25 | assert isinstance(ctx["max_frequency"], float) 26 | assert isinstance(ctx["total_memory"], float) 27 | 28 | 29 | def test_git_info_provider() -> None: 30 | """Tests git provider value integrity, along with some data sanity checks.""" 31 | g = GitEnvironmentInfo() 32 | # git info needs to be collected inside the nnbench repo, otherwise we get no values. 33 | os.chdir(Path(__file__).parent) 34 | ctx = g()["git"] 35 | 36 | # tag is not checked, because that can be empty (e.g. in a shallow repo clone). 
37 | for k in ["provider", "repository", "commit"]: 38 | assert k in ctx 39 | assert ctx[k] != "", f"empty value for context {k!r}" 40 | 41 | assert ctx["repository"].split("/")[1] == "nnbench" 42 | assert ctx["provider"] == "github.com" 43 | 44 | 45 | def test_python_info_provider() -> None: 46 | """Tests Python info, along with an example of Python package version scraping.""" 47 | packages = ["rich", "pytest"] 48 | p = PythonInfo(packages=packages) 49 | ctx = p()["python"] 50 | 51 | for k in ["version", "implementation", "packages"]: 52 | assert k in ctx 53 | 54 | assert list(ctx["packages"].keys()) == packages 55 | for v in ctx["packages"].values(): 56 | assert v != "" 57 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | import nnbench 2 | import nnbench.types 3 | from nnbench import benchmark, parametrize, product 4 | 5 | from .test_utils import has_expected_args 6 | 7 | 8 | def test_benchmark_no_args(): 9 | @benchmark 10 | def sample_benchmark() -> str: 11 | return "test" 12 | 13 | assert isinstance(sample_benchmark, nnbench.types.Benchmark) 14 | 15 | 16 | def test_benchmark_with_args(): 17 | @benchmark(name="Test Name", tags=("tag1", "tag2")) 18 | def another_benchmark() -> str: 19 | return "test" 20 | 21 | assert another_benchmark.name == "Test Name" 22 | assert another_benchmark.tags == ("tag1", "tag2") 23 | 24 | 25 | def test_parametrize(): 26 | @parametrize([{"param": 1}, {"param": 2}]) 27 | def parametrized_benchmark(param: int) -> int: 28 | return param 29 | 30 | all_benchmarks = list(parametrized_benchmark) 31 | assert len(all_benchmarks) == 2 32 | assert has_expected_args(all_benchmarks[0].fn, {"param": 1}) 33 | assert all_benchmarks[0].fn(**all_benchmarks[0].params) == 1 34 | assert has_expected_args(all_benchmarks[1].fn, {"param": 2}) 35 | assert all_benchmarks[1].fn(**all_benchmarks[1].params) == 2 36 | 37 | 38 | def test_product(): 39 | @product(iter1=[1, 2], iter2=["a", "b"]) 40 | def product_benchmark(iter1: int, iter2: str) -> tuple[int, str]: 41 | return iter1, iter2 42 | 43 | all_benchmarks = list(product_benchmark) 44 | assert len(all_benchmarks) == 4 45 | assert has_expected_args(all_benchmarks[0].fn, {"iter1": 1, "iter2": "a"}) 46 | assert all_benchmarks[0].fn(**all_benchmarks[0].params) == (1, "a") 47 | assert has_expected_args(all_benchmarks[1].fn, {"iter1": 1, "iter2": "b"}) 48 | assert all_benchmarks[1].fn(**all_benchmarks[1].params) == (1, "b") 49 | assert has_expected_args(all_benchmarks[2].fn, {"iter1": 2, "iter2": "a"}) 50 | assert all_benchmarks[2].fn(**all_benchmarks[2].params) == (2, "a") 51 | assert has_expected_args(all_benchmarks[3].fn, {"iter1": 2, "iter2": "b"}) 52 | assert all_benchmarks[3].fn(**all_benchmarks[3].params) == (2, "b") 53 | -------------------------------------------------------------------------------- /tests/test_file_reporter.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from nnbench.reporter.file import FileReporter 6 | from nnbench.types import BenchmarkResult 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "ext", 11 | ("yaml", "json", "ndjson", "parquet"), 12 | ) 13 | def test_file_reporter_roundtrip(tmp_path: Path, ext: str) -> None: 14 | """Tests data integrity for file reporter roundtrips.""" 15 | 16 | res = BenchmarkResult( 17 | run="my-run", 18 | context={"a": "b", "s": 1, "b.c": 1.0}, 19 | 
benchmarks=[{"name": "foo", "value": 1}, {"name": "bar", "value": 2}], 20 | timestamp=0, 21 | ) 22 | file = tmp_path / f"result.{ext}" 23 | f = FileReporter() 24 | f.write(res, file) 25 | (res2,) = f.read(file) 26 | 27 | if ext == "csv": 28 | for bm1, bm2 in zip(res.benchmarks, res2.benchmarks): 29 | assert bm1.keys() == bm2.keys() 30 | assert set(bm1.values()) == set(bm2.values()) 31 | else: 32 | assert res2 == res 33 | -------------------------------------------------------------------------------- /tests/test_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | import nnbench 6 | 7 | 8 | def test_runner_collection(testfolder: str) -> None: 9 | benchmarks = nnbench.collect(os.path.join(testfolder, "standard.py"), tags=("runner-collect",)) 10 | assert len(benchmarks) == 1 11 | 12 | benchmarks = nnbench.collect(testfolder, tags=("non-existing-tag",)) 13 | assert len(benchmarks) == 0 14 | 15 | benchmarks = nnbench.collect(testfolder, tags=("runner-collect",)) 16 | assert len(benchmarks) == 1 17 | 18 | 19 | def test_tag_selection(testfolder: str) -> None: 20 | PATH = os.path.join(testfolder, "tags.py") 21 | 22 | assert len(nnbench.collect(PATH)) == 3 23 | assert len(nnbench.collect(PATH, tags=("tag1",))) == 2 24 | assert len(nnbench.collect(PATH, tags=("tag2",))) == 1 25 | 26 | 27 | def test_context_assembly(testfolder: str) -> None: 28 | benchmarks = nnbench.collect(testfolder, tags=("standard",)) 29 | result = nnbench.run( 30 | benchmarks, 31 | params={"x": 1, "y": 1}, 32 | context=[lambda: {"foo": "bar"}], 33 | ) 34 | 35 | assert "foo" in result.context 36 | 37 | 38 | def test_error_on_duplicate_context_keys_in_runner(testfolder: str) -> None: 39 | def duplicate_provider() -> dict[str, str]: 40 | return {"foo": "baz"} 41 | 42 | benchmarks = nnbench.collect(testfolder, tags=("standard",)) 43 | with pytest.raises(ValueError, match="got multiple values for context key 'foo'"): 44 | nnbench.run( 45 | benchmarks, 46 | params={"x": 1, "y": 1}, 47 | context=[lambda: {"foo": "bar"}, duplicate_provider], 48 | ) 49 | 50 | 51 | def test_filter_benchmarks_on_params(testfolder: str) -> None: 52 | @nnbench.benchmark 53 | def prod(a: int, b: int = 1) -> int: 54 | return a * b 55 | 56 | benchmarks = [prod] 57 | rec1 = nnbench.run(benchmarks, params={"a": 1, "b": 2}) 58 | assert rec1.benchmarks[0]["parameters"] == {"a": 1, "b": 2} 59 | # Assert that the defaults are also present if not overridden. 
60 | rec2 = nnbench.run(benchmarks, params={"a": 1}) 61 | assert rec2.benchmarks[0]["parameters"] == {"a": 1, "b": 1} 62 | -------------------------------------------------------------------------------- /tests/test_types.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from nnbench.types import Interface 4 | 5 | 6 | def test_interface_with_no_arguments(): 7 | def fn() -> None: 8 | pass 9 | 10 | interface = Interface.from_callable(fn, {}) 11 | assert interface.names == () 12 | assert interface.types == () 13 | assert interface.defaults == () 14 | assert interface.variables == () 15 | assert interface.returntype is type(None) 16 | 17 | 18 | def test_interface_with_multiple_arguments(): 19 | def fn(a: int, b, c: str = "hello", d: float = 10.0) -> None: # type: ignore 20 | pass 21 | 22 | interface = Interface.from_callable(fn, {}) 23 | empty = inspect.Parameter.empty 24 | assert interface.names == ("a", "b", "c", "d") 25 | assert interface.types == (int, empty, str, float) 26 | assert interface.defaults == (empty, empty, "hello", 10.0) 27 | assert interface.variables == ( 28 | ("a", int, empty), 29 | ("b", empty, empty), 30 | ("c", str, "hello"), 31 | ("d", float, 10.0), 32 | ) 33 | assert interface.returntype is type(None) 34 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import pytest 4 | 5 | from nnbench.util import exists_module, modulename 6 | 7 | 8 | @pytest.mark.parametrize("name,expected", [("sys", True), ("yaml", True), ("pipapo", False)]) 9 | def test_ismodule(name: str, expected: bool) -> None: 10 | actual = exists_module(name) 11 | assert expected == actual 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "name,expected", 16 | [("sys", "sys"), ("__main__", "__main__"), ("src/my/module.py", "src.my.module")], 17 | ) 18 | def test_modulename(name: str, expected: str) -> None: 19 | actual = modulename(name) 20 | assert expected == actual 21 | 22 | 23 | def has_expected_args(fn, expected_args): 24 | signature = inspect.signature(fn) 25 | params = signature.parameters 26 | return all(param in params for param in expected_args) 27 | -------------------------------------------------------------------------------- /zizmor.yml: -------------------------------------------------------------------------------- 1 | rules: 2 | unpinned-uses: 3 | config: 4 | policies: 5 | actions/*: ref-pin 6 | astral-sh/setup-uv: ref-pin 7 | pypa/gh-action-pypi-publish: ref-pin 8 | --------------------------------------------------------------------------------
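As a closing illustration of how the modules listed above fit together, here is a minimal usage sketch; it is not a file in this repository, and the benchmark, parameter values, and the results.json file name are assumptions made for the example only. It defines a benchmark, runs it with a context provider, and writes and reads the result through the reporter registry.

import nnbench
from nnbench.reporter import get_reporter_implementation
from nnbench.reporter.file import FileReporter


@nnbench.benchmark(tags=("demo",))
def add(a: int, b: int = 1) -> int:
    return a + b


if __name__ == "__main__":
    # run() accepts a single benchmark, a family, or a collected list;
    # context providers are zero-argument callables returning dicts.
    result = nnbench.run(add, params={"a": 2}, context=[lambda: {"example": "demo"}])

    # Plain paths map to the "file" protocol and hence to FileReporter,
    # which dispatches on the file extension (.json here).
    reporter = get_reporter_implementation("results.json")
    reporter.write(result, "results.json", indent=2)

    # FileReporter.read() always returns a list of BenchmarkResult objects.
    (roundtrip,) = FileReporter().read("results.json")
    print(roundtrip.benchmarks[0]["value"])  # expected: 3 (a=2 plus default b=1)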