├── .github ├── ISSUE_TEMPLATE │ ├── bug.yml │ ├── config.yml │ ├── feature.yml │ ├── other.yml │ └── roadmap.yml ├── pull_request_template.md └── workflows │ ├── ci-tests-slow.yml │ ├── ci-tests.yml │ ├── ci.yml │ └── publish.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── .vscode ├── launch.json └── settings.json ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── assets ├── .gitignore ├── resolve-demo-assets.sh └── resolve-tests-assets.sh ├── docs ├── Makefile ├── make.bat └── source │ ├── _static │ ├── .DS_Store │ ├── css │ │ ├── custom.css │ │ └── nbsphinx.css │ ├── images │ │ ├── .DS_Store │ │ ├── favicon.ico │ │ ├── lczerolens-logo.png │ │ ├── lczerolens-logo.svg │ │ ├── one.png │ │ └── two.png │ └── switcher.json │ ├── about.rst │ ├── conf.py │ ├── features.rst │ ├── index.rst │ ├── notebooks │ ├── .gitignore │ ├── features │ │ ├── convert-official-weights.ipynb │ │ ├── evaluate-models-on-puzzles.ipynb │ │ ├── move-prediction.ipynb │ │ ├── probe-concepts.ipynb │ │ ├── run-models-on-gpu.ipynb │ │ └── visualise-heatmaps.ipynb │ ├── tutorials │ │ ├── automated-interpretability.ipynb │ │ ├── evidence-of-learned-look-ahead.ipynb │ │ ├── piece-value-estimation-using-lrp.ipynb │ │ └── train-saes.ipynb │ └── walkthrough.ipynb │ ├── start.rst │ └── tutorials.rst ├── pyproject.toml ├── scripts ├── __init__.py ├── constants.py ├── datasets │ ├── __init__.py │ ├── make_lichess_dataset.py │ └── make_tcec_dataset.py ├── lrp │ ├── __init__.py │ └── plane_analysis.py ├── results │ └── .gitignore └── visualisation.py ├── spaces └── .gitignore ├── src └── lczerolens │ ├── __init__.py │ ├── backends.py │ ├── board.py │ ├── concept.py │ ├── concepts │ ├── __init__.py │ ├── material.py │ ├── move.py │ └── threat.py │ ├── constants.py │ ├── lens.py │ ├── lenses │ ├── __init__.py │ ├── activation.py │ ├── composite.py │ ├── gradient.py │ ├── lrp │ │ ├── __init__.py │ │ ├── lens.py │ │ └── rules │ │ │ ├── 
__init__.py │ │ │ └── epsilon.py │ ├── patching.py │ ├── probing │ │ ├── __init__.py │ │ ├── lens.py │ │ └── probe.py │ └── sae │ │ ├── __init__.py │ │ └── buffer.py │ ├── model.py │ └── play │ ├── __init__.py │ ├── game.py │ ├── puzzle.py │ └── sampling.py ├── tests ├── __init__.py ├── assets │ └── error.ipynb ├── conftest.py ├── integration │ ├── __init__.py │ └── test_notebooks.py └── unit │ ├── __init__.py │ ├── concepts │ ├── __init__.py │ └── test_concepts.py │ ├── conftest.py │ ├── core │ ├── __init__.py │ ├── test_board.py │ ├── test_input_encoding.py │ ├── test_lczero.py │ ├── test_lens.py │ └── test_model.py │ ├── lenses │ ├── __init__.py │ ├── test_activation.py │ ├── test_gradient.py │ └── test_lrp.py │ └── play │ ├── __init__.py │ ├── test_puzzle.py │ └── test_sampling.py └── uv.lock /.github/ISSUE_TEMPLATE/bug.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report. 3 | title: "[Bug]: " 4 | labels: ["bug"] 5 | projects: [] 6 | assignees: 7 | - Xmaster6y 8 | body: 9 | - type: markdown 10 | attributes: 11 | value: | 12 | Thanks for taking the time to fill out this bug report! 13 | - type: textarea 14 | id: what-happened 15 | attributes: 16 | label: What happened? 17 | description: Also tell us, what did you expect to happen? 18 | placeholder: Tell us what you see! 19 | value: "A bug happened!" 20 | validations: 21 | required: true 22 | - type: dropdown 23 | id: version 24 | attributes: 25 | label: Version 26 | description: What version of the software are you running? 27 | options: 28 | - "latest" 29 | - "0.3.3" 30 | - "0.3.2" 31 | - "0.3.1" 32 | - "0.3.0" 33 | - "0.2.0" 34 | - "0.1.2" 35 | default: "latest" 36 | validations: 37 | required: true 38 | - type: dropdown 39 | id: environment 40 | attributes: 41 | label: In which environments is the bug happening? 
42 | multiple: true 43 | options: 44 | - Colab 45 | - Jupyter Notebook 46 | - Python Script 47 | - Linux 48 | - Windows 49 | - MacOS 50 | - Other 51 | - type: textarea 52 | id: logs 53 | attributes: 54 | label: Relevant log output 55 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. 56 | render: shell 57 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Community Support 4 | url: https://github.com/Xmaster6y/lczerolens/discussions 5 | about: Please ask and answer questions here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: File a feature request. 3 | title: "[Feature]: " 4 | labels: ["feature"] 5 | projects: [] 6 | assignees: 7 | - Xmaster6y 8 | body: 9 | - type: markdown 10 | attributes: 11 | value: | 12 | Thanks for taking the time to fill out this feature request! 13 | - type: textarea 14 | id: description 15 | attributes: 16 | label: Description 17 | description: A shortline description of the feature request. 18 | placeholder: Tell us what you want! 19 | value: "I want to use this feature!" 20 | validations: 21 | required: true 22 | - type: dropdown 23 | id: category 24 | attributes: 25 | label: Category 26 | description: What category does the feature belong to? 27 | options: 28 | - New Feature 29 | - Improvement 30 | - Bug Fix 31 | - Documentation 32 | - Other 33 | default: 0 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: use-case 38 | attributes: 39 | label: Use Case 40 | description: Please describe the use case for the feature. 
Use python code to describe the use case. 41 | render: python 42 | - type: textarea 43 | id: tasks 44 | attributes: 45 | label: Tasks 46 | description: List of tasks to implement the feature. 47 | value: | 48 | * [ ] Code is documented 49 | * [ ] Utilities and class tests are written 50 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/other.yml: -------------------------------------------------------------------------------- 1 | name: Other 2 | description: Detail an issue. 3 | title: "[Other]: " 4 | labels: ["other"] 5 | projects: [] 6 | assignees: 7 | - Xmaster6y 8 | body: 9 | - type: markdown 10 | attributes: 11 | value: | 12 | Thanks for taking the time to fill out this issue! 13 | - type: textarea 14 | id: description 15 | attributes: 16 | label: Description 17 | description: A shortline description of the issue. 18 | placeholder: Key insights about the issue. 19 | validations: 20 | required: true 21 | - type: dropdown 22 | id: category 23 | attributes: 24 | label: Category 25 | description: What category does the issue belong to? 26 | options: 27 | - New Feature 28 | - Improvement 29 | - Bug Fix 30 | - Documentation 31 | - Other 32 | default: 4 33 | validations: 34 | required: true 35 | - type: textarea 36 | id: tasks 37 | attributes: 38 | label: Tasks 39 | description: List of tasks to implement the feature. 40 | value: | 41 | * [ ] Code is documented 42 | * [ ] Utilities and class tests are written 43 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/roadmap.yml: -------------------------------------------------------------------------------- 1 | name: Roadmap 2 | description: Detail the roadmap to implement a new feature or solve a bug. 
3 | title: "[Roadmap]: " 4 | labels: ["roadmap"] 5 | projects: [] 6 | assignees: 7 | - Xmaster6y 8 | body: 9 | - type: markdown 10 | attributes: 11 | value: | 12 | Thanks for taking the time to fill out this roadmap! 13 | - type: textarea 14 | id: description 15 | attributes: 16 | label: Description 17 | description: A shortline description of the roadmap. 18 | placeholder: Key insights about the roadmap. 19 | validations: 20 | required: true 21 | - type: textarea 22 | id: issues 23 | attributes: 24 | label: Linked Issues 25 | description: List of issues for the roadmap. 26 | value: | 27 | * [ ] #? 28 | validations: 29 | required: true 30 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | Key insights about the PR. 4 | 5 | ## Linked Issues 6 | 7 | - Closes #? 8 | - #? 9 | -------------------------------------------------------------------------------- /.github/workflows/ci-tests-slow.yml: -------------------------------------------------------------------------------- 1 | name: CI (slow) 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | ci: 8 | uses: ./.github/workflows/ci.yml 9 | with: 10 | tests-type: tests-slow 11 | -------------------------------------------------------------------------------- /.github/workflows/ci-tests.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | workflow_dispatch: 6 | push: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | ci: 12 | uses: ./.github/workflows/ci.yml 13 | with: 14 | tests-type: tests 15 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | workflow_call: 3 | inputs: 4 | tests-type: 5 | required: true 6 | type: 
string 7 | 8 | jobs: 9 | ci: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.10", "3.11", "3.12"] 14 | environment: ci 15 | timeout-minutes: 10 16 | steps: 17 | - uses: actions/checkout@v4 18 | with: 19 | fetch-depth: 0 20 | - name: Install uv 21 | uses: astral-sh/setup-uv@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | uv sync --locked 27 | - name: Run checks 28 | run: | 29 | make checks 30 | - name: Download assets 31 | run: | 32 | make tests-assets 33 | - name: Run tests 34 | run: | 35 | make ${{ inputs.tests-type }} 36 | - name: Upload coverage reports to Codecov 37 | if: matrix.python-version == '3.10' 38 | uses: codecov/codecov-action@v5 39 | with: 40 | token: ${{ secrets.CODECOV_TOKEN }} 41 | - name: Upload test results to Codecov 42 | if: matrix.python-version == '3.10' && ${{ !cancelled() }} 43 | uses: codecov/test-results-action@v1 44 | with: 45 | token: ${{ secrets.CODECOV_TOKEN }} 46 | fail_ci_if_error: true 47 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [published] 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | environment: publish 12 | permissions: 13 | id-token: write 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Install uv 17 | uses: astral-sh/setup-uv@v5 18 | with: 19 | python-version: "3.11" 20 | - name: Install dependencies 21 | run: | 22 | uv sync --locked 23 | - name: Build package 24 | run: | 25 | uv build 26 | - name: Publish package 27 | uses: pypa/gh-action-pypi-publish@release/v1 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL 
files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pipenv 85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 88 | # install all needed dependencies. 89 | #Pipfile.lock 90 | 91 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 92 | __pypackages__/ 93 | 94 | # Celery stuff 95 | celerybeat-schedule 96 | celerybeat.pid 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # Environments 102 | .env 103 | .venv 104 | env/ 105 | venv/ 106 | ENV/ 107 | env.bak/ 108 | venv.bak/ 109 | 110 | # Spyder project settings 111 | .spyderproject 112 | .spyproject 113 | 114 | # Rope project settings 115 | .ropeproject 116 | 117 | # mkdocs documentation 118 | /site 119 | 120 | # mypy 121 | .mypy_cache/ 122 | .dmypy.json 123 | dmypy.json 124 | 125 | # Pyre type checker 126 | .pyre/ 127 | 128 | # Pickle files 129 | *.pkl 130 | 131 | # Various files 132 | ignored 133 | debug 134 | *.zip 135 | lc0 136 | !bin/lc0 137 | wandb 138 | **/.DS_Store 139 | junit* 140 | 141 | *secret* 142 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "spaces/lichess-puzzles-leaderboard"] 2 | path = spaces/lichess-puzzles-leaderboard 3 | url = https://huggingface.co/spaces/lczerolens/lichess-puzzles-leaderboard 4 | [submodule "spaces/lczerolens-backends-demo"] 5 | path = spaces/lczerolens-backends-demo 6 | url = https://huggingface.co/spaces/lczerolens/lczerolens-backends-demo 7 | [submodule "spaces/lczerolens-demo"] 8 | path = spaces/lczerolens-demo 9 | url = https://huggingface.co/spaces/Xmaster6y/lczerolens-demo 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-added-large-files 6 | args: ['--maxkb=1024'] 7 | - id: check-yaml 8 | - id: check-json 9 | - id: check-toml 10 | - id: end-of-file-fixer 11 | - id: trailing-whitespace 12 | - id: check-docstring-first 13 | - repo: 
https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.9.1 15 | hooks: 16 | - id: ruff 17 | args: [ --fix ] 18 | - id: ruff-format 19 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | commands: 8 | - asdf plugin add uv 9 | - asdf install uv latest 10 | - asdf global uv latest 11 | - uv sync --locked 12 | - make docs 13 | - mkdir -p $READTHEDOCS_OUTPUT 14 | - mv docs/build/html $READTHEDOCS_OUTPUT 15 | 16 | sphinx: 17 | configuration: docs/source/conf.py 18 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Script Plane Analysis", 6 | "type": "debugpy", 7 | "request": "launch", 8 | "module": "scripts.lrp.plane_analysis", 9 | "console": "integratedTerminal", 10 | "justMyCode": false 11 | } 12 | ] 13 | } 14 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestEnabled": true, 3 | "python.testing.pytestArgs": [ 4 | "tests" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 0.1.0 2 | title: LCZeroLens 3 | type: software 4 | authors: 5 | - name: Yoann Poupart 6 | repository-code: 'https://github.com/Xmaster6y/lczerolens' 7 | repository-artifact: 'https://github.com/Xmaster6y/lczerolens/releases/' 8 | keywords: 9 | - chess 10 | - explainable AI (XAI) 11 | license: MIT 12 | 
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute? 2 | 3 | ## Guidelines 4 | 5 | The project dependencies are managed using `uv`, see their installation [guide](https://docs.astral.sh/uv/). For even more stability, I recommend using `pyenv` or python `3.9.16`. 6 | 7 | Additionally, to make your life easier, install `make` to use the shortcut commands. 8 | 9 | ## Dev Install 10 | 11 | To install the dependencies: 12 | 13 | ```bash 14 | uv sync 15 | ``` 16 | 17 | Before committing, install `pre-commit`: 18 | 19 | ```bash 20 | uv run pre-commit install 21 | ``` 22 | 23 | To run the checks (`pre-commit` checks): 24 | 25 | ```bash 26 | make checks 27 | ``` 28 | 29 | To run the tests (using `pytest`): 30 | 31 | ```bash 32 | make tests 33 | ``` 34 | 35 | ## Branches 36 | 37 | Make a branch before making a pull request to `develop`. 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yoann Poupart 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: checks 2 | checks: 3 | uv run pre-commit run --all-files 4 | 5 | .PHONY: tests-assets 6 | tests-assets: 7 | bash assets/resolve-tests-assets.sh 8 | 9 | .PHONY: demo-assets 10 | demo-assets: 11 | bash assets/resolve-demo-assets.sh 12 | 13 | .PHONY: tests 14 | tests: 15 | uv run pytest tests --cov=src --cov-report=term-missing --cov-fail-under=50 -s -v --run-fast --run-backends --cov-branch --cov-report=xml --junitxml=junit.xml -o junit_family=legacy 16 | 17 | .PHONY: tests-fast 18 | tests-fast: 19 | uv run pytest tests --cov=src --cov-report=term-missing -s -v --run-fast 20 | 21 | .PHONY: tests-slow 22 | tests-slow: 23 | uv run pytest tests --cov=src --cov-report=term-missing -s -v --run-slow 24 | 25 | .PHONY: tests-backends 26 | tests-backends: 27 | uv run pytest tests --cov=src --cov-report=term-missing -s -v --run-backends 28 | 29 | .PHONY: docs 30 | docs: 31 | cd docs && uv run --group docs make html 32 | 33 | .PHONY: demo 34 | demo: 35 | uv run --group demo gradio spaces/lczerolens-demo/app.py 36 | 37 | .PHONY: demo-backends 38 | demo-backends: 39 | uv run --group demo gradio spaces/lczerolens-backends-demo/app.py 40 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | logo 2 | 3 | # lczerolens 🔍 4 | 5 | [![lczerolens](https://img.shields.io/pypi/v/lczerolens?color=purple)](https://pypi.org/project/lczerolens/) 6 | [![license](https://img.shields.io/badge/license-MIT-lightgrey.svg)](https://github.com/Xmaster6y/lczerolens/blob/main/LICENSE) 7 | [![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv) 8 | [![python versions](https://img.shields.io/pypi/pyversions/lczerolens.svg)](https://www.python.org/downloads/) 9 | 10 | [![codecov](https://codecov.io/gh/Xmaster6y/lczerolens/graph/badge.svg?token=JKJAWB451A)](https://codecov.io/gh/Xmaster6y/lczerolens) 11 | ![ci-tests](https://github.com/Xmaster6y/lczerolens/actions/workflows/ci-tests.yml/badge.svg) 12 | ![ci-tests-slow](https://github.com/Xmaster6y/lczerolens/actions/workflows/ci-tests-slow.yml/badge.svg) 13 | ![publish](https://github.com/Xmaster6y/lczerolens/actions/workflows/publish.yml/badge.svg) 14 | [![docs](https://readthedocs.org/projects/lczerolens/badge/?version=latest)](https://lczerolens.readthedocs.io/en/latest/?badge=latest) 15 | 16 | 17 | 18 | 19 | Leela Chess Zero (lc0) Lens (`lczerolens`): a set of utilities to make the analysis of Leela Chess Zero networks easy. 20 | 21 | ## Getting Started 22 | 23 | ### Installs 24 | 25 | ```bash 26 | pip install lczerolens 27 | ``` 28 | 29 | Take the `viz` extra to render heatmaps and the `backends` extra to use the `lc0` backends. 
30 | 31 | ### Run Models 32 | 33 | Get the best move predicted by a model: 34 | 35 | ```python 36 | from lczerolens import LczeroBoard, LczeroModel 37 | 38 | model = LczeroModel.from_onnx_path("demo/onnx-models/lc0-10-4238.onnx") 39 | board = LczeroBoard() 40 | output = model(board) 41 | legal_indices = board.get_legal_indices() 42 | best_move_idx = output["policy"].gather( 43 | dim=1, 44 | index=legal_indices.unsqueeze(0) 45 | ).argmax(dim=1).item() 46 | print(board.decode_move(legal_indices[best_move_idx])) 47 | ``` 48 | 49 | ### Intervene 50 | 51 | In addition to the built-in lenses, you can easily create your own lenses by subclassing the `Lens` class and overriding the `_intervene` method. E.g. compute an neuron ablation effect: 52 | 53 | ```python 54 | from lczerolens import LczeroBoard, LczeroModel, Lens 55 | 56 | class CustomLens(Lens): 57 | _grad_enabled: bool = False 58 | 59 | def _intervene(self, model: LczeroModel, **kwargs) -> dict: 60 | ablate = kwargs.get("ablate", False) 61 | if ablate: 62 | l5_module = getattr(model, "block5/conv2/relu") 63 | l5_module.output[0, :, 0, 0] = 0 # relative a1 64 | return getattr(model, "output/wdl").output[0,0].save() # win probability 65 | 66 | model = LczeroModel.from_onnx_path("path/to/model.onnx") 67 | lens = CustomLens() 68 | board = LczeroBoard() 69 | 70 | clean_results = lens.analyse(model, board) 71 | corrupted_results = lens.analyse(model, board, ablate=True) 72 | print((corrupted_results - clean_results) / clean_results) 73 | ``` 74 | 75 | ### Features 76 | 77 | - [Visualise Heatmaps](https://lczerolens.readthedocs.io/en/latest/notebooks/features/visualise-heatmaps.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/visualise-heatmaps.ipynb) 78 | - [Probe Concepts](https://lczerolens.readthedocs.io/en/latest/notebooks/features/probe-concepts.html): [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/probe-concepts.ipynb) 79 | - [Move Prediction](https://lczerolens.readthedocs.io/en/latest/notebooks/features/move-prediction.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/move-prediction.ipynb) 80 | - [Run Models on GPU](https://lczerolens.readthedocs.io/en/latest/notebooks/features/run-models-on-gpu.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/run-models-on-gpu.ipynb) 81 | - [Evaluate Models on Puzzles](https://lczerolens.readthedocs.io/en/latest/notebooks/features/evaluate-models-on-puzzles.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/evaluate-models-on-puzzles.ipynb) 82 | - [Convert Official Weights](https://lczerolens.readthedocs.io/en/latest/notebooks/features/convert-official-weights.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/convert-official-weights.ipynb) 83 | 84 | ### Tutorials 85 | 86 | - [Walkthrough](https://lczerolens.readthedocs.io/en/latest/notebooks/walkthrough.html): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/docs/source/notebooks/walkthrough.ipynb) 87 | - [Piece Value Estimation Using LRP](https://lczerolens.readthedocs.io/en/latest/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb): [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb) 88 | - [Evidence of Learned Look-Ahead](https://lczerolens.readthedocs.io/en/latest/notebooks/tutorials/evidence-of-learned-look-ahead.ipynb): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/evidence-of-learned-look-ahead.ipynb) 89 | - [Train SAEs](https://lczerolens.readthedocs.io/en/latest/notebooks/tutorials/train-saes.ipynb): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/train-saes.ipynb) 90 | 91 | ## Demo 92 | 93 | ### Spaces 94 | 95 | - [Lczerolens Demo](https://huggingface.co/spaces/lczerolens/lczerolens-demo) 96 | - [Lczerolens Backends Demo](https://huggingface.co/spaces/lczerolens/lczerolens-backends-demo) 97 | - [Lczerolens Puzzles Leaderboard](https://huggingface.co/spaces/lczerolens/lichess-puzzles-leaderboard) 98 | 99 | ### Local Demo 100 | 101 | Additionally, you can run the gradio demos locally. First you'll need to clone the spaces (after cloning the repo): 102 | 103 | ```bash 104 | git clone https://huggingface.co/spaces/Xmaster6y/lczerolens-demo spaces/lczerolens-demo 105 | ``` 106 | 107 | And optionally the backends demo: 108 | 109 | ```bash 110 | git clone https://huggingface.co/spaces/Xmaster6y/lczerolens-backends-demo spaces/lczerolens-backends-demo 111 | ``` 112 | 113 | And then launch the demo (running on port `8000`): 114 | 115 | ```bash 116 | make demo 117 | ``` 118 | 119 | To test the backends use: 120 | 121 | ```bash 122 | make demo-backends 123 | ``` 124 | 125 | ## Full Documentation 126 | 127 | See the full [documentation](https://lczerolens.readthedocs.io). 
128 | 129 | ## Contribute 130 | 131 | See the guidelines in [CONTRIBUTING.md](CONTRIBUTING.md). 132 | -------------------------------------------------------------------------------- /assets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !resolve-tests-assets.sh 4 | !resolve-demo-assets.sh 5 | -------------------------------------------------------------------------------- /assets/resolve-demo-assets.sh: -------------------------------------------------------------------------------- 1 | make test-assets 2 | 3 | uv run gdown 1CGo_UzFins8pQWQ0JQRrAgIPo57PG11g -O assets/test_stockfish_5000.jsonl 4 | -------------------------------------------------------------------------------- /assets/resolve-tests-assets.sh: -------------------------------------------------------------------------------- 1 | uv run gdown 1Ssl4JanqzQn3p-RoHRDk_aApykl-SukE -O assets/tinygyal-8.pb.gz 2 | uv run gdown 1WzBQV_zn5NnfsG0K8kOion0pvWxXhgKM -O assets/384x30-2022_0108_1903_17_608.pb.gz 3 | uv run gdown 1erxB3tULDURjpPhiPWVGr6X986Q8uE6U -O assets/maia-1100.pb.gz 4 | uv run gdown 1YqqANK-wuZIOmMweuK_oCU7vfPN7G_Z6 -O assets/t1-smolgen-512x15x8h-distilled-swa-3395000.pb.gz 5 | uv run gdown 15-eGN7Hz2NM6aEMRaQrbW3ScxxQpAqa5 -O assets/test_stockfish_10.jsonl 6 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/.DS_Store -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | .button-group { 2 | display: flex; 3 | flex-direction: row; 4 | flex-wrap: nowrap; 5 | justify-content: flex-start; 6 | align-items: center; 7 | align-content: stretch; 8 | gap: 10px; 9 | } 10 | 11 | .hero { 12 | height: 90vh; 13 | display: flex; 14 | flex-direction: row; 15 | flex-wrap: nowrap; 16 | justify-content: center; 17 | align-items: center; 18 | align-content: stretch; 19 | padding-bottom: 20vh; 20 | 21 | overflow: hidden; 22 | } 23 | 24 | html[data-theme="light"] { 25 | --pst-color-secondary: #b676b2; 26 | --pst-color-primary: #4ca4d2; 27 | --pst-color-primary-highlight: #b676b2; 28 | 29 | } 30 | 31 | 32 | .bd-sidebar-primary { 33 | width: fit-content !important; 34 | padding-right: 40px !important 35 | } 36 | 37 | 38 | .features { 39 | height: 60vh; 40 | overflow: hidden; 41 | } 42 | 43 | 44 | .image-container { 45 | margin-top: 50px; 46 | } 47 | 48 | img { 49 | pointer-events: none; 50 | } 51 | 52 | .img-bottom-sizing { 53 | max-height: 10vh; 54 | width: auto 55 | } 56 | 57 | .body-sizing{ 58 | height: 10vh; 59 | } 60 | 61 | .title-bot { 62 | margin-bottom: -10px !important; 63 | line-height: 
normal !important; 64 | } 65 | 66 | .sub-bot { 67 | margin-bottom: -10px !important; 68 | /* margin-top: -20px !important; */ 69 | line-height: normal !important; 70 | } 71 | 72 | .features-container { 73 | img { 74 | max-width: none; 75 | } 76 | 77 | display: flex; 78 | gap: 20px; 79 | } 80 | 81 | @media only screen and (max-width: 768px) { 82 | 83 | /* Adjust this value based on your breakpoint for mobile */ 84 | .front-container, 85 | .hero { 86 | height: auto; 87 | /* Change from fixed height to auto */ 88 | min-height: 50vh; 89 | /* Adjust this as needed */ 90 | } 91 | 92 | .features-container { 93 | margin-bottom: 20px; 94 | /* Increase bottom margin */ 95 | } 96 | 97 | .hero { 98 | margin-bottom: 30px; 99 | /* Adjust the bottom margin of the main container */ 100 | } 101 | 102 | .features { 103 | height: 110vh; 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /docs/source/_static/css/nbsphinx.css: -------------------------------------------------------------------------------- 1 | div.nboutput.container div.output_area:has(pre) { 2 | max-height: 600px; 3 | } 4 | -------------------------------------------------------------------------------- /docs/source/_static/images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/.DS_Store -------------------------------------------------------------------------------- /docs/source/_static/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/favicon.ico -------------------------------------------------------------------------------- /docs/source/_static/images/lczerolens-logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/lczerolens-logo.png -------------------------------------------------------------------------------- /docs/source/_static/images/one.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/one.png -------------------------------------------------------------------------------- /docs/source/_static/images/two.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/two.png -------------------------------------------------------------------------------- /docs/source/_static/switcher.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "version": "dev", 4 | "url": "https://lczerolens.readthedocs.io/en/latest/" 5 | }, 6 | { 7 | "version": "v0.3.3", 8 | "url": "https://lczerolens.readthedocs.io/en/v0.3.3/" 9 | }, 10 | { 11 | "version": "v0.3.2", 12 | "url": "https://lczerolens.readthedocs.io/en/v0.3.2/" 13 | }, 14 | { 15 | "version": "v0.3.1", 16 | "url": "https://lczerolens.readthedocs.io/en/v0.3.1/" 17 | }, 18 | { 19 | "version": "v0.3.0", 20 | "url": "https://lczerolens.readthedocs.io/en/v0.3.0/" 21 | }, 22 | { 23 | "version": "v0.2.0", 24 | "url": "https://lczerolens.readthedocs.io/en/v0.2.0/" 25 | }, 26 | { 27 | "version": "v0.1.2", 28 | "url": "https://lczerolens.readthedocs.io/en/v0.1.2/" 29 | } 30 | ] 31 | -------------------------------------------------------------------------------- /docs/source/about.rst: -------------------------------------------------------------------------------- 1 | 2 | About lczerolens 
3 | ================ 4 | 5 | Goal 6 | ---- 7 | 8 | Provide analysis tools for lc0 networks with ``torch``. 9 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | 3 | import os 4 | 5 | import lczerolens 6 | 7 | # Project Information 8 | project = "lczerolens" 9 | copyright = "2024, Yoann Poupart" 10 | author = "Yoann Poupart" 11 | 12 | 13 | # General Configuration 14 | extensions = [ 15 | # 'sphinx.ext.autosectionlabel', 16 | "sphinx.ext.autodoc", # Auto documentation from docstrings 17 | "sphinx.ext.napoleon", # Support for NumPy and Google style docstrings 18 | "sphinx.ext.viewcode", # View code in the browser 19 | "sphinx_copybutton", # Copy button for code blocks 20 | "sphinx_design", # Bootstrap design components 21 | "nbsphinx", # Jupyter notebook support 22 | "autoapi.extension", 23 | ] 24 | 25 | templates_path = ["_templates"] 26 | exclude_patterns = [] # type: ignore 27 | fixed_sidebar = True 28 | 29 | 30 | # HTML Output Options 31 | 32 | # See https://sphinx-themes.org/ for more 33 | html_theme = "pydata_sphinx_theme" 34 | html_title = "lczerolens" 35 | html_logo = "_static/images/lczerolens-logo.svg" 36 | html_static_path = ["_static"] 37 | 38 | html_favicon = "_static/images/favicon.ico" 39 | html_show_sourcelink = False 40 | 41 | # Define the json_url for our version switcher. 42 | json_url = "https://lczerolens.readthedocs.io/en/latest/_static/switcher.json" 43 | 44 | 45 | version_match = os.environ.get("READTHEDOCS_VERSION") 46 | release = lczerolens.__version__ 47 | # If READTHEDOCS_VERSION doesn't exist, we're not on RTD 48 | # If it is an integer, we're in a PR build and the version isn't correct.
49 | # If it's "latest" → change to "dev" 50 | if not version_match or version_match.isdigit() or version_match == "latest": 51 | # For local development, infer the version to match from the package. 52 | if "dev" in release or "rc" in release: 53 | version_match = "dev" 54 | # We want to keep the relative reference if we are in dev mode 55 | # but we want the whole url if we are effectively in a released version 56 | json_url = "_static/switcher.json" 57 | else: 58 | version_match = f"v{release}" 59 | elif version_match == "stable": 60 | version_match = f"v{release}" 61 | 62 | html_theme_options = { 63 | "show_nav_level": 2, 64 | "navigation_depth": 2, 65 | "show_toc_level": 2, 66 | "navbar_end": ["theme-switcher", "navbar-icon-links"], 67 | "navbar_align": "left", 68 | "icon_links": [ 69 | { 70 | "name": "GitHub", 71 | "url": "https://github.com/Xmaster6y/lczerolens", 72 | "icon": "fa-brands fa-github", 73 | }, 74 | { 75 | "name": "Discord", 76 | "url": "https://discord.gg/e7vhrTsjnt", 77 | "icon": "fa-brands fa-discord", 78 | }, 79 | { 80 | "name": "PyPI", 81 | "url": "https://pypi.org/project/lczerolens", 82 | "icon": "fa-custom fa-pypi", 83 | }, 84 | ], 85 | "show_version_warning_banner": True, 86 | "navbar_center": ["version-switcher", "navbar-nav"], 87 | "footer_start": ["copyright"], 88 | "footer_center": ["sphinx-version"], 89 | "switcher": { 90 | "json_url": json_url, 91 | "version_match": version_match, 92 | }, 93 | } 94 | html_sidebars = {"about": [], "start": []} 95 | 96 | html_context = {"default_mode": "auto"} 97 | 98 | html_css_files = [ 99 | "css/custom.css", 100 | "css/nbsphinx.css", 101 | ] 102 | 103 | # Nbsphinx 104 | nbsphinx_execute = "auto" 105 | 106 | # Autoapi 107 | autoapi_dirs = ["../../src"] 108 | autoapi_root = "api" 109 | autoapi_keep_files = False 110 | autodoc_typehints = "description" 111 | -------------------------------------------------------------------------------- /docs/source/features.rst:
-------------------------------------------------------------------------------- 1 | Features 2 | ========== 3 | 4 | .. raw:: html 5 | 6 | 13 | 14 | 19 | 20 | .. grid:: 2 21 | :gutter: 3 22 | 23 | .. grid-item-card:: 24 | :link: notebooks/features/visualise-heatmaps.ipynb 25 | :class-card: surface 26 | :class-body: surface 27 | 28 | .. raw:: html 29 | 30 |
31 |
32 | 33 |
34 |
35 |
Visualise Heatmaps
36 |

Visualise saliency heatmaps or encodings for a given chess board.

37 |
38 |
39 | 40 | .. grid-item-card:: 41 | :link: notebooks/features/probe-concepts.ipynb 42 | :class-card: surface 43 | :class-body: surface 44 | 45 | .. raw:: html 46 | 47 |
48 |
49 | 50 |
51 |
52 |
Probe Concepts
53 |

Probe the concepts with a dataset.

54 |
55 |
56 | 57 | .. grid-item-card:: 58 | :link: notebooks/features/move-prediction.ipynb 59 | :class-card: surface 60 | :class-body: surface 61 | 62 | .. raw:: html 63 | 64 |
65 |
66 | 67 |
68 |
69 |
Move Prediction
70 |

Make a move prediction for a given chess board.

71 |
72 |
73 | 74 | .. grid-item-card:: 75 | :link: notebooks/features/run-models-on-gpu.ipynb 76 | :class-card: surface 77 | :class-body: surface 78 | 79 | .. raw:: html 80 | 81 |
82 |
83 | 84 |
85 |
86 |
Run Models on GPU
87 |

Take advantage of GPU acceleration.

88 |
89 |
90 | 91 | .. grid-item-card:: 92 | :link: notebooks/features/evaluate-models-on-puzzles.ipynb 93 | :class-card: surface 94 | :class-body: surface 95 | 96 | .. raw:: html 97 | 98 |
99 |
100 | 101 |
102 |
103 |
Evaluate Models on Puzzles
104 |

Evaluate a model on a set of puzzles.

105 |
106 |
107 | 108 | .. grid-item-card:: 109 | :link: notebooks/features/convert-official-weights.ipynb 110 | :class-card: surface 111 | :class-body: surface 112 | 113 | .. raw:: html 114 | 115 |
116 |
117 | 118 |
119 |
120 |
Convert Official Weights
121 |

Convert lc0 networks to onnx.

122 |
123 |
124 | 125 | .. toctree:: 126 | :hidden: 127 | :maxdepth: 2 128 | 129 | notebooks/features/visualise-heatmaps.ipynb 130 | notebooks/features/probe-concepts.ipynb 131 | notebooks/features/move-prediction.ipynb 132 | notebooks/features/run-models-on-gpu.ipynb 133 | notebooks/features/evaluate-models-on-puzzles.ipynb 134 | notebooks/features/convert-official-weights.ipynb 135 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :html_theme.sidebar_secondary.remove: true 2 | :sd_hide_title: 3 | 4 | lczerolens 5 | ========== 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | :hidden: 10 | 11 | start 12 | features 13 | tutorials 14 | api/index 15 | About 16 | 17 | .. grid:: 1 1 2 2 18 | :class-container: hero 19 | :reverse: 20 | 21 | .. grid-item:: 22 | .. div:: 23 | 24 | .. image:: _static/images/lczerolens-logo.svg 25 | :width: 300 26 | :height: 300 27 | 28 | .. grid-item:: 29 | 30 | .. div:: sd-fs-1 sd-font-weight-bold title-bot sd-text-primary image-container 31 | 32 | LczeroLens 33 | 34 | .. div:: sd-fs-4 sd-font-weight-bold sd-my-0 sub-bot image-container 35 | 36 | Interpretability for lc0 networks 37 | 38 | **lczerolens** is a package for interpreting and manipulating the neural networks produced by lc0 39 | 40 | .. div:: button-group 41 | 42 | .. button-ref:: start 43 | :color: primary 44 | :shadow: 45 | 46 | Get Started 47 | 48 | .. button-ref:: tutorials 49 | :color: primary 50 | :outline: 51 | 52 | Tutorials 53 | 54 | .. button-ref:: api/index 55 | :color: primary 56 | :outline: 57 | 58 | API Reference 59 | 60 | 61 | .. div:: sd-fs-1 sd-font-weight-bold sd-text-center sd-text-primary sd-mb-5 62 | 63 | Key Features 64 | 65 | .. grid:: 1 1 2 2 66 | :class-container: features 67 | 68 | .. grid-item:: 69 | 70 | .. div:: features-container 71 | 72 | .. image:: _static/images/one.png 73 | :width: 150 74 | 75 | ..
div:: 76 | 77 | **Adaptability** 78 | 79 | Load a network from lc0 (``.pb`` or ``.onnx``) and load it with lczerolens using ``torch``. 80 | 81 | .. grid-item:: 82 | 83 | .. div:: features-container 84 | 85 | .. image:: _static/images/two.png 86 | :width: 150 87 | 88 | .. div:: 89 | 90 | **Interpretability** 91 | 92 | Easily compute saliency maps or aggregated statistics using the pre-built Interpretability methods. 93 | -------------------------------------------------------------------------------- /docs/source/notebooks/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !**/ 3 | !.gitignore 4 | !*.ipynb 5 | *.nbconvert.ipynb 6 | -------------------------------------------------------------------------------- /docs/source/notebooks/features/convert-official-weights.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "-qCMaHT6xuoB" 7 | }, 8 | "source": [ 9 | "# Convert Official Weights\n", 10 | "\n", 11 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/convert-official-weights.ipynb)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "id": "VtZdxVZZx2pL" 18 | }, 19 | "source": [ 20 | "## Setup" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "metadata": { 27 | "id": "iMj257hVxlgJ" 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "MODE = \"local\" # \"colab\" | \"colab-dev\" | \"local\"" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "colab": { 39 | "base_uri": "https://localhost:8080/" 40 | }, 41 | "id": "WtTH-Oz-yotw", 42 | "outputId": "62f302e9-5f63-46ab-82a3-0d77ec054d0d" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "if MODE == \"colab\":\n", 47 
| " !pip install -q lczerolens\n", 48 | "elif MODE == \"colab-dev\":\n", 49 | " !rm -r lczerolens\n", 50 | " !git clone https://github.com/Xmaster6y/lczerolens -b main\n", 51 | " !pip install -q ./lczerolens" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "colab": { 59 | "base_uri": "https://localhost:8080/" 60 | }, 61 | "id": "_GdhCPhtyoj3", 62 | "outputId": "0c3dfe61-3dbf-41ee-f6ec-cd524c713daf" 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "!gdown 1erxB3tULDURjpPhiPWVGr6X986Q8uE6U -O leela-network.pb.gz" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "KjLtlNV95WWx" 73 | }, 74 | "source": [ 75 | "To convert a network you'll need to have installed the `lc0` binaries (**takes about 10 minutes**):" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "metadata": { 82 | "colab": { 83 | "base_uri": "https://localhost:8080/" 84 | }, 85 | "id": "IuQ4hpbv5Z-o", 86 | "outputId": "338038e6-3d07-4a43-e9f0-b7797555e9ce" 87 | }, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", 94 | " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", 95 | " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", 96 | " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", 97 | " Building wheel for lczero_bindings (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "!pip install -q git+https://github.com/LeelaChessZero/lc0" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": { 108 | "id": "AM3K3PCgx8g0" 109 | }, 110 | "source": [ 111 | "## Convert a Model\n", 112 | "\n", 113 | "You can convert networks to `onnx` using the official `lc0` binaries or\n", 114 | "by using the `backends` module:" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 6, 120 | "metadata": { 121 | "colab": { 122 | "base_uri": "https://localhost:8080/", 123 | "height": 120 124 | }, 125 | "id": "f-6vcmwEyb7n", 126 | "outputId": "4989f19d-27cb-4cb7-bf01-28eaad64f405" 127 | }, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "\n", 134 | "Format\n", 135 | "~~~~~~\n", 136 | " Weights encoding: LINEAR16\n", 137 | " Input: INPUT_CLASSICAL_112_PLANE\n", 138 | " Network: NETWORK_SE_WITH_HEADFORMAT\n", 139 | " Policy: POLICY_CONVOLUTION\n", 140 | " Value: VALUE_WDL\n", 141 | "\n", 142 | "Weights\n", 143 | "~~~~~~~\n", 144 | " Blocks: 6\n", 145 | " SE blocks: 6\n", 146 | " Filters: 64\n", 147 | " Policy: Convolution\n", 148 | " Policy activation: ACTIVATION_DEFAULT\n", 149 | " Value: WDL\n", 150 | " MLH: Absent\n", 151 | "Converting Leela network to the ONNX.\n", 152 | "\n", 153 | "ONNX interface\n", 154 | "~~~~~~~~~~~~~~\n", 155 | " Data type: FLOAT\n", 156 | " Input planes: /input/planes\n", 157 | " Output WDL: /output/wdl\n", 158 | " Output Policy: /output/policy\n", 159 | "Done.\n", 160 | "\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "from lczerolens import backends\n", 166 | "\n", 167 | "output = backends.convert_to_onnx(\"leela-network.pb.gz\", \"leela-network.onnx\")\n", 168 | "print(output)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": { 174 | "id": "P3RrLQHF5qkI" 175 | }, 176 | "source": [ 177 | "See [Move Prediction](move-prediction.ipynb) to 
see how to use the converted network." 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": { 183 | "id": "qBzY4spb5nGg" 184 | }, 185 | "source": [ 186 | "## Note\n", 187 | "\n", 188 | "Only the latest networks are supported. To convert older weights, you should build the associated binaries." 189 | ] 190 | } 191 | ], 192 | "metadata": { 193 | "colab": { 194 | "provenance": [] 195 | }, 196 | "kernelspec": { 197 | "display_name": "Python 3", 198 | "name": "python3" 199 | }, 200 | "language_info": { 201 | "name": "python", 202 | "version": "3.11.10" 203 | } 204 | }, 205 | "nbformat": 4, 206 | "nbformat_minor": 0 207 | } 208 | -------------------------------------------------------------------------------- /docs/source/notebooks/tutorials/automated-interpretability.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Automated Interpretability\n", 8 | "\n", 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/automated-interpretability.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "MODE = \"local\" # \"colab\" | \"colab-dev\" | \"local\"" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "if MODE == \"colab\":\n", 35 | " !pip install -q lczerolens\n", 36 | "elif MODE == \"colab-dev\":\n", 37 | " !rm -r lczerolens\n", 38 | " !git clone https://github.com/Xmaster6y/lczerolens -b main\n", 39 | " !pip install -q ./lczerolens" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 
45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Downloading...\n", 52 | "From: https://drive.google.com/uc?id=15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X\n", 53 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/tutorials/lc0-19-1876.onnx\n", 54 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:02<00:00, 36.6MB/s]\n", 55 | "Downloading...\n", 56 | "From: https://drive.google.com/uc?id=1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd\n", 57 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/tutorials/lc0-19-4508.onnx\n", 58 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:02<00:00, 35.8MB/s]\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "!gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O lc0-19-1876.onnx\n", 64 | "!gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O lc0-19-4508.onnx" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "## Load a Model\n", 72 | "\n", 73 | "Load a leela network from file (already converted to `onnx`):" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "name": "stderr", 83 | "output_type": "stream", 84 | "text": [ 85 | "/Users/xmaster/Work/lczerolens/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 86 | " from .autonotebook import tqdm as notebook_tqdm\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "from lczerolens import LczeroModel\n", 92 | "\n", 93 | "strong_model = LczeroModel.from_path(\"lc0-19-4508.onnx\")\n", 94 | "weak_model = LczeroModel.from_path(\"lc0-19-1876.onnx\")" 95 | ] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": ".venv", 101 | "language": "python", 102 | "name": "python3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.11.10" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 2 119 | } 120 | -------------------------------------------------------------------------------- /docs/source/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Piece Value Estimation Using LRP\n", 8 | "\n", 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "MODE = \"local\" # \"colab\" | \"colab-dev\" | \"local\"" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "if MODE == \"colab\":\n", 35 | " !pip install -q lczerolens\n", 
36 | "elif MODE == \"colab-dev\":\n", 37 | " !rm -r lczerolens\n", 38 | " !git clone https://github.com/Xmaster6y/lczerolens -b main\n", 39 | " !pip install -q ./lczerolens" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Downloading...\n", 52 | "From: https://drive.google.com/uc?id=15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X\n", 53 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/tutorials/lc0-19-1876.onnx\n", 54 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:03<00:00, 26.8MB/s]\n", 55 | "Downloading...\n", 56 | "From: https://drive.google.com/uc?id=1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd\n", 57 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/tutorials/lc0-19-4508.onnx\n", 58 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:03<00:00, 25.6MB/s]\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "!gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O lc0-19-1876.onnx\n", 64 | "!gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O lc0-19-4508.onnx" 65 | ] 66 | } 67 | ], 68 | "metadata": { 69 | "kernelspec": { 70 | "display_name": ".venv", 71 | "language": "python", 72 | "name": "python3" 73 | }, 74 | "language_info": { 75 | "codemirror_mode": { 76 | "name": "ipython", 77 | "version": 3 78 | }, 79 | "file_extension": ".py", 80 | "mimetype": "text/x-python", 81 | "name": "python", 82 | "nbconvert_exporter": "python", 83 | "pygments_lexer": "ipython3", 84 | "version": "3.9.18" 85 | } 86 | }, 87 | "nbformat": 4, 88 | "nbformat_minor": 2 89 | } 90 | -------------------------------------------------------------------------------- /docs/source/notebooks/walkthrough.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Walkthrough\n", 8 | "\n", 9 | "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/walkthrough.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "MODE = \"local\" # \"colab\" | \"colab-dev\" | \"local\"" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "if MODE == \"colab\":\n", 35 | " !pip install -q lczerolens\n", 36 | "elif MODE == \"colab-dev\":\n", 37 | " !rm -r lczerolens\n", 38 | " !git clone https://github.com/Xmaster6y/lczerolens -b main\n", 39 | " !pip install -q ./lczerolens" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Downloading...\n", 52 | "From: https://drive.google.com/uc?id=15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X\n", 53 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/lc0-19-1876.onnx\n", 54 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:02<00:00, 32.9MB/s]\n", 55 | "Downloading...\n", 56 | "From: https://drive.google.com/uc?id=1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd\n", 57 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/lc0-19-4508.onnx\n", 58 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:03<00:00, 30.1MB/s]\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "!gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O lc0-19-1876.onnx\n", 64 | "!gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O lc0-19-4508.onnx" 65 | ] 66 | } 67 | ], 68 | "metadata": { 69 | "kernelspec": { 70 | "display_name": ".venv", 71 | "language": "python", 72 | "name": "python3" 73 | }, 74 | "language_info": { 75 | "codemirror_mode": { 76 | 
"name": "ipython", 77 | "version": 3 78 | }, 79 | "file_extension": ".py", 80 | "mimetype": "text/x-python", 81 | "name": "python", 82 | "nbconvert_exporter": "python", 83 | "pygments_lexer": "ipython3", 84 | "version": "3.11.10" 85 | } 86 | }, 87 | "nbformat": 4, 88 | "nbformat_minor": 2 89 | } 90 | -------------------------------------------------------------------------------- /docs/source/start.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | **lczerolens** is a package for running interpretability methods on lc0 models. 5 | It is designed to be easy to use and to work with the most common interpretability 6 | techniques. 7 | 8 | .. _installation: 9 | 10 | Installation 11 | ------------ 12 | 13 | To get started with lczerolens, install it with ``pip``. 14 | 15 | .. code-block:: console 16 | 17 | pip install lczerolens 18 | 19 | .. note:: 20 | 21 | The dependencies for lczerolens are currently substantial. 22 | It mainly depends on ``torch``, ``nnsight``, ``zennit`` and ``datasets``. 23 | 24 | First Steps 25 | ----------- 26 | 27 | .. grid:: 2 28 | :gutter: 2 29 | 30 | .. grid-item-card:: Features 31 | :link: features 32 | :link-type: doc 33 | 34 | Review the basic features provided by :bdg-primary:`lczerolens`. 35 | 36 | .. grid-item-card:: Walkthrough 37 | :link: notebooks/walkthrough.ipynb 38 | 39 | Walk through the basic functionality of the package. 40 | 41 | .. note:: 42 | 43 | Check out the :bdg-secondary:`walkthrough` to get a better understanding of the package. 44 | 45 | Advanced Features 46 | ----------------- 47 | 48 | .. warning:: 49 | 50 | This following section is under construction, not yet stable nor fully functional. 51 | 52 | .. grid:: 2 53 | :gutter: 2 54 | 55 | .. grid-item-card:: Tutorials 56 | :link: tutorials 57 | :link-type: doc 58 | 59 | See implementations of :bdg-primary:`lczerolens` through common interpretability techniques. 60 | 61 | .. 
grid-item-card:: API Reference 62 | :link: api/index 63 | :link-type: doc 64 | 65 | See the full API reference for :bdg-primary:`lczerolens` to extend its functionality. 66 | -------------------------------------------------------------------------------- /docs/source/tutorials.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | .. grid:: 2 5 | :gutter: 2 6 | 7 | .. grid-item-card:: Walkthrough 8 | :link: notebooks/walkthrough.ipynb 9 | 10 | :bdg-primary:`Main Features` 11 | 12 | .. grid-item-card:: Evidence of Learned Look-Ahead 13 | :link: notebooks/tutorials/evidence-of-learned-look-ahead.ipynb 14 | 15 | :bdg-primary:`Main Features` 16 | 17 | .. grid-item-card:: Piece Value Estimation Using LRP 18 | :link: notebooks/tutorials/piece-value-estimation-using-lrp.ipynb 19 | 20 | :bdg-primary:`Main Features` 21 | 22 | .. grid-item-card:: Train SAEs 23 | :link: notebooks/tutorials/train-saes.ipynb 24 | 25 | .. toctree:: 26 | :hidden: 27 | :maxdepth: 2 28 | 29 | notebooks/walkthrough.ipynb 30 | notebooks/tutorials/evidence-of-learned-look-ahead.ipynb 31 | notebooks/tutorials/piece-value-estimation-using-lrp.ipynb 32 | notebooks/tutorials/train-saes.ipynb 33 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lczerolens" 3 | version = "0.3.3-dev" 4 | description = "Interpretability for LeelaChessZero networks." 
5 | readme = "README.md" 6 | license = "MIT" 7 | license-files = ["LICENSE"] 8 | authors = [ 9 | {name = "Yoann Poupart", email = "yoann.poupart@ens-lyon.org"}, 10 | ] 11 | requires-python = ">=3.10" 12 | dependencies = [ 13 | "datasets>=3.2.0", 14 | "einops>=0.8.0", 15 | "jsonlines>=4.0.0", 16 | "nnsight>=0.3.7,<0.4.0", 17 | "onnx2torch>=1.5.15", 18 | "python-chess>=1.999", 19 | "scikit-learn>=1.6.1", 20 | "tensordict>=0.6.2", 21 | "typing-extensions>=4.12.2", 22 | ] 23 | classifiers = [ 24 | "Development Status :: 3 - Alpha", 25 | "Intended Audience :: Science/Research", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.10", 28 | "Programming Language :: Python :: 3.11", 29 | "Programming Language :: Python :: 3.12", 30 | ] 31 | 32 | [project.urls] 33 | homepage = "https://lczerolens.readthedocs.io/" 34 | documentation = "https://lczerolens.readthedocs.io/" 35 | source = "https://github.com/Xmaster6y/lczerolens" 36 | issues = "https://github.com/Xmaster6y/lczerolens/issues" 37 | releasenotes = "https://github.com/Xmaster6y/lczerolens/releases" 38 | 39 | [project.optional-dependencies] 40 | viz = [ 41 | "matplotlib>=3.10.0", 42 | ] 43 | backends = [ 44 | "v-lczero-bindings>=0.31.2" 45 | ] 46 | 47 | [dependency-groups] 48 | dev = [ 49 | "gdown>=5.2.0", 50 | "ipykernel>=6.29.5", 51 | "nbconvert>=7.16.5", 52 | "onnxruntime>=1.20.1", 53 | "pre-commit>=4.0.1", 54 | "pytest>=8.3.4", 55 | "pytest-cov>=6.0.0", 56 | "v-lczero-bindings>=0.31.2", 57 | ] 58 | demo = [ 59 | "gradio>=5.12.0", 60 | "matplotlib>=3.10.0", 61 | "v-lczero-bindings>=0.31.2", 62 | ] 63 | docs = [ 64 | "nbsphinx>=0.9.6", 65 | "pandoc>=2.4", 66 | "plotly>=5.24.1", 67 | "pydata-sphinx-theme>=0.16.1", 68 | "sphinx>=8.1.3", 69 | "sphinx-autoapi>=3.6.0", 70 | "sphinx-charts>=0.2.1", 71 | "sphinx-copybutton>=0.5.2", 72 | "sphinx-design>=0.6.1", 73 | ] 74 | scripts = [ 75 | "loguru>=0.7.3", 76 | "matplotlib>=3.10.0", 77 | "pylatex>=1.4.2", 78 | "safetensors>=0.5.2", 79 | 
"wandb>=0.19.2", 80 | ] 81 | 82 | [build] 83 | target-dir = "build/dist" 84 | 85 | [build-system] 86 | requires = ["setuptools"] 87 | build-backend = "setuptools.build_meta" 88 | 89 | [tool.uv] 90 | default-groups = ["dev"] 91 | 92 | [tool.ruff] 93 | line-length = 119 94 | target-version = "py311" 95 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/constants.py: -------------------------------------------------------------------------------- 1 | """Constants for the scripts.""" 2 | 3 | import os 4 | 5 | # Secrets 6 | HF_TOKEN = os.getenv("HF_TOKEN") 7 | WANDB_API_KEY = os.getenv("WANDB_API_KEY") 8 | -------------------------------------------------------------------------------- /scripts/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/scripts/datasets/__init__.py -------------------------------------------------------------------------------- /scripts/datasets/make_lichess_dataset.py: -------------------------------------------------------------------------------- 1 | """Script to generate the base datasets. 
2 | 3 | Run with: 4 | ```bash 5 | uv run python -m scripts.datasets.make_lichess_dataset 6 | ``` 7 | """ 8 | 9 | import argparse 10 | 11 | from datasets import Dataset 12 | from loguru import logger 13 | 14 | from scripts.constants import HF_TOKEN 15 | 16 | 17 | def main(args: argparse.Namespace): 18 | logger.info(f"Loading `{args.source_file}`...") 19 | 20 | dataset = Dataset.from_csv(args.source_file) 21 | logger.info(f"Loaded dataset: {dataset}") 22 | 23 | if args.push_to_hub: 24 | logger.info("Pushing to hub...") 25 | dataset.push_to_hub( 26 | repo_id=args.dataset_name, 27 | token=HF_TOKEN, 28 | ) 29 | 30 | 31 | def parse_args() -> argparse.Namespace: 32 | parser = argparse.ArgumentParser("make-lichess-dataset") 33 | parser.add_argument( 34 | "--source_file", 35 | type=str, 36 | default="./assets/lichess_db_puzzle.csv", 37 | ) 38 | parser.add_argument( 39 | "--dataset_name", 40 | type=str, 41 | default="lczerolens/lichess-puzzles", 42 | ) 43 | parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=False) 44 | return parser.parse_args() 45 | 46 | 47 | if __name__ == "__main__": 48 | args = parse_args() 49 | main(args) 50 | -------------------------------------------------------------------------------- /scripts/datasets/make_tcec_dataset.py: -------------------------------------------------------------------------------- 1 | """Script to generate the base datasets. 
2 | 3 | Run with: 4 | ```bash 5 | uv run python -m scripts.datasets.make_tcec_dataset 6 | ``` 7 | """ 8 | 9 | import argparse 10 | 11 | from datasets import Dataset 12 | from loguru import logger 13 | 14 | from scripts.constants import HF_TOKEN 15 | 16 | 17 | def main(args: argparse.Namespace): 18 | logger.info(f"Loading `{args.source_file}`...") 19 | 20 | dataset = Dataset.from_json(args.source_file) 21 | logger.info(f"Loaded dataset: {dataset}") 22 | 23 | if args.push_to_hub: 24 | logger.info("Pushing to hub...") 25 | dataset.push_to_hub( 26 | repo_id=args.dataset_name, 27 | token=HF_TOKEN, 28 | ) 29 | 30 | 31 | def parse_args() -> argparse.Namespace: 32 | parser = argparse.ArgumentParser("make-tcec-dataset") 33 | parser.add_argument( 34 | "--source_file", 35 | type=str, 36 | default="./assets/tcec-games.jsonl", 37 | ) 38 | parser.add_argument( 39 | "--dataset_name", 40 | type=str, 41 | default="lczerolens/tcec-games", 42 | ) 43 | parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=False) 44 | return parser.parse_args() 45 | 46 | 47 | if __name__ == "__main__": 48 | args = parse_args() 49 | main(args) 50 | -------------------------------------------------------------------------------- /scripts/lrp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/scripts/lrp/__init__.py -------------------------------------------------------------------------------- /scripts/lrp/plane_analysis.py: -------------------------------------------------------------------------------- 1 | """Script to compute the importance of each plane for the model. 
2 | 3 | Run with: 4 | ``` 5 | uv run python -m scripts.lrp.plane_analysis 6 | ``` 7 | """ 8 | 9 | import argparse 10 | from loguru import logger 11 | 12 | from datasets import Dataset 13 | from torch.utils.data import DataLoader 14 | import torch 15 | 16 | from lczerolens.encodings import move as move_encoding 17 | from lczerolens.concept import MulticlassConcept 18 | from lczerolens.model import ForceValueFlow, PolicyFlow 19 | from lczerolens import concept, Lens 20 | from scripts import visualisation 21 | 22 | 23 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | 26 | def main(args): 27 | dataset = Dataset.from_json( 28 | "./assets/TCEC_game_collection_random_boards_bestlegal.jsonl", features=MulticlassConcept.features 29 | ) 30 | logger.info(f"Loaded dataset with {len(dataset)} boards.") 31 | if args.target == "policy": 32 | wrapper = PolicyFlow.from_path(f"./assets/{args.model_name}").to(DEVICE) 33 | init_rel_fn = concept.concept_init_rel 34 | 35 | elif args.target == "value": 36 | wrapper = ForceValueFlow.from_path(f"./assets/{args.model_name}").to(DEVICE) 37 | init_rel_fn = None 38 | else: 39 | raise ValueError(f"Target '{args.target}' not supported.") 40 | lens = Lens.from_name("lrp") 41 | if not lens.is_compatible(wrapper): 42 | raise ValueError(f"Lens of type 'lrp' not compatible with model '{args.model_name}'.") 43 | 44 | dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=concept.concept_collate_fn) 45 | 46 | iter_analyse = lens.analyse_batched_boards( 47 | dataloader, 48 | wrapper, 49 | target=None, 50 | return_output=True, 51 | init_rel_fn=init_rel_fn, 52 | ) 53 | all_stats = { 54 | "relative_piece_relevance": [], 55 | "absolute_piece_relevance": [], 56 | "plane_relevance_proportion": [], 57 | "relative_piece_relevance_proportion": [], 58 | "absolute_piece_relevance_proportion": [], 59 | } 60 | n_plotted = 0 61 | for batch in iter_analyse: 62 | batched_relevances, boards, *infos = batch 63 | 
relevances, outputs = batched_relevances 64 | labels = infos[0] 65 | for rel, out, board, label in zip(relevances, outputs, boards, labels): 66 | max_config_rel = rel[:12].abs().max().item() 67 | if max_config_rel == 0: 68 | continue 69 | if n_plotted < args.plot_first_n: 70 | if board.turn: 71 | heatmap = rel.sum(dim=0).view(64) 72 | else: 73 | heatmap = rel.sum(dim=0).flip(0).view(64) 74 | if args.target == "policy": 75 | move = move_encoding.decode_move(label, (board.turn, not board.turn), board) 76 | else: 77 | move = None 78 | visualisation.render_heatmap( 79 | board, 80 | heatmap, 81 | arrows=[(move.from_square, move.to_square)] if move is not None else None, 82 | normalise="abs", 83 | save_to=f"./scripts/results/{args.target}_heatmap_{n_plotted}.png", 84 | ) 85 | n_plotted += 1 86 | 87 | plane_order = "PNBRQKpnbrqk" 88 | piece_relevance = {} 89 | for i, letter in enumerate(plane_order): 90 | num = (rel[i] != 0).sum().item() 91 | if num == 0: 92 | piece_relevance[letter] = 0 93 | else: 94 | piece_relevance[letter] = rel[i].sum().item() / num 95 | 96 | if args.find_interesting: 97 | if piece_relevance["q"] / max_config_rel > 0.9 and args.target == "value": 98 | if board.turn: 99 | heatmap = rel.sum(dim=0).view(64) 100 | else: 101 | heatmap = rel.sum(dim=0).flip(0).view(64) 102 | if args.target == "policy": 103 | move = move_encoding.decode_move(label, (board.turn, not board.turn), board) 104 | else: 105 | move = None 106 | visualisation.render_heatmap( 107 | board, 108 | heatmap, 109 | arrows=[(move.from_square, move.to_square)] if move is not None else None, 110 | normalise="abs", 111 | save_to=f"./scripts/results/{args.target}_heatmap_{n_plotted}.png", 112 | ) 113 | raise SystemExit 114 | 115 | if any(piece_relevance[k] / max_config_rel > 0.9 for k in "pnbrqk") and args.target == "policy": 116 | if board.turn: 117 | heatmap = rel.sum(dim=0).view(64) 118 | else: 119 | heatmap = rel.sum(dim=0).flip(0).view(64) 120 | if args.target == "policy": 121 | move = 
move_encoding.decode_move(label, (board.turn, not board.turn), board) 122 | else: 123 | move = None 124 | visualisation.render_heatmap( 125 | board, 126 | heatmap, 127 | arrows=[(move.from_square, move.to_square)] if move is not None else None, 128 | normalise="abs", 129 | save_to=f"./scripts/results/{args.target}_heatmap_{n_plotted}.png", 130 | ) 131 | raise SystemExit 132 | 133 | all_stats["absolute_piece_relevance"].append(piece_relevance) 134 | all_stats["relative_piece_relevance"].append({k: v / max_config_rel for k, v in piece_relevance.items()}) 135 | 136 | total_relevance = rel.abs().sum().item() 137 | clock = board.fullmove_number * 2 - (not board.turn) 138 | proportion = rel.abs().sum(dim=(1, 2)).div(total_relevance).tolist() 139 | all_stats["plane_relevance_proportion"].append({clock: proportion}) 140 | all_stats["relative_piece_relevance_proportion"].append( 141 | {clock: [v / max_config_rel for v in piece_relevance.values()]} 142 | ) 143 | all_stats["absolute_piece_relevance_proportion"].append({clock: proportion[:12]}) 144 | 145 | logger.info(f"Processed {len(all_stats['relative_piece_relevance'])} boards.") 146 | 147 | visualisation.render_boxplot( 148 | all_stats["relative_piece_relevance"], 149 | y_label="Relevance", 150 | title="Relative Relevance", 151 | save_to=f"./scripts/results/{args.target}_piece_relative_relevance.png", 152 | ) 153 | visualisation.render_boxplot( 154 | all_stats["absolute_piece_relevance"], 155 | y_label="Relevance", 156 | title="Absolute Relevance", 157 | save_to=f"./scripts/results/{args.target}_piece_absolute_relevance.png", 158 | ) 159 | visualisation.render_proportion_through_index( 160 | all_stats["plane_relevance_proportion"], 161 | plane_type="Pieces", 162 | y_label="Proportion of relevance", 163 | y_log=True, 164 | max_index=200, 165 | title="Proportion of relevance per piece", 166 | save_to=f"./scripts/results/{args.target}_plane_config_relevance.png", 167 | ) 168 | visualisation.render_proportion_through_index( 
169 | all_stats["plane_relevance_proportion"], 170 | plane_type="H0", 171 | y_label="Proportion of relevance", 172 | y_log=True, 173 | max_index=200, 174 | title="Proportion of relevance per plane", 175 | save_to=f"./scripts/results/{args.target}_plane_H0_relevance.png", 176 | ) 177 | visualisation.render_proportion_through_index( 178 | all_stats["plane_relevance_proportion"], 179 | plane_type="Hist", 180 | y_label="Proportion of relevance", 181 | y_log=True, 182 | max_index=200, 183 | title="Proportion of relevance per plane", 184 | save_to=f"./scripts/results/{args.target}_plane_hist_relevance.png", 185 | ) 186 | visualisation.render_proportion_through_index( 187 | all_stats["relative_piece_relevance_proportion"], 188 | plane_type="Pieces", 189 | y_label="Proportion of relevance", 190 | y_log=False, 191 | max_index=200, 192 | title="Proportion of relevance per piece", 193 | save_to=f"./scripts/results/{args.target}_piece_plane_relative_relevance.png", 194 | ) 195 | visualisation.render_proportion_through_index( 196 | all_stats["absolute_piece_relevance_proportion"], 197 | plane_type="Pieces", 198 | y_label="Proportion of relevance", 199 | y_log=False, 200 | max_index=200, 201 | title="Proportion of relevance per piece", 202 | save_to=f"./scripts/results/{args.target}_piece_plane_absolute_relevance.png", 203 | ) 204 | 205 | 206 | def parse_args() -> argparse.Namespace: 207 | parser = argparse.ArgumentParser("plane-importance") 208 | parser.add_argument("--model_name", type=str, default="64x6-2018_0627_1913_08_161.onnx") 209 | parser.add_argument("--target", type=str, default="value") 210 | parser.add_argument("--find_interesting", action=argparse.BooleanOptionalAction, default=False) 211 | parser.add_argument("--batch_size", type=int, default=100) 212 | parser.add_argument("--plot_first_n", type=int, default=5) 213 | return parser.parse_args() 214 | 215 | 216 | if __name__ == "__main__": 217 | args = parse_args() 218 | main(args) 219 | 
"""
Visualisation utils.
"""

import chess
import chess.svg
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from lczerolens.encodings import board as board_encoding

COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
ALPHA = 1.0
NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)


def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
    save_to=None,
):
    """
    Render a heatmap on the board.

    Parameters
    ----------
    board : chess.Board
        The board to render.
    heatmap : torch.Tensor
        Per-square values, indexable over the 64 squares.
    square : Optional[str]
        Square name (e.g. "e4") to highlight as a check marker; invalid
        names are silently ignored.
    vmin, vmax : Optional[float]
        Colormap bounds; derived from the heatmap when not given.
    arrows : Optional[list]
        Arrow tuples (from_square, to_square) to draw on the board.
    normalise : str
        Use "abs" to scale by the maximum absolute value and fix bounds to [-1, 1].
    save_to : Optional[str]
        PNG path to save the colorbar figure to; an SVG of the board is
        written next to it. When None, returns (svg_board, fig).
    """
    if normalise == "abs":
        a_max = heatmap.abs().max()
        if a_max != 0:
            heatmap = heatmap / a_max
        vmin = -1
        vmax = 1
    if vmin is None:
        vmin = heatmap.min()
    if vmax is None:
        vmax = heatmap.max()
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    color_dict = {}
    for square_index in range(64):
        color = COLOR_MAP(norm(heatmap[square_index]))
        color = (*color[:3], ALPHA)
        color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True)
    # The figure only carries the colorbar; the board itself is rendered as SVG.
    fig = plt.figure(figsize=(1, 6))
    ax = plt.gca()
    ax.axis("off")
    fig.colorbar(
        matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
        ax=ax,
        orientation="vertical",
        fraction=1.0,
    )
    if square is not None:
        try:
            check = chess.parse_square(square)
        except ValueError:
            check = None
    else:
        check = None
    if arrows is None:
        arrows = []

    svg_board = chess.svg.board(
        board,
        check=check,
        fill=color_dict,
        size=350,
        arrows=arrows,
    )
    if save_to is not None:
        plt.savefig(save_to)
        with open(save_to.replace(".png", ".svg"), "w") as f:
            f.write(svg_board)
        plt.close()
    else:
        plt.close()
        return svg_board, fig


def render_boxplot(
    data,
    filter_null=True,
    y_label=None,
    title=None,
    save_to=None,
):
    """Render one box per key of the dicts in ``data``.

    Parameters
    ----------
    data : list[dict]
        Non-empty list of dicts sharing the keys of the first entry.
    filter_null : bool
        Drop exact 0.0 values from the distributions.
    y_label, title : Optional[str]
        Axis label and figure title.
    save_to : Optional[str]
        Path to save the figure to; shows the figure when None.
    """
    labels = data[0].keys()
    boxed_data = {label: [] for label in labels}
    for d in data:
        for label in labels:
            v = d.get(label)
            if v == 0.0 and filter_null:
                continue
            boxed_data[label].append(v)
    # `labels=` was renamed `tick_labels=` in matplotlib 3.9 (the project
    # pins matplotlib>=3.10, where the old name is deprecated).
    plt.boxplot(boxed_data.values(), notch=True, vert=True, patch_artist=True, tick_labels=list(labels))
    plt.ylabel(y_label)
    plt.title(title)
    if save_to is not None:
        plt.savefig(save_to)
        plt.close()
    else:
        plt.show()


def render_proportion_through_index(
    data,
    plane_type="H0",
    max_index=None,
    y_log=False,
    y_label=None,
    title=None,
    save_to=None,
):
    """Plot mean +/- std of per-plane proportions against a move index.

    Parameters
    ----------
    data : list[dict]
        One-entry dicts mapping a move index to the 112-plane proportion list
        (or the 12 piece planes for ``plane_type="Pieces"``).
    plane_type : str
        Grouping: "H0" (current position / history / meta), "Hist"
        (per history step + castling + remaining) or "Pieces" (per piece letter).
    max_index : Optional[int]
        Ignore entries whose index exceeds this bound.
    y_log : bool
        Use a logarithmic y-axis.
    y_label, title : Optional[str]
        Axis label and figure title.
    save_to : Optional[str]
        Path to save the figure to; shows the figure when None.

    Raises
    ------
    ValueError
        If ``plane_type`` is not one of the supported groupings.
    """
    if plane_type == "H0":
        indexed_data = {
            "H0": {},
            "Hist": {},
            "Meta": {},
        }
        for d in data:
            index, proportion = next(iter(d.items()))
            if max_index is not None and index > max_index:
                continue
            if index not in indexed_data["H0"]:
                indexed_data["H0"][index] = [sum(proportion[:13])]
                indexed_data["Hist"][index] = [sum(proportion[13:104])]
                indexed_data["Meta"][index] = [sum(proportion[104:])]
            else:
                indexed_data["H0"][index].append(sum(proportion[:13]))
                indexed_data["Hist"][index].append(sum(proportion[13:104]))
                indexed_data["Meta"][index].append(sum(proportion[104:]))

    elif plane_type == "Hist":
        indexed_data = {
            "H0": {},
            "H1": {},
            "H2": {},
            "H3": {},
            "H4": {},
            "H5": {},
            "H6": {},
            "H7": {},
            "Castling": {},
            "Remaining": {},
        }
        for d in data:
            index, proportion = next(iter(d.items()))
            if max_index is not None and index > max_index:
                continue
            if index not in indexed_data["H0"]:
                for i in range(8):
                    indexed_data[f"H{i}"][index] = [sum(proportion[13 * i : 13 * (i + 1)])]
                indexed_data["Castling"][index] = [sum(proportion[104:108])]
                indexed_data["Remaining"][index] = [sum(proportion[108:])]
            else:
                for i in range(8):
                    indexed_data[f"H{i}"][index].append(sum(proportion[13 * i : 13 * (i + 1)]))
                indexed_data["Castling"][index].append(sum(proportion[104:108]))
                indexed_data["Remaining"][index].append(sum(proportion[108:]))

    elif plane_type == "Pieces":
        relative_plane_order = board_encoding.get_plane_order((chess.WHITE, chess.BLACK))
        indexed_data = {letter: {} for letter in relative_plane_order}
        for d in data:
            index, proportion = next(iter(d.items()))
            if max_index is not None and index > max_index:
                continue
            if index not in indexed_data[relative_plane_order[0]]:
                for i, letter in enumerate(relative_plane_order):
                    indexed_data[letter][index] = [proportion[i]]
            else:
                for i, letter in enumerate(relative_plane_order):
                    indexed_data[letter][index].append(proportion[i])
    else:
        raise ValueError(f"Invalid plane type: {plane_type}")

    n_curves = len(indexed_data)
    # Guard against division by zero when there is a single curve.
    colour_denominator = max(n_curves - 1, 1)
    for i, (label, curve_data) in enumerate(indexed_data.items()):
        indices = sorted(list(curve_data.keys()))
        mean_curve = [np.mean(curve_data[idx]) for idx in indices]
        std_curve = [np.std(curve_data[idx]) for idx in indices]
        c = COLOR_MAP(i / colour_denominator)
        plt.plot(indices, mean_curve, label=label, c=c)
        lower_bound = np.array(mean_curve) - np.array(std_curve)
        upper_bound = np.array(mean_curve) + np.array(std_curve)
        plt.fill_between(indices, lower_bound, upper_bound, alpha=0.2, color=c)
    if y_log:
        plt.yscale("log")
    plt.legend()
    plt.ylabel(y_label)
    plt.title(title)
    if save_to is not None:
        plt.savefig(save_to)
        plt.close()
    else:
        plt.show()
"""Utils from the lczero executable and bindings.

Notes
----
The lczero bindings are not installed by default. You can install them by
running `pip install lczerolens[backends]`.
"""

import subprocess

import chess
import torch

from lczerolens.board import LczeroBoard

try:
    from lczero.backends import Backend, GameState
except ImportError as e:
    raise ImportError(
        "LCZero bindings are required to use the backends, install them with `pip install lczerolens[backends]`."
    ) from e


def generic_command(args, verbose=False):
    """
    Run a generic `lc0` command and return its stdout as text.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero status; with `verbose` the
        captured stderr is appended to the message.
    """
    # `subprocess.run(capture_output=True)` drains both pipes concurrently,
    # avoiding the deadlock `Popen.wait()` + late `read()` can hit when the
    # child fills an OS pipe buffer before exiting.
    completed = subprocess.run(
        ["lc0", *args],
        capture_output=True,
    )
    if completed.returncode != 0:
        if verbose:
            stderr = f"\n[DEBUG] stderr:\n{completed.stderr.decode('utf-8')}"
        else:
            stderr = ""
        raise RuntimeError(f"Could not run `lc0 {' '.join(args)}`." + stderr)
    return completed.stdout.decode("utf-8")


def describenet(path, verbose=False):
    """
    Describe the net at the given path.
    """
    return generic_command(["describenet", "-w", path], verbose=verbose)


def convert_to_onnx(in_path, out_path, verbose=False):
    """
    Convert the Leela net at the given path to ONNX.
    """
    return generic_command(
        ["leela2onnx", f"--input={in_path}", f"--output={out_path}"],
        verbose=verbose,
    )


def convert_to_leela(in_path, out_path, verbose=False):
    """
    Convert the ONNX net at the given path to the Leela format.
    """
    return generic_command(
        ["onnx2leela", f"--input={in_path}", f"--output={out_path}"],
        verbose=verbose,
    )


def board_from_backend(lczero_backend: Backend, lczero_game: GameState, planes: int = 112):
    """
    Create a (112, 8, 8) input tensor from the lczero backend.

    Only the first `planes` planes are filled; the remaining planes stay zero.
    Each plane is a 64-bit occupancy mask scaled by the plane's value.
    """
    lczero_input = lczero_game.as_input(lczero_backend)
    lczero_input_tensor = torch.zeros((112, 64), dtype=torch.float)
    for plane in range(planes):
        # The mask is given LSB-first, hence the reversal of the bit string.
        mask_str = f"{lczero_input.mask(plane):b}".zfill(64)
        lczero_input_tensor[plane] = torch.tensor(
            tuple(map(int, reversed(mask_str))), dtype=torch.float
        ) * lczero_input.val(plane)
    return lczero_input_tensor.view((112, 8, 8))


def prediction_from_backend(
    lczero_backend: Backend,
    lczero_game: GameState,
    softmax: bool = False,
    only_legal: bool = False,
    illegal_value: float = 0,
):
    """
    Predict the move.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The (1858,) policy (entries outside `indices` set to `illegal_value`
        when `only_legal` is used) and the scalar value head output.
    """
    filtered_policy = torch.full((1858,), illegal_value, dtype=torch.float)
    lczero_input = lczero_game.as_input(lczero_backend)
    (lczero_output,) = lczero_backend.evaluate(lczero_input)
    if only_legal:
        indices = torch.tensor(lczero_game.policy_indices())
    else:
        indices = torch.arange(1858)
    if softmax:
        policy = torch.tensor(lczero_output.p_softmax(*range(1858)), dtype=torch.float)
    else:
        policy = torch.tensor(lczero_output.p_raw(*range(1858)), dtype=torch.float)
    value = torch.tensor(lczero_output.q())
    filtered_policy[indices] = policy[indices]
    return filtered_policy, value


def moves_with_castling_swap(lczero_game: GameState, board: LczeroBoard):
    """
    Get the legal moves and policy indices, remapping lc0's king-takes-rook
    castling notation to the standard UCI castling squares.
    """
    lczero_legal_moves = lczero_game.moves()
    lczero_policy_indices = list(lczero_game.policy_indices())
    for move in board.legal_moves:
        uci_move = move.uci()
        if uci_move in lczero_legal_moves:
            continue
        if board.is_castling(move):
            # lc0 encodes castling as king-to-rook (files h/a instead of g/c).
            leela_uci_move = uci_move.replace("g", "h").replace("c", "a")
            if leela_uci_move in lczero_legal_moves:
                lczero_legal_moves.remove(leela_uci_move)
                lczero_legal_moves.append(uci_move)
                lczero_policy_indices.remove(
                    LczeroBoard.encode_move(
                        chess.Move.from_uci(leela_uci_move),
                        board.turn,
                    )
                )
                lczero_policy_indices.append(LczeroBoard.encode_move(move, board.turn))
    return lczero_legal_moves, lczero_policy_indices
13 | from .constants import INVERTED_POLICY_INDEX, POLICY_INDEX 14 | 15 | 16 | class InputEncoding(int, Enum): 17 | """Input encoding for the board tensor.""" 18 | 19 | INPUT_CLASSICAL_112_PLANE = 0 20 | INPUT_CLASSICAL_112_PLANE_REPEATED = 1 21 | INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED = 2 22 | INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS = 3 23 | 24 | 25 | class LczeroBoard(chess.Board): 26 | """A class for wrapping the LczeroBoard class.""" 27 | 28 | @staticmethod 29 | def get_plane_order(us: bool): 30 | """Get the plane order for the given us view. 31 | 32 | Parameters 33 | ---------- 34 | us : bool 35 | The us_them tuple. 36 | 37 | Returns 38 | ------- 39 | str 40 | The plane order. 41 | """ 42 | plane_orders = {chess.WHITE: "PNBRQK", chess.BLACK: "pnbrqk"} 43 | return plane_orders[us] + plane_orders[not us] 44 | 45 | @staticmethod 46 | def get_piece_index(piece: str, us: bool, plane_order: Optional[str] = None): 47 | """Converts a piece to its index in the plane order. 48 | 49 | Parameters 50 | ---------- 51 | piece : str 52 | The piece to convert. 53 | us : bool 54 | The us_them tuple. 55 | plane_order : Optional[str] 56 | The plane order. 57 | 58 | Returns 59 | ------- 60 | int 61 | The index of the piece in the plane order. 62 | """ 63 | if plane_order is None: 64 | plane_order = LczeroBoard.get_plane_order(us) 65 | return f"{plane_order}0".index(piece) 66 | 67 | def to_config_tensor( 68 | self, 69 | us: Optional[bool] = None, 70 | ): 71 | """Converts a LczeroBoard to a tensor based on the pieces configuration. 72 | 73 | Parameters 74 | ---------- 75 | us : Optional[bool] 76 | The us_them tuple. 77 | 78 | Returns 79 | ------- 80 | torch.Tensor 81 | The 13x8x8 tensor. 
82 | """ 83 | if us is None: 84 | us = self.turn 85 | plane_order = LczeroBoard.get_plane_order(us) 86 | 87 | def piece_to_index(piece: str): 88 | return f"{plane_order}0".index(piece) 89 | 90 | fen_board = self.fen().split(" ")[0] 91 | fen_rep = re.sub(r"(\d)", lambda m: "0" * int(m.group(1)), fen_board) 92 | rows = fen_rep.split("/") 93 | rev_rows = rows[::-1] 94 | ordered_fen = "".join(rev_rows) 95 | 96 | config_tensor = torch.zeros((13, 8, 8), dtype=torch.float) 97 | ordinal_board = torch.tensor(tuple(map(piece_to_index, ordered_fen)), dtype=torch.float) 98 | ordinal_board = ordinal_board.reshape((8, 8)).unsqueeze(0) 99 | piece_tensor = torch.tensor(tuple(map(piece_to_index, plane_order)), dtype=torch.float) 100 | piece_tensor = piece_tensor.reshape((12, 1, 1)) 101 | config_tensor[:12] = (ordinal_board == piece_tensor).float() 102 | if self.is_repetition(2): # Might be wrong if the full history is not available 103 | config_tensor[12] = torch.ones((8, 8), dtype=torch.float) 104 | return config_tensor if us == chess.WHITE else config_tensor.flip(1) 105 | 106 | def to_input_tensor( 107 | self, 108 | *, 109 | input_encoding: InputEncoding = InputEncoding.INPUT_CLASSICAL_112_PLANE, 110 | ): 111 | """Create the lc0 input tensor from the history of a game. 112 | 113 | Parameters 114 | ---------- 115 | input_encoding : InputEncoding 116 | The input encoding method. 117 | 118 | Returns 119 | ------- 120 | torch.Tensor 121 | The 112x8x8 tensor. 
122 | """ 123 | 124 | input_tensor = torch.zeros((112, 8, 8), dtype=torch.float) 125 | us = self.turn 126 | them = not us 127 | moves = [] 128 | 129 | if ( 130 | input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE 131 | or input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED 132 | ): 133 | for i in range(8): 134 | config_tensor = self.to_config_tensor(us) 135 | input_tensor[i * 13 : (i + 1) * 13] = config_tensor 136 | try: 137 | moves.append(self.pop()) 138 | except IndexError: 139 | if input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED: 140 | input_tensor[(i + 1) * 13 : 104] = config_tensor.repeat(7 - i, 1, 1) 141 | break 142 | 143 | elif input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED: 144 | config_tensor = self.to_config_tensor(us) 145 | input_tensor[:104] = config_tensor.repeat(8, 1, 1) 146 | elif input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS: 147 | input_tensor[:13] = self.to_config_tensor(us) 148 | else: 149 | raise ValueError(f"Got unexpected input encoding {input_encoding}") 150 | 151 | # Restore the moves 152 | for move in reversed(moves): 153 | self.push(move) 154 | 155 | if self.has_queenside_castling_rights(us): 156 | input_tensor[104] = torch.ones((8, 8), dtype=torch.float) 157 | if self.has_kingside_castling_rights(us): 158 | input_tensor[105] = torch.ones((8, 8), dtype=torch.float) 159 | if self.has_queenside_castling_rights(them): 160 | input_tensor[106] = torch.ones((8, 8), dtype=torch.float) 161 | if self.has_kingside_castling_rights(them): 162 | input_tensor[107] = torch.ones((8, 8), dtype=torch.float) 163 | if us == chess.BLACK: 164 | input_tensor[108] = torch.ones((8, 8), dtype=torch.float) 165 | input_tensor[109] = torch.ones((8, 8), dtype=torch.float) * self.halfmove_clock 166 | input_tensor[111] = torch.ones((8, 8), dtype=torch.float) 167 | 168 | return input_tensor 169 | 170 | @staticmethod 171 | def encode_move( 172 | move: chess.Move, 173 | us: 
bool, 174 | ) -> int: 175 | """ 176 | Converts a chess.Move object to an index. 177 | 178 | Parameters 179 | ---------- 180 | move : chess.Move 181 | The chess move to encode. 182 | us : bool 183 | The side to move (True for white, False for black). 184 | 185 | Returns 186 | ------- 187 | int 188 | The encoded move index. 189 | """ 190 | from_square = move.from_square 191 | to_square = move.to_square 192 | 193 | if us == chess.BLACK: 194 | from_square_row = from_square // 8 195 | from_square_col = from_square % 8 196 | from_square = 8 * (7 - from_square_row) + from_square_col 197 | to_square_row = to_square // 8 198 | to_square_col = to_square % 8 199 | to_square = 8 * (7 - to_square_row) + to_square_col 200 | us_uci_move = chess.SQUARE_NAMES[from_square] + chess.SQUARE_NAMES[to_square] 201 | if move.promotion is not None: 202 | if move.promotion == chess.BISHOP: 203 | us_uci_move += "b" 204 | elif move.promotion == chess.ROOK: 205 | us_uci_move += "r" 206 | elif move.promotion == chess.QUEEN: 207 | us_uci_move += "q" 208 | # Knight promotion is the default 209 | return INVERTED_POLICY_INDEX[us_uci_move] 210 | 211 | def decode_move( 212 | self, 213 | index: int, 214 | ) -> chess.Move: 215 | """ 216 | Converts an index to a chess.Move object. 217 | 218 | Parameters 219 | ---------- 220 | index : int 221 | The index to convert. 222 | 223 | Returns 224 | ------- 225 | chess.Move 226 | The chess move. 
227 | """ 228 | us = self.turn 229 | us_uci_move = POLICY_INDEX[index] 230 | from_square = chess.SQUARE_NAMES.index(us_uci_move[:2]) 231 | to_square = chess.SQUARE_NAMES.index(us_uci_move[2:4]) 232 | if us == chess.BLACK: 233 | from_square_row = from_square // 8 234 | from_square_col = from_square % 8 235 | from_square = 8 * (7 - from_square_row) + from_square_col 236 | to_square_row = to_square // 8 237 | to_square_col = to_square % 8 238 | to_square = 8 * (7 - to_square_row) + to_square_col 239 | 240 | uci_move = chess.SQUARE_NAMES[from_square] + chess.SQUARE_NAMES[to_square] 241 | from_piece = self.piece_at(from_square) 242 | if from_piece == chess.PAWN and to_square >= 56: # Knight promotion is the default 243 | uci_move += "n" 244 | return chess.Move.from_uci(uci_move) 245 | 246 | def get_legal_indices( 247 | self, 248 | ) -> torch.Tensor: 249 | """ 250 | Gets the legal indices. 251 | 252 | Returns 253 | ------- 254 | torch.Tensor 255 | Tensor containing indices of legal moves. 256 | """ 257 | us = self.turn 258 | return torch.tensor([self.encode_move(move, us) for move in self.legal_moves]) 259 | 260 | def get_next_legal_boards( 261 | self, 262 | n_history: int = 7, 263 | ) -> Generator["LczeroBoard", None, None]: 264 | """ 265 | Gets the next legal boards. 266 | 267 | Parameters 268 | ---------- 269 | n_history : int, optional 270 | Number of previous positions to keep in the move stack, by default 7. 271 | 272 | Returns 273 | ------- 274 | Generator[LczeroBoard, None, None] 275 | Generator yielding board positions after each legal move. 
276 | """ 277 | working_board = self.copy(stack=n_history) 278 | for move in working_board.legal_moves: 279 | working_board.push(move) 280 | yield working_board.copy(stack=n_history) 281 | working_board.pop() 282 | 283 | def render_heatmap( 284 | self, 285 | heatmap: Union[torch.Tensor, np.ndarray], 286 | square: Optional[str] = None, 287 | vmin: Optional[float] = None, 288 | vmax: Optional[float] = None, 289 | arrows: Optional[List[Tuple[str, str]]] = None, 290 | normalise: str = "none", 291 | save_to: Optional[str] = None, 292 | cmap_name: str = "RdYlBu_r", 293 | alpha: float = 1.0, 294 | relative_board_view: bool = True, 295 | heatmap_mode: str = "relative_flip", 296 | ) -> Tuple[Optional[str], Any]: 297 | """Render a heatmap on the board. 298 | 299 | Parameters 300 | ---------- 301 | heatmap : torch.Tensor or numpy.ndarray 302 | The heatmap values to visualize on the board (64,) or (8, 8). 303 | square : Optional[str], default=None 304 | Chess square to highlight (e.g. 'e4'). 305 | vmin : Optional[float], default=None 306 | Minimum value for the colormap normalization. 307 | vmax : Optional[float], default=None 308 | Maximum value for the colormap normalization. 309 | arrows : Optional[List[Tuple[str, str]]], default=None 310 | List of arrow tuples (from_square, to_square) to draw on board. 311 | normalise : str, default="none" 312 | Normalization method. Use "abs" for absolute value normalization. 313 | save_to : Optional[str], default=None 314 | Path to save the visualization. If None, returns the figure. 315 | cmap_name : str, default="RdYlBu_r" 316 | Name of matplotlib colormap to use. 317 | alpha : float, default=1.0 318 | Opacity of the heatmap overlay. 319 | relative_board_view : bool, default=True 320 | Whether to use the relative board view. 
321 | heatmap_mode : str, default="relative_flip" 322 | Use "relative_flip" if the heatmap corresponds to a relative flip of the board, 323 | "relative_rotation" if it corresponds to a relative rotation of the board, 324 | or "absolute" if it is already in the correct orientation. 325 | 326 | Returns 327 | ------- 328 | Union[Tuple[str, matplotlib.figure.Figure], None] 329 | If save_to is None, returns (SVG string, matplotlib figure). 330 | If save_to is provided, saves files and returns None. 331 | 332 | Raises 333 | ------ 334 | ValueError 335 | If save_to is provided and does not end with `.svg`. 336 | """ 337 | try: 338 | import matplotlib 339 | import matplotlib.pyplot as plt 340 | except ImportError as e: 341 | raise ImportError( 342 | "matplotlib is required to render heatmaps, install it with `pip install lczerolens[viz]`." 343 | ) from e 344 | 345 | if heatmap.ndim > 1: 346 | heatmap = heatmap.view(64) 347 | 348 | if heatmap_mode == "relative_flip": 349 | if not self.turn: 350 | heatmap = heatmap.view(8, 8).flip(0).view(64) 351 | elif heatmap_mode == "relative_rotation": 352 | if not self.turn: 353 | heatmap = heatmap.view(8, 8).flip(1).flip(0).view(64) 354 | elif heatmap_mode == "absolute": 355 | pass 356 | else: 357 | raise ValueError( 358 | f"Got unexpected heatmap_mode {heatmap_mode!r}. 
" 359 | "Valid options are ['relative_flip', 'relative_rotation', 'absolute']" 360 | ) 361 | 362 | cmap = matplotlib.colormaps[cmap_name].resampled(1000) 363 | 364 | if normalise == "abs": 365 | a_max = heatmap.abs().max() 366 | if a_max != 0: 367 | heatmap = heatmap / a_max 368 | vmin = -1 369 | vmax = 1 370 | if vmin is None: 371 | vmin = heatmap.min() 372 | if vmax is None: 373 | vmax = heatmap.max() 374 | norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False) 375 | 376 | color_dict = {} 377 | for square_index in range(64): 378 | color = cmap(norm(heatmap[square_index])) 379 | color = (*color[:3], alpha) 380 | color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True) 381 | fig = plt.figure(figsize=(1, 4.1)) 382 | ax = plt.gca() 383 | ax.axis("off") 384 | fig.colorbar( 385 | matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap), 386 | ax=ax, 387 | orientation="vertical", 388 | fraction=1.0, 389 | ) 390 | if square is not None: 391 | try: 392 | check = chess.parse_square(square) 393 | except ValueError: 394 | check = None 395 | else: 396 | check = None 397 | if arrows is None: 398 | arrows = [] 399 | 400 | svg_board = chess.svg.board( 401 | self, 402 | orientation=self.turn if relative_board_view else chess.WHITE, 403 | check=check, 404 | fill=color_dict, 405 | size=400, 406 | arrows=arrows, 407 | ) 408 | buffer = io.BytesIO() 409 | fig.savefig(buffer, format="svg") 410 | svg_colorbar = buffer.getvalue().decode("utf-8") 411 | plt.close() 412 | 413 | if save_to is None: 414 | return svg_board, svg_colorbar 415 | elif not save_to.endswith(".svg"): 416 | raise ValueError("only saving to `svg` is supported") 417 | 418 | with open(save_to.replace(".svg", "_board.svg"), "w") as f: 419 | f.write(svg_board) 420 | with open(save_to.replace(".svg", "_colorbar.svg"), "w") as f: 421 | f.write(svg_colorbar) 422 | -------------------------------------------------------------------------------- /src/lczerolens/concept.py: 
class BinaryConcept(Concept):
    """
    Concept whose label is binary (0 or 1).
    """

    # Dataset schema for binary concept datasets.
    features = Features(
        {
            "gameid": Value("string"),
            "moves": Sequence(Value("string")),
            "fen": Value("string"),
            "label": ClassLabel(num_classes=2),
        }
    )

    @staticmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """
        Compute binary classification metrics for the given predictions.
        """
        metric_fns = {
            "accuracy": metrics.accuracy_score,
            "precision": metrics.precision_score,
            "recall": metrics.recall_score,
            "f1": metrics.f1_score,
        }
        return {name: fn(labels, predictions) for name, fn in metric_fns.items()}
class OrBinaryConcept(BinaryConcept):
    """
    Binary concept that is the logical OR of several binary concepts.
    """

    def __init__(self, *concepts: BinaryConcept):
        # Reject the first non-binary concept, matching validation order.
        for candidate in concepts:
            if not isinstance(candidate, BinaryConcept):
                raise ValueError(f"{candidate} is not a BinaryConcept")
        self.concepts = concepts

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """
        Label the board positively if any sub-concept does.
        """
        for concept in self.concepts:
            if concept.compute_label(board):
                return True
        return False
def concept_init_rel(output, infos):
    """Build the initial relevance tensor for LRP from concept labels.

    Keeps only the logit at each sample's label position; every other
    entry of the relevance tensor is zero.
    """
    labels = infos[0]
    rel = torch.zeros_like(output)
    for row, label in enumerate(labels):
        rel[row, label] = output[row, label]
    return rel
"HasThreat", 10 | "HasMateThreat", 11 | "HasMaterialAdvantage", 12 | "BestLegalMove", 13 | "PieceBestLegalMove", 14 | ] 15 | -------------------------------------------------------------------------------- /src/lczerolens/concepts/material.py: -------------------------------------------------------------------------------- 1 | """All concepts related to material.""" 2 | 3 | from typing import Dict, Optional 4 | 5 | import chess 6 | 7 | from lczerolens.board import LczeroBoard 8 | from lczerolens.concept import BinaryConcept 9 | 10 | 11 | class HasPiece(BinaryConcept): 12 | """Class for material concept-based XAI methods.""" 13 | 14 | def __init__( 15 | self, 16 | piece: str, 17 | relative: bool = True, 18 | ): 19 | """Initialize the class.""" 20 | self.piece = chess.Piece.from_symbol(piece) 21 | self.relative = relative 22 | 23 | def compute_label( 24 | self, 25 | board: LczeroBoard, 26 | ) -> int: 27 | """Compute the label for a given model and input.""" 28 | if self.relative: 29 | color = self.piece.color if board.turn else not self.piece.color 30 | else: 31 | color = self.piece.color 32 | squares = board.pieces(self.piece.piece_type, color) 33 | return 1 if len(squares) > 0 else 0 34 | 35 | 36 | class HasMaterialAdvantage(BinaryConcept): 37 | """Class for material concept-based XAI methods. 38 | 39 | Attributes 40 | ---------- 41 | piece_values : Dict[int, int] 42 | The piece values. 43 | """ 44 | 45 | piece_values = { 46 | chess.PAWN: 1, 47 | chess.KNIGHT: 3, 48 | chess.BISHOP: 3, 49 | chess.ROOK: 5, 50 | chess.QUEEN: 9, 51 | chess.KING: 0, 52 | } 53 | 54 | def __init__( 55 | self, 56 | relative: bool = True, 57 | ): 58 | """ 59 | Initialize the class. 60 | """ 61 | self.relative = relative 62 | 63 | def compute_label( 64 | self, 65 | board: LczeroBoard, 66 | piece_values: Optional[Dict[int, int]] = None, 67 | ) -> int: 68 | """ 69 | Compute the label for a given model and input. 
class BestLegalMove(MulticlassConcept):
    """Multiclass concept: the policy index of the model's best legal move."""

    def __init__(
        self,
        model: LczeroModel,
    ):
        """Wrap the model in a policy flow."""
        self.policy_flow = PolicyFlow(model)

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return the policy index of the highest-probability legal move."""
        (policy,) = self.policy_flow(board)
        probs = torch.softmax(policy.squeeze(0), dim=-1)

        # Restrict the argmax to legal moves only.
        indices = [LczeroBoard.encode_move(move, board.turn) for move in board.legal_moves]
        best_position = probs[indices].argmax().item()
        return indices[best_position]
class HasThreat(BinaryConcept):
    """
    Binary concept: is a given piece currently attacked by the opponent?
    """

    def __init__(
        self,
        piece: str,
        relative: bool = True,
    ):
        """
        Store the piece (from its symbol) and colour mode.
        """
        self.piece = chess.Piece.from_symbol(piece)
        self.relative = relative

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """
        Return 1 if any such piece is attacked by the other colour, else 0.
        """
        color = self.piece.color
        # In relative mode the colour follows the side to move.
        if self.relative and not board.turn:
            color = not color
        threatened = any(
            board.is_attacked_by(not color, square)
            for square in board.pieces(self.piece.piece_type, color)
        )
        return 1 if threatened else 0
54 | """ 55 | for move in board.legal_moves: 56 | board.push(move) 57 | if board.is_checkmate(): 58 | board.pop() 59 | return 1 60 | board.pop() 61 | return 0 62 | -------------------------------------------------------------------------------- /src/lczerolens/lens.py: -------------------------------------------------------------------------------- 1 | """Generic lens class.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Dict, Iterable, Generator, Callable, Type, Union, Optional, Any 5 | 6 | import torch 7 | import re 8 | 9 | from lczerolens.model import LczeroModel 10 | from lczerolens.board import LczeroBoard 11 | 12 | 13 | class Lens(ABC): 14 | """Generic lens class for analysing model activations.""" 15 | 16 | _lens_type: str 17 | _registry: Dict[str, Type["Lens"]] = {} 18 | _grad_enabled: bool = False 19 | 20 | @classmethod 21 | def register(cls, name: str) -> Callable: 22 | """Registers the lens. 23 | 24 | Parameters 25 | ---------- 26 | name : str 27 | The name of the lens. 28 | 29 | Returns 30 | ------- 31 | Callable 32 | The decorator to register the lens. 33 | 34 | Raises 35 | ------ 36 | ValueError 37 | If the lens name is already registered. 38 | """ 39 | 40 | if name in cls._registry: 41 | raise ValueError(f"Lens {name} already registered.") 42 | 43 | def decorator(subclass: Type["Lens"]): 44 | subclass._lens_type = name 45 | cls._registry[name] = subclass 46 | return subclass 47 | 48 | return decorator 49 | 50 | @classmethod 51 | def from_name(cls, name: str, *args, **kwargs) -> "Lens": 52 | """Returns the lens from its name. 53 | 54 | Parameters 55 | ---------- 56 | name : str 57 | The name of the lens. 58 | 59 | Returns 60 | ------- 61 | Lens 62 | The lens instance. 63 | 64 | Raises 65 | ------ 66 | KeyError 67 | If the lens name is not found. 
68 | """ 69 | if name not in cls._registry: 70 | raise KeyError(f"Lens {name} not found.") 71 | return cls._registry[name](*args, **kwargs) 72 | 73 | def __init__(self, pattern: Optional[str] = None): 74 | """Initialise the lens. 75 | 76 | Parameters 77 | ---------- 78 | pattern : Optional[str], default=None 79 | The pattern to match the modules. 80 | """ 81 | if pattern is None: 82 | pattern = r"a^" # match nothing by default 83 | self._pattern = pattern 84 | self._reg_exp = re.compile(pattern) 85 | 86 | @property 87 | def pattern(self) -> str: 88 | """The pattern to match the modules.""" 89 | return self._pattern 90 | 91 | @pattern.setter 92 | def pattern(self, pattern: str): 93 | self._pattern = pattern 94 | self._reg_exp = re.compile(pattern) 95 | 96 | def _get_modules(self, model: LczeroModel) -> Generator[tuple[str, Any], None, None]: 97 | """Get the modules to intervene on.""" 98 | for name, module in model.named_modules(): 99 | fixed_name = name.lstrip(". ") # nnsight outputs names with a dot 100 | if self._reg_exp.match(fixed_name): 101 | yield fixed_name, module 102 | 103 | def is_compatible(self, model: LczeroModel) -> bool: 104 | """Returns whether the lens is compatible with the model. 105 | 106 | Parameters 107 | ---------- 108 | model : LczeroModel 109 | The NNsight model. 110 | 111 | Returns 112 | ------- 113 | bool 114 | Whether the lens is compatible with the model. 115 | """ 116 | return isinstance(model, LczeroModel) 117 | 118 | def _ensure_compatible(self, model: LczeroModel): 119 | """Ensure the lens is compatible with the model. 120 | 121 | Parameters 122 | ---------- 123 | model : LczeroModel 124 | The NNsight model. 125 | 126 | Raises 127 | ------ 128 | ValueError 129 | If the lens is not compatible with the model. 
    def _trace(
        self,
        model: LczeroModel,
        *inputs: Union[LczeroBoard, torch.Tensor],
        model_kwargs: dict,
        intervention_kwargs: dict,
    ):
        """Trace the model and intervene on it.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.
        inputs : Union[LczeroBoard, torch.Tensor]
            The inputs.
        model_kwargs : dict
            The model kwargs.
        intervention_kwargs : dict
            The intervention kwargs.

        Returns
        -------
        dict
            The intervention results.
        """
        # NOTE(review): the return happens *inside* the trace context —
        # presumably so nnsight resolves the saved proxies when the context
        # exits; confirm against nnsight's tracing semantics before moving it.
        with model.trace(*inputs, **model_kwargs):
            return self._intervene(model, **intervention_kwargs)
217 | """ 218 | if not isinstance(model, LczeroModel): 219 | raise ValueError(f"Model is not a LczeroModel. Got {type(model)}.") 220 | self._ensure_compatible(model) 221 | model_kwargs = kwargs.get("model_kwargs", {}) 222 | prepared_model = self.prepare(model, **kwargs) 223 | with torch.set_grad_enabled(self._grad_enabled): 224 | return self._trace(prepared_model, *inputs, model_kwargs=model_kwargs, intervention_kwargs=kwargs) 225 | 226 | def analyse_batched( 227 | self, 228 | model: LczeroModel, 229 | iter_inputs: Iterable[Union[LczeroBoard, torch.Tensor]], 230 | **kwargs, 231 | ) -> Generator[dict, None, None]: 232 | """Analyse a batches of inputs. 233 | 234 | Parameters 235 | ---------- 236 | model : LczeroModel 237 | The NNsight model. 238 | iter_inputs : Iterable[Tuple[Union[LczeroBoard, torch.Tensor], dict]] 239 | The iterator over the inputs. 240 | 241 | Returns 242 | ------- 243 | Generator[dict, None, None] 244 | The iterator over the statistics. 245 | 246 | Raises 247 | ------ 248 | ValueError 249 | If the lens is not compatible with the model. 250 | """ 251 | self._ensure_compatible(model) 252 | model_kwargs = kwargs.get("model_kwargs", {}) 253 | prepared_model = self.prepare(model, **kwargs) 254 | for inputs, dynamic_intervention_kwargs in iter_inputs: 255 | kwargs.update(dynamic_intervention_kwargs) 256 | yield self._trace(prepared_model, *inputs, model_kwargs=model_kwargs, intervention_kwargs=kwargs) 257 | -------------------------------------------------------------------------------- /src/lczerolens/lenses/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lenses module. 
@Lens.register("activation")
class ActivationLens(Lens):
    """
    Class for activation-based XAI methods.

    Examples
    --------

    .. code-block:: python

        model = LczeroModel.from_path(model_path)
        lens = ActivationLens()
        board = LczeroBoard()
        # The model is the first positional argument of `analyse`.
        results = lens.analyse(model, board)
    """

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> dict:
        """Save the output (and optionally input) of every matched module."""
        save_inputs = kwargs.get("save_inputs", False)
        results = {}
        for name, module in self._get_modules(model):
            if save_inputs:
                results[f"{name}_input"] = module.input.save()
            results[f"{name}_output"] = module.output.save()
        return results
code-block:: python 17 | 18 | model = LczeroModel.from_path(model_path) 19 | lens = CompositeLens([ActivationLens(), GradientLens()]) 20 | board = LczeroBoard() 21 | results = lens.analyse(board, model=model) 22 | """ 23 | 24 | def __init__(self, lenses: Union[List[Lens], Dict[str, Lens]], merge_results: bool = True): 25 | self._lens_map = lenses if isinstance(lenses, dict) else {f"lens_{i}": lens for i, lens in enumerate(lenses)} 26 | self.merge_results = merge_results 27 | 28 | def is_compatible(self, model: LczeroModel) -> bool: 29 | return all(lens.is_compatible(model) for lens in self._lens_map.values()) 30 | 31 | def prepare(self, model: LczeroModel, **kwargs) -> LczeroModel: 32 | for lens in self._lens_map.values(): 33 | model = lens.prepare(model, **kwargs) 34 | return model 35 | 36 | def _intervene(self, model: LczeroModel, **kwargs) -> Dict[str, Any]: 37 | results = {name: lens._intervene(model, **kwargs) for name, lens in self._lens_map.items()} 38 | if self.merge_results: 39 | return {k: v for d in results.values() for k, v in d.items()} 40 | return results 41 | -------------------------------------------------------------------------------- /src/lczerolens/lenses/gradient.py: -------------------------------------------------------------------------------- 1 | """Compute Gradient heatmap for a given model and input.""" 2 | 3 | from lczerolens.model import LczeroModel 4 | from lczerolens.lens import Lens 5 | 6 | 7 | @Lens.register("gradient") 8 | class GradientLens(Lens): 9 | """Class for gradient-based XAI methods.""" 10 | 11 | _grad_enabled: bool = True 12 | 13 | def __init__(self, *, input_requires_grad: bool = True, **kwargs): 14 | self.input_requires_grad = input_requires_grad 15 | super().__init__(**kwargs) 16 | 17 | def _intervene(self, model: LczeroModel, **kwargs) -> dict: 18 | init_target = kwargs.get("init_target", lambda model: model.output["value"]) 19 | init_gradient = kwargs.get("init_gradient", lambda model: None) 20 | 21 | results = {} 
@Lens.register("lrp")
class LrpLens(Lens):
    """Class for wrapping the LCZero models."""

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> Any:
        # TODO: Refactor this logic
        # NOTE(review): currently a no-op stub — `analyse` returns None for
        # this lens until the LRP intervention is implemented.
        pass
def stabilize(tensor, epsilon=1e-6):
    """Shift every value away from zero by ``epsilon``, keeping its sign.

    Negative entries receive ``-epsilon``, all others (including zero)
    ``+epsilon``, so a later division by the result never hits exactly zero.
    """
    return torch.where(tensor < 0, tensor - epsilon, tensor + epsilon)
class MulUniformFunction(Function):
    """Elementwise product whose relevance is split 50/50 between inputs."""

    @staticmethod
    def forward(ctx, input_a, input_b):
        # Plain elementwise product; nothing needs saving for backward.
        return input_a * input_b

    @staticmethod
    def backward(ctx, *grad_outputs):
        # LRP uniform rule: each factor receives half the incoming relevance.
        half_relevance = grad_outputs[0] * 0.5
        return half_relevance, half_relevance
# --- src/lczerolens/lenses/patching.py ---
"""Patching lens."""

from typing import Callable

from lczerolens.model import LczeroModel
from lczerolens.lens import Lens


@Lens.register("patching")
class PatchingLens(Lens):
    """
    Class for activation-patching XAI methods.

    Examples
    --------

    .. code-block:: python

        model = LczeroModel.from_path(model_path)
        # The patch function receives (name, module, **kwargs) for each module.
        patch_fn = lambda name, module, **kwargs: None
        lens = PatchingLens(patch_fn)
        board = LczeroBoard()
        results = lens.analyse(board, model=model)
    """

    def __init__(self, patch_fn: Callable, **kwargs):
        # Called as patch_fn(name, module, **kwargs) on every selected module.
        self._patch_fn = patch_fn
        super().__init__(**kwargs)

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> dict:
        """Apply the patch function to every selected module; produces no results."""
        for name, module in self._get_modules(model):
            self._patch_fn(name, module, **kwargs)
        return {}


# --- src/lczerolens/lenses/probing/__init__.py ---
from .lens import ProbingLens

__all__ = ["ProbingLens"]


# --- src/lczerolens/lenses/probing/lens.py ---
"""Probing lens."""

from typing import Callable

from lczerolens.model import LczeroModel
from lczerolens.lens import Lens


@Lens.register("probing")
class ProbingLens(Lens):
    """
    Class for probing-based XAI methods.

    Examples
    --------

    .. code-block:: python

        model = LczeroModel.from_path(model_path)
        lens = ProbingLens(probe_fn)
        board = LczeroBoard()
        results = lens.analyse(board, model=model)
    """

    def __init__(self, probe_fn: Callable, **kwargs):
        # Called on each selected module's saved output proxy.
        self._probe_fn = probe_fn
        super().__init__(**kwargs)

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> dict:
        """Run the probe function on every selected module's output."""
        return {name: self._probe_fn(module.output.save()) for name, module in self._get_modules(model)}


# --- src/lczerolens/lenses/probing/probe.py ---
"""Probes for probing-based XAI."""

from abc import ABC, abstractmethod
from typing import Any

import torch


# Numerical stabiliser added to norms before division.
EPS = 1e-6


class Probe(ABC):
    """Abstract class for probes."""

    def __init__(self):
        # Flipped to True by train(); predict() refuses to run before that.
        self._trained = False

    @abstractmethod
    def train(
        self,
        activations: torch.Tensor,
        labels: Any,
        **kwargs,
    ):
        """Train the probe."""

    @abstractmethod
    def predict(self, activations: torch.Tensor, **kwargs):
        """Predict with the probe."""


class SignalCav(Probe):
    """Signal CAV (concept activation vector) probe."""

    def train(
        self,
        activations: torch.Tensor,
        labels: torch.Tensor,
        **kwargs,
    ):
        """Fit one normalised CAV per label dimension.

        Parameters
        ----------
        activations : torch.Tensor
            Batch of activations, shape ``(batch, features)``.
        labels : torch.Tensor
            Batch of labels, shape ``(batch, concepts)``.

        Raises
        ------
        ValueError
            If the batch sizes differ or either tensor is not 2-D.
        """
        if len(activations) != len(labels):
            raise ValueError("Number of activations and labels must match")
        if len(activations.shape) != 2:
            raise ValueError("Activations must be a batch of tensors")
        if len(labels.shape) != 2:
            raise ValueError("Labels must be a batch of tensors")

        # NOTE(review): means are taken over dim=1 (per sample), not over the
        # batch dimension — classic signal-CAV centres per feature/label over
        # the batch (dim=0); confirm this is intended.
        mean_activation = activations.mean(dim=1, keepdim=True)
        mean_label = labels.mean(dim=1, keepdim=True)
        scaled_activations = activations - mean_activation
        scaled_labels = labels - mean_label
        # Contraction over the batch dimension; torch.einsum replaces the
        # einops dependency with the identical contraction.
        cav = torch.einsum("ba,bd->ad", scaled_activations, scaled_labels)
        self._h = cav / (cav.norm(dim=0, keepdim=True) + EPS)
        # Fix: the original never flipped this flag, so predict() always
        # raised "Probe not trained".
        self._trained = True

    def predict(self, activations: torch.Tensor, **kwargs):
        """Project activations onto the trained CAVs (norm-scaled score).

        Raises
        ------
        ValueError
            If the probe is untrained or the input is not 2-D.
        """
        if not self._trained:
            raise ValueError("Probe not trained")

        if len(activations.shape) != 2:
            raise ValueError("Activations must be a batch of tensors")

        dot_prod = torch.einsum("ba,ad->bd", activations, self._h)
        return dot_prod / (activations.norm(dim=1, keepdim=True) + EPS)


# --- src/lczerolens/lenses/sae/__init__.py ---
# (empty module; the dump stored it as an external asset placeholder)


# --- src/lczerolens/lenses/sae/buffer.py ---
"""Activation buffer used to stream training batches for SAEs."""

from typing import Any, Optional, Callable
from dataclasses import dataclass

import torch
from datasets import Dataset
from torch.utils.data import DataLoader, TensorDataset

from lczerolens.model import LczeroModel


@dataclass
class ActivationBuffer:
    """Computes model activations in chunks and re-serves them as shuffled
    training batches of size ``train_batch_size``.
    """

    model: LczeroModel
    dataset: Dataset
    # Maps (raw batch, model) -> activation tensor.
    compute_fn: Callable[[Any, LczeroModel], torch.Tensor]
    n_batches_in_buffer: int = 10
    compute_batch_size: int = 64
    train_batch_size: int = 2048
    dataloader_kwargs: Optional[dict] = None
    logger: Optional[Callable] = None

    def __post_init__(self):
        if self.dataloader_kwargs is None:
            self.dataloader_kwargs = {}
        self._buffer = []
        self._remainder = None
        self._make_dataloader_it()

    def _make_dataloader_it(self):
        # Fresh iterator over the source dataset.
        self._dataloader_it = iter(
            DataLoader(self.dataset, batch_size=self.compute_batch_size, **self.dataloader_kwargs)
        )

    @torch.no_grad()
    def _fill_buffer(self):
        """Compute up to ``n_batches_in_buffer`` activation batches (kept on CPU)."""
        if self.logger is not None:
            self.logger.info("Computing activations...")
        self._buffer = []
        while len(self._buffer) < self.n_batches_in_buffer:
            try:
                next_batch = next(self._dataloader_it)
            except StopIteration:
                break
            activations = self.compute_fn(next_batch, self.model)
            self._buffer.append(activations.to("cpu"))
        if not self._buffer:
            # Source exhausted: signal end-of-iteration to callers.
            raise StopIteration

    def _make_activations_it(self):
        # A leftover partial batch is recycled into the new shuffle pool.
        if self._remainder is not None:
            self._buffer.append(self._remainder)
            self._remainder = None
        activations_ds = TensorDataset(torch.cat(self._buffer, dim=0))
        if self.logger is not None:
            self.logger.info(f"Activations dataset of size {len(activations_ds)}")

        self._activations_it = iter(
            DataLoader(
                activations_ds,
                batch_size=self.train_batch_size,
                shuffle=True,
            )
        )

    def __iter__(self):
        self._make_dataloader_it()
        self._fill_buffer()
        self._make_activations_it()
        self._remainder = None
        return self

    def __next__(self):
        try:
            activations = next(self._activations_it)[0]
            if activations.shape[0] < self.train_batch_size:
                # Partial batch: stash it, refill, and draw a full batch.
                self._remainder = activations
                self._fill_buffer()
                self._make_activations_it()
                activations = next(self._activations_it)[0]
            return activations
        except StopIteration:
            try:
                self._fill_buffer()
                self._make_activations_it()
                # Fix: the original discarded this recursive result and fell
                # through to `raise StopIteration`, ending iteration even
                # though fresh activations had just been loaded.
                return self.__next__()
            except StopIteration as e:
                if self._remainder is not None:
                    activations = self._remainder
                    self._remainder = None
                    return activations
                raise StopIteration from e
# --- src/lczerolens/model.py ---
"""Class for wrapping the LCZero models."""

import os
from typing import Dict, Type, Any, Tuple, Union

import torch
from onnx2torch import convert
from onnx2torch.utils.safe_shape_inference import safe_shape_inference
from tensordict import TensorDict
from torch import nn
from nnsight import NNsight
from contextlib import contextmanager

from lczerolens.board import InputEncoding, LczeroBoard


class LczeroModel(NNsight):
    """Class for wrapping the LCZero models."""

    def trace(
        self,
        *inputs: Any,
        **kwargs: Any,
    ):
        """Trace the model; scanning and validation are always disabled."""
        kwargs["scan"] = False
        kwargs["validate"] = False
        return super().trace(*inputs, **kwargs)

    def _execute(self, *prepared_inputs: torch.Tensor, **kwargs) -> Any:
        # Input-preparation-only kwargs must not reach the underlying forward.
        kwargs.pop("input_encoding", None)
        kwargs.pop("input_requires_grad", None)
        with self._ensure_proper_forward():
            return super()._execute(*prepared_inputs, **kwargs)

    def _prepare_inputs(self, *inputs: Union[LczeroBoard, torch.Tensor], **kwargs) -> Tuple[Tuple[Any], int]:
        """Convert boards (or pass through a tensor) into a batched input tensor."""
        input_encoding = kwargs.pop("input_encoding", InputEncoding.INPUT_CLASSICAL_112_PLANE)
        input_requires_grad = kwargs.pop("input_requires_grad", False)

        # A single pre-built tensor is forwarded untouched.
        if len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
            return inputs, len(inputs[0])
        for board in inputs:
            if not isinstance(board, LczeroBoard):
                raise ValueError(f"Got invalid input type {type(board)}.")

        tensor_list = [board.to_input_tensor(input_encoding=input_encoding).unsqueeze(0) for board in inputs]
        batched_tensor = torch.cat(tensor_list, dim=0)
        # Fix: move to the device BEFORE setting requires_grad, otherwise the
        # device copy is a non-leaf tensor and its .grad is never populated.
        batched_tensor = batched_tensor.to(self.device)
        if input_requires_grad:
            batched_tensor.requires_grad = True

        return (batched_tensor,), len(inputs)

    def __call__(self, *inputs, **kwargs):
        prepared_inputs, _ = self._prepare_inputs(*inputs, **kwargs)
        return self._execute(*prepared_inputs, **kwargs)

    def __getattr__(self, key):
        # Outside a trace, attribute access falls through to the wrapped model.
        if self._envoy._tracer is None:
            return getattr(self._model, key)
        return super().__getattr__(key)

    def __setattr__(self, key, value):
        # Outside a trace, module assignment goes to the wrapped model so that
        # submodules are registered on the underlying nn.Module.
        if (
            (key not in ("_model", "_model_key"))
            and (isinstance(value, torch.nn.Module))
            and (self._envoy._tracer is None)
        ):
            setattr(self._model, key, value)
        else:
            super().__setattr__(key, value)

    @property
    def device(self):
        """Returns the device (of the first parameter)."""
        return next(self.parameters()).device

    @device.setter
    def device(self, device: torch.device):
        """Sets the device."""
        self.to(device)

    @classmethod
    def from_path(cls, model_path: str) -> "LczeroModel":
        """Creates a wrapper from a model path.

        Parameters
        ----------
        model_path : str
            Path to the model file (.onnx or .pt)

        Returns
        -------
        LczeroModel
            The wrapped model instance

        Raises
        ------
        NotImplementedError
            If the model file extension is not supported
        """
        if model_path.endswith(".onnx"):
            return cls.from_onnx_path(model_path)
        elif model_path.endswith(".pt"):
            return cls.from_torch_path(model_path)
        else:
            raise NotImplementedError(f"Model path {model_path} is not supported.")

    @classmethod
    def from_onnx_path(cls, onnx_model_path: str, check: bool = True) -> "LczeroModel":
        """Builds a model from an ONNX file path.

        Parameters
        ----------
        onnx_model_path : str
            Path to the ONNX model file
        check : bool, optional
            Whether to perform shape inference check, by default True

        Returns
        -------
        LczeroModel
            The wrapped model instance

        Raises
        ------
        FileNotFoundError
            If the model file does not exist
        ValueError
            If the model could not be loaded
        """
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"Model path {onnx_model_path} does not exist.")
        try:
            if check:
                onnx_model = safe_shape_inference(onnx_model_path)
            else:
                # Fix: `onnx_model` was previously unbound when check=False,
                # raising a NameError (masked as ValueError below).
                # onnx2torch.convert also accepts a file path directly.
                onnx_model = onnx_model_path
            onnx_torch_model = convert(onnx_model)
            return cls(onnx_torch_model)
        except Exception as e:
            raise ValueError(f"Could not load model at {onnx_model_path}.") from e

    @classmethod
    def from_torch_path(cls, torch_model_path: str) -> "LczeroModel":
        """Builds a model from a PyTorch file path.

        Parameters
        ----------
        torch_model_path : str
            Path to the PyTorch model file

        Returns
        -------
        LczeroModel
            The wrapped model instance

        Raises
        ------
        FileNotFoundError
            If the model file does not exist
        ValueError
            If the model could not be loaded or is not a valid model type
        """
        if not os.path.exists(torch_model_path):
            raise FileNotFoundError(f"Model path {torch_model_path} does not exist.")
        try:
            # NOTE(review): torch.load un-pickles arbitrary objects — only load
            # trusted checkpoints.
            torch_model = torch.load(torch_model_path)
        except Exception as e:
            raise ValueError(f"Could not load model at {torch_model_path}.") from e
        if isinstance(torch_model, LczeroModel):
            return torch_model
        elif isinstance(torch_model, nn.Module):
            return cls(torch_model)
        else:
            raise ValueError(f"Could not load model at {torch_model_path}.")

    @contextmanager
    def _ensure_proper_forward(self):
        """Temporarily wraps forward so outputs come back as a named TensorDict."""
        old_forward = self._model.forward

        # The last graph node's inputs carry the ONNX output names
        # ("output_policy" -> "policy", ...).
        output_node = list(self._model.graph.nodes)[-1]
        output_names = [n.name.replace("output_", "") for n in output_node.all_input_nodes]

        def td_forward(x):
            old_out = old_forward(x)
            return TensorDict(
                {name: old_out[i] for i, name in enumerate(output_names)},
                batch_size=x.shape[0],
            )

        self._model.forward = td_forward
        try:
            yield
        finally:
            # Fix: restore the original forward even if the body raises.
            self._model.forward = old_forward


class Flow(LczeroModel):
    """Base class for isolating a flow."""

    _flow_type: str
    _registry: Dict[str, Type["Flow"]] = {}

    def __init__(
        self,
        model_key,
        *args,
        **kwargs,
    ):
        if isinstance(model_key, LczeroModel):
            raise ValueError("Use the `from_model` classmethod to create a flow.")
        if not self.is_compatible(model_key):
            raise ValueError(f"The model does not have a {self._flow_type} head.")
        super().__init__(model_key, *args, **kwargs)

    @classmethod
    def register(cls, name: str):
        """Registers the flow.

        Parameters
        ----------
        name : str
            The name of the flow to register.

        Returns
        -------
        Callable
            Decorator function that registers the flow subclass.

        Raises
        ------
        ValueError
            If the flow name is already registered.
        """
        if name in cls._registry:
            raise ValueError(f"Flow {name} already registered.")

        def decorator(subclass):
            cls._registry[name] = subclass
            subclass._flow_type = name
            return subclass

        return decorator

    @classmethod
    def from_name(cls, name: str, *args, **kwargs) -> "Flow":
        """Returns the flow from its name.

        Parameters
        ----------
        name : str
            The name of the flow to instantiate.
        *args
            Positional arguments passed to flow constructor.
        **kwargs
            Keyword arguments passed to flow constructor.

        Returns
        -------
        Flow
            The instantiated flow.

        Raises
        ------
        KeyError
            If the flow name is not found.
        """
        if name not in cls._registry:
            raise KeyError(f"Flow {name} not found.")
        return cls._registry[name](*args, **kwargs)

    @classmethod
    def from_model(cls, name: str, model: LczeroModel, *args, **kwargs) -> "Flow":
        """Returns the flow from a model.

        Parameters
        ----------
        name : str
            The name of the flow to instantiate.
        model : LczeroModel
            The model to create the flow from.
        *args
            Positional arguments passed to flow constructor.
        **kwargs
            Keyword arguments passed to flow constructor.

        Returns
        -------
        Flow
            The instantiated flow.

        Raises
        ------
        KeyError
            If the flow name is not found.
        """
        if name not in cls._registry:
            raise KeyError(f"Flow {name} not found.")
        flow_class = cls._registry[name]
        return flow_class(model._model, *args, **kwargs)

    @classmethod
    def is_compatible(cls, model: nn.Module) -> bool:
        """Checks if the model is compatible with this flow.

        Parameters
        ----------
        model : nn.Module
            The model to check compatibility with.

        Returns
        -------
        bool
            Whether the model is compatible with this flow.
        """
        return hasattr(model, cls._flow_type) or hasattr(model, f"output/{cls._flow_type}")

    @contextmanager
    def _ensure_proper_forward(self):
        """Rewrites the forward function to return only this flow's output."""
        flow_type = getattr(self, "_flow_type", None)
        if flow_type is None:
            # Fix: a @contextmanager must yield exactly once; a bare return
            # here raised RuntimeError("generator didn't yield").
            yield
            return

        with super()._ensure_proper_forward():
            old_forward = self._model.forward

            def flow_forward(*inputs, **kwargs):
                out = old_forward(*inputs, **kwargs)
                return out[flow_type]

            self._model.forward = flow_forward
            try:
                yield
            finally:
                self._model.forward = old_forward


@Flow.register("policy")
class PolicyFlow(Flow):
    """Class for isolating the policy flow."""


@Flow.register("value")
class ValueFlow(Flow):
    """Class for isolating the value flow."""


@Flow.register("wdl")
class WdlFlow(Flow):
    """Class for isolating the WDL flow."""


@Flow.register("mlh")
class MlhFlow(Flow):
    """Class for isolating the MLH flow."""


@Flow.register("force_value")
class ForceValueFlow(Flow):
    """Class for forcing and isolating the value flow."""

    @classmethod
    def is_compatible(cls, model: nn.Module):
        # Works with either a native value head or a WDL head to reduce.
        return ValueFlow.is_compatible(model) or WdlFlow.is_compatible(model)

    @contextmanager
    def _ensure_proper_forward(self):
        flow_type = getattr(self, "_flow_type", None)
        if flow_type is None:
            # Fix: yield instead of bare return (see Flow._ensure_proper_forward).
            yield
            return

        with LczeroModel._ensure_proper_forward(self):
            old_forward = self._model.forward

            def flow_forward(*inputs, **kwargs):
                out = old_forward(*inputs, **kwargs)
                if "value" in out.keys():
                    return out["value"]
                # Reduce WDL to a scalar value: win=+1, draw=0, loss=-1.
                # NOTE(review): out.device may be None for a CPU TensorDict;
                # confirm this matmul behaves on GPU models.
                return out["wdl"] @ torch.tensor([1.0, 0.0, -1.0], device=out.device)

            self._model.forward = flow_forward
            try:
                yield
            finally:
                self._model.forward = old_forward
# --- src/lczerolens/play/__init__.py ---
"""
Play module for the lczerolens package.
"""

from .sampling import ModelSampler, Sampler, RandomSampler, PolicySampler
from .game import Game
from .puzzle import Puzzle
from . import sampling, game, puzzle

__all__ = ["ModelSampler", "Sampler", "RandomSampler", "PolicySampler", "Game", "Puzzle", "sampling", "game", "puzzle"]


# --- src/lczerolens/play/game.py ---
"""Preprocess functions for chess games."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from datasets import Features, Value, Sequence

from lczerolens.board import LczeroBoard

# Schema for raw game rows (one move string per game).
GAME_DATASET_FEATURES = Features(
    {
        "gameid": Value("string"),
        "moves": Value("string"),
    }
)

# Schema for per-position rows derived from games.
BOARD_DATASET_FEATURES = Features(
    {
        "gameid": Value("string"),
        "moves": Sequence(Value("string")),
        "fen": Value("string"),
    }
)


@dataclass
class Game:
    """A parsed game: id, SAN moves, and the ply index at which the opening
    book was exited (None if no book-exit marker was present)."""

    gameid: str
    moves: List[str]
    book_exit: Optional[int] = None

    @classmethod
    def from_dict(cls, obj: Dict[str, str]) -> "Game":
        """Parse a raw dataset row into a Game.

        Parameters
        ----------
        obj : Dict[str, str]
            Row with ``gameid`` and ``moves`` (space-separated, optionally
            containing a single ``{ Book exit }`` marker).

        Raises
        ------
        ValueError
            If a required key is missing or more than one book-exit marker
            is present.
        """
        if "moves" not in obj:
            # Fix: the exceptions were instantiated but never raised, so the
            # validation silently did nothing.
            raise ValueError("The dict should contain `moves`.")
        if "gameid" not in obj:
            raise ValueError("The dict should contain `gameid`.")
        *pre, post = obj["moves"].split("{ Book exit }")
        if pre:
            if len(pre) > 1:
                raise ValueError("More than one book exit")
            (pre,) = pre
            # Drop move numbers like "1." and keep the SAN tokens.
            parsed_pre_moves = [m for m in pre.split() if not m.endswith(".")]
            book_exit = len(parsed_pre_moves)
        else:
            parsed_pre_moves = []
            book_exit = None
        parsed_moves = parsed_pre_moves + [m for m in post.split() if not m.endswith(".")]
        return cls(
            gameid=obj["gameid"],
            moves=parsed_moves,
            book_exit=book_exit,
        )

    def to_boards(
        self,
        n_history: int = 0,
        skip_book_exit: bool = False,
        skip_first_n: int = 0,
        output_dict=True,
    ) -> List[Union[Dict[str, Any], LczeroBoard]]:
        """Replay the game and emit one board (or dict row) per position.

        Parameters
        ----------
        n_history : int
            Number of trailing moves kept on each copied board.
        skip_book_exit : bool
            Skip positions before the book exit (when one exists).
        skip_first_n : int
            Skip the first n positions.
        output_dict : bool
            Emit dict rows (fen/moves/gameid) instead of LczeroBoard objects.
        """
        working_board = LczeroBoard()
        if skip_first_n > 0 or (skip_book_exit and (self.book_exit is not None)):
            boards = []
        else:
            if output_dict:
                boards = [
                    {
                        "fen": working_board.fen(),
                        "moves": [],
                        "gameid": self.gameid,
                    }
                ]
            else:
                boards = [working_board.copy(stack=n_history)]

        for i, move in enumerate(self.moves[:-1]):  # skip the last move as it can be over
            working_board.push_san(move)
            if (i < skip_first_n) or (skip_book_exit and (self.book_exit is not None) and (i < self.book_exit)):
                continue
            if output_dict:
                save_board = working_board.copy(stack=n_history)
                boards.append(
                    {
                        # root() is the position before the kept history moves.
                        "fen": save_board.root().fen(),
                        "moves": [move.uci() for move in save_board.move_stack],
                        "gameid": self.gameid,
                    }
                )
            else:
                boards.append(working_board.copy(stack=n_history))
        return boards

    @staticmethod
    def board_collate_fn(batch):
        """Collate dict rows into (boards, {}) for a DataLoader.

        NOTE(review): rows store UCI strings but push_san is used — python-chess
        accepts LAN-style "e2e4" in parse_san, confirm this holds for all moves.
        """
        boards = []
        for element in batch:
            board = LczeroBoard(element["fen"])
            for move in element["moves"]:
                board.push_san(move)
            boards.append(board)
        return boards, {}


# --- src/lczerolens/play/puzzle.py ---
"""Preprocess functions for chess puzzles."""

from dataclasses import dataclass
from typing import Dict, List, Union, Tuple, Optional, Iterable

import chess
import torch
from datasets import Features, Value
from itertools import tee, chain 10 | 11 | from lczerolens.board import LczeroBoard 12 | from .sampling import Sampler 13 | 14 | 15 | PUZZLE_DATASET_FEATURES = Features( 16 | { 17 | "PuzzleId": Value("string"), 18 | "FEN": Value("string"), 19 | "Moves": Value("string"), 20 | "Rating": Value("int64"), 21 | "RatingDeviation": Value("int64"), 22 | "Popularity": Value("int64"), 23 | "NbPlays": Value("int64"), 24 | "Themes": Value("string"), 25 | "GameUrl": Value("string"), 26 | "OpeningTags": Value("string"), 27 | } 28 | ) 29 | 30 | 31 | @dataclass 32 | class Puzzle: 33 | puzzle_id: str 34 | fen: str 35 | initial_move: chess.Move 36 | moves: List[chess.Move] 37 | rating: int 38 | rating_deviation: int 39 | popularity: int 40 | nb_plays: int 41 | themes: List[str] 42 | game_url: str 43 | opening_tags: List[str] 44 | 45 | @classmethod 46 | def from_dict(cls, obj: Dict[str, Union[str, int, None]]) -> "Puzzle": 47 | uci_moves = obj["Moves"].split() 48 | moves = [chess.Move.from_uci(uci_move) for uci_move in uci_moves] 49 | return cls( 50 | puzzle_id=obj["PuzzleId"], 51 | fen=obj["FEN"], 52 | initial_move=moves[0], 53 | moves=moves[1:], 54 | rating=obj["Rating"], 55 | rating_deviation=obj["RatingDeviation"], 56 | popularity=obj["Popularity"], 57 | nb_plays=obj["NbPlays"], 58 | themes=obj["Themes"].split() if obj["Themes"] is not None else [], 59 | game_url=obj["GameUrl"], 60 | opening_tags=obj["OpeningTags"].split() if obj["OpeningTags"] is not None else [], 61 | ) 62 | 63 | def __len__(self) -> int: 64 | return len(self.moves) 65 | 66 | @property 67 | def initial_board(self) -> LczeroBoard: 68 | board = LczeroBoard(self.fen) 69 | board.push(self.initial_move) 70 | return board 71 | 72 | def board_move_generator(self, all_moves: bool = False) -> Iterable[Tuple[LczeroBoard, chess.Move]]: 73 | board = self.initial_board 74 | initial_turn = board.turn 75 | for move in self.moves: 76 | if not all_moves and board.turn != initial_turn: 77 | board.push(move) 78 | continue 79 | 
    @classmethod
    def evaluate_multiple(
        cls,
        puzzles: Iterable["Puzzle"],
        sampler: Sampler,
        all_moves: bool = False,
        compute_metrics: bool = True,
        **kwargs,
    ) -> Union[Iterable[Dict[str, float]], Iterable[Tuple[torch.Tensor, torch.Tensor, chess.Move]]]:
        """Lazily evaluate a sampler over a stream of puzzles.

        Parameters
        ----------
        puzzles : Iterable[Puzzle]
            Puzzles to evaluate; may be a one-shot iterator (it is tee'd).
        sampler : Sampler
            Sampler providing utilities and move choices.
        all_moves : bool
            Forwarded to ``board_move_generator``: if False, only positions
            where it is the solving side's turn are evaluated.
        compute_metrics : bool
            If True, yield one metrics dict per puzzle via ``compute_metrics``;
            otherwise yield raw ``(utility, legal_indices, predicted_move)``
            tuples, one per evaluated position.
        **kwargs
            Forwarded to ``sampler.get_utilities`` (e.g. batch_size).

        Returns
        -------
        Union[Iterable[Dict[str, float]], Iterable[Tuple[torch.Tensor, torch.Tensor, chess.Move]]]
            A lazy iterable; nothing is computed until it is consumed.
        """
        # Tee the puzzle stream: one copy drives metric computation, the other
        # drives board/move generation, so single-use iterators are supported.
        metric_puzzles, board_move_puzzles = tee(puzzles)
        board_move_generator = chain.from_iterable(
            puzzle.board_move_generator(all_moves) for puzzle in board_move_puzzles
        )

        def board_generator():
            # Project out just the boards; the paired expected moves are
            # recomputed inside compute_metrics from the teed puzzle stream.
            for board, _ in board_move_generator:
                yield board

        util_boards, move_boards = tee(board_generator())

        def metric_inputs_generator():
            util_gen = sampler.get_utilities(util_boards, **kwargs)
            # zip keeps each board aligned with the utilities computed for it.
            for board, (utility, legal_indices, _) in zip(move_boards, util_gen):
                predicted_move = sampler.choose_move(board, utility, legal_indices)
                yield utility, legal_indices, predicted_move

        if compute_metrics:
            return cls.compute_metrics(metric_puzzles, metric_inputs_generator(), all_moves=all_moves)
        else:
            return metric_inputs_generator()
    @abstractmethod
    def get_utilities(
        self, boards: Iterable[LczeroBoard], **kwargs
    ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]:
        """Lazily compute per-legal-move utilities for each board.

        Parameters
        ----------
        boards : Iterable[LczeroBoard]
            The boards to evaluate; may be consumed lazily, so generators
            are acceptable.
        **kwargs
            Implementation-specific options (e.g. batching, callbacks).

        Returns
        -------
        Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]
            For each board: a utility tensor aligned with the legal move
            indices tensor, and a dictionary of values worth logging.
        """
        pass
class RandomSampler(Sampler):
    """Sampler assigning identical utility to every legal move.

    With flat logits, the base class's categorical ``choose_move`` draws
    moves uniformly at random.
    """

    def get_utilities(
        self, boards: Iterable[LczeroBoard], **kwargs
    ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]:
        """Yield all-ones utilities over the legal moves of each board."""
        for current_board in boards:
            indices = current_board.get_legal_indices()
            yield torch.ones_like(indices, dtype=torch.float32), indices, {}
    def _get_batched_stats(self, boards, batch_size, use_next_boards=True, **kwargs):
        """Run the model over boards (optionally with their successors) in batches.

        Parameters
        ----------
        boards : Iterable[LczeroBoard]
            Positions to evaluate.
        batch_size : int
            Maximum number of positions per model call; -1 disables batching.
        use_next_boards : bool
            If True, each board is batched together with all of its legal
            successor positions (needed for the Q/M utility terms).

        Yields
        ------
        tuple
            For each input board: its legal move indices and the slice of the
            model outputs covering the board followed by its successors.
        """
        next_batch = []
        next_legal_indices = []

        def generator(next_batch, next_legal_indices):
            # One model call for the whole accumulated batch, then re-slice
            # the outputs per input board: 1 entry for the board itself plus
            # one per legal move when successors were included.
            all_stats = self.model(*next_batch, **kwargs)
            offset = 0
            for legal_indices in next_legal_indices:
                n_boards = legal_indices.shape[0] + 1 if use_next_boards else 1
                batch_stats = all_stats[offset : offset + n_boards]
                offset += n_boards
                yield legal_indices, batch_stats

        for board in boards:
            legal_indices = board.get_legal_indices()
            next_boards = list(board.get_next_legal_boards()) if use_next_boards else []
            # Flush the pending batch before adding would overflow it.
            # NOTE(review): a single board whose successor count alone exceeds
            # batch_size is still sent as one oversized call — confirm callers
            # accept this.
            if len(next_batch) + len(next_boards) + 1 > batch_size and batch_size != -1:
                yield from generator(next_batch, next_legal_indices)
                next_batch = []
                next_legal_indices = []
            next_batch.extend([board] + next_boards)
            next_legal_indices.append(legal_indices)
        if next_batch:
            # Flush the final partial batch.
            yield from generator(next_batch, next_legal_indices)
155 | elif "wdl" in batch_stats.keys(): 156 | to_log["wdl_w"] = batch_stats["wdl"][0][0].item() 157 | to_log["wdl_d"] = batch_stats["wdl"][0][1].item() 158 | to_log["wdl_l"] = batch_stats["wdl"][0][2].item() 159 | scores = torch.tensor([1, self.draw_score, -1]) 160 | return batch_stats["wdl"][1:] @ scores 161 | return torch.zeros(batch_stats.batch_size[0] - 1) 162 | 163 | def _get_m_values(self, batch_stats, q_values, to_log): 164 | if "mlh" in batch_stats.keys(): 165 | to_log["mlh"] = batch_stats["mlh"][0].item() 166 | delta_m_values = self.m_slope * (batch_stats["mlh"][1:, 0] - batch_stats["mlh"][0, 0]) 167 | delta_m_values.clamp_(-self.m_max, self.m_max) 168 | scaled_q_values = torch.relu(q_values.abs() - self.q_threshold) / (1 - self.q_threshold) 169 | poly_q_values = self.k_0 + self.k_1 * scaled_q_values + self.k_2 * scaled_q_values**2 170 | return -q_values.sign() * delta_m_values * poly_q_values 171 | return torch.zeros(batch_stats.batch_size[0] - 1) 172 | 173 | def _get_p_values( 174 | self, 175 | batch_stats, 176 | legal_indices, 177 | to_log, 178 | ): 179 | if "policy" in batch_stats.keys(): 180 | legal_policy = batch_stats["policy"][0].gather(0, legal_indices) 181 | to_log["max_legal_policy"] = legal_policy.max().item() 182 | return legal_policy 183 | return torch.zeros_like(legal_indices) 184 | 185 | def _use_callback(self, callback, batch_stats, to_log): 186 | if callback is not None: 187 | to_log_update = callback(batch_stats, to_log) 188 | if not isinstance(to_log_update, dict): 189 | raise ValueError("Callback must return a dictionary.") 190 | to_log |= to_log_update 191 | 192 | 193 | @dataclass 194 | class PolicySampler(ModelSampler): 195 | use_suboptimal: bool = False 196 | 197 | @torch.no_grad() 198 | def get_utilities( 199 | self, boards: Iterable[LczeroBoard], **kwargs 200 | ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]: 201 | batch_size = kwargs.pop("batch_size", -1) 202 | callback = kwargs.pop("callback", None) 203 | 204 
| to_log = {} 205 | for legal_indices, batch_stats in self._get_batched_stats(boards, batch_size, use_next_boards=False, **kwargs): 206 | legal_policy = batch_stats["policy"][0].gather(0, legal_indices.to(batch_stats["policy"].device)) 207 | if self.use_suboptimal: 208 | idx = legal_policy.argmax() 209 | legal_policy[idx] = torch.tensor(-1e3) 210 | 211 | self._use_callback(callback, batch_stats, to_log) 212 | 213 | yield legal_policy, legal_indices, to_log 214 | 215 | 216 | @dataclass 217 | class SelfPlay: 218 | """A class for generating games.""" 219 | 220 | white: Sampler 221 | black: Sampler 222 | 223 | def play( 224 | self, 225 | board: Optional[LczeroBoard] = None, 226 | max_moves: int = 100, 227 | to_play: chess.Color = chess.WHITE, 228 | report_fn: Optional[Callable[[dict, chess.Color], None]] = None, 229 | white_kwargs: Optional[Dict] = None, 230 | black_kwargs: Optional[Dict] = None, 231 | ): 232 | """ 233 | Plays a game. 234 | """ 235 | if board is None: 236 | board = LczeroBoard() 237 | if white_kwargs is None: 238 | white_kwargs = {} 239 | if black_kwargs is None: 240 | black_kwargs = {} 241 | game = [] 242 | if to_play == chess.BLACK: 243 | move, _ = next(iter(self.black.get_next_moves([board], **black_kwargs))) 244 | board.push(move) 245 | game.append(move) 246 | for _ in range(max_moves): 247 | if board.is_game_over() or len(game) >= max_moves: 248 | break 249 | move, to_log = next(iter(self.white.get_next_moves([board], **white_kwargs))) 250 | if report_fn is not None: 251 | report_fn(to_log, board.turn) 252 | board.push(move) 253 | game.append(move) 254 | 255 | if board.is_game_over() or len(game) >= max_moves: 256 | break 257 | move, to_log = next(iter(self.black.get_next_moves([board], **black_kwargs))) 258 | if report_fn is not None: 259 | report_fn(to_log, board.turn) 260 | board.push(move) 261 | game.append(move) 262 | if board.is_game_over() or len(game) >= max_moves: 263 | break 264 | return game, board 265 | 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/__init__.py -------------------------------------------------------------------------------- /tests/assets/error.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "assert False" 10 | ] 11 | } 12 | ], 13 | "metadata": { 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "nbformat": 4, 19 | "nbformat_minor": 2 20 | } 21 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | File to test the encodings for the Leela Chess Zero engine. 
3 | """ 4 | 5 | import onnxruntime as ort 6 | import pytest 7 | from lczero.backends import Backend, Weights 8 | 9 | from lczerolens import LczeroModel 10 | from lczerolens import backends as lczero_utils 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def tiny_lczero_backend(): 15 | lczero_weights = Weights("assets/tinygyal-8.pb.gz") 16 | yield Backend(weights=lczero_weights) 17 | 18 | 19 | @pytest.fixture(scope="session") 20 | def tiny_ensure_network(): 21 | lczero_utils.convert_to_onnx("assets/tinygyal-8.pb.gz", "assets/tinygyal-8.onnx") 22 | yield 23 | 24 | 25 | @pytest.fixture(scope="session") 26 | def tiny_model(tiny_ensure_network): 27 | yield LczeroModel.from_path("assets/tinygyal-8.onnx") 28 | 29 | 30 | @pytest.fixture(scope="session") 31 | def tiny_senet_ort(tiny_ensure_network): 32 | senet_ort = ort.InferenceSession("assets/tinygyal-8.onnx") 33 | yield senet_ort 34 | 35 | 36 | @pytest.fixture(scope="class") 37 | def maia_ensure_network(): 38 | lczero_utils.convert_to_onnx("assets/maia-1100.pb.gz", "assets/maia-1100.onnx") 39 | yield 40 | 41 | 42 | @pytest.fixture(scope="class") 43 | def maia_model(maia_ensure_network): 44 | yield LczeroModel.from_path("assets/maia-1100.onnx") 45 | 46 | 47 | @pytest.fixture(scope="class") 48 | def maia_senet_ort(maia_ensure_network): 49 | senet_ort = ort.InferenceSession("assets/maia-1100.onnx") 50 | yield senet_ort 51 | 52 | 53 | @pytest.fixture(scope="class") 54 | def winner_ensure_network(): 55 | lczero_utils.convert_to_onnx( 56 | "assets/384x30-2022_0108_1903_17_608.pb.gz", 57 | "assets/384x30-2022_0108_1903_17_608.onnx", 58 | ) 59 | yield 60 | 61 | 62 | @pytest.fixture(scope="class") 63 | def winner_model(winner_ensure_network): 64 | yield LczeroModel.from_path("assets/384x30-2022_0108_1903_17_608.onnx") 65 | 66 | 67 | @pytest.fixture(scope="class") 68 | def winner_senet_ort(winner_ensure_network): 69 | yield ort.InferenceSession("assets/384x30-2022_0108_1903_17_608.onnx") 70 | 71 | 72 | def 
def pytest_collection_modifyitems(config, items):
    """Skip opt-in tests whose enabling CLI flag was not passed.

    Each marker in (slow, fast, backends) is gated by a matching
    ``--run-<marker>`` option registered in ``pytest_addoption``.
    """
    gates = {}
    for keyword in ("slow", "fast", "backends"):
        flag = f"--run-{keyword}"
        if not config.getoption(flag):
            gates[keyword] = pytest.mark.skip(reason=f"{flag} not given in cli: skipping {keyword} tests")

    for item in items:
        for keyword, skip_marker in gates.items():
            if keyword in item.keywords:
                item.add_marker(skip_marker)
3 | """ 4 | 5 | import subprocess 6 | import pytest 7 | 8 | NOTEBOOKS = [ 9 | "docs/source/notebooks/features/visualise-heatmaps.ipynb", 10 | "docs/source/notebooks/features/probe-concepts.ipynb", 11 | "docs/source/notebooks/features/convert-official-weights.ipynb", 12 | "docs/source/notebooks/features/move-prediction.ipynb", 13 | "docs/source/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb", 14 | "docs/source/notebooks/walkthrough.ipynb", 15 | ] 16 | 17 | 18 | def run_notebook(notebook): 19 | result = subprocess.run( 20 | ["uv", "run", "jupyter", "nbconvert", "--to", "notebook", "--execute", notebook], 21 | stderr=subprocess.PIPE, 22 | ) 23 | if result.returncode != 0: 24 | raise subprocess.CalledProcessError(result.returncode, result.args, result.stderr) 25 | 26 | 27 | class TestNotebooks: 28 | def test_error_notebook(self): 29 | with pytest.raises(subprocess.CalledProcessError): 30 | run_notebook("tests/assets/error.ipynb") 31 | 32 | @pytest.mark.slow 33 | @pytest.mark.parametrize("notebook", NOTEBOOKS) 34 | def test_notebook(self, notebook): 35 | run_notebook(notebook) 36 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/concepts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/concepts/__init__.py -------------------------------------------------------------------------------- /tests/unit/concepts/test_concepts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test cases for the concept module. 
3 | """ 4 | 5 | from lczerolens.concept import BinaryConcept, AndBinaryConcept 6 | from lczerolens.concepts import ( 7 | HasPiece, 8 | HasThreat, 9 | ) 10 | from lczerolens import LczeroBoard 11 | 12 | 13 | class TestBinaryConcept: 14 | """ 15 | Test cases for the BinaryConcept class. 16 | """ 17 | 18 | def test_compute_metrics(self): 19 | """ 20 | Test the compute_metrics method. 21 | """ 22 | predictions = [0, 1, 0, 1] 23 | labels = [0, 1, 1, 1] 24 | metrics = BinaryConcept.compute_metrics(predictions, labels) 25 | assert metrics["accuracy"] == 0.75 26 | assert metrics["precision"] == 1.0 27 | assert metrics["recall"] == 0.6666666666666666 28 | 29 | def test_compute_label(self): 30 | """ 31 | Test the compute_label method. 32 | """ 33 | concept = AndBinaryConcept(HasPiece("p"), HasPiece("n")) 34 | assert concept.compute_label(LczeroBoard("8/8/8/8/8/8/8/8 w - - 0 1")) == 0 35 | assert concept.compute_label(LczeroBoard("8/p7/8/8/8/8/8/8 w - - 0 1")) == 0 36 | assert concept.compute_label(LczeroBoard("8/pn6/8/8/8/8/8/8 w - - 0 1")) == 1 37 | 38 | def test_relative_threat(self): 39 | """ 40 | Test the relative threat concept. 41 | """ 42 | concept = HasThreat("p", relative=True) # Is an enemy pawn threatened? 43 | assert concept.compute_label(LczeroBoard("8/8/8/8/8/8/8/8 w - - 0 1")) == 0 44 | assert concept.compute_label(LczeroBoard("R7/8/8/8/8/8/p7/8 w - - 0 1")) == 1 45 | assert concept.compute_label(LczeroBoard("R7/8/8/8/8/8/p7/8 b - - 0 1")) == 0 46 | -------------------------------------------------------------------------------- /tests/unit/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | File to test the encodings for the Leela Chess Zero engine. 
3 | """ 4 | 5 | import random 6 | import chess 7 | import pytest 8 | 9 | from lczerolens import LczeroBoard 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def random_move_board_list(): 14 | board = LczeroBoard() 15 | seed = 42 16 | random.seed(seed) 17 | move_list = [] 18 | board_list = [board.copy()] 19 | for _ in range(20): 20 | move = random.choice(list(board.legal_moves)) 21 | move_list.append(move) 22 | board.push(move) 23 | board_list.append(board.copy(stack=8)) 24 | return move_list, board_list 25 | 26 | 27 | @pytest.fixture(scope="module") 28 | def repetition_move_board_list(): 29 | board = LczeroBoard() 30 | move_list = [] 31 | board_list = [board.copy()] 32 | for uci_move in ("b1a3", "b8c6", "a3b1", "c6b8") * 4: 33 | move = chess.Move.from_uci(uci_move) 34 | move_list.append(move) 35 | board.push(move) 36 | board_list.append(board.copy(stack=True)) # Full stack is needed for repetition detection 37 | return move_list, board_list 38 | 39 | 40 | @pytest.fixture(scope="module") 41 | def long_move_board_list(): 42 | board = LczeroBoard() 43 | seed = 6 44 | random.seed(seed) 45 | move_list = [] 46 | board_list = [board.copy()] 47 | for _ in range(80): 48 | move = random.choice(list(board.legal_moves)) 49 | move_list.append(move) 50 | board.push(move) 51 | board_list.append(board.copy(stack=8)) 52 | return move_list, board_list 53 | -------------------------------------------------------------------------------- /tests/unit/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/core/__init__.py -------------------------------------------------------------------------------- /tests/unit/core/test_board.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the board utils. 
3 | """ 4 | 5 | from typing import List, Tuple 6 | import pytest 7 | import chess 8 | from lczero.backends import GameState 9 | 10 | from lczerolens import backends as lczero_utils 11 | from lczerolens import LczeroBoard 12 | 13 | 14 | @pytest.mark.backends 15 | class TestWithBackend: 16 | def test_board_to_config_tensor( 17 | self, random_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend 18 | ): 19 | """ 20 | Test that the board to tensor function works. 21 | """ 22 | move_list, board_list = random_move_board_list 23 | for i, board in enumerate(board_list): 24 | board_tensor = board.to_config_tensor() 25 | uci_moves = [move.uci() for move in move_list[:i]] 26 | lczero_game = GameState(moves=uci_moves) 27 | lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13) 28 | assert (board_tensor == lczero_input_tensor[:13]).all() 29 | 30 | def test_board_to_input_tensor( 31 | self, random_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend 32 | ): 33 | """ 34 | Test that the board to tensor function works. 35 | """ 36 | move_list, board_list = random_move_board_list 37 | for i, board in enumerate(board_list): 38 | board_tensor = board.to_input_tensor() 39 | uci_moves = [move.uci() for move in move_list[:i]] 40 | lczero_game = GameState(moves=uci_moves) 41 | lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game) 42 | # assert (board_tensor == lczero_input_tensor).all() 43 | for plane in range(112): 44 | assert (board_tensor[plane] == lczero_input_tensor[plane]).all() 45 | 46 | 47 | @pytest.mark.backends 48 | class TestRepetition: 49 | def test_board_to_config_tensor( 50 | self, repetition_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend 51 | ): 52 | """ 53 | Test that the board to tensor function works. 
54 | """ 55 | move_list, board_list = repetition_move_board_list 56 | for i, board in enumerate(board_list): 57 | uci_moves = [move.uci() for move in move_list[:i]] 58 | board_tensor = board.to_config_tensor() 59 | lczero_game = GameState(moves=uci_moves) 60 | lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13) 61 | assert (board_tensor == lczero_input_tensor[:13]).all() 62 | 63 | def test_board_to_input_tensor( 64 | self, repetition_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend 65 | ): 66 | """ 67 | Test that the board to tensor function works. 68 | """ 69 | move_list, board_list = repetition_move_board_list 70 | for i, board in enumerate(board_list): 71 | uci_moves = [move.uci() for move in move_list[:i]] 72 | board_tensor = board.to_input_tensor() 73 | lczero_game = GameState(moves=uci_moves) 74 | lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game) 75 | assert (board_tensor == lczero_input_tensor).all() 76 | 77 | 78 | @pytest.mark.backends 79 | class TestLong: 80 | def test_board_to_config_tensor( 81 | self, long_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend 82 | ): 83 | """ 84 | Test that the board to tensor function works. 85 | """ 86 | move_list, board_list = long_move_board_list 87 | for i, board in enumerate(board_list): 88 | uci_moves = [move.uci() for move in move_list[:i]] 89 | board_tensor = board.to_config_tensor() 90 | lczero_game = GameState(moves=uci_moves) 91 | lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13) 92 | assert (board_tensor == lczero_input_tensor[:13]).all() 93 | 94 | def test_board_to_input_tensor( 95 | self, long_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend 96 | ): 97 | """ 98 | Test that the board to tensor function works. 
99 | """ 100 | move_list, board_list = long_move_board_list 101 | for i, board in enumerate(board_list): 102 | uci_moves = [move.uci() for move in move_list[:i]] 103 | board_tensor = board.to_input_tensor() 104 | lczero_game = GameState(moves=uci_moves) 105 | lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game) 106 | assert (board_tensor == lczero_input_tensor).all() 107 | 108 | 109 | class TestStability: 110 | def test_encode_decode(self, random_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]]): 111 | """ 112 | Test that encoding and decoding a move is the identity. 113 | """ 114 | us, them = chess.WHITE, chess.BLACK 115 | for move, board in zip(*random_move_board_list): 116 | encoded_move = LczeroBoard.encode_move(move, us) 117 | decoded_move = board.decode_move(encoded_move) 118 | assert move == decoded_move 119 | us, them = them, us 120 | 121 | 122 | @pytest.mark.backends 123 | class TestBackend: 124 | def test_encode_decode_random(self, random_move_board_list): 125 | """ 126 | Test that encoding and decoding a move corresponds to the backend. 127 | """ 128 | move_list, board_list = random_move_board_list 129 | for i, board in enumerate(board_list): 130 | lczero_game = GameState(moves=[move.uci() for move in move_list[:i]]) 131 | legal_moves = [move.uci() for move in board.legal_moves] 132 | ( 133 | lczero_legal_moves, 134 | lczero_policy_indices, 135 | ) = lczero_utils.moves_with_castling_swap(lczero_game, board) 136 | assert len(legal_moves) == len(lczero_legal_moves) 137 | assert set(legal_moves) == set(lczero_legal_moves) 138 | policy_indices = [LczeroBoard.encode_move(move, board.turn) for move in board.legal_moves] 139 | assert len(lczero_policy_indices) == len(policy_indices) 140 | assert set(lczero_policy_indices) == set(policy_indices) 141 | 142 | def test_encode_decode_long(self, long_move_board_list): 143 | """ 144 | Test that encoding and decoding a move corresponds to the backend. 
145 | """ 146 | move_list, board_list = long_move_board_list 147 | for i, board in enumerate(board_list): 148 | lczero_game = GameState(moves=[move.uci() for move in move_list[:i]]) 149 | legal_moves = [move.uci() for move in board.legal_moves] 150 | ( 151 | lczero_legal_moves, 152 | lczero_policy_indices, 153 | ) = lczero_utils.moves_with_castling_swap(lczero_game, board) 154 | assert len(legal_moves) == len(lczero_legal_moves) 155 | assert set(legal_moves) == set(lczero_legal_moves) 156 | policy_indices = [LczeroBoard.encode_move(move, board.turn) for move in board.legal_moves] 157 | assert len(lczero_policy_indices) == len(policy_indices) 158 | assert set(lczero_policy_indices) == set(policy_indices) 159 | -------------------------------------------------------------------------------- /tests/unit/core/test_input_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the input encoding. 3 | """ 4 | 5 | import pytest 6 | import chess 7 | 8 | from lczerolens import LczeroBoard, InputEncoding 9 | 10 | 11 | class TestInputEncoding: 12 | @pytest.mark.parametrize( 13 | "input_encoding_expected", 14 | [ 15 | (InputEncoding.INPUT_CLASSICAL_112_PLANE, 32), 16 | (InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED, 32 * 8), 17 | (InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED, 32 * 8), 18 | (InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS, 32), 19 | ], 20 | ) 21 | def test_sum_initial_config_planes(self, input_encoding_expected): 22 | """ 23 | Test the sum of the config planes for the initial board. 
24 | """ 25 | input_encoding, expected = input_encoding_expected 26 | board = LczeroBoard() 27 | board_tensor = board.to_input_tensor(input_encoding=input_encoding) 28 | assert board_tensor[:104].sum() == expected 29 | 30 | @pytest.mark.parametrize( 31 | "input_encoding_expected", 32 | [ 33 | (InputEncoding.INPUT_CLASSICAL_112_PLANE, 31 + 32 * 4), 34 | (InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED, 31 + 32 * 7), 35 | (InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED, 31 * 8), 36 | (InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS, 31), 37 | ], 38 | ) 39 | def test_sum_qga_config_planes(self, input_encoding_expected): 40 | """ 41 | Test the sum of the config planes for the queen's gambit accepted. 42 | """ 43 | input_encoding, expected = input_encoding_expected 44 | board = LczeroBoard() 45 | moves = [ 46 | chess.Move.from_uci("d2d4"), 47 | chess.Move.from_uci("d7d5"), 48 | chess.Move.from_uci("c2c4"), 49 | chess.Move.from_uci("d5c4"), 50 | ] 51 | for move in moves: 52 | board.push(move) 53 | board_tensor = board.to_input_tensor(input_encoding=input_encoding) 54 | assert board_tensor[:104].sum() == expected 55 | -------------------------------------------------------------------------------- /tests/unit/core/test_lczero.py: -------------------------------------------------------------------------------- 1 | """ 2 | LCZero utils tests. 3 | """ 4 | 5 | import torch 6 | import pytest 7 | from lczero.backends import GameState 8 | 9 | from lczerolens import backends as lczero_utils 10 | 11 | 12 | class TestExecution: 13 | def test_describenet(self): 14 | """ 15 | Test that the describenet function works. 16 | """ 17 | description = lczero_utils.describenet("assets/tinygyal-8.pb.gz") 18 | assert isinstance(description, str) 19 | assert "Minimal Lc0 version:" in description 20 | 21 | def test_convertnet(self): 22 | """ 23 | Test that the convert_to_onnx function works. 
24 | """ 25 | conversion = lczero_utils.convert_to_onnx("assets/tinygyal-8.pb.gz", "assets/tinygyal-8.onnx") 26 | assert isinstance(conversion, str) 27 | assert "INPUT_CLASSICAL_112_PLANE" in conversion 28 | 29 | def test_generic_command(self): 30 | """ 31 | Test that the generic command function works. 32 | """ 33 | generic_command = lczero_utils.generic_command(["--help"]) 34 | assert isinstance(generic_command, str) 35 | assert "Usage: lc0" in generic_command 36 | 37 | @pytest.mark.backends 38 | def test_board_from_backend(self, tiny_lczero_backend): 39 | """ 40 | Test that the board from backend function works. 41 | """ 42 | lczero_game = GameState() 43 | lczero_board_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game) 44 | assert lczero_board_tensor.shape == (112, 8, 8) 45 | 46 | @pytest.mark.backends 47 | def test_prediction_from_backend(self, tiny_lczero_backend): 48 | """ 49 | Test that the prediction from backend function works. 50 | """ 51 | lczero_game = GameState() 52 | lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game) 53 | assert lczero_policy.shape == (1858,) 54 | assert (lczero_value >= -1) and (lczero_value <= 1) 55 | lczero_policy_softmax, _ = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game, softmax=True) 56 | assert lczero_policy_softmax.shape == (1858,) 57 | assert (lczero_policy_softmax >= 0).all() and (lczero_policy_softmax <= 1).all() 58 | assert torch.softmax(lczero_policy, dim=0).allclose(lczero_policy_softmax, atol=1e-4) 59 | -------------------------------------------------------------------------------- /tests/unit/core/test_lens.py: -------------------------------------------------------------------------------- 1 | """Lens tests.""" 2 | 3 | from typing import Any 4 | import pytest 5 | 6 | from lczerolens import Lens 7 | from lczerolens.model import LczeroModel 8 | 9 | 10 | @Lens.register("test_lens") 11 | class TestLens(Lens): 12 | """Test 
lens.""" 13 | 14 | def is_compatible(self, model: LczeroModel) -> bool: 15 | return True 16 | 17 | def _intervene(self, model: LczeroModel, **kwargs) -> Any: 18 | pass 19 | 20 | 21 | class TestLensRegistry: 22 | def test_lens_registry_duplicate(self): 23 | """Test that registering a lens with an existing name raises an error.""" 24 | with pytest.raises(ValueError, match="Lens .* already registered"): 25 | 26 | @Lens.register("test_lens") 27 | class DuplicateLens(Lens): 28 | """Duplicate lens.""" 29 | 30 | def is_compatible(self, model: LczeroModel) -> bool: 31 | return True 32 | 33 | def analyse(self, *inputs, **kwargs) -> Any: 34 | pass 35 | 36 | def test_lens_registry_missing(self): 37 | """Test that instantiating a non-registered lens raises an error.""" 38 | with pytest.raises(KeyError, match="Lens .* not found"): 39 | Lens.from_name("non_existent_lens") 40 | 41 | def test_lens_type(self): 42 | """Test that the lens type is correct.""" 43 | assert TestLens._lens_type == "test_lens" 44 | -------------------------------------------------------------------------------- /tests/unit/core/test_model.py: -------------------------------------------------------------------------------- 1 | """Model tests.""" 2 | 3 | import pytest 4 | import torch 5 | from lczero.backends import GameState 6 | 7 | from lczerolens import Flow, LczeroBoard 8 | from lczerolens import backends as lczero_utils 9 | 10 | 11 | @pytest.mark.backends 12 | class TestModel: 13 | def test_model_prediction(self, tiny_lczero_backend, tiny_model): 14 | """Test that the model prediction works.""" 15 | board = LczeroBoard() 16 | (out,) = tiny_model(board) 17 | policy = out["policy"] 18 | value = out["value"] 19 | lczero_game = GameState() 20 | lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game) 21 | assert torch.allclose(policy, lczero_policy, atol=1e-4) 22 | assert torch.allclose(value, lczero_value, atol=1e-4) 23 | 24 | def 
test_model_prediction_random(self, tiny_lczero_backend, tiny_model, random_move_board_list): 25 | """Test that the model prediction works.""" 26 | move_list, board_list = random_move_board_list 27 | for i, board in enumerate(board_list): 28 | (out,) = tiny_model(board) 29 | policy = out["policy"] 30 | value = out["value"] 31 | lczero_game = GameState(moves=[move.uci() for move in move_list[:i]]) 32 | lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game) 33 | assert torch.allclose(policy, lczero_policy, atol=1e-4) 34 | assert torch.allclose(value, lczero_value, atol=1e-4) 35 | 36 | def test_model_prediction_repetition(self, tiny_lczero_backend, tiny_model, repetition_move_board_list): 37 | """Test that the model prediction works.""" 38 | move_list, board_list = repetition_move_board_list 39 | for i, board in enumerate(board_list): 40 | (out,) = tiny_model(board) 41 | policy = out["policy"] 42 | value = out["value"] 43 | lczero_game = GameState(moves=[move.uci() for move in move_list[:i]]) 44 | lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game) 45 | assert torch.allclose(policy, lczero_policy, atol=1e-4) 46 | assert torch.allclose(value, lczero_value, atol=1e-4) 47 | 48 | def test_model_prediction_long(self, tiny_lczero_backend, tiny_model, long_move_board_list): 49 | """Test that the model prediction works.""" 50 | move_list, board_list = long_move_board_list 51 | for i, board in enumerate(board_list): 52 | (out,) = tiny_model(board) 53 | policy = out["policy"] 54 | value = out["value"] 55 | lczero_game = GameState(moves=[move.uci() for move in move_list[:i]]) 56 | lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game) 57 | assert torch.allclose(policy, lczero_policy, atol=1e-4) 58 | assert torch.allclose(value, lczero_value, atol=1e-4) 59 | 60 | 61 | class TestFlows: 62 | def test_policy_flow(self, tiny_model): 63 | """Test 
that the policy flow works.""" 64 | policy_flow = Flow.from_model("policy", tiny_model) 65 | board = LczeroBoard() 66 | (policy,) = policy_flow(board) 67 | model_policy = tiny_model(board)["policy"][0] 68 | assert torch.allclose(policy, model_policy) 69 | 70 | def test_value_flow(self, tiny_model): 71 | """Test that the value flow works.""" 72 | value_flow = Flow.from_model("value", tiny_model) 73 | board = LczeroBoard() 74 | (value,) = value_flow(board) 75 | model_value = tiny_model(board)["value"][0] 76 | assert torch.allclose(value, model_value) 77 | 78 | def test_wdl_flow(self, winner_model): 79 | """Test that the wdl flow works.""" 80 | wdl_flow = Flow.from_model("wdl", winner_model) 81 | board = LczeroBoard() 82 | (wdl,) = wdl_flow(board) 83 | model_wdl = winner_model(board)["wdl"][0] 84 | assert torch.allclose(wdl, model_wdl) 85 | 86 | def test_mlh_flow(self, winner_model): 87 | """Test that the mlh flow works.""" 88 | mlh_flow = Flow.from_model("mlh", winner_model) 89 | board = LczeroBoard() 90 | (mlh,) = mlh_flow(board) 91 | model_mlh = winner_model(board)["mlh"][0] 92 | assert torch.allclose(mlh, model_mlh) 93 | 94 | def test_force_value_flow_value(self, tiny_model): 95 | """Test that the force value flow works.""" 96 | force_value_flow = Flow.from_model("force_value", tiny_model) 97 | board = LczeroBoard() 98 | (value,) = force_value_flow(board) 99 | model_value = tiny_model(board)["value"][0] 100 | assert torch.allclose(value, model_value) 101 | 102 | def test_force_value_flow_wdl(self, winner_model): 103 | """Test that the force value flow works.""" 104 | force_value_flow = Flow.from_model("force_value", winner_model) 105 | board = LczeroBoard() 106 | (wdl,) = force_value_flow(board) 107 | model_wdl = winner_model(board)["wdl"][0] 108 | model_value = model_wdl @ torch.tensor([1.0, 0.0, -1.0], device=model_wdl.device) 109 | assert torch.allclose(wdl, model_value) 110 | 111 | def test_incompatible_flows(self, tiny_model, winner_model): 112 | """Test that 
the flows raise an error * 113 | when the model is incompatible. 114 | """ 115 | with pytest.raises(ValueError): 116 | Flow.from_model("value", winner_model) 117 | with pytest.raises(ValueError): 118 | Flow.from_model("wdl", tiny_model) 119 | with pytest.raises(ValueError): 120 | Flow.from_model("mlh", tiny_model) 121 | 122 | with pytest.raises(ValueError): 123 | Flow._registry["value"](tiny_model) 124 | 125 | 126 | @Flow.register("test_flow") 127 | class TestFlow(Flow): 128 | """Test flow.""" 129 | 130 | 131 | class TestFlowRegistry: 132 | def test_flow_registry_duplicate(self): 133 | """Test that registering a flow with an existing name raises an error.""" 134 | with pytest.raises(ValueError, match="Flow .* already registered"): 135 | 136 | @Flow.register("test_flow") 137 | class DuplicateFlow(Flow): 138 | """Duplicate flow.""" 139 | 140 | def test_flow_registry_missing(self, tiny_model): 141 | """Test that instantiating a non-registered flow raises an error.""" 142 | with pytest.raises(KeyError, match="Flow .* not found"): 143 | Flow.from_model("non_existent_flow", tiny_model) 144 | 145 | def test_flow_type(self): 146 | """Test that the flow type is correct.""" 147 | assert TestFlow._flow_type == "test_flow" 148 | -------------------------------------------------------------------------------- /tests/unit/lenses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/lenses/__init__.py -------------------------------------------------------------------------------- /tests/unit/lenses/test_activation.py: -------------------------------------------------------------------------------- 1 | """Activation lens tests.""" 2 | 3 | from lczerolens import Lens 4 | from lczerolens.lenses import ActivationLens 5 | from lczerolens.board import LczeroBoard 6 | 7 | 8 | class TestLens: 9 | def test_is_compatible(self, tiny_model): 10 | 
lens = Lens.from_name("activation") 11 | assert isinstance(lens, ActivationLens) 12 | assert lens.is_compatible(tiny_model) 13 | 14 | def test_analyse_board(self, tiny_model): 15 | lens = ActivationLens(pattern=r".*") 16 | board = LczeroBoard() 17 | results = lens.analyse(tiny_model, board) 18 | 19 | assert len(results) > 0 20 | for key in results: 21 | assert key.endswith("_output") 22 | 23 | def test_analyse_with_inputs(self, tiny_model): 24 | lens = ActivationLens(pattern=r".*") 25 | board = LczeroBoard() 26 | results = lens.analyse(tiny_model, board, save_inputs=True) 27 | 28 | input_keys = [k for k in results if k.endswith("_input")] 29 | output_keys = [k for k in results if k.endswith("_output")] 30 | 31 | assert len(input_keys) > 0 32 | assert len(output_keys) > 0 33 | 34 | def test_analyse_specific_modules(self, tiny_model): 35 | lens = ActivationLens(pattern=r".*conv.*") 36 | board = LczeroBoard() 37 | results = lens.analyse(tiny_model, board) 38 | 39 | assert len(results) > 0 40 | for key in results: 41 | module_name = key.replace("_output", "") 42 | assert "conv" in module_name 43 | -------------------------------------------------------------------------------- /tests/unit/lenses/test_gradient.py: -------------------------------------------------------------------------------- 1 | """Gradient lens tests.""" 2 | 3 | from lczerolens import Lens 4 | from lczerolens.lenses import GradientLens 5 | from lczerolens.board import LczeroBoard 6 | 7 | 8 | class TestLens: 9 | def test_is_compatible(self, tiny_model): 10 | lens = Lens.from_name("gradient") 11 | assert isinstance(lens, GradientLens) 12 | assert lens.is_compatible(tiny_model) 13 | 14 | def test_analyse_board(self, tiny_model): 15 | lens = GradientLens() 16 | board = LczeroBoard() 17 | results = lens.analyse(tiny_model, board) 18 | 19 | assert "input_grad" in results 20 | 21 | def test_analyse_without_input_grad(self, tiny_model): 22 | lens = GradientLens(input_requires_grad=False) 23 | board = 
LczeroBoard() 24 | results = lens.analyse(tiny_model, board) 25 | 26 | assert "input_grad" not in results 27 | 28 | def test_analyse_specific_modules(self, tiny_model): 29 | lens = GradientLens(pattern=r".*conv.*relu", input_requires_grad=False) 30 | board = LczeroBoard() 31 | results = lens.analyse(tiny_model, board) 32 | 33 | assert len(results) > 0 34 | for key in results: 35 | module_name = key.replace("_output_grad", "") 36 | assert "conv" in module_name 37 | -------------------------------------------------------------------------------- /tests/unit/lenses/test_lrp.py: -------------------------------------------------------------------------------- 1 | """LRP lens tests.""" 2 | 3 | from lczerolens import Lens 4 | from lczerolens.lenses import LrpLens 5 | 6 | 7 | class TestLens: 8 | def test_is_compatible(self, tiny_model): 9 | lens = Lens.from_name("lrp") 10 | assert isinstance(lens, LrpLens) 11 | assert lens.is_compatible(tiny_model) 12 | -------------------------------------------------------------------------------- /tests/unit/play/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/play/__init__.py -------------------------------------------------------------------------------- /tests/unit/play/test_puzzle.py: -------------------------------------------------------------------------------- 1 | """Puzzle tests.""" 2 | 3 | import pytest 4 | 5 | 6 | from lczerolens.play.puzzle import Puzzle 7 | from lczerolens.play.sampling import RandomSampler, PolicySampler 8 | 9 | 10 | @pytest.fixture 11 | def opening_puzzle(): 12 | return { 13 | "PuzzleId": "1", 14 | "FEN": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", 15 | "Moves": "e2e4 e7e5 d2d4 d7d5", 16 | "Rating": 1000, 17 | "RatingDeviation": 100, 18 | "Popularity": 1000, 19 | "NbPlays": 1000, 20 | "Themes": "Opening", 21 | "GameUrl": 
"https://lichess.org/training/1", 22 | "OpeningTags": "Ruy Lopez", 23 | } 24 | 25 | 26 | @pytest.fixture 27 | def easy_puzzle(): 28 | return { 29 | "PuzzleId": "00008", 30 | "FEN": "r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - - 0 24", 31 | "Moves": "f2g3 e6e7 b2b1 b3c1 b1c1 h6c1", 32 | "Rating": 1913, 33 | "RatingDeviation": 75, 34 | "Popularity": 94, 35 | "NbPlays": 6230, 36 | "Themes": "crushing hangingPiece long middlegame", 37 | "GameUrl": "https://lichess.org/787zsVup/black#47", 38 | "OpeningTags": None, 39 | } 40 | 41 | 42 | class TestPuzzle: 43 | def test_puzzle_creation(self, opening_puzzle): 44 | """Test puzzle creation.""" 45 | puzzle = Puzzle.from_dict(opening_puzzle) 46 | assert len(puzzle) == 3 47 | assert puzzle.rating == 1000 48 | assert puzzle.rating_deviation == 100 49 | assert puzzle.popularity == 1000 50 | assert puzzle.nb_plays == 1000 51 | assert puzzle.themes == ["Opening"] 52 | assert puzzle.game_url == "https://lichess.org/training/1" 53 | assert puzzle.opening_tags == ["Ruy", "Lopez"] 54 | 55 | def test_puzzle_use(self, opening_puzzle): 56 | """Test puzzle use.""" 57 | puzzle = Puzzle.from_dict(opening_puzzle) 58 | assert len(list(puzzle.board_move_generator())) == 2 59 | assert len(list(puzzle.board_move_generator(all_moves=True))) == 3 60 | 61 | 62 | class TestRandomSampler: 63 | def test_puzzle_evaluation(self, opening_puzzle): 64 | """Test puzzle evaluation.""" 65 | puzzle = Puzzle.from_dict(opening_puzzle) 66 | sampler = RandomSampler() 67 | metrics = puzzle.evaluate(sampler) 68 | assert metrics["score"] != 1.0 69 | assert abs(metrics["perplexity"] - (20.0 * 30) ** 0.5) < 1e-3 70 | 71 | def test_puzzle_multiple_evaluation_len(self, easy_puzzle): 72 | """Test puzzle evaluation.""" 73 | puzzles = [Puzzle.from_dict(easy_puzzle) for _ in range(10)] 74 | sampler = RandomSampler() 75 | all_results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=True, compute_metrics=False) 76 | assert len(list(all_results)) == 10 * 5 77 | results = 
Puzzle.evaluate_multiple(puzzles, sampler, compute_metrics=False) 78 | assert len(list(results)) == 10 * 3 79 | 80 | def test_puzzle_multiple_evaluation(self, easy_puzzle): 81 | """Test puzzle evaluation.""" 82 | puzzles = [Puzzle.from_dict(easy_puzzle) for _ in range(10)] 83 | sampler = RandomSampler() 84 | all_results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=True) 85 | assert len(list(all_results)) == 10 86 | results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=False) 87 | assert len(list(results)) == 10 88 | 89 | def test_puzzle_multiple_evaluation_batch_size(self, easy_puzzle): 90 | """Test puzzle evaluation.""" 91 | puzzles = [Puzzle.from_dict(easy_puzzle) for _ in range(10)] 92 | sampler = RandomSampler() 93 | all_results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=True, batch_size=5) 94 | assert len(list(all_results)) == 10 95 | results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=False, batch_size=5) 96 | assert len(list(results)) == 10 97 | 98 | 99 | class TestPolicySampler: 100 | def test_puzzle_evaluation(self, easy_puzzle, winner_model): 101 | """Test puzzle evaluation.""" 102 | puzzle = Puzzle.from_dict(easy_puzzle) 103 | sampler = PolicySampler(model=winner_model, use_argmax=True) 104 | metrics = puzzle.evaluate(sampler, all_moves=True) 105 | assert metrics["score"] > 0.0 106 | assert metrics["perplexity"] < 15.0 107 | 108 | def test_puzzle_multiple_evaluation(self, easy_puzzle, tiny_model): 109 | """Test puzzle evaluation.""" 110 | puzzles = [Puzzle.from_dict(easy_puzzle) for _ in range(10)] 111 | sampler = PolicySampler(model=tiny_model, use_argmax=False) 112 | all_results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=True) 113 | assert len(list(all_results)) == 10 114 | results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=False) 115 | assert len(list(results)) == 10 116 | -------------------------------------------------------------------------------- /tests/unit/play/test_sampling.py: 
-------------------------------------------------------------------------------- 1 | """Sampling tests.""" 2 | 3 | from lczerolens.play.sampling import ModelSampler, SelfPlay, PolicySampler, RandomSampler 4 | from lczerolens.board import LczeroBoard 5 | 6 | 7 | class TestRandomSampler: 8 | def test_get_utilities(self): 9 | """Test get_utilities method.""" 10 | board = LczeroBoard() 11 | sampler = RandomSampler() 12 | utility, _, _ = next(iter(sampler.get_utilities([board, board]))) 13 | assert utility.shape[0] == 20 14 | 15 | 16 | class TestModelSampler: 17 | def test_get_utilities_tiny(self, tiny_model): 18 | """Test get_utilities method.""" 19 | board = LczeroBoard() 20 | sampler = ModelSampler(tiny_model, use_argmax=False) 21 | utility, _, _ = next(iter(sampler.get_utilities([board, board]))) 22 | assert utility.shape[0] == 20 23 | 24 | def test_get_utilities_winner(self, winner_model): 25 | """Test get_utilities method.""" 26 | board = LczeroBoard() 27 | sampler = ModelSampler(winner_model, use_argmax=False) 28 | utility, _, _ = next(iter(sampler.get_utilities([board, board]))) 29 | assert utility.shape[0] == 20 30 | 31 | def test_policy_sampler_tiny(self, tiny_model): 32 | """Test policy_sampler method.""" 33 | board = LczeroBoard() 34 | sampler = PolicySampler(tiny_model, use_argmax=False) 35 | utility, _, _ = next(iter(sampler.get_utilities([board, board]))) 36 | assert utility.shape[0] == 20 37 | 38 | 39 | class TestSelfPlay: 40 | def test_play(self, tiny_model, winner_model): 41 | """Test play method.""" 42 | board = LczeroBoard() 43 | white = ModelSampler(tiny_model, use_argmax=False) 44 | black = ModelSampler(winner_model, use_argmax=False) 45 | self_play = SelfPlay(white=white, black=black) 46 | logs = [] 47 | 48 | def report_fn(log, to_play): 49 | logs.append((log, to_play)) 50 | 51 | game, board = self_play.play(board=board, max_moves=10, report_fn=report_fn) 52 | 53 | assert len(game) == len(logs) == 10 54 | 55 | 56 | class TestBatchedPolicySampler: 
57 | def test_batched_policy_sampler_ag(self, tiny_model): 58 | """Test batched_policy_sampler method.""" 59 | boards = [LczeroBoard() for _ in range(10)] 60 | 61 | sampler_ag = PolicySampler(tiny_model, use_argmax=True) 62 | moves = sampler_ag.get_next_moves(boards) 63 | assert len(list(moves)) == 10 64 | assert all([move == moves[0] for move in moves]) 65 | 66 | def test_batched_policy_sampler_no_ag(self, tiny_model): 67 | """Test batched_policy_sampler method.""" 68 | boards = [LczeroBoard() for _ in range(10)] 69 | 70 | sampler_no_ag = PolicySampler(tiny_model, use_argmax=False) 71 | moves = sampler_no_ag.get_next_moves(boards) 72 | assert len(list(moves)) == 10 73 | 74 | def test_batched_policy_sampler_no_ag_sub(self, tiny_model): 75 | """Test batched_policy_sampler method.""" 76 | boards = [LczeroBoard() for _ in range(10)] 77 | 78 | sampler_no_ag = PolicySampler(tiny_model, use_argmax=False, use_suboptimal=True) 79 | moves = sampler_no_ag.get_next_moves(boards) 80 | assert len(list(moves)) == 10 81 | --------------------------------------------------------------------------------