├── .github
├── ISSUE_TEMPLATE
│ ├── bug.yml
│ ├── config.yml
│ ├── feature.yml
│ ├── other.yml
│ └── roadmap.yml
├── pull_request_template.md
└── workflows
│ ├── ci-tests-slow.yml
│ ├── ci-tests.yml
│ ├── ci.yml
│ └── publish.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── .vscode
├── launch.json
└── settings.json
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── assets
├── .gitignore
├── resolve-demo-assets.sh
└── resolve-tests-assets.sh
├── docs
├── Makefile
├── make.bat
└── source
│ ├── _static
│ ├── .DS_Store
│ ├── css
│ │ ├── custom.css
│ │ └── nbsphinx.css
│ ├── images
│ │ ├── .DS_Store
│ │ ├── favicon.ico
│ │ ├── lczerolens-logo.png
│ │ ├── lczerolens-logo.svg
│ │ ├── one.png
│ │ └── two.png
│ └── switcher.json
│ ├── about.rst
│ ├── conf.py
│ ├── features.rst
│ ├── index.rst
│ ├── notebooks
│ ├── .gitignore
│ ├── features
│ │ ├── convert-official-weights.ipynb
│ │ ├── evaluate-models-on-puzzles.ipynb
│ │ ├── move-prediction.ipynb
│ │ ├── probe-concepts.ipynb
│ │ ├── run-models-on-gpu.ipynb
│ │ └── visualise-heatmaps.ipynb
│ ├── tutorials
│ │ ├── automated-interpretability.ipynb
│ │ ├── evidence-of-learned-look-ahead.ipynb
│ │ ├── piece-value-estimation-using-lrp.ipynb
│ │ └── train-saes.ipynb
│ └── walkthrough.ipynb
│ ├── start.rst
│ └── tutorials.rst
├── pyproject.toml
├── scripts
├── __init__.py
├── constants.py
├── datasets
│ ├── __init__.py
│ ├── make_lichess_dataset.py
│ └── make_tcec_dataset.py
├── lrp
│ ├── __init__.py
│ └── plane_analysis.py
├── results
│ └── .gitignore
└── visualisation.py
├── spaces
└── .gitignore
├── src
└── lczerolens
│ ├── __init__.py
│ ├── backends.py
│ ├── board.py
│ ├── concept.py
│ ├── concepts
│ ├── __init__.py
│ ├── material.py
│ ├── move.py
│ └── threat.py
│ ├── constants.py
│ ├── lens.py
│ ├── lenses
│ ├── __init__.py
│ ├── activation.py
│ ├── composite.py
│ ├── gradient.py
│ ├── lrp
│ │ ├── __init__.py
│ │ ├── lens.py
│ │ └── rules
│ │ │ ├── __init__.py
│ │ │ └── epsilon.py
│ ├── patching.py
│ ├── probing
│ │ ├── __init__.py
│ │ ├── lens.py
│ │ └── probe.py
│ └── sae
│ │ ├── __init__.py
│ │ └── buffer.py
│ ├── model.py
│ └── play
│ ├── __init__.py
│ ├── game.py
│ ├── puzzle.py
│ └── sampling.py
├── tests
├── __init__.py
├── assets
│ └── error.ipynb
├── conftest.py
├── integration
│ ├── __init__.py
│ └── test_notebooks.py
└── unit
│ ├── __init__.py
│ ├── concepts
│ ├── __init__.py
│ └── test_concepts.py
│ ├── conftest.py
│ ├── core
│ ├── __init__.py
│ ├── test_board.py
│ ├── test_input_encoding.py
│ ├── test_lczero.py
│ ├── test_lens.py
│ └── test_model.py
│ ├── lenses
│ ├── __init__.py
│ ├── test_activation.py
│ ├── test_gradient.py
│ └── test_lrp.py
│ └── play
│ ├── __init__.py
│ ├── test_puzzle.py
│ └── test_sampling.py
└── uv.lock
/.github/ISSUE_TEMPLATE/bug.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: File a bug report.
3 | title: "[Bug]: "
4 | labels: ["bug"]
5 | projects: []
6 | assignees:
7 | - Xmaster6y
8 | body:
9 | - type: markdown
10 | attributes:
11 | value: |
12 | Thanks for taking the time to fill out this bug report!
13 | - type: textarea
14 | id: what-happened
15 | attributes:
16 | label: What happened?
17 | description: Also tell us, what did you expect to happen?
18 | placeholder: Tell us what you see!
19 | value: "A bug happened!"
20 | validations:
21 | required: true
22 | - type: dropdown
23 | id: version
24 | attributes:
25 | label: Version
26 | description: What version of the software are you running?
27 | options:
28 | - "latest"
29 | - "0.3.3"
30 | - "0.3.2"
31 | - "0.3.1"
32 | - "0.3.0"
33 | - "0.2.0"
34 | - "0.1.2"
35 | default: 0
36 | validations:
37 | required: true
38 | - type: dropdown
39 | id: environment
40 | attributes:
41 | label: In which environments is the bug happening?
42 | multiple: true
43 | options:
44 | - Colab
45 | - Jupyter Notebook
46 | - Python Script
47 | - Linux
48 | - Windows
49 | - MacOS
50 | - Other
51 | - type: textarea
52 | id: logs
53 | attributes:
54 | label: Relevant log output
55 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks.
56 | render: shell
57 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Community Support
4 | url: https://github.com/Xmaster6y/lczerolens/discussions
5 | about: Please ask and answer questions here.
6 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature.yml:
--------------------------------------------------------------------------------
1 | name: Feature Request
2 | description: File a feature request.
3 | title: "[Feature]: "
4 | labels: ["feature"]
5 | projects: []
6 | assignees:
7 | - Xmaster6y
8 | body:
9 | - type: markdown
10 | attributes:
11 | value: |
12 | Thanks for taking the time to fill out this feature request!
13 | - type: textarea
14 | id: description
15 | attributes:
16 | label: Description
17 | description: A short one-line description of the feature request.
18 | placeholder: Tell us what you want!
19 | value: "I want to use this feature!"
20 | validations:
21 | required: true
22 | - type: dropdown
23 | id: category
24 | attributes:
25 | label: Category
26 | description: What category does the feature belong to?
27 | options:
28 | - New Feature
29 | - Improvement
30 | - Bug Fix
31 | - Documentation
32 | - Other
33 | default: 0
34 | validations:
35 | required: true
36 | - type: textarea
37 | id: use-case
38 | attributes:
39 | label: Use Case
40 | description: Please describe the use case for the feature. Use python code to describe the use case.
41 | render: python
42 | - type: textarea
43 | id: tasks
44 | attributes:
45 | label: Tasks
46 | description: List of tasks to implement the feature.
47 | value: |
48 | * [ ] Code is documented
49 | * [ ] Utilities and class tests are written
50 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/other.yml:
--------------------------------------------------------------------------------
1 | name: Other
2 | description: Detail an issue.
3 | title: "[Other]: "
4 | labels: ["other"]
5 | projects: []
6 | assignees:
7 | - Xmaster6y
8 | body:
9 | - type: markdown
10 | attributes:
11 | value: |
12 | Thanks for taking the time to fill out this issue!
13 | - type: textarea
14 | id: description
15 | attributes:
16 | label: Description
17 | description: A short one-line description of the issue.
18 | placeholder: Key insights about the issue.
19 | validations:
20 | required: true
21 | - type: dropdown
22 | id: category
23 | attributes:
24 | label: Category
25 | description: What category does the issue belong to?
26 | options:
27 | - New Feature
28 | - Improvement
29 | - Bug Fix
30 | - Documentation
31 | - Other
32 | default: 4
33 | validations:
34 | required: true
35 | - type: textarea
36 | id: tasks
37 | attributes:
38 | label: Tasks
39 | description: List of tasks to resolve the issue.
40 | value: |
41 | * [ ] Code is documented
42 | * [ ] Utilities and class tests are written
43 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/roadmap.yml:
--------------------------------------------------------------------------------
1 | name: Roadmap
2 | description: Detail the roadmap to implement a new feature or solve a bug.
3 | title: "[Roadmap]: "
4 | labels: ["roadmap"]
5 | projects: []
6 | assignees:
7 | - Xmaster6y
8 | body:
9 | - type: markdown
10 | attributes:
11 | value: |
12 | Thanks for taking the time to fill out this roadmap!
13 | - type: textarea
14 | id: description
15 | attributes:
16 | label: Description
17 | description: A short one-line description of the roadmap.
18 | placeholder: Key insights about the roadmap.
19 | validations:
20 | required: true
21 | - type: textarea
22 | id: issues
23 | attributes:
24 | label: Linked Issues
25 | description: List of issues for the roadmap.
26 | value: |
27 | * [ ] #?
28 | validations:
29 | required: true
30 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## What does this PR do?
2 |
3 | Key insights about the PR.
4 |
5 | ## Linked Issues
6 |
7 | - Closes #?
8 | - #?
9 |
--------------------------------------------------------------------------------
/.github/workflows/ci-tests-slow.yml:
--------------------------------------------------------------------------------
1 | name: CI (slow)
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | jobs:
7 | ci:
8 | uses: ./.github/workflows/ci.yml
9 | with:
10 | tests-type: tests-slow
11 |
--------------------------------------------------------------------------------
/.github/workflows/ci-tests.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | pull_request:
5 | workflow_dispatch:
6 | push:
7 | branches:
8 | - main
9 |
10 | jobs:
11 | ci:
12 | uses: ./.github/workflows/ci.yml
13 | with:
14 | tests-type: tests
15 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | on:
2 | workflow_call:
3 | inputs:
4 | tests-type:
5 | required: true
6 | type: string
7 |
8 | jobs:
9 | ci:
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version: ["3.10", "3.11", "3.12"]
14 | environment: ci
15 | timeout-minutes: 10
16 | steps:
17 | - uses: actions/checkout@v4
18 | with:
19 | fetch-depth: 0
20 | - name: Install uv
21 | uses: astral-sh/setup-uv@v5
22 | with:
23 | python-version: ${{ matrix.python-version }}
24 | - name: Install dependencies
25 | run: |
26 | uv sync --locked
27 | - name: Run checks
28 | run: |
29 | make checks
30 | - name: Download assets
31 | run: |
32 | make tests-assets
33 | - name: Run tests
34 | run: |
35 | make ${{ inputs.tests-type }}
36 | - name: Upload coverage reports to Codecov
37 | if: matrix.python-version == '3.10'
38 | uses: codecov/codecov-action@v5
39 | with:
40 | token: ${{ secrets.CODECOV_TOKEN }}
41 | - name: Upload test results to Codecov
42 | if: matrix.python-version == '3.10' && !cancelled()
43 | uses: codecov/test-results-action@v1
44 | with:
45 | token: ${{ secrets.CODECOV_TOKEN }}
46 | fail_ci_if_error: true
47 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | workflow_dispatch:
5 | release:
6 | types: [published]
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | environment: publish
12 | permissions:
13 | id-token: write
14 | steps:
15 | - uses: actions/checkout@v4
16 | - name: Install uv
17 | uses: astral-sh/setup-uv@v5
18 | with:
19 | python-version: "3.11"
20 | - name: Install dependencies
21 | run: |
22 | uv sync --locked
23 | - name: Build package
24 | run: |
25 | uv build
26 | - name: Publish package
27 | uses: pypa/gh-action-pypi-publish@release/v1
28 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pipenv
85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
88 | # install all needed dependencies.
89 | #Pipfile.lock
90 |
91 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
92 | __pypackages__/
93 |
94 | # Celery stuff
95 | celerybeat-schedule
96 | celerybeat.pid
97 |
98 | # SageMath parsed files
99 | *.sage.py
100 |
101 | # Environments
102 | .env
103 | .venv
104 | env/
105 | venv/
106 | ENV/
107 | env.bak/
108 | venv.bak/
109 |
110 | # Spyder project settings
111 | .spyderproject
112 | .spyproject
113 |
114 | # Rope project settings
115 | .ropeproject
116 |
117 | # mkdocs documentation
118 | /site
119 |
120 | # mypy
121 | .mypy_cache/
122 | .dmypy.json
123 | dmypy.json
124 |
125 | # Pyre type checker
126 | .pyre/
127 |
128 | # Pickle files
129 | *.pkl
130 |
131 | # Various files
132 | ignored
133 | debug
134 | *.zip
135 | lc0
136 | !bin/lc0
137 | wandb
138 | **/.DS_Store
139 | junit*
140 |
141 | *secret*
142 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "spaces/lichess-puzzles-leaderboard"]
2 | path = spaces/lichess-puzzles-leaderboard
3 | url = https://huggingface.co/spaces/lczerolens/lichess-puzzles-leaderboard
4 | [submodule "spaces/lczerolens-backends-demo"]
5 | path = spaces/lczerolens-backends-demo
6 | url = https://huggingface.co/spaces/lczerolens/lczerolens-backends-demo
7 | [submodule "spaces/lczerolens-demo"]
8 | path = spaces/lczerolens-demo
9 | url = https://huggingface.co/spaces/Xmaster6y/lczerolens-demo
10 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: check-added-large-files
6 | args: ['--maxkb=1024']
7 | - id: check-yaml
8 | - id: check-json
9 | - id: check-toml
10 | - id: end-of-file-fixer
11 | - id: trailing-whitespace
12 | - id: check-docstring-first
13 | - repo: https://github.com/astral-sh/ruff-pre-commit
14 | rev: v0.9.1
15 | hooks:
16 | - id: ruff
17 | args: [ --fix ]
18 | - id: ruff-format
19 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-22.04
5 | tools:
6 | python: "3.11"
7 | commands:
8 | - asdf plugin add uv
9 | - asdf install uv latest
10 | - asdf global uv latest
11 | - uv sync --locked
12 | - make docs
13 | - mkdir -p $READTHEDOCS_OUTPUT
14 | - mv docs/build/html $READTHEDOCS_OUTPUT
15 |
16 | sphinx:
17 | configuration: docs/source/conf.py
18 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.2.0",
3 | "configurations": [
4 | {
5 | "name": "Script Plane Analysis",
6 | "type": "debugpy",
7 | "request": "launch",
8 | "module": "scripts.lrp.plane_analysis",
9 | "console": "integratedTerminal",
10 | "justMyCode": false
11 | }
12 | ]
13 | }
14 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestEnabled": true,
3 | "python.testing.pytestArgs": [
4 | "tests"
5 | ]
6 | }
7 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: LCZeroLens
3 | type: software
4 | authors:
5 | - name: Yoann Poupart
6 | repository-code: 'https://github.com/Xmaster6y/lczerolens'
7 | repository-artifact: 'https://github.com/Xmaster6y/lczerolens/releases/'
8 | keywords:
9 | - chess
10 | - explainable AI (XAI)
11 | license: MIT
12 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute?
2 |
3 | ## Guidelines
4 |
5 | The project dependencies are managed using `uv`, see their installation [guide](https://docs.astral.sh/uv/). For even more stability, I recommend using `pyenv` to pin a Python version supported by CI (`3.10`–`3.12`).
6 |
7 | Additionally, to make your life easier, install `make` to use the shortcut commands.
8 |
9 | ## Dev Install
10 |
11 | To install the dependencies:
12 |
13 | ```bash
14 | uv sync
15 | ```
16 |
17 | Before committing, install `pre-commit`:
18 |
19 | ```bash
20 | uv run pre-commit install
21 | ```
22 |
23 | To run the checks (`pre-commit` checks):
24 |
25 | ```bash
26 | make checks
27 | ```
28 |
29 | To run the tests (using `pytest`):
30 |
31 | ```bash
32 | make tests
33 | ```
34 |
35 | ## Branches
36 |
37 | Make a branch before making a pull request to `develop`.
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yoann Poupart
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: checks
2 | checks:
3 | uv run pre-commit run --all-files
4 |
5 | .PHONY: tests-assets
6 | tests-assets:
7 | bash assets/resolve-tests-assets.sh
8 |
9 | .PHONY: demo-assets
10 | demo-assets:
11 | bash assets/resolve-demo-assets.sh
12 |
13 | .PHONY: tests
14 | tests:
15 | uv run pytest tests --cov=src --cov-report=term-missing --cov-fail-under=50 -s -v --run-fast --run-backends --cov-branch --cov-report=xml --junitxml=junit.xml -o junit_family=legacy
16 |
17 | .PHONY: tests-fast
18 | tests-fast:
19 | uv run pytest tests --cov=src --cov-report=term-missing -s -v --run-fast
20 |
21 | .PHONY: tests-slow
22 | tests-slow:
23 | uv run pytest tests --cov=src --cov-report=term-missing -s -v --run-slow
24 |
25 | .PHONY: tests-backends
26 | tests-backends:
27 | uv run pytest tests --cov=src --cov-report=term-missing -s -v --run-backends
28 |
29 | .PHONY: docs
30 | docs:
31 | cd docs && uv run --group docs make html
32 |
33 | .PHONY: demo
34 | demo:
35 | uv run --group demo gradio spaces/lczerolens-demo/app.py
36 |
37 | .PHONY: demo-backends
38 | demo-backends:
39 | uv run --group demo gradio spaces/lczerolens-backends-demo/app.py
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # lczerolens 🔍
4 |
5 | [](https://pypi.org/project/lczerolens/)
6 | [](https://github.com/Xmaster6y/lczerolens/blob/main/LICENSE)
7 | [](https://github.com/astral-sh/uv)
8 | [](https://www.python.org/downloads/)
9 |
10 | [](https://codecov.io/gh/Xmaster6y/lczerolens)
11 | 
12 | 
13 | 
14 | [](https://lczerolens.readthedocs.io/en/latest/?badge=latest)
15 |
16 |
17 |
18 |
19 | Leela Chess Zero (lc0) Lens (`lczerolens`): a set of utilities to make the analysis of Leela Chess Zero networks easy.
20 |
21 | ## Getting Started
22 |
23 | ### Installs
24 |
25 | ```bash
26 | pip install lczerolens
27 | ```
28 |
29 | Take the `viz` extra to render heatmaps and the `backends` extra to use the `lc0` backends.
30 |
31 | ### Run Models
32 |
33 | Get the best move predicted by a model:
34 |
35 | ```python
36 | from lczerolens import LczeroBoard, LczeroModel
37 |
38 | model = LczeroModel.from_onnx_path("demo/onnx-models/lc0-10-4238.onnx")
39 | board = LczeroBoard()
40 | output = model(board)
41 | legal_indices = board.get_legal_indices()
42 | best_move_idx = output["policy"].gather(
43 | dim=1,
44 | index=legal_indices.unsqueeze(0)
45 | ).argmax(dim=1).item()
46 | print(board.decode_move(legal_indices[best_move_idx]))
47 | ```
48 |
49 | ### Intervene
50 |
51 | In addition to the built-in lenses, you can easily create your own lenses by subclassing the `Lens` class and overriding the `_intervene` method. E.g. compute a neuron ablation effect:
52 |
53 | ```python
54 | from lczerolens import LczeroBoard, LczeroModel, Lens
55 |
56 | class CustomLens(Lens):
57 | _grad_enabled: bool = False
58 |
59 | def _intervene(self, model: LczeroModel, **kwargs) -> dict:
60 | ablate = kwargs.get("ablate", False)
61 | if ablate:
62 | l5_module = getattr(model, "block5/conv2/relu")
63 | l5_module.output[0, :, 0, 0] = 0 # relative a1
64 | return getattr(model, "output/wdl").output[0,0].save() # win probability
65 |
66 | model = LczeroModel.from_onnx_path("path/to/model.onnx")
67 | lens = CustomLens()
68 | board = LczeroBoard()
69 |
70 | clean_results = lens.analyse(model, board)
71 | corrupted_results = lens.analyse(model, board, ablate=True)
72 | print((corrupted_results - clean_results) / clean_results)
73 | ```
74 |
75 | ### Features
76 |
77 | - [Visualise Heatmaps](https://lczerolens.readthedocs.io/en/latest/notebooks/features/visualise-heatmaps.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/visualise-heatmaps.ipynb)
78 | - [Probe Concepts](https://lczerolens.readthedocs.io/en/latest/notebooks/features/probe-concepts.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/probe-concepts.ipynb)
79 | - [Move Prediction](https://lczerolens.readthedocs.io/en/latest/notebooks/features/move-prediction.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/move-prediction.ipynb)
80 | - [Run Models on GPU](https://lczerolens.readthedocs.io/en/latest/notebooks/features/run-models-on-gpu.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/run-models-on-gpu.ipynb)
81 | - [Evaluate Models on Puzzles](https://lczerolens.readthedocs.io/en/latest/notebooks/features/evaluate-models-on-puzzles.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/evaluate-models-on-puzzles.ipynb)
82 | - [Convert Official Weights](https://lczerolens.readthedocs.io/en/latest/notebooks/features/convert-official-weights.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/convert-official-weights.ipynb)
83 |
84 | ### Tutorials
85 |
86 | - [Walkthrough](https://lczerolens.readthedocs.io/en/latest/notebooks/walkthrough.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/walkthrough.ipynb)
87 | - [Piece Value Estimation Using LRP](https://lczerolens.readthedocs.io/en/latest/notebooks/tutorials/piece-value-estimation-using-lrp.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb)
88 | - [Evidence of Learned Look-Ahead](https://lczerolens.readthedocs.io/en/latest/notebooks/tutorials/evidence-of-learned-look-ahead.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/evidence-of-learned-look-ahead.ipynb)
89 | - [Train SAEs](https://lczerolens.readthedocs.io/en/latest/notebooks/tutorials/train-saes.html): [](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/train-saes.ipynb)
90 |
91 | ## Demo
92 |
93 | ### Spaces
94 |
95 | - [Lczerolens Demo](https://huggingface.co/spaces/lczerolens/lczerolens-demo)
96 | - [Lczerolens Backends Demo](https://huggingface.co/spaces/lczerolens/lczerolens-backends-demo)
97 | - [Lczerolens Puzzles Leaderboard](https://huggingface.co/spaces/lczerolens/lichess-puzzles-leaderboard)
98 |
99 | ### Local Demo
100 |
101 | Additionally, you can run the gradio demos locally. First you'll need to clone the spaces (after cloning the repo):
102 |
103 | ```bash
104 | git clone https://huggingface.co/spaces/Xmaster6y/lczerolens-demo spaces/lczerolens-demo
105 | ```
106 |
107 | And optionally the backends demo:
108 |
109 | ```bash
110 | git clone https://huggingface.co/spaces/Xmaster6y/lczerolens-backends-demo spaces/lczerolens-backends-demo
111 | ```
112 |
113 | And then launch the demo (running on port `8000`):
114 |
115 | ```bash
116 | make demo
117 | ```
118 |
119 | To test the backends use:
120 |
121 | ```bash
122 | make demo-backends
123 | ```
124 |
125 | ## Full Documentation
126 |
127 | See the full [documentation](https://lczerolens.readthedocs.io).
128 |
129 | ## Contribute
130 |
131 | See the guidelines in [CONTRIBUTING.md](CONTRIBUTING.md).
132 |
--------------------------------------------------------------------------------
/assets/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !resolve-tests-assets.sh
4 | !resolve-demo-assets.sh
5 |
--------------------------------------------------------------------------------
/assets/resolve-demo-assets.sh:
--------------------------------------------------------------------------------
1 | make tests-assets
2 |
3 | uv run gdown 1CGo_UzFins8pQWQ0JQRrAgIPo57PG11g -O assets/test_stockfish_5000.jsonl
4 |
--------------------------------------------------------------------------------
/assets/resolve-tests-assets.sh:
--------------------------------------------------------------------------------
1 | uv run gdown 1Ssl4JanqzQn3p-RoHRDk_aApykl-SukE -O assets/tinygyal-8.pb.gz
2 | uv run gdown 1WzBQV_zn5NnfsG0K8kOion0pvWxXhgKM -O assets/384x30-2022_0108_1903_17_608.pb.gz
3 | uv run gdown 1erxB3tULDURjpPhiPWVGr6X986Q8uE6U -O assets/maia-1100.pb.gz
4 | uv run gdown 1YqqANK-wuZIOmMweuK_oCU7vfPN7G_Z6 -O assets/t1-smolgen-512x15x8h-distilled-swa-3395000.pb.gz
5 | uv run gdown 15-eGN7Hz2NM6aEMRaQrbW3ScxxQpAqa5 -O assets/test_stockfish_10.jsonl
6 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/.DS_Store
--------------------------------------------------------------------------------
/docs/source/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | .button-group {
2 | display: flex;
3 | flex-direction: row;
4 | flex-wrap: nowrap;
5 | justify-content: flex-start;
6 | align-items: center;
7 | align-content: stretch;
8 | gap: 10px;
9 | }
10 |
11 | .hero {
12 | height: 90vh;
13 | display: flex;
14 | flex-direction: row;
15 | flex-wrap: nowrap;
16 | justify-content: center;
17 | align-items: center;
18 | align-content: stretch;
19 | padding-bottom: 20vh;
20 |
21 | overflow: hidden;
22 | }
23 |
24 | html[data-theme="light"] {
25 | --pst-color-secondary: #b676b2;
26 | --pst-color-primary: #4ca4d2;
27 | --pst-color-primary-highlight: #b676b2;
28 |
29 | }
30 |
31 |
32 | .bd-sidebar-primary {
33 | width: fit-content !important;
34 | padding-right: 40px !important
35 | }
36 |
37 |
38 | .features {
39 | height: 60vh;
40 | overflow: hidden;
41 | }
42 |
43 |
44 | .image-container {
45 | margin-top: 50px;
46 | }
47 |
48 | img {
49 | pointer-events: none;
50 | }
51 |
52 | .img-bottom-sizing {
53 | max-height: 10vh;
54 | width: auto
55 | }
56 |
57 | .body-sizing{
58 | height: 10vh;
59 | }
60 |
61 | .title-bot {
62 | margin-bottom: -10px !important;
63 | line-height: normal !important;
64 | }
65 |
66 | .sub-bot {
67 | margin-bottom: -10px !important;
68 | /* margin-top: -20px !important; */
69 | line-height: normal !important;
70 | }
71 |
72 | .features-container {
73 | img {
74 | max-width: none;
75 | }
76 |
77 | display: flex;
78 | gap: 20px;
79 | }
80 |
81 | @media only screen and (max-width: 768px) {
82 |
83 | /* Adjust this value based on your breakpoint for mobile */
84 | .front-container,
85 | .hero {
86 | height: auto;
87 | /* Change from fixed height to auto */
88 | min-height: 50vh;
89 | /* Adjust this as needed */
90 | }
91 |
92 | .features-container {
93 | margin-bottom: 20px;
94 | /* Increase bottom margin */
95 | }
96 |
97 | .hero {
98 | margin-bottom: 30px;
99 | /* Adjust the bottom margin of the main container */
100 | }
101 |
102 | .features {
103 | height: 110vh;
104 | }
105 | }
106 |
--------------------------------------------------------------------------------
/docs/source/_static/css/nbsphinx.css:
--------------------------------------------------------------------------------
1 | div.nboutput.container div.output_area:has(pre) {
2 | max-height: 600px;
3 | }
4 |
--------------------------------------------------------------------------------
/docs/source/_static/images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/.DS_Store
--------------------------------------------------------------------------------
/docs/source/_static/images/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/favicon.ico
--------------------------------------------------------------------------------
/docs/source/_static/images/lczerolens-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/lczerolens-logo.png
--------------------------------------------------------------------------------
/docs/source/_static/images/one.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/one.png
--------------------------------------------------------------------------------
/docs/source/_static/images/two.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/docs/source/_static/images/two.png
--------------------------------------------------------------------------------
/docs/source/_static/switcher.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "version": "dev",
4 | "url": "https://lczerolens.readthedocs.io/en/latest/"
5 | },
6 | {
7 | "version": "v0.3.3",
8 | "url": "https://lczerolens.readthedocs.io/en/v0.3.3/"
9 | },
10 | {
11 | "version": "v0.3.2",
12 | "url": "https://lczerolens.readthedocs.io/en/v0.3.2/"
13 | },
14 | {
15 | "version": "v0.3.1",
16 | "url": "https://lczerolens.readthedocs.io/en/v0.3.1/"
17 | },
18 | {
19 | "version": "v0.3.0",
20 | "url": "https://lczerolens.readthedocs.io/en/v0.3.0/"
21 | },
22 | {
23 | "version": "v0.2.0",
24 | "url": "https://lczerolens.readthedocs.io/en/v0.2.0/"
25 | },
26 | {
27 | "version": "v0.1.2",
28 | "url": "https://lczerolens.readthedocs.io/en/v0.1.2/"
29 | }
30 | ]
31 |
--------------------------------------------------------------------------------
/docs/source/about.rst:
--------------------------------------------------------------------------------
1 |
2 | About lczerolens
3 | ================
4 |
5 | Goal
6 | ----
7 |
8 | Provide analysis tools for lc0 networks with ``torch``.
9 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.

import os

import lczerolens

# Project Information
project = "lczerolens"
copyright = "2024, Yoann Poupart"
author = "Yoann Poupart"


# General Configuration
extensions = [
    # 'sphinx.ext.autosectionlabel',
    "sphinx.ext.autodoc",  # Auto documentation from docstrings
    "sphinx.ext.napoleon",  # Support for NumPy and Google style docstrings
    "sphinx.ext.viewcode",  # View code in the browser
    "sphinx_copybutton",  # Copy button for code blocks
    "sphinx_design",  # Bootstrap design components
    "nbsphinx",  # Jupyter notebook support
    "autoapi.extension",  # API pages generated from the package sources
]

templates_path = ["_templates"]
exclude_patterns = []  # type: ignore
fixed_sidebar = True


# HTML Output Options

# See https://sphinx-themes.org/ for more
html_theme = "pydata_sphinx_theme"
html_title = "lczerolens"
html_logo = "_static/images/lczerolens-logo.svg"
html_static_path = ["_static"]

html_favicon = "_static/images/favicon.ico"
html_show_sourcelink = False

# Define the json_url for our version switcher.
json_url = "https://lczerolens.readthedocs.io/en/latest/_static/switcher.json"


version_match = os.environ.get("READTHEDOCS_VERSION")
release = lczerolens.__version__
# If READTHEDOCS_VERSION doesn't exist, we're not on RTD
# If it is an integer, we're in a PR build and the version isn't correct.
# If it's "latest" → change to "dev"
if not version_match or version_match.isdigit() or version_match == "latest":
    # For local development, infer the version to match from the package.
    if "dev" in release or "rc" in release:
        version_match = "dev"
        # We want to keep the relative reference if we are in dev mode
        # but we want the whole url if we are effectively in a released version
        json_url = "_static/switcher.json"
    else:
        version_match = f"v{release}"
elif version_match == "stable":
    version_match = f"v{release}"

html_theme_options = {
    "show_nav_level": 2,
    "navigation_depth": 2,
    "show_toc_level": 2,
    "navbar_end": ["theme-switcher", "navbar-icon-links"],
    "navbar_align": "left",
    "icon_links": [
        {
            "name": "GitHub",
            "url": "https://github.com/Xmaster6y/lczerolens",
            "icon": "fa-brands fa-github",
        },
        {
            "name": "Discord",
            "url": "https://discord.gg/e7vhrTsjnt",
            "icon": "fa-brands fa-discord",
        },
        {
            "name": "PyPI",
            # Link to this project's PyPI page (was a leftover template URL
            # pointing at pydata-sphinx-theme).
            "url": "https://pypi.org/project/lczerolens/",
            "icon": "fa-custom fa-pypi",
        },
    ],
    "show_version_warning_banner": True,
    "navbar_center": ["version-switcher", "navbar-nav"],
    "footer_start": ["copyright"],
    "footer_center": ["sphinx-version"],
    "switcher": {
        "json_url": json_url,
        "version_match": version_match,
    },
}
html_sidebars = {"about": [], "start": []}

html_context = {"default_mode": "auto"}

html_css_files = [
    "css/custom.css",
    "css/nbsphinx.css",
]

# Nbsphinx
nbsphinx_execute = "auto"

# Autoapi
autoapi_dirs = ["../../src"]
autoapi_root = "api"
autoapi_keep_files = False
autodoc_typehints = "description"
--------------------------------------------------------------------------------
/docs/source/features.rst:
--------------------------------------------------------------------------------
1 | Features
2 | ==========
3 |
4 | .. raw:: html
5 |
6 |
13 |
14 |
19 |
20 | .. grid:: 2
21 | :gutter: 3
22 |
23 | .. grid-item-card::
24 | :link: notebooks/features/visualise-heatmaps.ipynb
25 | :class-card: surface
26 | :class-body: surface
27 |
28 | .. raw:: html
29 |
30 |
31 |
32 |
33 |
34 |
35 |
Visualise Heatmaps
36 |
Visualise saliency heatmaps or encodings for a given chess board.
37 |
38 |
39 |
40 | .. grid-item-card::
41 | :link: notebooks/features/probe-concepts.ipynb
42 | :class-card: surface
43 | :class-body: surface
44 |
45 | .. raw:: html
46 |
47 |
48 |
49 |
50 |
51 |
52 |
Probe Concepts
53 |
Probe the concepts with a dataset.
54 |
55 |
56 |
57 | .. grid-item-card::
58 | :link: notebooks/features/move-prediction.ipynb
59 | :class-card: surface
60 | :class-body: surface
61 |
62 | .. raw:: html
63 |
64 |
65 |
66 |
67 |
68 |
69 |
Move Prediction
70 |
Make a move prediction for a given chess board.
71 |
72 |
73 |
74 | .. grid-item-card::
75 | :link: notebooks/features/run-models-on-gpu.ipynb
76 | :class-card: surface
77 | :class-body: surface
78 |
79 | .. raw:: html
80 |
81 |
82 |
83 |
84 |
85 |
86 |
Run Models on GPU
87 |
Take advantage of GPU acceleration.
88 |
89 |
90 |
91 | .. grid-item-card::
92 | :link: notebooks/features/evaluate-models-on-puzzles.ipynb
93 | :class-card: surface
94 | :class-body: surface
95 |
96 | .. raw:: html
97 |
98 |
99 |
100 |
101 |
102 |
103 |
Evaluate Models on Puzzles
104 |
Evaluate a model on a set of puzzles.
105 |
106 |
107 |
108 | .. grid-item-card::
109 | :link: notebooks/features/convert-official-weights.ipynb
110 | :class-card: surface
111 | :class-body: surface
112 |
113 | .. raw:: html
114 |
115 |
116 |
117 |
118 |
119 |
120 |
Convert Official Weights
121 |
Convert lc0 networks to onnx.
122 |
123 |
124 |
125 | .. toctree::
126 | :hidden:
127 | :maxdepth: 2
128 |
129 | notebooks/features/visualise-heatmaps.ipynb
130 | notebooks/features/probe-concepts.ipynb
131 | notebooks/features/move-prediction.ipynb
132 | notebooks/features/run-models-on-gpu.ipynb
133 | notebooks/features/evaluate-models-on-puzzles.ipynb
134 | notebooks/features/convert-official-weights.ipynb
135 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | :html_theme.sidebar_secondary.remove: true
2 | :sd_hide_title:
3 |
4 | lczerolens
5 | ==========
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 | :hidden:
10 |
11 | start
12 | features
13 | tutorials
14 | api/index
15 | About
16 |
17 | .. grid:: 1 1 2 2
18 | :class-container: hero
19 | :reverse:
20 |
21 | .. grid-item::
22 | .. div::
23 |
24 | .. image:: _static/images/lczerolens-logo.svg
25 | :width: 300
26 | :height: 300
27 |
28 | .. grid-item::
29 |
30 | .. div:: sd-fs-1 sd-font-weight-bold title-bot sd-text-primary image-container
31 |
32 | LczeroLens
33 |
34 | .. div:: sd-fs-4 sd-font-weight-bold sd-my-0 sub-bot image-container
35 |
36 | Interpretability for lc0 networks
37 |
38 |         **lczerolens** is a package for interpreting and manipulating the neural networks produced by lc0.
39 |
40 | .. div:: button-group
41 |
42 | .. button-ref:: start
43 | :color: primary
44 | :shadow:
45 |
46 | Get Started
47 |
48 | .. button-ref:: tutorials
49 | :color: primary
50 | :outline:
51 |
52 | Tutorials
53 |
54 | .. button-ref:: api/index
55 | :color: primary
56 | :outline:
57 |
58 | API Reference
59 |
60 |
61 | .. div:: sd-fs-1 sd-font-weight-bold sd-text-center sd-text-primary sd-mb-5
62 |
63 | Key Features
64 |
65 | .. grid:: 1 1 2 2
66 | :class-container: features
67 |
68 | .. grid-item::
69 |
70 | .. div:: features-container
71 |
72 | .. image:: _static/images/one.png
73 | :width: 150
74 |
75 | .. div::
76 |
77 | **Adaptability**
78 |
79 | Load a network from lc0 (``.pb`` or ``.onnx``) and load it with lczerolens using ``torch``.
80 |
81 | .. grid-item::
82 |
83 | .. div:: features-container
84 |
85 | .. image:: _static/images/two.png
86 | :width: 150
87 |
88 | .. div::
89 |
90 | **Interpretability**
91 |
92 |             Easily compute saliency maps or aggregated statistics using the pre-built interpretability methods.
93 |
--------------------------------------------------------------------------------
/docs/source/notebooks/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !**/
3 | !.gitignore
4 | !*.ipynb
5 | *.nbconvert.ipynb
6 |
--------------------------------------------------------------------------------
/docs/source/notebooks/features/convert-official-weights.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "-qCMaHT6xuoB"
7 | },
8 | "source": [
9 | "# Convert Official Weights\n",
10 | "\n",
11 | "[](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/features/convert-official-weights.ipynb)"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {
17 | "id": "VtZdxVZZx2pL"
18 | },
19 | "source": [
20 | "## Setup"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 1,
26 | "metadata": {
27 | "id": "iMj257hVxlgJ"
28 | },
29 | "outputs": [],
30 | "source": [
31 | "MODE = \"local\" # \"colab\" | \"colab-dev\" | \"local\""
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "colab": {
39 | "base_uri": "https://localhost:8080/"
40 | },
41 | "id": "WtTH-Oz-yotw",
42 | "outputId": "62f302e9-5f63-46ab-82a3-0d77ec054d0d"
43 | },
44 | "outputs": [],
45 | "source": [
46 | "if MODE == \"colab\":\n",
47 | " !pip install -q lczerolens\n",
48 | "elif MODE == \"colab-dev\":\n",
49 | " !rm -r lczerolens\n",
50 | " !git clone https://github.com/Xmaster6y/lczerolens -b main\n",
51 | " !pip install -q ./lczerolens"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {
58 | "colab": {
59 | "base_uri": "https://localhost:8080/"
60 | },
61 | "id": "_GdhCPhtyoj3",
62 | "outputId": "0c3dfe61-3dbf-41ee-f6ec-cd524c713daf"
63 | },
64 | "outputs": [],
65 | "source": [
66 | "!gdown 1erxB3tULDURjpPhiPWVGr6X986Q8uE6U -O leela-network.pb.gz"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {
72 | "id": "KjLtlNV95WWx"
73 | },
74 | "source": [
75 | "To convert a network you'll need to have installed the `lc0` binaries (**takes about 10 minutes**):"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 4,
81 | "metadata": {
82 | "colab": {
83 | "base_uri": "https://localhost:8080/"
84 | },
85 | "id": "IuQ4hpbv5Z-o",
86 | "outputId": "338038e6-3d07-4a43-e9f0-b7797555e9ce"
87 | },
88 | "outputs": [
89 | {
90 | "name": "stdout",
91 | "output_type": "stream",
92 | "text": [
93 | " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
94 | " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
95 | " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
96 | " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
97 | " Building wheel for lczero_bindings (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
98 | ]
99 | }
100 | ],
101 | "source": [
102 | "!pip install -q git+https://github.com/LeelaChessZero/lc0"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "metadata": {
108 | "id": "AM3K3PCgx8g0"
109 | },
110 | "source": [
111 | "## Convert a Model\n",
112 | "\n",
113 | "You can convert networks to `onnx` using the official `lc0` binaries or\n",
114 | "by using the `backends` module:"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 6,
120 | "metadata": {
121 | "colab": {
122 | "base_uri": "https://localhost:8080/",
123 | "height": 120
124 | },
125 | "id": "f-6vcmwEyb7n",
126 | "outputId": "4989f19d-27cb-4cb7-bf01-28eaad64f405"
127 | },
128 | "outputs": [
129 | {
130 | "name": "stdout",
131 | "output_type": "stream",
132 | "text": [
133 | "\n",
134 | "Format\n",
135 | "~~~~~~\n",
136 | " Weights encoding: LINEAR16\n",
137 | " Input: INPUT_CLASSICAL_112_PLANE\n",
138 | " Network: NETWORK_SE_WITH_HEADFORMAT\n",
139 | " Policy: POLICY_CONVOLUTION\n",
140 | " Value: VALUE_WDL\n",
141 | "\n",
142 | "Weights\n",
143 | "~~~~~~~\n",
144 | " Blocks: 6\n",
145 | " SE blocks: 6\n",
146 | " Filters: 64\n",
147 | " Policy: Convolution\n",
148 | " Policy activation: ACTIVATION_DEFAULT\n",
149 | " Value: WDL\n",
150 | " MLH: Absent\n",
151 | "Converting Leela network to the ONNX.\n",
152 | "\n",
153 | "ONNX interface\n",
154 | "~~~~~~~~~~~~~~\n",
155 | " Data type: FLOAT\n",
156 | " Input planes: /input/planes\n",
157 | " Output WDL: /output/wdl\n",
158 | " Output Policy: /output/policy\n",
159 | "Done.\n",
160 | "\n"
161 | ]
162 | }
163 | ],
164 | "source": [
165 | "from lczerolens import backends\n",
166 | "\n",
167 | "output = backends.convert_to_onnx(\"leela-network.pb.gz\", \"leela-network.onnx\")\n",
168 | "print(output)"
169 | ]
170 | },
171 | {
172 | "cell_type": "markdown",
173 | "metadata": {
174 | "id": "P3RrLQHF5qkI"
175 | },
176 | "source": [
177 | "See [Move Prediction](move-prediction.ipynb) to see how to use the converted network."
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {
183 | "id": "qBzY4spb5nGg"
184 | },
185 | "source": [
186 | "## Note\n",
187 | "\n",
188 | "Only the latest networks are supported. To convert older weights, you should build the associated binaries."
189 | ]
190 | }
191 | ],
192 | "metadata": {
193 | "colab": {
194 | "provenance": []
195 | },
196 | "kernelspec": {
197 | "display_name": "Python 3",
198 | "name": "python3"
199 | },
200 | "language_info": {
201 | "name": "python",
202 | "version": "3.11.10"
203 | }
204 | },
205 | "nbformat": 4,
206 | "nbformat_minor": 0
207 | }
208 |
--------------------------------------------------------------------------------
/docs/source/notebooks/tutorials/automated-interpretability.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Automated Interpretability\n",
8 | "\n",
9 | "[](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/automated-interpretability.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Setup"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 1,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "MODE = \"local\" # \"colab\" | \"colab-dev\" | \"local\""
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "if MODE == \"colab\":\n",
35 | " !pip install -q lczerolens\n",
36 | "elif MODE == \"colab-dev\":\n",
37 | " !rm -r lczerolens\n",
38 | " !git clone https://github.com/Xmaster6y/lczerolens -b main\n",
39 | " !pip install -q ./lczerolens"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "name": "stdout",
49 | "output_type": "stream",
50 | "text": [
51 | "Downloading...\n",
52 | "From: https://drive.google.com/uc?id=15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X\n",
53 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/tutorials/lc0-19-1876.onnx\n",
54 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:02<00:00, 36.6MB/s]\n",
55 | "Downloading...\n",
56 | "From: https://drive.google.com/uc?id=1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd\n",
57 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/tutorials/lc0-19-4508.onnx\n",
58 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:02<00:00, 35.8MB/s]\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "!gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O lc0-19-1876.onnx\n",
64 | "!gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O lc0-19-4508.onnx"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "metadata": {},
70 | "source": [
71 | "## Load a Model\n",
72 | "\n",
73 | "Load a leela network from file (already converted to `onnx`):"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 4,
79 | "metadata": {},
80 | "outputs": [
81 | {
82 | "name": "stderr",
83 | "output_type": "stream",
84 | "text": [
85 | "/Users/xmaster/Work/lczerolens/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
86 | " from .autonotebook import tqdm as notebook_tqdm\n"
87 | ]
88 | }
89 | ],
90 | "source": [
91 | "from lczerolens import LczeroModel\n",
92 | "\n",
93 | "strong_model = LczeroModel.from_path(\"lc0-19-4508.onnx\")\n",
94 | "weak_model = LczeroModel.from_path(\"lc0-19-1876.onnx\")"
95 | ]
96 | }
97 | ],
98 | "metadata": {
99 | "kernelspec": {
100 | "display_name": ".venv",
101 | "language": "python",
102 | "name": "python3"
103 | },
104 | "language_info": {
105 | "codemirror_mode": {
106 | "name": "ipython",
107 | "version": 3
108 | },
109 | "file_extension": ".py",
110 | "mimetype": "text/x-python",
111 | "name": "python",
112 | "nbconvert_exporter": "python",
113 | "pygments_lexer": "ipython3",
114 | "version": "3.11.10"
115 | }
116 | },
117 | "nbformat": 4,
118 | "nbformat_minor": 2
119 | }
120 |
--------------------------------------------------------------------------------
/docs/source/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Piece Value Estimation Using LRP\n",
8 | "\n",
9 | "[](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Setup"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 1,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "MODE = \"local\" # \"colab\" | \"colab-dev\" | \"local\""
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "if MODE == \"colab\":\n",
35 | " !pip install -q lczerolens\n",
36 | "elif MODE == \"colab-dev\":\n",
37 | " !rm -r lczerolens\n",
38 | " !git clone https://github.com/Xmaster6y/lczerolens -b main\n",
39 | " !pip install -q ./lczerolens"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "name": "stdout",
49 | "output_type": "stream",
50 | "text": [
51 | "Downloading...\n",
52 | "From: https://drive.google.com/uc?id=15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X\n",
53 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/tutorials/lc0-19-1876.onnx\n",
54 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:03<00:00, 26.8MB/s]\n",
55 | "Downloading...\n",
56 | "From: https://drive.google.com/uc?id=1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd\n",
57 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/tutorials/lc0-19-4508.onnx\n",
58 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:03<00:00, 25.6MB/s]\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "!gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O lc0-19-1876.onnx\n",
64 | "!gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O lc0-19-4508.onnx"
65 | ]
66 | }
67 | ],
68 | "metadata": {
69 | "kernelspec": {
70 | "display_name": ".venv",
71 | "language": "python",
72 | "name": "python3"
73 | },
74 | "language_info": {
75 | "codemirror_mode": {
76 | "name": "ipython",
77 | "version": 3
78 | },
79 | "file_extension": ".py",
80 | "mimetype": "text/x-python",
81 | "name": "python",
82 | "nbconvert_exporter": "python",
83 | "pygments_lexer": "ipython3",
84 | "version": "3.9.18"
85 | }
86 | },
87 | "nbformat": 4,
88 | "nbformat_minor": 2
89 | }
90 |
--------------------------------------------------------------------------------
/docs/source/notebooks/walkthrough.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Walkthrough\n",
8 | "\n",
9 | "[](https://colab.research.google.com/github/Xmaster6y/lczerolens/blob/main/docs/source/notebooks/walkthrough.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Setup"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 1,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "MODE = \"local\" # \"colab\" | \"colab-dev\" | \"local\""
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "if MODE == \"colab\":\n",
35 | " !pip install -q lczerolens\n",
36 | "elif MODE == \"colab-dev\":\n",
37 | " !rm -r lczerolens\n",
38 | " !git clone https://github.com/Xmaster6y/lczerolens -b main\n",
39 | " !pip install -q ./lczerolens"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "name": "stdout",
49 | "output_type": "stream",
50 | "text": [
51 | "Downloading...\n",
52 | "From: https://drive.google.com/uc?id=15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X\n",
53 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/lc0-19-1876.onnx\n",
54 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:02<00:00, 32.9MB/s]\n",
55 | "Downloading...\n",
56 | "From: https://drive.google.com/uc?id=1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd\n",
57 | "To: /Users/xmaster/Work/lczerolens/docs/source/notebooks/lc0-19-4508.onnx\n",
58 | "100%|██████████████████████████████████████| 97.1M/97.1M [00:03<00:00, 30.1MB/s]\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "!gdown 15__7FHvIR5-JbJvDg2eGUhIPZpkYyM7X -O lc0-19-1876.onnx\n",
64 | "!gdown 1CvMyX3KuYxCJUKz9kOb9VX8zIkfISALd -O lc0-19-4508.onnx"
65 | ]
66 | }
67 | ],
68 | "metadata": {
69 | "kernelspec": {
70 | "display_name": ".venv",
71 | "language": "python",
72 | "name": "python3"
73 | },
74 | "language_info": {
75 | "codemirror_mode": {
76 | "name": "ipython",
77 | "version": 3
78 | },
79 | "file_extension": ".py",
80 | "mimetype": "text/x-python",
81 | "name": "python",
82 | "nbconvert_exporter": "python",
83 | "pygments_lexer": "ipython3",
84 | "version": "3.11.10"
85 | }
86 | },
87 | "nbformat": 4,
88 | "nbformat_minor": 2
89 | }
90 |
--------------------------------------------------------------------------------
/docs/source/start.rst:
--------------------------------------------------------------------------------
1 | Getting Started
2 | ===============
3 |
4 | **lczerolens** is a package for running interpretability methods on lc0 models.
5 | It is designed to be easy to use and to work with the most common interpretability
6 | techniques.
7 |
8 | .. _installation:
9 |
10 | Installation
11 | ------------
12 |
13 | To get started with lczerolens, install it with ``pip``.
14 |
15 | .. code-block:: console
16 |
17 | pip install lczerolens
18 |
19 | .. note::
20 |
21 | The dependencies for lczerolens are currently substantial.
22 | It mainly depends on ``torch``, ``nnsight``, ``zennit`` and ``datasets``.
23 |
24 | First Steps
25 | -----------
26 |
27 | .. grid:: 2
28 | :gutter: 2
29 |
30 | .. grid-item-card:: Features
31 | :link: features
32 | :link-type: doc
33 |
34 | Review the basic features provided by :bdg-primary:`lczerolens`.
35 |
36 | .. grid-item-card:: Walkthrough
37 | :link: notebooks/walkthrough.ipynb
38 |
39 | Walk through the basic functionality of the package.
40 |
41 | .. note::
42 |
43 | Check out the :bdg-secondary:`walkthrough` to get a better understanding of the package.
44 |
45 | Advanced Features
46 | -----------------
47 |
48 | .. warning::
49 |
50 |     The following section is under construction and is not yet stable or fully functional.
51 |
52 | .. grid:: 2
53 | :gutter: 2
54 |
55 | .. grid-item-card:: Tutorials
56 | :link: tutorials
57 | :link-type: doc
58 |
59 | See implementations of :bdg-primary:`lczerolens` through common interpretability techniques.
60 |
61 | .. grid-item-card:: API Reference
62 | :link: api/index
63 | :link-type: doc
64 |
65 | See the full API reference for :bdg-primary:`lczerolens` to extend its functionality.
66 |
--------------------------------------------------------------------------------
/docs/source/tutorials.rst:
--------------------------------------------------------------------------------
1 | Tutorials
2 | =========
3 |
4 | .. grid:: 2
5 | :gutter: 2
6 |
7 | .. grid-item-card:: Walkthrough
8 | :link: notebooks/walkthrough.ipynb
9 |
10 | :bdg-primary:`Main Features`
11 |
12 | .. grid-item-card:: Evidence of Learned Look-Ahead
13 | :link: notebooks/tutorials/evidence-of-learned-look-ahead.ipynb
14 |
15 | :bdg-primary:`Main Features`
16 |
17 | .. grid-item-card:: Piece Value Estimation Using LRP
18 | :link: notebooks/tutorials/piece-value-estimation-using-lrp.ipynb
19 |
20 | :bdg-primary:`Main Features`
21 |
22 | .. grid-item-card:: Train SAEs
23 | :link: notebooks/tutorials/train-saes.ipynb
24 |
25 | .. toctree::
26 | :hidden:
27 | :maxdepth: 2
28 |
29 | notebooks/walkthrough.ipynb
30 | notebooks/tutorials/evidence-of-learned-look-ahead.ipynb
31 | notebooks/tutorials/piece-value-estimation-using-lrp.ipynb
32 | notebooks/tutorials/train-saes.ipynb
33 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "lczerolens"
3 | version = "0.3.3-dev"
4 | description = "Interpretability for LeelaChessZero networks."
5 | readme = "README.md"
6 | license = "MIT"
7 | license-files = ["LICENSE"]
8 | authors = [
9 | {name = "Yoann Poupart", email = "yoann.poupart@ens-lyon.org"},
10 | ]
11 | requires-python = ">=3.10"
12 | dependencies = [
13 | "datasets>=3.2.0",
14 | "einops>=0.8.0",
15 | "jsonlines>=4.0.0",
16 | "nnsight>=0.3.7,<0.4.0",
17 | "onnx2torch>=1.5.15",
18 | "python-chess>=1.999",
19 | "scikit-learn>=1.6.1",
20 | "tensordict>=0.6.2",
21 | "typing-extensions>=4.12.2",
22 | ]
23 | classifiers = [
24 | "Development Status :: 3 - Alpha",
25 | "Intended Audience :: Science/Research",
26 | "Programming Language :: Python :: 3",
27 | "Programming Language :: Python :: 3.10",
28 | "Programming Language :: Python :: 3.11",
29 | "Programming Language :: Python :: 3.12",
30 | ]
31 |
32 | [project.urls]
33 | homepage = "https://lczerolens.readthedocs.io/"
34 | documentation = "https://lczerolens.readthedocs.io/"
35 | source = "https://github.com/Xmaster6y/lczerolens"
36 | issues = "https://github.com/Xmaster6y/lczerolens/issues"
37 | releasenotes = "https://github.com/Xmaster6y/lczerolens/releases"
38 |
39 | [project.optional-dependencies]
40 | viz = [
41 | "matplotlib>=3.10.0",
42 | ]
43 | backends = [
44 | "v-lczero-bindings>=0.31.2"
45 | ]
46 |
47 | [dependency-groups]
48 | dev = [
49 | "gdown>=5.2.0",
50 | "ipykernel>=6.29.5",
51 | "nbconvert>=7.16.5",
52 | "onnxruntime>=1.20.1",
53 | "pre-commit>=4.0.1",
54 | "pytest>=8.3.4",
55 | "pytest-cov>=6.0.0",
56 | "v-lczero-bindings>=0.31.2",
57 | ]
58 | demo = [
59 | "gradio>=5.12.0",
60 | "matplotlib>=3.10.0",
61 | "v-lczero-bindings>=0.31.2",
62 | ]
63 | docs = [
64 | "nbsphinx>=0.9.6",
65 | "pandoc>=2.4",
66 | "plotly>=5.24.1",
67 | "pydata-sphinx-theme>=0.16.1",
68 | "sphinx>=8.1.3",
69 | "sphinx-autoapi>=3.6.0",
70 | "sphinx-charts>=0.2.1",
71 | "sphinx-copybutton>=0.5.2",
72 | "sphinx-design>=0.6.1",
73 | ]
74 | scripts = [
75 | "loguru>=0.7.3",
76 | "matplotlib>=3.10.0",
77 | "pylatex>=1.4.2",
78 | "safetensors>=0.5.2",
79 | "wandb>=0.19.2",
80 | ]
81 |
82 | [build]
83 | target-dir = "build/dist"
84 |
85 | [build-system]
86 | requires = ["setuptools"]
87 | build-backend = "setuptools.build_meta"
88 |
89 | [tool.uv]
90 | default-groups = ["dev"]
91 |
92 | [tool.ruff]
93 | line-length = 119
94 | target-version = "py311"
95 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/constants.py:
--------------------------------------------------------------------------------
1 | """Constants for the scripts."""
2 |
3 | import os
4 |
5 | # Secrets
6 | HF_TOKEN = os.getenv("HF_TOKEN")
7 | WANDB_API_KEY = os.getenv("WANDB_API_KEY")
8 |
--------------------------------------------------------------------------------
/scripts/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/scripts/datasets/__init__.py
--------------------------------------------------------------------------------
/scripts/datasets/make_lichess_dataset.py:
--------------------------------------------------------------------------------
1 | """Script to generate the base datasets.
2 |
3 | Run with:
4 | ```bash
5 | uv run python -m scripts.datasets.make_lichess_dataset
6 | ```
7 | """
8 |
9 | import argparse
10 |
11 | from datasets import Dataset
12 | from loguru import logger
13 |
14 | from scripts.constants import HF_TOKEN
15 |
16 |
def main(args: argparse.Namespace):
    """Load the lichess puzzles CSV and optionally push it to the HF hub."""
    logger.info(f"Loading `{args.source_file}`...")

    dataset = Dataset.from_csv(args.source_file)
    logger.info(f"Loaded dataset: {dataset}")

    if args.push_to_hub:
        logger.info("Pushing to hub...")
        dataset.push_to_hub(repo_id=args.dataset_name, token=HF_TOKEN)
29 |
30 |
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI argument parser for this script."""
    parser = argparse.ArgumentParser("make-lichess-dataset")
    parser.add_argument("--source_file", type=str, default="./assets/lichess_db_puzzle.csv")
    parser.add_argument("--dataset_name", type=str, default="lczerolens/lichess-puzzles")
    parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=False)
    return parser.parse_args()
45 |
46 |
47 | if __name__ == "__main__":
48 | args = parse_args()
49 | main(args)
50 |
--------------------------------------------------------------------------------
/scripts/datasets/make_tcec_dataset.py:
--------------------------------------------------------------------------------
1 | """Script to generate the base datasets.
2 |
3 | Run with:
4 | ```bash
5 | uv run python -m scripts.datasets.make_tcec_dataset
6 | ```
7 | """
8 |
9 | import argparse
10 |
11 | from datasets import Dataset
12 | from loguru import logger
13 |
14 | from scripts.constants import HF_TOKEN
15 |
16 |
def main(args: argparse.Namespace):
    """Load the TCEC games JSON(L) file and optionally push it to the HF hub."""
    logger.info(f"Loading `{args.source_file}`...")

    dataset = Dataset.from_json(args.source_file)
    logger.info(f"Loaded dataset: {dataset}")

    if args.push_to_hub:
        logger.info("Pushing to hub...")
        dataset.push_to_hub(repo_id=args.dataset_name, token=HF_TOKEN)
29 |
30 |
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI argument parser for this script."""
    parser = argparse.ArgumentParser("make-tcec-dataset")
    parser.add_argument("--source_file", type=str, default="./assets/tcec-games.jsonl")
    parser.add_argument("--dataset_name", type=str, default="lczerolens/tcec-games")
    parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=False)
    return parser.parse_args()
45 |
46 |
47 | if __name__ == "__main__":
48 | args = parse_args()
49 | main(args)
50 |
--------------------------------------------------------------------------------
/scripts/lrp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/scripts/lrp/__init__.py
--------------------------------------------------------------------------------
/scripts/lrp/plane_analysis.py:
--------------------------------------------------------------------------------
1 | """Script to compute the importance of each plane for the model.
2 |
3 | Run with:
4 | ```
5 | uv run python -m scripts.lrp.plane_analysis
6 | ```
7 | """
8 |
9 | import argparse
10 | from loguru import logger
11 |
12 | from datasets import Dataset
13 | from torch.utils.data import DataLoader
14 | import torch
15 |
16 | from lczerolens.encodings import move as move_encoding
17 | from lczerolens.concept import MulticlassConcept
18 | from lczerolens.model import ForceValueFlow, PolicyFlow
19 | from lczerolens import concept, Lens
20 | from scripts import visualisation
21 |
22 |
# Prefer the GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 |
25 |
def _save_heatmap(rel, board, label, target, save_path):
    """Render and save the plane-summed relevance of `rel` as a board heatmap.

    The heatmap is flipped when black is to move so it is always shown from
    the side-to-move point of view; for the policy target the decoded move is
    drawn as an arrow.
    """
    if board.turn:
        heatmap = rel.sum(dim=0).view(64)
    else:
        heatmap = rel.sum(dim=0).flip(0).view(64)
    if target == "policy":
        move = move_encoding.decode_move(label, (board.turn, not board.turn), board)
    else:
        move = None
    visualisation.render_heatmap(
        board,
        heatmap,
        arrows=[(move.from_square, move.to_square)] if move is not None else None,
        normalise="abs",
        save_to=save_path,
    )


def _render_summaries(all_stats, target):
    """Render the aggregate boxplots and per-index proportion curves."""
    visualisation.render_boxplot(
        all_stats["relative_piece_relevance"],
        y_label="Relevance",
        title="Relative Relevance",
        save_to=f"./scripts/results/{target}_piece_relative_relevance.png",
    )
    visualisation.render_boxplot(
        all_stats["absolute_piece_relevance"],
        y_label="Relevance",
        title="Absolute Relevance",
        save_to=f"./scripts/results/{target}_piece_absolute_relevance.png",
    )
    visualisation.render_proportion_through_index(
        all_stats["plane_relevance_proportion"],
        plane_type="Pieces",
        y_label="Proportion of relevance",
        y_log=True,
        max_index=200,
        title="Proportion of relevance per piece",
        save_to=f"./scripts/results/{target}_plane_config_relevance.png",
    )
    visualisation.render_proportion_through_index(
        all_stats["plane_relevance_proportion"],
        plane_type="H0",
        y_label="Proportion of relevance",
        y_log=True,
        max_index=200,
        title="Proportion of relevance per plane",
        save_to=f"./scripts/results/{target}_plane_H0_relevance.png",
    )
    visualisation.render_proportion_through_index(
        all_stats["plane_relevance_proportion"],
        plane_type="Hist",
        y_label="Proportion of relevance",
        y_log=True,
        max_index=200,
        title="Proportion of relevance per plane",
        save_to=f"./scripts/results/{target}_plane_hist_relevance.png",
    )
    visualisation.render_proportion_through_index(
        all_stats["relative_piece_relevance_proportion"],
        plane_type="Pieces",
        y_label="Proportion of relevance",
        y_log=False,
        max_index=200,
        title="Proportion of relevance per piece",
        save_to=f"./scripts/results/{target}_piece_plane_relative_relevance.png",
    )
    visualisation.render_proportion_through_index(
        all_stats["absolute_piece_relevance_proportion"],
        plane_type="Pieces",
        y_label="Proportion of relevance",
        y_log=False,
        max_index=200,
        title="Proportion of relevance per piece",
        save_to=f"./scripts/results/{target}_piece_plane_absolute_relevance.png",
    )


def main(args):
    """Compute LRP relevances over the TCEC board dataset and plot statistics.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments: model_name, target ("policy" or "value"),
        find_interesting, batch_size, plot_first_n.

    Raises
    ------
    ValueError
        If the target is unsupported or the LRP lens is incompatible with
        the model.
    SystemExit
        When --find_interesting hits a matching board (after saving its
        heatmap).
    """
    dataset = Dataset.from_json(
        "./assets/TCEC_game_collection_random_boards_bestlegal.jsonl", features=MulticlassConcept.features
    )
    logger.info(f"Loaded dataset with {len(dataset)} boards.")
    if args.target == "policy":
        wrapper = PolicyFlow.from_path(f"./assets/{args.model_name}").to(DEVICE)
        init_rel_fn = concept.concept_init_rel

    elif args.target == "value":
        wrapper = ForceValueFlow.from_path(f"./assets/{args.model_name}").to(DEVICE)
        init_rel_fn = None
    else:
        raise ValueError(f"Target '{args.target}' not supported.")
    lens = Lens.from_name("lrp")
    if not lens.is_compatible(wrapper):
        raise ValueError(f"Lens of type 'lrp' not compatible with model '{args.model_name}'.")

    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=concept.concept_collate_fn)

    iter_analyse = lens.analyse_batched_boards(
        dataloader,
        wrapper,
        target=None,
        return_output=True,
        init_rel_fn=init_rel_fn,
    )
    all_stats = {
        "relative_piece_relevance": [],
        "absolute_piece_relevance": [],
        "plane_relevance_proportion": [],
        "relative_piece_relevance_proportion": [],
        "absolute_piece_relevance_proportion": [],
    }
    n_plotted = 0
    for batch in iter_analyse:
        batched_relevances, boards, *infos = batch
        relevances, outputs = batched_relevances
        labels = infos[0]
        for rel, out, board, label in zip(relevances, outputs, boards, labels):
            # Skip boards with no relevance on any of the 12 piece planes.
            max_config_rel = rel[:12].abs().max().item()
            if max_config_rel == 0:
                continue
            if n_plotted < args.plot_first_n:
                _save_heatmap(
                    rel, board, label, args.target, f"./scripts/results/{args.target}_heatmap_{n_plotted}.png"
                )
                n_plotted += 1

            # Mean relevance per piece type, averaged over non-zero squares.
            plane_order = "PNBRQKpnbrqk"
            piece_relevance = {}
            for i, letter in enumerate(plane_order):
                num = (rel[i] != 0).sum().item()
                piece_relevance[letter] = 0 if num == 0 else rel[i].sum().item() / num

            if args.find_interesting:
                # Stop at the first "interesting" board: queen-dominated
                # relevance for the value head, or opponent-piece-dominated
                # relevance for the policy head.
                if piece_relevance["q"] / max_config_rel > 0.9 and args.target == "value":
                    _save_heatmap(
                        rel, board, label, args.target, f"./scripts/results/{args.target}_heatmap_{n_plotted}.png"
                    )
                    raise SystemExit

                if any(piece_relevance[k] / max_config_rel > 0.9 for k in "pnbrqk") and args.target == "policy":
                    _save_heatmap(
                        rel, board, label, args.target, f"./scripts/results/{args.target}_heatmap_{n_plotted}.png"
                    )
                    raise SystemExit

            all_stats["absolute_piece_relevance"].append(piece_relevance)
            all_stats["relative_piece_relevance"].append({k: v / max_config_rel for k, v in piece_relevance.items()})

            total_relevance = rel.abs().sum().item()
            # Ply counter: two plies per full move, minus one when black moves.
            clock = board.fullmove_number * 2 - (not board.turn)
            proportion = rel.abs().sum(dim=(1, 2)).div(total_relevance).tolist()
            all_stats["plane_relevance_proportion"].append({clock: proportion})
            all_stats["relative_piece_relevance_proportion"].append(
                {clock: [v / max_config_rel for v in piece_relevance.values()]}
            )
            all_stats["absolute_piece_relevance_proportion"].append({clock: proportion[:12]})

    logger.info(f"Processed {len(all_stats['relative_piece_relevance'])} boards.")

    _render_summaries(all_stats, args.target)
204 |
205 |
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI argument parser for the plane analysis."""
    parser = argparse.ArgumentParser("plane-importance")
    parser.add_argument("--model_name", type=str, default="64x6-2018_0627_1913_08_161.onnx")
    parser.add_argument("--target", type=str, default="value")
    parser.add_argument("--find_interesting", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--plot_first_n", type=int, default=5)
    return parser.parse_args()
214 |
215 |
216 | if __name__ == "__main__":
217 | args = parse_args()
218 | main(args)
219 |
--------------------------------------------------------------------------------
/scripts/results/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/scripts/visualisation.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualisation utils.
3 | """
4 |
5 | import chess
6 | import chess.svg
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 |
11 | from lczerolens.encodings import board as board_encoding
12 |
# Diverging colormap used by all renderers, resampled to 1000 discrete colors.
COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
# Opacity applied to every square fill color.
ALPHA = 1.0
# Default [0, 1] normaliser; not referenced by the functions in this file —
# presumably kept for external importers (TODO confirm before removing).
NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
16 |
17 |
def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
    save_to=None,
):
    """
    Render a heatmap on the board.

    Parameters
    ----------
    board : chess.Board
        The board to render.
    heatmap : torch.Tensor
        Per-square values, integer-indexable over 0..63.
    square : Optional[str]
        Square name (e.g. "e4") to highlight with a check marker; silently
        ignored if it cannot be parsed.
    vmin, vmax : Optional[float]
        Colormap bounds; default to the heatmap min/max.
    arrows : Optional[list]
        (from_square, to_square) pairs to draw as arrows.
    normalise : str
        "abs" scales the heatmap by its max absolute value and uses [-1, 1]
        bounds; any other value leaves the heatmap untouched.
    save_to : Optional[str]
        If given, the colorbar figure is saved there and the SVG board is
        written next to it (".png" replaced by ".svg").

    Returns
    -------
    tuple
        The SVG board string and the (closed) colorbar figure.
    """
    if normalise == "abs":
        a_max = heatmap.abs().max()
        if a_max != 0:
            heatmap = heatmap / a_max
        vmin = -1
        vmax = 1
    if vmin is None:
        vmin = heatmap.min()
    if vmax is None:
        vmax = heatmap.max()
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    # Map each square's value to a hex fill color with the module alpha.
    color_dict = {}
    for square_index in range(64):
        color = COLOR_MAP(norm(heatmap[square_index]))
        color = (*color[:3], ALPHA)
        color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True)
    fig = plt.figure(figsize=(1, 6))
    try:
        ax = plt.gca()
        ax.axis("off")
        fig.colorbar(
            matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP),
            ax=ax,
            orientation="vertical",
            fraction=1.0,
        )
        if square is not None:
            try:
                check = chess.parse_square(square)
            except ValueError:
                check = None
        else:
            check = None
        if arrows is None:
            arrows = []

        svg_board = chess.svg.board(
            board,
            check=check,
            fill=color_dict,
            size=350,
            arrows=arrows,
        )
        if save_to is not None:
            plt.savefig(save_to)
            with open(save_to.replace(".png", ".svg"), "w") as f:
                f.write(svg_board)
    finally:
        # Always release the figure, even if saving or SVG rendering raises;
        # the previous version leaked the figure on any exception.
        plt.close(fig)
    return svg_board, fig
82 |
83 |
def render_boxplot(
    data,
    filter_null=True,
    y_label=None,
    title=None,
    save_to=None,
):
    """Render one box per label from a list of label->value mappings.

    Parameters
    ----------
    data : list[dict]
        Non-empty list of mappings; the labels are taken from the first entry.
    filter_null : bool
        If True, values equal to 0.0 are excluded from the boxes.
    y_label : Optional[str]
        Label for the y axis.
    title : Optional[str]
        Figure title.
    save_to : Optional[str]
        If given, the figure is saved there; otherwise it is shown.

    Raises
    ------
    ValueError
        If `data` is empty.
    """
    if not data:
        raise ValueError("`data` must contain at least one entry.")
    labels = data[0].keys()
    boxed_data = {label: [] for label in labels}
    for d in data:
        for label in labels:
            v = d.get(label)
            # Skip missing keys (None) and, when requested, exact zeros.
            if v is None or (filter_null and v == 0.0):
                continue
            boxed_data[label].append(v)
    # `tick_labels` replaces the `labels` kwarg deprecated in Matplotlib 3.9
    # (the project pins matplotlib>=3.10).
    plt.boxplot(boxed_data.values(), notch=True, vert=True, patch_artist=True, tick_labels=labels)
    plt.ylabel(y_label)
    plt.title(title)
    if save_to is not None:
        plt.savefig(save_to)
        plt.close()
    else:
        plt.show()
108 |
def render_proportion_through_index(
    data,
    plane_type="H0",
    max_index=None,
    y_log=False,
    y_label=None,
    title=None,
    save_to=None,
):
    """Plot mean +/- std curves of per-plane relevance proportions vs an index.

    Parameters
    ----------
    data : list[dict]
        Each entry maps a single index (e.g. a ply clock) to a list of
        per-plane proportions.
    plane_type : str
        Grouping mode: "H0" (current config vs history vs meta planes),
        "Hist" (one curve per history step plus castling/remaining planes),
        or "Pieces" (one curve per piece letter).
    max_index : Optional[int]
        Entries with an index above this value are skipped.
    y_log : bool
        If True, use a logarithmic y axis.
    y_label : Optional[str]
        Label for the y axis.
    title : Optional[str]
        Figure title.
    save_to : Optional[str]
        If given, the figure is saved there; otherwise it is shown.

    Raises
    ------
    ValueError
        If `plane_type` is not one of the supported modes.
    """
    # Plane layout assumed here (matches LczeroBoard.to_input_tensor):
    # 13 planes per history step, 8 steps -> planes 0..103; castling rights
    # 104..107; remaining meta planes 108 onwards.
    if plane_type == "H0":
        indexed_data = {
            "H0": {},
            "Hist": {},
            "Meta": {},
        }
        for d in data:
            # Each data entry holds exactly one {index: proportions} pair.
            index, proportion = next(iter(d.items()))
            if max_index is not None and index > max_index:
                continue
            if index not in indexed_data["H0"]:
                indexed_data["H0"][index] = [sum(proportion[:13])]
                indexed_data["Hist"][index] = [sum(proportion[13:104])]
                indexed_data["Meta"][index] = [sum(proportion[104:])]
            else:
                indexed_data["H0"][index].append(sum(proportion[:13]))
                indexed_data["Hist"][index].append(sum(proportion[13:104]))
                indexed_data["Meta"][index].append(sum(proportion[104:]))

    elif plane_type == "Hist":
        indexed_data = {
            "H0": {},
            "H1": {},
            "H2": {},
            "H3": {},
            "H4": {},
            "H5": {},
            "H6": {},
            "H7": {},
            "Castling": {},
            "Remaining": {},
        }
        for d in data:
            index, proportion = next(iter(d.items()))
            if max_index is not None and index > max_index:
                continue
            if index not in indexed_data["H0"]:
                # One bucket per 13-plane history step.
                for i in range(8):
                    indexed_data[f"H{i}"][index] = [sum(proportion[13 * i : 13 * (i + 1)])]
                indexed_data["Castling"][index] = [sum(proportion[104:108])]
                indexed_data["Remaining"][index] = [sum(proportion[108:])]
            else:
                for i in range(8):
                    indexed_data[f"H{i}"][index].append(sum(proportion[13 * i : 13 * (i + 1)]))
                indexed_data["Castling"][index].append(sum(proportion[104:108]))
                indexed_data["Remaining"][index].append(sum(proportion[108:]))

    elif plane_type == "Pieces":
        # One curve per piece letter, in white-relative plane order.
        relative_plane_order = board_encoding.get_plane_order((chess.WHITE, chess.BLACK))
        indexed_data = {letter: {} for letter in relative_plane_order}
        for d in data:
            index, proportion = next(iter(d.items()))
            if max_index is not None and index > max_index:
                continue
            if index not in indexed_data[relative_plane_order[0]]:
                for i, letter in enumerate(relative_plane_order):
                    indexed_data[letter][index] = [proportion[i]]
            else:
                for i, letter in enumerate(relative_plane_order):
                    indexed_data[letter][index].append(proportion[i])
    else:
        raise ValueError(f"Invalid plane type: {plane_type}")

    # One colored curve per bucket, with a +/- one-std shaded band.
    n_curves = len(indexed_data)
    for i, (label, curve_data) in enumerate(indexed_data.items()):
        indices = sorted(list(curve_data.keys()))
        mean_curve = [np.mean(curve_data[idx]) for idx in indices]
        std_curve = [np.std(curve_data[idx]) for idx in indices]
        c = COLOR_MAP(i / (n_curves - 1))
        plt.plot(indices, mean_curve, label=label, c=c)
        lower_bound = np.array(mean_curve) - np.array(std_curve)
        upper_bound = np.array(mean_curve) + np.array(std_curve)
        plt.fill_between(indices, lower_bound, upper_bound, alpha=0.2, color=c)
    if y_log:
        plt.yscale("log")
    plt.legend()
    plt.ylabel(y_label)
    plt.title(title)
    if save_to is not None:
        plt.savefig(save_to)
        plt.close()
    else:
        plt.show()
201 |
--------------------------------------------------------------------------------
/spaces/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/src/lczerolens/__init__.py:
--------------------------------------------------------------------------------
1 | """Main module for the lczerolens package."""
2 |
3 | from importlib.metadata import PackageNotFoundError, version
4 |
5 | from .board import LczeroBoard, InputEncoding
6 | from .model import LczeroModel, Flow
7 | from .lens import Lens
8 | from . import lenses, concepts, play
9 |
10 | try:
11 | __version__ = version("lczerolens")
12 | except PackageNotFoundError:
13 | __version__ = "unknown version"
14 |
15 | __all__ = [
16 | "LczeroBoard",
17 | "LczeroModel",
18 | "Flow",
19 | "InputEncoding",
20 | "Lens",
21 | "lenses",
22 | "concepts",
23 | "play",
24 | ]
25 |
--------------------------------------------------------------------------------
/src/lczerolens/backends.py:
--------------------------------------------------------------------------------
1 | """Utils from the lczero executable and bindings.
2 |
3 | Notes
4 | ----
5 | The lczero bindings are not installed by default. You can install them by
6 | running `pip install lczerolens[backends]`.
7 | """
8 |
9 | import subprocess
10 |
11 | import chess
12 | import torch
13 |
14 | from lczerolens.board import LczeroBoard
15 |
16 | try:
17 | from lczero.backends import Backend, GameState
18 | except ImportError as e:
19 | raise ImportError(
20 | "LCZero bindings are required to use the backends, install them with `pip install lczerolens[backends]`."
21 | ) from e
22 |
23 |
def generic_command(args, verbose=False):
    """
    Run a generic `lc0` command and return its decoded stdout.

    Parameters
    ----------
    args : list[str]
        Arguments forwarded to the `lc0` executable.
    verbose : bool
        If True, include lc0's stderr in the error message on failure.

    Returns
    -------
    str
        The decoded stdout of the command.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero status.
    """
    popen = subprocess.Popen(
        ["lc0", *args],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # communicate() drains both pipes concurrently; the previous wait() could
    # deadlock when the child filled a pipe buffer before exiting.
    stdout_data, stderr_data = popen.communicate()
    if popen.returncode != 0:
        if verbose:
            stderr = f"\n[DEBUG] stderr:\n{stderr_data.decode('utf-8')}"
        else:
            stderr = ""
        raise RuntimeError(f"Could not run `lc0 {' '.join(args)}`." + stderr)
    return stdout_data.decode("utf-8")
41 |
42 |
def describenet(path, verbose=False):
    """
    Describe the net at the given path.

    Runs `lc0 describenet -w <path>` and returns its stdout.

    Parameters
    ----------
    path : str
        Path to the network weights file.
    verbose : bool
        Forwarded to `generic_command`; include stderr in failure messages.

    Returns
    -------
    str
        The decoded output of the command.
    """
    return generic_command(["describenet", "-w", path], verbose=verbose)
48 |
49 |
def convert_to_onnx(in_path, out_path, verbose=False):
    """
    Convert a Leela-format net to ONNX via `lc0 leela2onnx`.

    Parameters
    ----------
    in_path : str
        Path to the input Leela weights file.
    out_path : str
        Path where the ONNX model is written.
    verbose : bool
        Forwarded to `generic_command`; include stderr in failure messages.

    Returns
    -------
    str
        The decoded output of the command.
    """
    return generic_command(
        ["leela2onnx", f"--input={in_path}", f"--output={out_path}"],
        verbose=verbose,
    )
58 |
59 |
def convert_to_leela(in_path, out_path, verbose=False):
    """
    Convert an ONNX net to Leela format via `lc0 onnx2leela`.

    Parameters
    ----------
    in_path : str
        Path to the input ONNX model.
    out_path : str
        Path where the Leela weights file is written.
    verbose : bool
        Forwarded to `generic_command`; include stderr in failure messages.

    Returns
    -------
    str
        The decoded output of the command.
    """
    return generic_command(
        ["onnx2leela", f"--input={in_path}", f"--output={out_path}"],
        verbose=verbose,
    )
68 |
69 |
def board_from_backend(lczero_backend: Backend, lczero_game: GameState, planes: int = 112):
    """
    Create a board input tensor from the lczero backend.

    Parameters
    ----------
    lczero_backend : Backend
        The lczero backend used to build the network input.
    lczero_game : GameState
        The game state to encode.
    planes : int
        Number of planes to read from the backend input. NOTE(review): the
        output tensor is always 112x8x8; planes beyond `planes` stay zero —
        confirm callers never pass planes > 112.

    Returns
    -------
    torch.Tensor
        The 112x8x8 input tensor.
    """
    lczero_input = lczero_game.as_input(lczero_backend)
    lczero_input_tensor = torch.zeros((112, 64), dtype=torch.float)
    for plane in range(planes):
        # Each plane is a 64-bit occupancy mask scaled by the plane value;
        # the binary string is reversed so the lowest bit lands on index 0.
        mask_str = f"{lczero_input.mask(plane):b}".zfill(64)
        lczero_input_tensor[plane] = torch.tensor(
            tuple(map(int, reversed(mask_str))), dtype=torch.float
        ) * lczero_input.val(plane)
    return lczero_input_tensor.view((112, 8, 8))
82 |
83 |
def prediction_from_backend(
    lczero_backend: Backend,
    lczero_game: GameState,
    softmax: bool = False,
    only_legal: bool = False,
    illegal_value: float = 0,
):
    """
    Predicts the move.

    Parameters
    ----------
    lczero_backend : Backend
        The lczero backend used for evaluation.
    lczero_game : GameState
        The game state to evaluate.
    softmax : bool
        If True, use the softmaxed policy (`p_softmax`); otherwise raw logits.
    only_legal : bool
        If True, only the legal-move indices keep their policy values; all
        other entries keep `illegal_value`.
    illegal_value : float
        Fill value for indices that are not selected.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The (1858,) policy tensor and the scalar value `q`.
    """
    filtered_policy = torch.full((1858,), illegal_value, dtype=torch.float)
    lczero_input = lczero_game.as_input(lczero_backend)
    (lczero_output,) = lczero_backend.evaluate(lczero_input)
    if only_legal:
        # NOTE(review): presumably non-empty for non-terminal positions;
        # an empty index list would make the float tensor indexing fail.
        indices = torch.tensor(lczero_game.policy_indices())
    else:
        indices = torch.tensor(range(1858))
    if softmax:
        policy = torch.tensor(lczero_output.p_softmax(*range(1858)), dtype=torch.float)
    else:
        policy = torch.tensor(lczero_output.p_raw(*range(1858)), dtype=torch.float)
    value = torch.tensor(lczero_output.q())
    filtered_policy[indices] = policy[indices]
    return filtered_policy, value
108 |
109 |
def moves_with_castling_swap(lczero_game: GameState, board: LczeroBoard):
    """
    Get the moves with castling swap.

    Reconciles castling encodings between python-chess (king-to-destination
    UCI, e.g. "e1g1") and the lczero bindings, which report a different
    square for castling: the g/c file is swapped for h/a here — presumably
    lc0's king-takes-rook notation (e.g. "e1h1"); TODO confirm against the
    bindings' documentation.

    Parameters
    ----------
    lczero_game : GameState
        The lczero game state providing moves and policy indices.
    board : LczeroBoard
        The python-chess board for the same position.

    Returns
    -------
    tuple[list[str], list[int]]
        The legal moves (UCI strings) and policy indices, with lc0-style
        castling entries replaced by their standard-UCI equivalents.
    """
    lczero_legal_moves = lczero_game.moves()
    lczero_policy_indices = list(lczero_game.policy_indices())
    for move in board.legal_moves:
        uci_move = move.uci()
        if uci_move in lczero_legal_moves:
            continue
        if board.is_castling(move):
            # Castling moves start from the e-file, so replacing "g"->"h" and
            # "c"->"a" only rewrites the destination file.
            leela_uci_move = uci_move.replace("g", "h").replace("c", "a")
            if leela_uci_move in lczero_legal_moves:
                lczero_legal_moves.remove(leela_uci_move)
                lczero_legal_moves.append(uci_move)
                lczero_policy_indices.remove(
                    LczeroBoard.encode_move(
                        chess.Move.from_uci(leela_uci_move),
                        board.turn,
                    )
                )
                lczero_policy_indices.append(LczeroBoard.encode_move(move, board.turn))
    return lczero_legal_moves, lczero_policy_indices
133 |
--------------------------------------------------------------------------------
/src/lczerolens/board.py:
--------------------------------------------------------------------------------
1 | """Board class."""
2 |
3 | import re
4 | from enum import Enum
5 | from typing import Optional, Generator, Tuple, List, Union, Any
6 |
7 | import chess
8 | import chess.svg
9 | import torch
10 | import io
11 | import numpy as np
12 |
13 | from .constants import INVERTED_POLICY_INDEX, POLICY_INDEX
14 |
15 |
class InputEncoding(int, Enum):
    """Input encoding for the board tensor."""

    # Standard lc0 input: 8 history steps of 13 planes plus meta planes.
    INPUT_CLASSICAL_112_PLANE = 0
    # Same, but a short game history is padded by repeating the earliest
    # available configuration across the remaining history planes.
    INPUT_CLASSICAL_112_PLANE_REPEATED = 1
    # No history: the current configuration is repeated over all 8 steps.
    INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED = 2
    # No history: only the first 13 planes are filled, the rest stay zero.
    INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS = 3
23 |
24 |
25 | class LczeroBoard(chess.Board):
26 | """A class for wrapping the LczeroBoard class."""
27 |
    @staticmethod
    def get_plane_order(us: bool):
        """Get the plane order for the given us view.

        Parameters
        ----------
        us : bool
            The side whose view is used (chess.WHITE or chess.BLACK).

        Returns
        -------
        str
            The 12-character plane order, "us" pieces first — e.g.
            "PNBRQKpnbrqk" for white, "pnbrqkPNBRQK" for black.
        """
        plane_orders = {chess.WHITE: "PNBRQK", chess.BLACK: "pnbrqk"}
        return plane_orders[us] + plane_orders[not us]
44 |
    @staticmethod
    def get_piece_index(piece: str, us: bool, plane_order: Optional[str] = None):
        """Converts a piece to its index in the plane order.

        Parameters
        ----------
        piece : str
            The piece symbol to convert (e.g. "P", "k", or "0" for empty).
        us : bool
            The side whose view is used (chess.WHITE or chess.BLACK).
        plane_order : Optional[str]
            The plane order; derived from `us` when not given.

        Returns
        -------
        int
            The index of the piece in the plane order; the appended "0"
            maps empty squares to index 12.
        """
        if plane_order is None:
            plane_order = LczeroBoard.get_plane_order(us)
        return f"{plane_order}0".index(piece)
66 |
    def to_config_tensor(
        self,
        us: Optional[bool] = None,
    ):
        """Converts a LczeroBoard to a tensor based on the pieces configuration.

        Parameters
        ----------
        us : Optional[bool]
            The side whose point of view is used; defaults to the side to move.

        Returns
        -------
        torch.Tensor
            The 13x8x8 tensor: planes 0-11 are one-hot piece planes in
            `get_plane_order` order, plane 12 flags a repetition.
        """
        if us is None:
            us = self.turn
        plane_order = LczeroBoard.get_plane_order(us)

        def piece_to_index(piece: str):
            # Empty squares ("0") map to index 12, past the 12 piece planes.
            return f"{plane_order}0".index(piece)

        # Expand FEN digit runs into explicit "0"s so each square is one char.
        fen_board = self.fen().split(" ")[0]
        fen_rep = re.sub(r"(\d)", lambda m: "0" * int(m.group(1)), fen_board)
        rows = fen_rep.split("/")
        # FEN lists rank 8 first; reverse so row index 0 corresponds to rank 1.
        rev_rows = rows[::-1]
        ordered_fen = "".join(rev_rows)

        config_tensor = torch.zeros((13, 8, 8), dtype=torch.float)
        ordinal_board = torch.tensor(tuple(map(piece_to_index, ordered_fen)), dtype=torch.float)
        ordinal_board = ordinal_board.reshape((8, 8)).unsqueeze(0)
        piece_tensor = torch.tensor(tuple(map(piece_to_index, plane_order)), dtype=torch.float)
        piece_tensor = piece_tensor.reshape((12, 1, 1))
        # Broadcast-compare to build all 12 one-hot piece planes at once.
        config_tensor[:12] = (ordinal_board == piece_tensor).float()
        if self.is_repetition(2):  # Might be wrong if the full history is not available
            config_tensor[12] = torch.ones((8, 8), dtype=torch.float)
        # Black's view mirrors the board vertically (flip along the rank axis).
        return config_tensor if us == chess.WHITE else config_tensor.flip(1)
105 |
    def to_input_tensor(
        self,
        *,
        input_encoding: InputEncoding = InputEncoding.INPUT_CLASSICAL_112_PLANE,
    ):
        """Create the lc0 input tensor from the history of a game.

        Plane layout: 0-103 hold 8 history steps of 13 configuration planes
        (most recent first); 104-107 are castling-rights planes; 108 flags
        black to move; 109 is the halfmove clock; 110 is left at zero; 111 is
        all ones.

        Parameters
        ----------
        input_encoding : InputEncoding
            The input encoding method.

        Returns
        -------
        torch.Tensor
            The 112x8x8 tensor.

        Raises
        ------
        ValueError
            If the input encoding is not recognised.
        """

        input_tensor = torch.zeros((112, 8, 8), dtype=torch.float)
        us = self.turn
        them = not us
        moves = []

        if (
            input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE
            or input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED
        ):
            # Walk backwards through the move stack to encode each history
            # step; popped moves are restored below.
            for i in range(8):
                config_tensor = self.to_config_tensor(us)
                input_tensor[i * 13 : (i + 1) * 13] = config_tensor
                try:
                    moves.append(self.pop())
                except IndexError:
                    # Ran out of history: optionally repeat the earliest
                    # configuration into the remaining history planes.
                    if input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED:
                        input_tensor[(i + 1) * 13 : 104] = config_tensor.repeat(7 - i, 1, 1)
                    break

        elif input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED:
            config_tensor = self.to_config_tensor(us)
            input_tensor[:104] = config_tensor.repeat(8, 1, 1)
        elif input_encoding == InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS:
            input_tensor[:13] = self.to_config_tensor(us)
        else:
            raise ValueError(f"Got unexpected input encoding {input_encoding}")

        # Restore the moves
        for move in reversed(moves):
            self.push(move)

        # Meta planes: castling rights (104-107), side to move (108),
        # halfmove clock (109) and the constant all-ones plane (111).
        if self.has_queenside_castling_rights(us):
            input_tensor[104] = torch.ones((8, 8), dtype=torch.float)
        if self.has_kingside_castling_rights(us):
            input_tensor[105] = torch.ones((8, 8), dtype=torch.float)
        if self.has_queenside_castling_rights(them):
            input_tensor[106] = torch.ones((8, 8), dtype=torch.float)
        if self.has_kingside_castling_rights(them):
            input_tensor[107] = torch.ones((8, 8), dtype=torch.float)
        if us == chess.BLACK:
            input_tensor[108] = torch.ones((8, 8), dtype=torch.float)
        input_tensor[109] = torch.ones((8, 8), dtype=torch.float) * self.halfmove_clock
        input_tensor[111] = torch.ones((8, 8), dtype=torch.float)

        return input_tensor
169 |
170 | @staticmethod
171 | def encode_move(
172 | move: chess.Move,
173 | us: bool,
174 | ) -> int:
175 | """
176 | Converts a chess.Move object to an index.
177 |
178 | Parameters
179 | ----------
180 | move : chess.Move
181 | The chess move to encode.
182 | us : bool
183 | The side to move (True for white, False for black).
184 |
185 | Returns
186 | -------
187 | int
188 | The encoded move index.
189 | """
190 | from_square = move.from_square
191 | to_square = move.to_square
192 |
193 | if us == chess.BLACK:
194 | from_square_row = from_square // 8
195 | from_square_col = from_square % 8
196 | from_square = 8 * (7 - from_square_row) + from_square_col
197 | to_square_row = to_square // 8
198 | to_square_col = to_square % 8
199 | to_square = 8 * (7 - to_square_row) + to_square_col
200 | us_uci_move = chess.SQUARE_NAMES[from_square] + chess.SQUARE_NAMES[to_square]
201 | if move.promotion is not None:
202 | if move.promotion == chess.BISHOP:
203 | us_uci_move += "b"
204 | elif move.promotion == chess.ROOK:
205 | us_uci_move += "r"
206 | elif move.promotion == chess.QUEEN:
207 | us_uci_move += "q"
208 | # Knight promotion is the default
209 | return INVERTED_POLICY_INDEX[us_uci_move]
210 |
211 | def decode_move(
212 | self,
213 | index: int,
214 | ) -> chess.Move:
215 | """
216 | Converts an index to a chess.Move object.
217 |
218 | Parameters
219 | ----------
220 | index : int
221 | The index to convert.
222 |
223 | Returns
224 | -------
225 | chess.Move
226 | The chess move.
227 | """
228 | us = self.turn
229 | us_uci_move = POLICY_INDEX[index]
230 | from_square = chess.SQUARE_NAMES.index(us_uci_move[:2])
231 | to_square = chess.SQUARE_NAMES.index(us_uci_move[2:4])
232 | if us == chess.BLACK:
233 | from_square_row = from_square // 8
234 | from_square_col = from_square % 8
235 | from_square = 8 * (7 - from_square_row) + from_square_col
236 | to_square_row = to_square // 8
237 | to_square_col = to_square % 8
238 | to_square = 8 * (7 - to_square_row) + to_square_col
239 |
240 | uci_move = chess.SQUARE_NAMES[from_square] + chess.SQUARE_NAMES[to_square]
241 | from_piece = self.piece_at(from_square)
242 | if from_piece == chess.PAWN and to_square >= 56: # Knight promotion is the default
243 | uci_move += "n"
244 | return chess.Move.from_uci(uci_move)
245 |
246 | def get_legal_indices(
247 | self,
248 | ) -> torch.Tensor:
249 | """
250 | Gets the legal indices.
251 |
252 | Returns
253 | -------
254 | torch.Tensor
255 | Tensor containing indices of legal moves.
256 | """
257 | us = self.turn
258 | return torch.tensor([self.encode_move(move, us) for move in self.legal_moves])
259 |
260 | def get_next_legal_boards(
261 | self,
262 | n_history: int = 7,
263 | ) -> Generator["LczeroBoard", None, None]:
264 | """
265 | Gets the next legal boards.
266 |
267 | Parameters
268 | ----------
269 | n_history : int, optional
270 | Number of previous positions to keep in the move stack, by default 7.
271 |
272 | Returns
273 | -------
274 | Generator[LczeroBoard, None, None]
275 | Generator yielding board positions after each legal move.
276 | """
277 | working_board = self.copy(stack=n_history)
278 | for move in working_board.legal_moves:
279 | working_board.push(move)
280 | yield working_board.copy(stack=n_history)
281 | working_board.pop()
282 |
283 | def render_heatmap(
284 | self,
285 | heatmap: Union[torch.Tensor, np.ndarray],
286 | square: Optional[str] = None,
287 | vmin: Optional[float] = None,
288 | vmax: Optional[float] = None,
289 | arrows: Optional[List[Tuple[str, str]]] = None,
290 | normalise: str = "none",
291 | save_to: Optional[str] = None,
292 | cmap_name: str = "RdYlBu_r",
293 | alpha: float = 1.0,
294 | relative_board_view: bool = True,
295 | heatmap_mode: str = "relative_flip",
296 | ) -> Tuple[Optional[str], Any]:
297 | """Render a heatmap on the board.
298 |
299 | Parameters
300 | ----------
301 | heatmap : torch.Tensor or numpy.ndarray
302 | The heatmap values to visualize on the board (64,) or (8, 8).
303 | square : Optional[str], default=None
304 | Chess square to highlight (e.g. 'e4').
305 | vmin : Optional[float], default=None
306 | Minimum value for the colormap normalization.
307 | vmax : Optional[float], default=None
308 | Maximum value for the colormap normalization.
309 | arrows : Optional[List[Tuple[str, str]]], default=None
310 | List of arrow tuples (from_square, to_square) to draw on board.
311 | normalise : str, default="none"
312 | Normalization method. Use "abs" for absolute value normalization.
313 | save_to : Optional[str], default=None
314 | Path to save the visualization. If None, returns the figure.
315 | cmap_name : str, default="RdYlBu_r"
316 | Name of matplotlib colormap to use.
317 | alpha : float, default=1.0
318 | Opacity of the heatmap overlay.
319 | relative_board_view : bool, default=True
320 | Whether to use the relative board view.
321 | heatmap_mode : str, default="relative_flip"
322 | Use "relative_flip" if the heatmap corresponds to a relative flip of the board,
323 | "relative_rotation" if it corresponds to a relative rotation of the board,
324 | or "absolute" if it is already in the correct orientation.
325 |
326 | Returns
327 | -------
328 | Union[Tuple[str, matplotlib.figure.Figure], None]
329 | If save_to is None, returns (SVG string, matplotlib figure).
330 | If save_to is provided, saves files and returns None.
331 |
332 | Raises
333 | ------
334 | ValueError
335 | If save_to is provided and does not end with `.svg`.
336 | """
337 | try:
338 | import matplotlib
339 | import matplotlib.pyplot as plt
340 | except ImportError as e:
341 | raise ImportError(
342 | "matplotlib is required to render heatmaps, install it with `pip install lczerolens[viz]`."
343 | ) from e
344 |
345 | if heatmap.ndim > 1:
346 | heatmap = heatmap.view(64)
347 |
348 | if heatmap_mode == "relative_flip":
349 | if not self.turn:
350 | heatmap = heatmap.view(8, 8).flip(0).view(64)
351 | elif heatmap_mode == "relative_rotation":
352 | if not self.turn:
353 | heatmap = heatmap.view(8, 8).flip(1).flip(0).view(64)
354 | elif heatmap_mode == "absolute":
355 | pass
356 | else:
357 | raise ValueError(
358 | f"Got unexpected heatmap_mode {heatmap_mode!r}. "
359 | "Valid options are ['relative_flip', 'relative_rotation', 'absolute']"
360 | )
361 |
362 | cmap = matplotlib.colormaps[cmap_name].resampled(1000)
363 |
364 | if normalise == "abs":
365 | a_max = heatmap.abs().max()
366 | if a_max != 0:
367 | heatmap = heatmap / a_max
368 | vmin = -1
369 | vmax = 1
370 | if vmin is None:
371 | vmin = heatmap.min()
372 | if vmax is None:
373 | vmax = heatmap.max()
374 | norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
375 |
376 | color_dict = {}
377 | for square_index in range(64):
378 | color = cmap(norm(heatmap[square_index]))
379 | color = (*color[:3], alpha)
380 | color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True)
381 | fig = plt.figure(figsize=(1, 4.1))
382 | ax = plt.gca()
383 | ax.axis("off")
384 | fig.colorbar(
385 | matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap),
386 | ax=ax,
387 | orientation="vertical",
388 | fraction=1.0,
389 | )
390 | if square is not None:
391 | try:
392 | check = chess.parse_square(square)
393 | except ValueError:
394 | check = None
395 | else:
396 | check = None
397 | if arrows is None:
398 | arrows = []
399 |
400 | svg_board = chess.svg.board(
401 | self,
402 | orientation=self.turn if relative_board_view else chess.WHITE,
403 | check=check,
404 | fill=color_dict,
405 | size=400,
406 | arrows=arrows,
407 | )
408 | buffer = io.BytesIO()
409 | fig.savefig(buffer, format="svg")
410 | svg_colorbar = buffer.getvalue().decode("utf-8")
411 | plt.close()
412 |
413 | if save_to is None:
414 | return svg_board, svg_colorbar
415 | elif not save_to.endswith(".svg"):
416 | raise ValueError("only saving to `svg` is supported")
417 |
418 | with open(save_to.replace(".svg", "_board.svg"), "w") as f:
419 | f.write(svg_board)
420 | with open(save_to.replace(".svg", "_colorbar.svg"), "w") as f:
421 | f.write(svg_colorbar)
422 |
--------------------------------------------------------------------------------
/src/lczerolens/concept.py:
--------------------------------------------------------------------------------
1 | """Class for concept-based XAI methods."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Any
5 |
6 | import torch
7 | from sklearn import metrics
8 | from datasets import Features, Value, Sequence, ClassLabel
9 |
10 | from lczerolens.board import LczeroBoard
11 |
12 |
class Concept(ABC):
    """
    Class for concept-based XAI methods.

    A concept maps a board position to a label and defines how predictions
    of that label are scored and which dataset features it requires.
    """

    @abstractmethod
    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """
        Compute the concept label for the given board position.
        """
        pass

    @staticmethod
    @abstractmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """
        Compute evaluation metrics comparing predictions against labels.
        """
        pass

    @property
    @abstractmethod
    def features(self) -> Features:
        """
        Return the ``datasets`` features schema expected for this concept.
        """
        pass
46 |
47 |
class BinaryConcept(Concept):
    """
    Base class for concepts whose label is a binary class (0 or 1).
    """

    # Dataset schema: game id, SAN move list, FEN and a two-class label.
    features = Features(
        {
            "gameid": Value("string"),
            "moves": Sequence(Value("string")),
            "fen": Value("string"),
            "label": ClassLabel(num_classes=2),
        }
    )

    @staticmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """
        Compute binary classification metrics for the given predictions.
        """
        scorers = {
            "accuracy": metrics.accuracy_score,
            "precision": metrics.precision_score,
            "recall": metrics.recall_score,
            "f1": metrics.f1_score,
        }
        return {name: scorer(labels, predictions) for name, scorer in scorers.items()}
76 |
77 |
class NullConcept(BinaryConcept):
    """
    Trivial binary concept that labels every board as 0 (useful as a baseline).
    """

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """
        Return the constant label 0, regardless of the board.
        """
        return 0
91 |
92 |
class OrBinaryConcept(BinaryConcept):
    """
    Binary concept that is the logical OR of several binary concepts.
    """

    def __init__(self, *concepts: BinaryConcept):
        # Reject anything that is not a BinaryConcept up front.
        for concept in concepts:
            if not isinstance(concept, BinaryConcept):
                raise ValueError(f"{concept} is not a BinaryConcept")
        self.concepts = concepts

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """
        Return True as soon as any sub-concept labels the board positively.
        """
        for concept in self.concepts:
            if concept.compute_label(board):
                return True
        return False
112 |
113 |
class AndBinaryConcept(BinaryConcept):
    """
    Binary concept that is the logical AND of several binary concepts.
    """

    def __init__(self, *concepts: BinaryConcept):
        # Reject anything that is not a BinaryConcept up front.
        for concept in concepts:
            if not isinstance(concept, BinaryConcept):
                raise ValueError(f"{concept} is not a BinaryConcept")
        self.concepts = concepts

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> Any:
        """
        Return False as soon as any sub-concept labels the board negatively.
        """
        for concept in self.concepts:
            if not concept.compute_label(board):
                return False
        return True
133 |
134 |
class MulticlassConcept(Concept):
    """
    Base class for concepts whose label is one of many classes.
    """

    # Dataset schema: game id, SAN move list, FEN and an integer label.
    features = Features(
        {
            "gameid": Value("string"),
            "moves": Sequence(Value("string")),
            "fen": Value("string"),
            "label": Value("int32"),
        }
    )

    @staticmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """
        Compute weighted multiclass classification metrics.
        """
        results = {"accuracy": metrics.accuracy_score(labels, predictions)}
        for name, scorer in (
            ("precision", metrics.precision_score),
            ("recall", metrics.recall_score),
            ("f1", metrics.f1_score),
        ):
            results[name] = scorer(labels, predictions, average="weighted")
        return results
163 |
164 |
class ContinuousConcept(Concept):
    """
    Base class for concepts whose label is a real-valued quantity.
    """

    # Dataset schema: game id, SAN move list, FEN and a float label.
    features = Features(
        {
            "gameid": Value("string"),
            "moves": Sequence(Value("string")),
            "fen": Value("string"),
            "label": Value("float32"),
        }
    )

    @staticmethod
    def compute_metrics(
        predictions,
        labels,
    ):
        """
        Compute regression metrics for the given predictions.
        """
        results = {}
        for name, scorer in (
            ("rmse", metrics.root_mean_squared_error),
            ("mae", metrics.mean_absolute_error),
            ("r2", metrics.r2_score),
        ):
            results[name] = scorer(labels, predictions)
        return results
192 |
193 |
def concept_collate_fn(batch):
    """Collate dataset rows into (boards, labels, batch).

    Each row provides a starting "fen" and a list of SAN "moves" that are
    replayed on the board before it is collected.
    """
    boards = []
    labels = []
    for element in batch:
        board = LczeroBoard(element["fen"])
        # Replay the recorded moves to reach the labelled position.
        for san in element["moves"]:
            board.push_san(san)
        boards.append(board)
        labels.append(element["label"])
    return boards, labels, batch
204 |
205 |
def concept_init_rel(output, infos):
    """Initialise a relevance tensor from the model output.

    For every batch element, keeps only the output entry at its label index
    and zeroes everything else (the usual LRP relevance initialisation).

    Parameters
    ----------
    output : torch.Tensor
        Model output of shape (batch, num_classes).
    infos : tuple
        Collated batch infos; ``infos[0]`` is the sequence of label indices
        (one per batch element).

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as ``output`` containing only the
        labelled entries.
    """
    labels = torch.as_tensor(infos[0], device=output.device)
    batch_indices = torch.arange(output.shape[0], device=output.device)
    rel = torch.zeros_like(output)
    # Vectorised equivalent of the per-row loop: rel[i, labels[i]] = output[i, labels[i]].
    rel[batch_indices, labels] = output[batch_indices, labels]
    return rel
212 |
--------------------------------------------------------------------------------
/src/lczerolens/concepts/__init__.py:
--------------------------------------------------------------------------------
1 | """Concepts module."""
2 |
3 | from .material import HasMaterialAdvantage, HasPiece
4 | from .move import BestLegalMove, PieceBestLegalMove
5 | from .threat import HasMateThreat, HasThreat
6 |
7 | __all__ = [
8 | "HasPiece",
9 | "HasThreat",
10 | "HasMateThreat",
11 | "HasMaterialAdvantage",
12 | "BestLegalMove",
13 | "PieceBestLegalMove",
14 | ]
15 |
--------------------------------------------------------------------------------
/src/lczerolens/concepts/material.py:
--------------------------------------------------------------------------------
1 | """All concepts related to material."""
2 |
3 | from typing import Dict, Optional
4 |
5 | import chess
6 |
7 | from lczerolens.board import LczeroBoard
8 | from lczerolens.concept import BinaryConcept
9 |
10 |
class HasPiece(BinaryConcept):
    """Binary concept: is at least one piece of the given kind on the board?"""

    def __init__(
        self,
        piece: str,
        relative: bool = True,
    ):
        """Store the piece and the colour-relativity flag.

        Parameters
        ----------
        piece : str
            Piece symbol, e.g. "N" or "p".
        relative : bool, optional
            If True, interpret the piece colour relative to the side to move.
        """
        self.piece = chess.Piece.from_symbol(piece)
        self.relative = relative

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return 1 if the (possibly colour-relative) piece is present, else 0."""
        color = self.piece.color
        if self.relative and not board.turn:
            color = not color
        squares = board.pieces(self.piece.piece_type, color)
        return 1 if len(squares) > 0 else 0
34 |
35 |
class HasMaterialAdvantage(BinaryConcept):
    """Binary concept: does one side hold a strict material advantage?

    Attributes
    ----------
    piece_values : Dict[int, int]
        The piece values.
    """

    piece_values = {
        chess.PAWN: 1,
        chess.KNIGHT: 3,
        chess.BISHOP: 3,
        chess.ROOK: 5,
        chess.QUEEN: 9,
        chess.KING: 0,
    }

    def __init__(
        self,
        relative: bool = True,
    ):
        """
        Store the colour-relativity flag.

        Parameters
        ----------
        relative : bool, optional
            If True, compare material from the side to move's perspective;
            otherwise compare white against black.
        """
        self.relative = relative

    def compute_label(
        self,
        board: LczeroBoard,
        piece_values: Optional[Dict[int, int]] = None,
    ) -> int:
        """
        Return 1 if "our" side has strictly more material value than "theirs".
        """
        values = self.piece_values if piece_values is None else piece_values
        us = board.turn if self.relative else chess.WHITE
        them = not us
        # Piece types range over 1..6 (pawn through king).
        our_value = sum(len(board.pieces(pt, us)) * values[pt] for pt in range(1, 7))
        their_value = sum(len(board.pieces(pt, them)) * values[pt] for pt in range(1, 7))
        return 1 if our_value > their_value else 0
83 |
--------------------------------------------------------------------------------
/src/lczerolens/concepts/move.py:
--------------------------------------------------------------------------------
1 | """All concepts related to move."""
2 |
3 | import chess
4 | import torch
5 |
6 | from lczerolens.board import LczeroBoard
7 | from lczerolens.model import LczeroModel, PolicyFlow
8 | from lczerolens.concept import BinaryConcept, MulticlassConcept
9 |
10 |
class BestLegalMove(MulticlassConcept):
    """Multiclass concept: the policy index of the model's best legal move."""

    def __init__(
        self,
        model: LczeroModel,
    ):
        """Wrap the model in a policy flow."""
        self.policy_flow = PolicyFlow(model)

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return the policy index of the highest-probability legal move."""
        (policy,) = self.policy_flow(board)
        probs = torch.softmax(policy.squeeze(0), dim=-1)

        # Restrict the policy to legal moves before taking the argmax.
        indices = [LczeroBoard.encode_move(move, board.turn) for move in board.legal_moves]
        best = probs[indices].argmax().item()
        return indices[best]
32 |
33 |
class PieceBestLegalMove(BinaryConcept):
    """Binary concept: is the model's best legal move made by a given piece?"""

    def __init__(
        self,
        model: LczeroModel,
        piece: str,
    ):
        """Wrap the model in a policy flow and store the piece of interest."""
        self.policy_flow = PolicyFlow(model)
        self.piece = chess.Piece.from_symbol(piece)

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """Return 1 if the best legal move starts from a square holding the piece."""
        (policy,) = self.policy_flow(board)
        probs = torch.softmax(policy.squeeze(0), dim=-1)

        # Restrict the policy to legal moves before taking the argmax.
        moves = list(board.legal_moves)
        indices = [LczeroBoard.encode_move(move, board.turn) for move in moves]
        best_move = moves[probs[indices].argmax().item()]
        return 1 if board.piece_at(best_move.from_square) == self.piece else 0
61 |
--------------------------------------------------------------------------------
/src/lczerolens/concepts/threat.py:
--------------------------------------------------------------------------------
1 | """All concepts related to threats."""
2 |
3 | import chess
4 |
5 | from lczerolens.board import LczeroBoard
6 | from lczerolens.concept import BinaryConcept
7 |
8 |
class HasThreat(BinaryConcept):
    """
    Binary concept: is any piece of the given kind under attack?
    """

    def __init__(
        self,
        piece: str,
        relative: bool = True,
    ):
        """
        Store the piece and the colour-relativity flag.
        """
        self.piece = chess.Piece.from_symbol(piece)
        self.relative = relative

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """
        Return 1 if any such piece is attacked by the opposite colour, else 0.
        """
        color = self.piece.color
        if self.relative and not board.turn:
            color = not color
        squares = board.pieces(self.piece.piece_type, color)
        return int(any(board.is_attacked_by(not color, sq) for sq in squares))
41 |
42 |
class HasMateThreat(BinaryConcept):
    """
    Binary concept: does the side to move have a mate in one?
    """

    def compute_label(
        self,
        board: LczeroBoard,
    ) -> int:
        """
        Return 1 if some legal move delivers immediate checkmate, else 0.
        """
        for move in board.legal_moves:
            # Try the move on the board, then restore it before deciding.
            board.push(move)
            is_mate = board.is_checkmate()
            board.pop()
            if is_mate:
                return 1
        return 0
62 |
--------------------------------------------------------------------------------
/src/lczerolens/lens.py:
--------------------------------------------------------------------------------
1 | """Generic lens class."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Dict, Iterable, Generator, Callable, Type, Union, Optional, Any
5 |
6 | import torch
7 | import re
8 |
9 | from lczerolens.model import LczeroModel
10 | from lczerolens.board import LczeroBoard
11 |
12 |
class Lens(ABC):
    """Generic lens class for analysing model activations.

    Subclasses are registered by name via :meth:`register` and instantiated
    with :meth:`from_name`. A lens selects the modules it acts on through a
    regular-expression ``pattern`` and implements :meth:`_intervene` to
    collect results during an nnsight trace.
    """

    # Name under which the subclass was registered (set by the decorator).
    _lens_type: str
    # Global registry mapping lens names to subclasses.
    _registry: Dict[str, Type["Lens"]] = {}
    # Whether gradients are enabled while tracing (subclasses may override).
    _grad_enabled: bool = False

    @classmethod
    def register(cls, name: str) -> Callable:
        """Registers the lens.

        Parameters
        ----------
        name : str
            The name of the lens.

        Returns
        -------
        Callable
            The decorator to register the lens.

        Raises
        ------
        ValueError
            If the lens name is already registered.
        """

        if name in cls._registry:
            raise ValueError(f"Lens {name} already registered.")

        def decorator(subclass: Type["Lens"]):
            subclass._lens_type = name
            cls._registry[name] = subclass
            return subclass

        return decorator

    @classmethod
    def from_name(cls, name: str, *args, **kwargs) -> "Lens":
        """Returns the lens from its name.

        Parameters
        ----------
        name : str
            The name of the lens.

        Returns
        -------
        Lens
            The lens instance.

        Raises
        ------
        KeyError
            If the lens name is not found.
        """
        if name not in cls._registry:
            raise KeyError(f"Lens {name} not found.")
        return cls._registry[name](*args, **kwargs)

    def __init__(self, pattern: Optional[str] = None):
        """Initialise the lens.

        Parameters
        ----------
        pattern : Optional[str], default=None
            The pattern to match the modules. When None, a regex that can
            never match is used, so no modules are selected by default.
        """
        if pattern is None:
            pattern = r"a^"  # match nothing by default
        self._pattern = pattern
        self._reg_exp = re.compile(pattern)

    @property
    def pattern(self) -> str:
        """The pattern to match the modules."""
        return self._pattern

    @pattern.setter
    def pattern(self, pattern: str):
        # Keep the compiled regex in sync with the pattern string.
        self._pattern = pattern
        self._reg_exp = re.compile(pattern)

    def _get_modules(self, model: LczeroModel) -> Generator[tuple[str, Any], None, None]:
        """Yield the (name, module) pairs whose name matches the pattern."""
        for name, module in model.named_modules():
            fixed_name = name.lstrip(". ")  # nnsight outputs names with a dot
            if self._reg_exp.match(fixed_name):
                yield fixed_name, module

    def is_compatible(self, model: LczeroModel) -> bool:
        """Returns whether the lens is compatible with the model.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.

        Returns
        -------
        bool
            Whether the lens is compatible with the model.
        """
        return isinstance(model, LczeroModel)

    def _ensure_compatible(self, model: LczeroModel):
        """Ensure the lens is compatible with the model.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.

        Raises
        ------
        ValueError
            If the lens is not compatible with the model.
        """
        if not self.is_compatible(model):
            raise ValueError(f"Lens {self._lens_type} is not compatible with model of type {type(model)}.")

    def prepare(self, model: LczeroModel, **kwargs) -> LczeroModel:
        """Prepare the model for the lens.

        The base implementation is a no-op; subclasses may wrap or modify
        the model before tracing.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.

        Returns
        -------
        LczeroModel
            The prepared model.
        """
        return model

    @abstractmethod
    def _intervene(self, model: LczeroModel, **kwargs) -> dict:
        """Intervene on the model.

        Called inside an active trace context; implementations register
        saves/interventions on the traced model.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.

        Returns
        -------
        dict
            The intervention results.
        """
        pass

    def _trace(
        self,
        model: LczeroModel,
        *inputs: Union[LczeroBoard, torch.Tensor],
        model_kwargs: dict,
        intervention_kwargs: dict,
    ):
        """Trace the model and intervene on it.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.
        inputs : Union[LczeroBoard, torch.Tensor]
            The inputs.
        model_kwargs : dict
            The model kwargs.
        intervention_kwargs : dict
            The intervention kwargs.

        Returns
        -------
        dict
            The intervention results.
        """
        with model.trace(*inputs, **model_kwargs):
            return self._intervene(model, **intervention_kwargs)

    def analyse(
        self,
        model: LczeroModel,
        *inputs: Union[LczeroBoard, torch.Tensor],
        **kwargs,
    ) -> dict:
        """Analyse the input.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.
        inputs : Union[LczeroBoard, torch.Tensor]
            The inputs.

        Returns
        -------
        dict
            The analysis results.

        Raises
        ------
        ValueError
            If the lens is not compatible with the model.
        """
        if not isinstance(model, LczeroModel):
            raise ValueError(f"Model is not a LczeroModel. Got {type(model)}.")
        self._ensure_compatible(model)
        model_kwargs = kwargs.get("model_kwargs", {})
        prepared_model = self.prepare(model, **kwargs)
        # Gradients are only enabled for lenses that need a backward pass.
        with torch.set_grad_enabled(self._grad_enabled):
            return self._trace(prepared_model, *inputs, model_kwargs=model_kwargs, intervention_kwargs=kwargs)

    def analyse_batched(
        self,
        model: LczeroModel,
        iter_inputs: Iterable[tuple[Any, dict]],
        **kwargs,
    ) -> Generator[dict, None, None]:
        """Analyse a batches of inputs.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model.
        iter_inputs : Iterable[Tuple[Union[LczeroBoard, torch.Tensor], dict]]
            The iterator over (inputs, intervention kwargs) pairs.

        Returns
        -------
        Generator[dict, None, None]
            The iterator over the statistics.

        Raises
        ------
        ValueError
            If the lens is not compatible with the model.
        """
        self._ensure_compatible(model)
        model_kwargs = kwargs.get("model_kwargs", {})
        prepared_model = self.prepare(model, **kwargs)
        for inputs, dynamic_intervention_kwargs in iter_inputs:
            # NOTE(review): kwargs is updated in place, so keys set by an
            # earlier batch persist into later ones unless overwritten.
            kwargs.update(dynamic_intervention_kwargs)
            yield self._trace(prepared_model, *inputs, model_kwargs=model_kwargs, intervention_kwargs=kwargs)
257 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Lenses module.
3 | """
4 |
5 | from .activation import ActivationLens
6 | from .composite import CompositeLens
7 | from .gradient import GradientLens
8 | from .lrp import LrpLens
9 | from .patching import PatchingLens
10 | from .probing import ProbingLens
11 |
12 | __all__ = [
13 | "ActivationLens",
14 | "CompositeLens",
15 | "GradientLens",
16 | "LrpLens",
17 | "PatchingLens",
18 | "ProbingLens",
19 | ]
20 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/activation.py:
--------------------------------------------------------------------------------
1 | """Activation lens."""
2 |
3 | from lczerolens.model import LczeroModel
4 | from lczerolens.lens import Lens
5 |
6 |
@Lens.register("activation")
class ActivationLens(Lens):
    """
    Class for activation-based XAI methods.

    Saves the outputs (and optionally inputs) of every module whose name
    matches the lens pattern during a trace.

    Examples
    --------

    .. code-block:: python

        model = LczeroModel.from_path(model_path)
        lens = ActivationLens(pattern=r"block")
        board = LczeroBoard()
        results = lens.analyse(model, board)
    """

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> dict:
        """Register activation saves on every matched module.

        Parameters
        ----------
        model : LczeroModel
            The NNsight model being traced.

        Returns
        -------
        dict
            Saved activations keyed by ``{module}_input`` / ``{module}_output``.
        """
        save_inputs = kwargs.get("save_inputs", False)
        results = {}
        for name, module in self._get_modules(model):
            if save_inputs:
                results[f"{name}_input"] = module.input.save()
            results[f"{name}_output"] = module.output.save()
        return results
35 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/composite.py:
--------------------------------------------------------------------------------
1 | """Composite lens for XAI."""
2 |
3 | from typing import List, Dict, Union, Any
4 |
5 | from lczerolens.lens import Lens
6 | from lczerolens.model import LczeroModel
7 |
8 |
@Lens.register("composite")
class CompositeLens(Lens):
    """Composite lens for XAI.

    Runs several sub-lenses inside a single trace and returns their results,
    either merged into one flat dict or keyed by sub-lens name.

    Examples
    --------

    .. code-block:: python

        model = LczeroModel.from_path(model_path)
        lens = CompositeLens([ActivationLens(), GradientLens()])
        board = LczeroBoard()
        results = lens.analyse(model, board)
    """

    def __init__(self, lenses: Union[List[Lens], Dict[str, Lens]], merge_results: bool = True):
        """Initialise the composite lens.

        Parameters
        ----------
        lenses : Union[List[Lens], Dict[str, Lens]]
            The sub-lenses, optionally keyed by name (a list gets
            auto-generated "lens_{i}" keys).
        merge_results : bool, default=True
            Whether to flatten the sub-lens results into a single dict.
        """
        # Initialise base-class pattern state; the previous implementation
        # skipped this, leaving `pattern`/`_get_modules` broken.
        super().__init__()
        self._lens_map = lenses if isinstance(lenses, dict) else {f"lens_{i}": lens for i, lens in enumerate(lenses)}
        self.merge_results = merge_results

    def is_compatible(self, model: LczeroModel) -> bool:
        """A composite lens is compatible iff every sub-lens is."""
        return all(lens.is_compatible(model) for lens in self._lens_map.values())

    def prepare(self, model: LczeroModel, **kwargs) -> LczeroModel:
        """Let each sub-lens prepare the model in turn."""
        for lens in self._lens_map.values():
            model = lens.prepare(model, **kwargs)
        return model

    def _intervene(self, model: LczeroModel, **kwargs) -> Dict[str, Any]:
        """Run every sub-lens intervention; merge results if configured.

        When merging, later sub-lenses overwrite duplicate keys.
        """
        results = {name: lens._intervene(model, **kwargs) for name, lens in self._lens_map.items()}
        if self.merge_results:
            return {k: v for d in results.values() for k, v in d.items()}
        return results
41 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/gradient.py:
--------------------------------------------------------------------------------
1 | """Compute Gradient heatmap for a given model and input."""
2 |
3 | from lczerolens.model import LczeroModel
4 | from lczerolens.lens import Lens
5 |
6 |
@Lens.register("gradient")
class GradientLens(Lens):
    """Class for gradient-based XAI methods.

    Backpropagates from a scalar target (by default the "value" entry of the
    model output) and saves the gradients of the input and of every matched
    module's output.
    """

    # A backward pass is required, so gradients must stay enabled.
    _grad_enabled: bool = True

    def __init__(self, *, input_requires_grad: bool = True, **kwargs):
        """Initialise the lens.

        Parameters
        ----------
        input_requires_grad : bool, default=True
            Whether to enable and save the gradient of the model input.
        """
        self.input_requires_grad = input_requires_grad
        super().__init__(**kwargs)

    def _intervene(self, model: LczeroModel, **kwargs) -> dict:
        """Register gradient saves, then backpropagate from the target.

        Optional kwargs: ``init_target`` maps the traced model to the tensor
        to backpropagate from (default: the "value" output); ``init_gradient``
        maps it to the initial gradient for ``backward`` (default: None).
        """
        init_target = kwargs.get("init_target", lambda model: model.output["value"])
        init_gradient = kwargs.get("init_gradient", lambda model: None)

        results = {}
        if self.input_requires_grad:
            model.input.requires_grad_(self.input_requires_grad)
            results["input_grad"] = model.input.grad.save()
        for name, module in self._get_modules(model):
            results[f"{name}_output_grad"] = module.output.grad.save()
        # Gradient saves are registered before backward so they are captured
        # during the traced backward pass.
        target = init_target(model)
        target.backward(gradient=init_gradient(model))
        return results
30 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/lrp/__init__.py:
--------------------------------------------------------------------------------
1 | from .lens import LrpLens
2 |
3 | __all__ = ["LrpLens"]
4 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/lrp/lens.py:
--------------------------------------------------------------------------------
1 | """Compute LRP heatmap for a given model and input."""
2 |
3 | from typing import Any
4 |
5 |
6 | from lczerolens.model import LczeroModel
7 | from lczerolens.lens import Lens
8 |
9 |
@Lens.register("lrp")
class LrpLens(Lens):
    """Class for wrapping the LCZero models.

    Placeholder: no LRP intervention is implemented yet, so ``analyse``
    currently yields None for this lens.
    """

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> Any:
        # TODO: Refactor this logic
        pass
21 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/lrp/rules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/src/lczerolens/lenses/lrp/rules/__init__.py
--------------------------------------------------------------------------------
/src/lczerolens/lenses/lrp/rules/epsilon.py:
--------------------------------------------------------------------------------
"""Function classes to apply the LRP rules to the layers of the network.

Classes
-------
AddEpsilon
    Addition with an epsilon-stabilised relevance backward rule.
MatMulEpsilon
    Matrix multiplication (input x parameter) with an epsilon-stabilised rule.
BilinearMatMulEpsilon
    Bilinear matrix multiplication (two inputs) with an epsilon-stabilised rule.
MulUniform
    Elementwise multiplication distributing the relevance 50/50 to both inputs.
SoftmaxEpsilonFunction
    Softmax with a dedicated relevance backward rule.
"""
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Function
16 |
17 |
18 | def stabilize(tensor, epsilon=1e-6):
19 | return tensor + epsilon * ((-1) ** (tensor < 0))
20 |
21 |
class AddEpsilonFunction(Function):
    """Autograd function implementing the LRP epsilon rule for addition.

    Forward computes ``input_a + input_b``; backward redistributes the
    incoming relevance to each operand proportionally to its contribution,
    using an epsilon-stabilised denominator.
    """

    @staticmethod
    def forward(ctx, input_a, input_b, epsilon=1e-6):
        # Save operands, sum and epsilon for the relevance backward pass.
        output = input_a + input_b
        ctx.save_for_backward(input_a, input_b, output, torch.tensor(epsilon))
        return output

    @staticmethod
    def backward(ctx, *grad_output):
        # grad_output[0] carries the relevance from the layer above.
        input_a, input_b, output, epsilon = ctx.saved_tensors
        # Normalise by the stabilised sum so the relevance is split
        # proportionally to each operand's share of the output.
        out_relevance = grad_output[0] / stabilize(output, epsilon)
        # None: no gradient for the epsilon argument.
        return out_relevance * input_a, out_relevance * input_b, None
34 |
35 |
class AddEpsilon(torch.nn.Module):
    """Module wrapper applying :class:`AddEpsilonFunction` with a fixed epsilon."""

    def __init__(self, epsilon=1e-6):
        super().__init__()
        # Stabilisation constant used by the LRP backward rule.
        self.epsilon = epsilon

    def forward(self, x, y):
        """Add ``x`` and ``y`` with the epsilon-stabilised relevance rule."""
        return AddEpsilonFunction.apply(x, y, self.epsilon)
43 |
44 |
class MatMulEpsilonFunction(Function):
    """Autograd function implementing the LRP epsilon rule for ``input @ param``.

    Backward attributes the relevance to the input only (the parameter
    receives no relevance), stabilised by epsilon.
    """

    @staticmethod
    def forward(ctx, input, param, epsilon=1e-6):
        output = torch.matmul(input, param)
        # Save everything needed by the relevance backward pass.
        ctx.save_for_backward(input, param, output, torch.tensor(epsilon))

        return output

    @staticmethod
    def backward(ctx, *grad_outputs):
        input, param, output, epsilon = ctx.saved_tensors
        # grad_outputs[0] carries the relevance from the layer above.
        out_relevance = grad_outputs[0]

        # Normalise by the stabilised output, propagate back through the
        # transposed parameter, and scale by the input contribution.
        out_relevance = out_relevance / stabilize(output, epsilon)
        relevance = (out_relevance @ param.T) * input
        # None, None: no relevance for the parameter or epsilon arguments.
        return relevance, None, None
61 |
62 |
class MatMulEpsilon(torch.nn.Module):
    """Module wrapper applying :class:`MatMulEpsilonFunction` with a fixed epsilon."""

    def __init__(self, epsilon=1e-6):
        super().__init__()
        # Stabilisation constant used by the LRP backward rule.
        self.epsilon = epsilon

    def forward(self, x, y):
        """Matrix-multiply ``x`` by ``y`` with the epsilon-stabilised relevance rule."""
        return MatMulEpsilonFunction.apply(x, y, self.epsilon)
70 |
71 |
class BilinearMatMulEpsilonFunction(Function):
    """Autograd function implementing the LRP epsilon rule for a bilinear matmul.

    Both operands receive relevance; the ``2 *`` factor in the stabilised
    denominator accounts for the split between the two inputs.
    """

    @staticmethod
    def forward(ctx, input_a, input_b, epsilon=1e-6):
        outputs = torch.matmul(input_a, input_b)
        ctx.save_for_backward(input_a, input_b, outputs)
        # Plain float is enough here; no need to save it as a tensor.
        ctx.epsilon = epsilon
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        input_a, input_b, outputs = ctx.saved_tensors
        normalised = grad_outputs[0] / stabilize(2 * outputs, ctx.epsilon)

        # NOTE(review): the explicit permute(0, 1, -1, -2) assumes 4-d
        # operands (e.g. batch, head, m, n) — confirm against the callers.
        relevance_a = torch.matmul(normalised, input_b.permute(0, 1, -1, -2)) * input_a
        relevance_b = torch.matmul(input_a.permute(0, 1, -1, -2), normalised) * input_b

        return relevance_a, relevance_b, None
91 |
92 |
class BilinearMatMulEpsilon(torch.nn.Module):
    """Module wrapper around :class:`BilinearMatMulEpsilonFunction`.

    Performs a plain matrix multiplication in the forward pass; the custom
    backward implements the LRP epsilon rule with relevance split between
    both operands.
    """

    def __init__(self, epsilon=1e-6):
        super().__init__()
        # Stabilisation constant used when dividing by the output.
        self.epsilon = epsilon

    def forward(self, x, y):
        return BilinearMatMulEpsilonFunction.apply(x, y, self.epsilon)
100 |
101 |
class MulUniformFunction(Function):
    """Autograd function splitting relevance uniformly for a multiplication.

    The backward pass gives each operand exactly half of the incoming
    relevance, independent of the operand values.
    """

    @staticmethod
    def forward(ctx, input_a, input_b):
        # Nothing is saved: the backward pass is value-independent.
        return input_a * input_b

    @staticmethod
    def backward(ctx, *grad_outputs):
        half = grad_outputs[0] / 2
        return half, half
112 |
113 |
class MulUniform(torch.nn.Module):
    """Module wrapper around :class:`MulUniformFunction`.

    Performs a plain element-wise multiplication in the forward pass; the
    custom backward splits relevance 50/50 between the two inputs.
    """

    def forward(self, x, y):
        return MulUniformFunction.apply(x, y)
117 |
118 |
class SoftmaxEpsilonFunction(Function):
    """Autograd function propagating relevance through a softmax.

    The backward pass subtracts the output-weighted total of the incoming
    relevance before scaling by the pre-softmax inputs.
    """

    @staticmethod
    def forward(ctx, inputs, dim):
        outputs = F.softmax(inputs, dim=dim)
        ctx.save_for_backward(inputs, outputs)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs, outputs = ctx.saved_tensors
        incoming = grad_outputs[0]

        # NOTE(review): the reduction is always over the last axis, even if
        # the forward ``dim`` differed — confirm callers only use dim=-1.
        weighted_total = outputs * incoming.sum(-1, keepdim=True)
        relevance = (incoming - weighted_total) * inputs

        return (relevance, None)
134 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/patching.py:
--------------------------------------------------------------------------------
1 | """Patching lens."""
2 |
3 | from typing import Callable
4 |
5 | from lczerolens.model import LczeroModel
6 | from lczerolens.lens import Lens
7 |
8 |
@Lens.register("patching")
class PatchingLens(Lens):
    """
    Class for patching-based XAI methods.

    Applies a user-provided patch function to every module selected by the
    lens before execution.

    Examples
    --------

    .. code-block:: python

        model = LczeroModel.from_path(model_path)
        patch_fn = lambda name, module, **kwargs: None
        lens = PatchingLens(patch_fn)
        board = LczeroBoard()
        results = lens.analyse(board, model=model)
    """

    def __init__(self, patch_fn: Callable, **kwargs):
        # Callable invoked as ``patch_fn(name, module, **kwargs)`` for each
        # module selected by the lens.
        self._patch_fn = patch_fn
        super().__init__(**kwargs)

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> dict:
        # Patch every selected module in place; nothing is collected, so an
        # empty dict is returned.
        for name, module in self._get_modules(model):
            self._patch_fn(name, module, **kwargs)
        return {}
38 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/probing/__init__.py:
--------------------------------------------------------------------------------
1 | from .lens import ProbingLens
2 |
3 | __all__ = ["ProbingLens"]
4 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/probing/lens.py:
--------------------------------------------------------------------------------
1 | """Probing lens."""
2 |
3 | from typing import Callable
4 |
5 | from lczerolens.model import LczeroModel
6 | from lczerolens.lens import Lens
7 |
8 |
@Lens.register("probing")
class ProbingLens(Lens):
    """
    Class for probing-based XAI methods.

    Runs a user-provided probe function on the saved output of every
    module selected by the lens.

    Examples
    --------

    .. code-block:: python

        model = LczeroModel.from_path(model_path)
        lens = ProbingLens(probe)
        board = LczeroBoard()
        results = lens.analyse(board, model=model)
    """

    def __init__(self, probe_fn: Callable, **kwargs):
        # Callable applied to each selected module's saved output proxy.
        self._probe_fn = probe_fn
        super().__init__(**kwargs)

    def _intervene(
        self,
        model: LczeroModel,
        **kwargs,
    ) -> dict:
        results = {}
        for name, module in self._get_modules(model):
            results[name] = self._probe_fn(module.output.save())
        return results
35 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/probing/probe.py:
--------------------------------------------------------------------------------
1 | """Probing lens for XAI."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Any
5 |
6 | import einops
7 | import torch
8 |
9 |
10 | EPS = 1e-6
11 |
12 |
class Probe(ABC):
    """Interface for probes fitted on model activations."""

    def __init__(self):
        # Subclasses are expected to flip this flag once fitted.
        self._trained = False

    @abstractmethod
    def train(
        self,
        activations: torch.Tensor,
        labels: Any,
        **kwargs,
    ):
        """Fit the probe on a batch of activations and labels."""

    @abstractmethod
    def predict(self, activations: torch.Tensor, **kwargs):
        """Run the fitted probe on a batch of activations."""
33 |
34 |
class SignalCav(Probe):
    """Signal CAV probe.

    Learns a normalised covariance-based concept activation vector (CAV)
    between activations and labels, then scores new activations by their
    projection onto it.
    """

    def train(
        self,
        activations: torch.Tensor,
        labels: torch.Tensor,
        **kwargs,
    ):
        """Fit the CAV from paired activations ``(b, a)`` and labels ``(b, d)``.

        Raises
        ------
        ValueError
            If the batch sizes differ or either tensor is not 2-d.
        """
        if len(activations) != len(labels):
            raise ValueError("Number of activations and labels must match")
        if len(activations.shape) != 2:
            raise ValueError("Activations must a batch of tensors")
        if len(labels.shape) != 2:
            raise ValueError("Labels must a batch of tensors")

        # NOTE(review): centring is over dim=1 (per-sample mean across
        # features), not over the batch dim — confirm this is intentional.
        mean_activation = activations.mean(dim=1, keepdim=True)
        mean_label = labels.mean(dim=1, keepdim=True)
        scaled_activations = activations - mean_activation
        scaled_labels = labels - mean_label
        cav = einops.einsum(scaled_activations, scaled_labels, "b a, b d -> a d")
        self._h = cav / (cav.norm(dim=0, keepdim=True) + EPS)
        # Bug fix: mark the probe as fitted — previously this flag was never
        # set, so ``predict`` always raised "Probe not trained".
        self._trained = True

    def predict(self, activations: torch.Tensor, **kwargs):
        """Project activations ``(b, a)`` onto the learned CAV.

        Raises
        ------
        ValueError
            If the probe is untrained or the input is not 2-d.
        """
        if not self._trained:
            raise ValueError("Probe not trained")

        if len(activations.shape) != 2:
            raise ValueError("Activations must a batch of tensors")

        dot_prod = einops.einsum(activations, self._h, "b a, a d -> b d")
        return dot_prod / (activations.norm(dim=1, keepdim=True) + EPS)
67 |
--------------------------------------------------------------------------------
/src/lczerolens/lenses/sae/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/src/lczerolens/lenses/sae/__init__.py
--------------------------------------------------------------------------------
/src/lczerolens/lenses/sae/buffer.py:
--------------------------------------------------------------------------------
1 | """Activation lens for XAI."""
2 |
3 | from typing import Any, Optional, Callable
4 | from dataclasses import dataclass
5 |
6 | import torch
7 | from datasets import Dataset
8 | from torch.utils.data import DataLoader, TensorDataset
9 |
10 | from lczerolens.model import LczeroModel
11 |
12 |
@dataclass
class ActivationBuffer:
    """Iterator over shuffled activation batches computed on the fly.

    Streams ``compute_batch_size``-sized batches from ``dataset``, runs
    ``compute_fn`` to obtain activations, keeps up to
    ``n_batches_in_buffer`` compute batches buffered on CPU, and serves
    them shuffled in ``train_batch_size``-sized batches.
    """

    model: LczeroModel
    dataset: Dataset
    compute_fn: Callable[[Any, LczeroModel], torch.Tensor]
    n_batches_in_buffer: int = 10
    compute_batch_size: int = 64
    train_batch_size: int = 2048
    dataloader_kwargs: Optional[dict] = None
    logger: Optional[Callable] = None

    def __post_init__(self):
        if self.dataloader_kwargs is None:
            self.dataloader_kwargs = {}
        self._buffer = []
        # Trailing partial batch kept until more activations accumulate.
        self._remainder = None
        self._make_dataloader_it()

    def _make_dataloader_it(self):
        # Fresh iterator over the source dataset.
        self._dataloader_it = iter(
            DataLoader(self.dataset, batch_size=self.compute_batch_size, **self.dataloader_kwargs)
        )

    @torch.no_grad()
    def _fill_buffer(self):
        """Refill the CPU buffer; raises StopIteration once the dataset is exhausted."""
        if self.logger is not None:
            self.logger.info("Computing activations...")
        self._buffer = []
        while len(self._buffer) < self.n_batches_in_buffer:
            try:
                next_batch = next(self._dataloader_it)
            except StopIteration:
                break
            activations = self.compute_fn(next_batch, self.model)
            self._buffer.append(activations.to("cpu"))
        if not self._buffer:
            raise StopIteration

    def _make_activations_it(self):
        """Build a shuffled iterator over everything currently buffered."""
        if self._remainder is not None:
            self._buffer.append(self._remainder)
            self._remainder = None
        activations_ds = TensorDataset(torch.cat(self._buffer, dim=0))
        if self.logger is not None:
            self.logger.info(f"Activations dataset of size {len(activations_ds)}")

        self._activations_it = iter(
            DataLoader(
                activations_ds,
                batch_size=self.train_batch_size,
                shuffle=True,
            )
        )

    def __iter__(self):
        self._make_dataloader_it()
        self._fill_buffer()
        self._make_activations_it()
        self._remainder = None
        return self

    def __next__(self):
        try:
            activations = next(self._activations_it)[0]
            if activations.shape[0] < self.train_batch_size:
                # Partial batch: stash it and try to top it up with fresh data.
                self._remainder = activations
                self._fill_buffer()
                self._make_activations_it()
                activations = next(self._activations_it)[0]
            return activations
        except StopIteration:
            try:
                self._fill_buffer()
                self._make_activations_it()
                # Bug fix: the recursive call's result was previously
                # discarded and control fell through to ``raise
                # StopIteration``, ending iteration after a successful refill.
                return self.__next__()
            except StopIteration as e:
                if self._remainder is not None:
                    activations = self._remainder
                    self._remainder = None
                    return activations
                raise StopIteration from e
95 |
--------------------------------------------------------------------------------
/src/lczerolens/model.py:
--------------------------------------------------------------------------------
1 | """Class for wrapping the LCZero models."""
2 |
3 | import os
4 | from typing import Dict, Type, Any, Tuple, Union
5 |
6 | import torch
7 | from onnx2torch import convert
8 | from onnx2torch.utils.safe_shape_inference import safe_shape_inference
9 | from tensordict import TensorDict
10 | from torch import nn
11 | from nnsight import NNsight
12 | from contextlib import contextmanager
13 |
14 | from lczerolens.board import InputEncoding, LczeroBoard
15 |
16 |
class LczeroModel(NNsight):
    """Class for wrapping the LCZero models."""

    def trace(
        self,
        *inputs: Any,
        **kwargs: Dict[str, Any],
    ):
        """Trace the model, always disabling nnsight scanning/validation."""
        kwargs["scan"] = False
        kwargs["validate"] = False
        return super().trace(*inputs, **kwargs)

    def _execute(self, *prepared_inputs: torch.Tensor, **kwargs) -> Any:
        """Execute the model with a forward that returns a ``TensorDict``."""
        # These kwargs are only meaningful for input preparation.
        kwargs.pop("input_encoding", None)
        kwargs.pop("input_requires_grad", None)
        with self._ensure_proper_forward():
            return super()._execute(*prepared_inputs, **kwargs)

    def _prepare_inputs(self, *inputs: Union[LczeroBoard, torch.Tensor], **kwargs) -> Tuple[Tuple[Any], int]:
        """Encode boards into a batched input tensor on the model device.

        A single tensor input is passed through unchanged (assumed to be an
        already-encoded batch); otherwise every input must be an
        ``LczeroBoard``.

        Raises
        ------
        ValueError
            If an input is neither a tensor nor an ``LczeroBoard``.
        """
        input_encoding = kwargs.pop("input_encoding", InputEncoding.INPUT_CLASSICAL_112_PLANE)
        input_requires_grad = kwargs.pop("input_requires_grad", False)

        if len(inputs) == 1 and isinstance(inputs[0], torch.Tensor):
            return inputs, len(inputs[0])
        for board in inputs:
            if not isinstance(board, LczeroBoard):
                raise ValueError(f"Got invalid input type {type(board)}.")

        tensor_list = [board.to_input_tensor(input_encoding=input_encoding).unsqueeze(0) for board in inputs]
        batched_tensor = torch.cat(tensor_list, dim=0)
        if input_requires_grad:
            batched_tensor.requires_grad = True
        batched_tensor = batched_tensor.to(self.device)

        return (batched_tensor,), len(inputs)

    def __call__(self, *inputs, **kwargs):
        """Run the model on boards (or an encoded tensor)."""
        prepared_inputs, _ = self._prepare_inputs(*inputs, **kwargs)
        return self._execute(*prepared_inputs, **kwargs)

    def __getattr__(self, key):
        # Outside of a trace, resolve attributes on the wrapped torch module.
        if self._envoy._tracer is None:
            return getattr(self._model, key)
        return super().__getattr__(key)

    def __setattr__(self, key, value):
        # Outside of a trace, attach submodules to the wrapped torch module
        # so they are registered on it rather than on the wrapper.
        if (
            (key not in ("_model", "_model_key"))
            and (isinstance(value, torch.nn.Module))
            and (self._envoy._tracer is None)
        ):
            setattr(self._model, key, value)
        else:
            super().__setattr__(key, value)

    @property
    def device(self):
        """Returns the device."""
        return next(self.parameters()).device

    @device.setter
    def device(self, device: torch.device):
        """Sets the device."""
        self.to(device)

    @classmethod
    def from_path(cls, model_path: str) -> "LczeroModel":
        """Creates a wrapper from a model path.

        Parameters
        ----------
        model_path : str
            Path to the model file (.onnx or .pt)

        Returns
        -------
        LczeroModel
            The wrapped model instance

        Raises
        ------
        NotImplementedError
            If the model file extension is not supported
        """
        if model_path.endswith(".onnx"):
            return cls.from_onnx_path(model_path)
        elif model_path.endswith(".pt"):
            return cls.from_torch_path(model_path)
        else:
            raise NotImplementedError(f"Model path {model_path} is not supported.")

    @classmethod
    def from_onnx_path(cls, onnx_model_path: str, check: bool = True) -> "LczeroModel":
        """Builds a model from an ONNX file path.

        Parameters
        ----------
        onnx_model_path : str
            Path to the ONNX model file
        check : bool, optional
            Whether to perform shape inference check, by default True

        Returns
        -------
        LczeroModel
            The wrapped model instance

        Raises
        ------
        FileNotFoundError
            If the model file does not exist
        ValueError
            If the model could not be loaded
        """
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"Model path {onnx_model_path} does not exist.")
        try:
            if check:
                onnx_model = safe_shape_inference(onnx_model_path)
            onnx_torch_model = convert(onnx_model)
            return cls(onnx_torch_model)
        except Exception as e:
            raise ValueError(f"Could not load model at {onnx_model_path}.") from e

    @classmethod
    def from_torch_path(cls, torch_model_path: str) -> "LczeroModel":
        """Builds a model from a PyTorch file path.

        Parameters
        ----------
        torch_model_path : str
            Path to the PyTorch model file

        Returns
        -------
        LczeroModel
            The wrapped model instance

        Raises
        ------
        FileNotFoundError
            If the model file does not exist
        ValueError
            If the model could not be loaded or is not a valid model type
        """
        if not os.path.exists(torch_model_path):
            raise FileNotFoundError(f"Model path {torch_model_path} does not exist.")
        try:
            # NOTE(review): ``torch.load`` unpickles arbitrary code — only
            # load model files from trusted sources.
            torch_model = torch.load(torch_model_path)
        except Exception as e:
            raise ValueError(f"Could not load model at {torch_model_path}.") from e
        if isinstance(torch_model, LczeroModel):
            return torch_model
        elif isinstance(torch_model, nn.Module):
            return cls(torch_model)
        else:
            raise ValueError(f"Could not load model at {torch_model_path}.")

    @contextmanager
    def _ensure_proper_forward(self):
        """Temporarily swap ``forward`` for one returning a ``TensorDict``."""
        old_forward = self._model.forward

        # Recover the output names from the converted ONNX graph.
        output_node = list(self._model.graph.nodes)[-1]
        output_names = [n.name.replace("output_", "") for n in output_node.all_input_nodes]

        def td_forward(x):
            old_out = old_forward(x)
            return TensorDict(
                {name: old_out[i] for i, name in enumerate(output_names)},
                batch_size=x.shape[0],
            )

        self._model.forward = td_forward
        try:
            yield
        finally:
            # Always restore the original forward, even if execution failed.
            self._model.forward = old_forward
192 |
193 |
class Flow(LczeroModel):
    """Base class for isolating a flow."""

    # Name of the model output this flow isolates; set by ``register``.
    _flow_type: str
    _registry: Dict[str, Type["Flow"]] = {}

    def __init__(
        self,
        model_key,
        *args,
        **kwargs,
    ):
        if isinstance(model_key, LczeroModel):
            raise ValueError("Use the `from_model` classmethod to create a flow.")
        if not self.is_compatible(model_key):
            raise ValueError(f"The model does not have a {self._flow_type} head.")
        super().__init__(model_key, *args, **kwargs)

    @classmethod
    def register(cls, name: str):
        """Registers the flow.

        Parameters
        ----------
        name : str
            The name of the flow to register.

        Returns
        -------
        Callable
            Decorator function that registers the flow subclass.

        Raises
        ------
        ValueError
            If the flow name is already registered.
        """

        if name in cls._registry:
            raise ValueError(f"Flow {name} already registered.")

        def decorator(subclass):
            cls._registry[name] = subclass
            subclass._flow_type = name
            return subclass

        return decorator

    @classmethod
    def from_name(cls, name: str, *args, **kwargs) -> "Flow":
        """Returns the flow from its name.

        Parameters
        ----------
        name : str
            The name of the flow to instantiate.
        *args
            Positional arguments passed to flow constructor.
        **kwargs
            Keyword arguments passed to flow constructor.

        Returns
        -------
        Flow
            The instantiated flow.

        Raises
        ------
        KeyError
            If the flow name is not found.
        """
        if name not in cls._registry:
            raise KeyError(f"Flow {name} not found.")
        return cls._registry[name](*args, **kwargs)

    @classmethod
    def from_model(cls, name: str, model: LczeroModel, *args, **kwargs) -> "Flow":
        """Returns the flow from a model.

        Parameters
        ----------
        name : str
            The name of the flow to instantiate.
        model : LczeroModel
            The model to create the flow from.
        *args
            Positional arguments passed to flow constructor.
        **kwargs
            Keyword arguments passed to flow constructor.

        Returns
        -------
        Flow
            The instantiated flow.

        Raises
        ------
        KeyError
            If the flow name is not found.
        """
        if name not in cls._registry:
            raise KeyError(f"Flow {name} not found.")
        flow_class = cls._registry[name]
        return flow_class(model._model, *args, **kwargs)

    @classmethod
    def is_compatible(cls, model: nn.Module) -> bool:
        """Checks if the model is compatible with this flow.

        Parameters
        ----------
        model : nn.Module
            The model to check compatibility with.

        Returns
        -------
        bool
            Whether the model is compatible with this flow.
        """
        return hasattr(model, cls._flow_type) or hasattr(model, f"output/{cls._flow_type}")

    @contextmanager
    def _ensure_proper_forward(self):
        """Rewrites the forward function to return the flow output."""
        flow_type = getattr(self, "_flow_type", None)
        if flow_type is None:
            # Bug fix: act as a no-op context manager. A bare ``return``
            # before yielding makes ``@contextmanager`` raise
            # ``RuntimeError: generator didn't yield``.
            yield
            return

        with super()._ensure_proper_forward():
            old_forward = self._model.forward

            def flow_forward(*inputs, **kwargs):
                out = old_forward(*inputs, **kwargs)
                return out[flow_type]

            self._model.forward = flow_forward
            try:
                yield
            finally:
                # Always restore the original forward, even on failure.
                self._model.forward = old_forward
332 |
333 |
@Flow.register("policy")
class PolicyFlow(Flow):
    """Class for isolating the policy flow (the ``policy`` model output)."""
337 |
338 |
@Flow.register("value")
class ValueFlow(Flow):
    """Class for isolating the value flow (the ``value`` model output)."""
342 |
343 |
@Flow.register("wdl")
class WdlFlow(Flow):
    """Class for isolating the WDL flow (the ``wdl`` model output)."""
347 |
348 |
@Flow.register("mlh")
class MlhFlow(Flow):
    """Class for isolating the MLH flow (the ``mlh`` model output)."""
352 |
353 |
@Flow.register("force_value")
class ForceValueFlow(Flow):
    """Class for forcing and isolating the value flow.

    Falls back to collapsing the WDL head into a scalar value when the
    model has no dedicated value head.
    """

    @classmethod
    def is_compatible(cls, model: nn.Module):
        """Compatible when the model has either a value or a WDL head."""
        return ValueFlow.is_compatible(model) or WdlFlow.is_compatible(model)

    @contextmanager
    def _ensure_proper_forward(self):
        """Rewrites forward to return ``value``, or ``wdl`` collapsed to a scalar."""
        flow_type = getattr(self, "_flow_type", None)
        if flow_type is None:
            # Bug fix: act as a no-op context manager instead of returning
            # before yielding (which breaks ``@contextmanager``).
            yield
            return

        with LczeroModel._ensure_proper_forward(self):
            old_forward = self._model.forward

            def flow_forward(*inputs, **kwargs):
                out = old_forward(*inputs, **kwargs)
                if "value" in out.keys():
                    return out["value"]
                wdl = out["wdl"]
                # Bug fix: use the tensor's own device — the TensorDict is
                # built without a device, so ``out.device`` may be unset
                # while the wdl tensor may live on an accelerator.
                return wdl @ torch.tensor([1.0, 0.0, -1.0], device=wdl.device)

            self._model.forward = flow_forward
            try:
                yield
            finally:
                self._model.forward = old_forward
380 |
--------------------------------------------------------------------------------
/src/lczerolens/play/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Play module for the lczerolens package.
3 | """
4 |
5 | from .sampling import ModelSampler, Sampler, RandomSampler, PolicySampler
6 | from .game import Game
7 | from .puzzle import Puzzle
8 | from . import sampling, game, puzzle
9 |
10 | __all__ = ["ModelSampler", "Sampler", "RandomSampler", "PolicySampler", "Game", "Puzzle", "sampling", "game", "puzzle"]
11 |
--------------------------------------------------------------------------------
/src/lczerolens/play/game.py:
--------------------------------------------------------------------------------
1 | """Preproces functions for chess games."""
2 |
3 | from dataclasses import dataclass
4 | from typing import Any, Dict, List, Optional, Union
5 |
6 | from datasets import Features, Value, Sequence
7 |
8 | from lczerolens.board import LczeroBoard
9 |
# Schema for raw game rows: one PGN-like move-text string per game.
GAME_DATASET_FEATURES = Features(
    {
        "gameid": Value("string"),
        "moves": Value("string"),
    }
)

# Schema for per-position rows: a root FEN plus the moves leading to it.
BOARD_DATASET_FEATURES = Features(
    {
        "gameid": Value("string"),
        "moves": Sequence(Value("string")),
        "fen": Value("string"),
    }
)
24 |
25 |
26 | @dataclass
27 | class Game:
28 | gameid: str
29 | moves: List[str]
30 | book_exit: Optional[int] = None
31 |
32 | @classmethod
33 | def from_dict(cls, obj: Dict[str, str]) -> "Game":
34 | if "moves" not in obj:
35 | ValueError("The dict should contain `moves`.")
36 | if "gameid" not in obj:
37 | ValueError("The dict should contain `gameid`.")
38 | *pre, post = obj["moves"].split("{ Book exit }")
39 | if pre:
40 | if len(pre) > 1:
41 | raise ValueError("More than one book exit")
42 | (pre,) = pre
43 | parsed_pre_moves = [m for m in pre.split() if not m.endswith(".")]
44 | book_exit = len(parsed_pre_moves)
45 | else:
46 | parsed_pre_moves = []
47 | book_exit = None
48 | parsed_moves = parsed_pre_moves + [m for m in post.split() if not m.endswith(".")]
49 | return cls(
50 | gameid=obj["gameid"],
51 | moves=parsed_moves,
52 | book_exit=book_exit,
53 | )
54 |
55 | def to_boards(
56 | self,
57 | n_history: int = 0,
58 | skip_book_exit: bool = False,
59 | skip_first_n: int = 0,
60 | output_dict=True,
61 | ) -> List[Union[Dict[str, Any], LczeroBoard]]:
62 | working_board = LczeroBoard()
63 | if skip_first_n > 0 or (skip_book_exit and (self.book_exit is not None)):
64 | boards = []
65 | else:
66 | if output_dict:
67 | boards = [
68 | {
69 | "fen": working_board.fen(),
70 | "moves": [],
71 | "gameid": self.gameid,
72 | }
73 | ]
74 | else:
75 | boards = [working_board.copy(stack=n_history)]
76 |
77 | for i, move in enumerate(self.moves[:-1]): # skip the last move as it can be over
78 | working_board.push_san(move)
79 | if (i < skip_first_n) or (skip_book_exit and (self.book_exit is not None) and (i < self.book_exit)):
80 | continue
81 | if output_dict:
82 | save_board = working_board.copy(stack=n_history)
83 | boards.append(
84 | {
85 | "fen": save_board.root().fen(),
86 | "moves": [move.uci() for move in save_board.move_stack],
87 | "gameid": self.gameid,
88 | }
89 | )
90 | else:
91 | boards.append(working_board.copy(stack=n_history))
92 | return boards
93 |
94 | @staticmethod
95 | def board_collate_fn(batch):
96 | boards = []
97 | for element in batch:
98 | board = LczeroBoard(element["fen"])
99 | for move in element["moves"]:
100 | board.push_san(move)
101 | boards.append(board)
102 | return boards, {}
103 |
--------------------------------------------------------------------------------
/src/lczerolens/play/puzzle.py:
--------------------------------------------------------------------------------
1 | """Preproces functions for chess puzzles."""
2 |
3 | from dataclasses import dataclass
4 | from typing import Dict, List, Union, Tuple, Optional, Iterable
5 |
6 | import chess
7 | import torch
8 | from datasets import Features, Value
9 | from itertools import tee, chain
10 |
11 | from lczerolens.board import LczeroBoard
12 | from .sampling import Sampler
13 |
14 |
# Schema matching the Lichess puzzle CSV export columns.
PUZZLE_DATASET_FEATURES = Features(
    {
        "PuzzleId": Value("string"),
        "FEN": Value("string"),
        "Moves": Value("string"),
        "Rating": Value("int64"),
        "RatingDeviation": Value("int64"),
        "Popularity": Value("int64"),
        "NbPlays": Value("int64"),
        "Themes": Value("string"),
        "GameUrl": Value("string"),
        "OpeningTags": Value("string"),
    }
)
29 |
30 |
@dataclass
class Puzzle:
    """A chess puzzle: a start FEN, a setup move, and a solution line."""

    puzzle_id: str
    fen: str
    # The move played before the puzzle starts (creates the position).
    initial_move: chess.Move
    # The solution moves, starting with the side to move in the puzzle.
    moves: List[chess.Move]
    rating: int
    rating_deviation: int
    popularity: int
    nb_plays: int
    themes: List[str]
    game_url: str
    opening_tags: List[str]

    @classmethod
    def from_dict(cls, obj: Dict[str, Union[str, int, None]]) -> "Puzzle":
        """Build a ``Puzzle`` from a dataset row with the Lichess CSV columns."""
        uci_moves = obj["Moves"].split()
        moves = [chess.Move.from_uci(uci_move) for uci_move in uci_moves]
        return cls(
            puzzle_id=obj["PuzzleId"],
            fen=obj["FEN"],
            # The first listed move is played before the puzzle starts.
            initial_move=moves[0],
            moves=moves[1:],
            rating=obj["Rating"],
            rating_deviation=obj["RatingDeviation"],
            popularity=obj["Popularity"],
            nb_plays=obj["NbPlays"],
            themes=obj["Themes"].split() if obj["Themes"] is not None else [],
            game_url=obj["GameUrl"],
            opening_tags=obj["OpeningTags"].split() if obj["OpeningTags"] is not None else [],
        )

    def __len__(self) -> int:
        """Number of solution moves."""
        return len(self.moves)

    @property
    def initial_board(self) -> LczeroBoard:
        """The board after the initial (setup) move has been played."""
        board = LczeroBoard(self.fen)
        board.push(self.initial_move)
        return board

    def board_move_generator(self, all_moves: bool = False) -> Iterable[Tuple[LczeroBoard, chess.Move]]:
        """Yield ``(board, move)`` pairs along the solution line.

        By default only positions where the puzzle side is to move are
        yielded; with ``all_moves=True`` every position is yielded.
        """
        board = self.initial_board
        initial_turn = board.turn
        for move in self.moves:
            if not all_moves and board.turn != initial_turn:
                # Opponent reply: play it without yielding.
                board.push(move)
                continue
            yield board.copy(), move
            board.push(move)

    @classmethod
    def evaluate_multiple(
        cls,
        puzzles: Iterable["Puzzle"],
        sampler: Sampler,
        all_moves: bool = False,
        compute_metrics: bool = True,
        **kwargs,
    ) -> Union[Iterable[Dict[str, float]], Iterable[Tuple[torch.Tensor, torch.Tensor, chess.Move]]]:
        """Lazily evaluate a sampler on several puzzles.

        Returns per-puzzle metric dicts when ``compute_metrics`` is True,
        otherwise raw ``(utility, legal_indices, predicted_move)`` tuples.
        """
        # ``tee`` lets the puzzle stream be traversed twice without
        # materialising it: once for metrics, once for board generation.
        metric_puzzles, board_move_puzzles = tee(puzzles)
        board_move_generator = chain.from_iterable(
            puzzle.board_move_generator(all_moves) for puzzle in board_move_puzzles
        )

        def board_generator():
            # Drop the expected move; only the boards feed the sampler.
            for board, _ in board_move_generator:
                yield board

        util_boards, move_boards = tee(board_generator())

        def metric_inputs_generator():
            util_gen = sampler.get_utilities(util_boards, **kwargs)
            for board, (utility, legal_indices, _) in zip(move_boards, util_gen):
                predicted_move = sampler.choose_move(board, utility, legal_indices)
                yield utility, legal_indices, predicted_move

        if compute_metrics:
            return cls.compute_metrics(metric_puzzles, metric_inputs_generator(), all_moves=all_moves)
        else:
            return metric_inputs_generator()

    def evaluate(self, sampler: Sampler, all_moves: bool = False, **kwargs) -> Tuple[float, Optional[float]]:
        """Evaluate the sampler on this single puzzle."""
        return next(iter(self.evaluate_multiple([self], sampler, all_moves, **kwargs)))

    @staticmethod
    def compute_metrics(
        puzzles: Iterable["Puzzle"],
        inputs: Iterable[Tuple[torch.Tensor, torch.Tensor, chess.Move]],
        all_moves: bool = False,
    ) -> Iterable[Dict[str, float]]:
        """Yield score/perplexity metrics for each puzzle.

        ``perplexity`` is the geometric mean of the inverse probability
        assigned to the solution moves; ``normalized_perplexity`` rescales
        each factor by the number of legal moves.
        """
        iter_inputs = iter(inputs)
        for puzzle in puzzles:
            # Without all_moves only every other ply belongs to the puzzle
            # side, hence the rounded-up halving.
            total = len(puzzle) if all_moves else (len(puzzle) + 1) // 2
            metrics = {"score": 0.0, "perplexity": 1.0, "normalized_perplexity": 1.0}
            for board, move in puzzle.board_move_generator(all_moves=all_moves):
                utility, legal_indices, predicted_move = next(iter_inputs)
                index = LczeroBoard.encode_move(move, board.turn)
                probs = torch.softmax(utility, dim=0)
                move_prob = probs[legal_indices == index].item()
                metrics["perplexity"] *= move_prob ** (-1 / total)
                metrics["normalized_perplexity"] *= (len(legal_indices) * move_prob) ** (-1 / total)
                if predicted_move == move:
                    metrics["score"] += 1
            metrics["score"] /= total
            yield metrics

    def _repr_svg_(self) -> str:
        """Render the post-setup board as SVG (for notebooks)."""
        return self.initial_board._repr_svg_()
140 |
--------------------------------------------------------------------------------
/src/lczerolens/play/sampling.py:
--------------------------------------------------------------------------------
1 | """Classes for playing."""
2 |
3 | from abc import ABC, abstractmethod
4 | from dataclasses import dataclass
5 | from typing import Optional, Callable, Tuple, Dict, Iterable
6 |
7 | import chess
8 | import torch
9 | from torch.distributions import Categorical
10 | from itertools import tee
11 |
12 | from lczerolens.model import LczeroModel
13 | from lczerolens.board import LczeroBoard
14 |
15 |
class Sampler(ABC):
    """Base class turning per-board utilities into move choices."""

    @abstractmethod
    def get_utilities(
        self, boards: Iterable[LczeroBoard], **kwargs
    ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]:
        """Get the utility of the board.

        Parameters
        ----------
        boards : Iterable[LczeroBoard]
            The boards to evaluate.

        Returns
        -------
        Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]
            The iterable over utilities, legal indices, and log dictionaries.
        """

    def choose_move(self, board: LczeroBoard, utility: torch.Tensor, legal_indices: torch.Tensor) -> chess.Move:
        """Sample a move from the utility distribution over legal moves.

        Parameters
        ----------
        board : LczeroBoard
            The board.
        utility : torch.Tensor
            The utility assigned to each legal move (used as logits).
        legal_indices : torch.Tensor
            The encoded indices of the legal moves.

        Returns
        -------
        chess.Move
            The sampled move.
        """
        distribution = Categorical(logits=utility)
        sampled_idx = distribution.sample()
        return board.decode_move(legal_indices[sampled_idx])

    def get_next_moves(self, boards: Iterable[LczeroBoard], **kwargs) -> Iterable[Tuple[chess.Move, Dict[str, float]]]:
        """Yield the chosen move and its log dictionary for each board.

        Parameters
        ----------
        boards : Iterable[LczeroBoard]
            The boards to evaluate.

        Returns
        -------
        Iterable[Tuple[chess.Move, Dict[str, float]]]
            The iterable over the moves and log dictionaries.
        """
        # Duplicate the stream: one copy feeds get_utilities, the other
        # pairs each board with its utilities.
        board_stream, utility_boards = tee(boards)
        utilities = self.get_utilities(utility_boards, **kwargs)
        for board, (utility, legal_indices, to_log) in zip(board_stream, utilities):
            yield self.choose_move(board, utility, legal_indices), to_log
73 |
74 |
class RandomSampler(Sampler):
    """Sampler assigning equal utility to every legal move."""

    def get_utilities(
        self, boards: Iterable[LczeroBoard], **kwargs
    ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]:
        """Yield uniform utilities over each board's legal moves."""
        for board in boards:
            indices = board.get_legal_indices()
            uniform = torch.ones_like(indices, dtype=torch.float32)
            yield uniform, indices, {}
83 |
84 |
@dataclass
class ModelSampler(Sampler):
    """Sampler deriving move utilities from an ``LczeroModel``.

    The utility of a move is ``alpha * Q + beta * M + gamma * P`` where Q is
    the value-head score of the position reached by the move, M a moves-left
    bonus and P the policy-head logit of the move.
    """

    model: LczeroModel
    use_argmax: bool = True  # Pick the argmax instead of sampling.
    alpha: float = 1.0  # Weight of the Q (value) term.
    beta: float = 1.0  # Weight of the M (moves-left) term.
    gamma: float = 1.0  # Weight of the P (policy) term.
    draw_score: float = 0.0  # Q score assigned to a draw for WDL heads.
    m_max: float = 0.0345  # Clamp bound for moves-left deltas.
    m_slope: float = 0.0027  # Scale applied to moves-left deltas.
    k_0: float = 0.0  # Constant coefficient of the Q polynomial.
    k_1: float = 1.6521  # Linear coefficient of the Q polynomial.
    k_2: float = -0.6521  # Quadratic coefficient of the Q polynomial.
    q_threshold: float = 0.8  # |Q| below this disables the M term.

    def choose_move(self, board: LczeroBoard, utility: torch.Tensor, legal_indices: torch.Tensor) -> chess.Move:
        """Choose the best move, or sample one when ``use_argmax`` is off."""
        if self.use_argmax:
            idx = utility.argmax()
            return board.decode_move(legal_indices[idx])
        return super().choose_move(board, utility, legal_indices)

    @torch.no_grad()
    def get_utilities(
        self, boards: Iterable[LczeroBoard], **kwargs
    ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]:
        """Compute the utilities of the legal moves for each board.

        Parameters
        ----------
        boards : Iterable[LczeroBoard]
            The boards to evaluate.

        Returns
        -------
        Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]
            The iterable over utilities, legal indices, and log dictionaries.
        """
        batch_size = kwargs.pop("batch_size", -1)
        callback = kwargs.pop("callback", None)

        for legal_indices, batch_stats in self._get_batched_stats(boards, batch_size, **kwargs):
            to_log = {}
            utility = 0
            q_values = self._get_q_values(batch_stats, to_log)
            utility += self.alpha * q_values
            utility += self.beta * self._get_m_values(batch_stats, q_values, to_log)
            utility += self.gamma * self._get_p_values(batch_stats, legal_indices, to_log)
            to_log["max_utility"] = utility.max().item()

            self._use_callback(callback, batch_stats, to_log)

            yield utility, legal_indices, to_log

    def _get_batched_stats(self, boards, batch_size, use_next_boards=True, **kwargs):
        """Run the model on batches of boards (plus their next positions).

        Yields ``(legal_indices, batch_stats)`` per source board, where
        ``batch_stats`` holds the model outputs for the board itself
        (index 0) followed by the positions after each legal move when
        ``use_next_boards`` is set.
        """
        next_batch = []
        next_legal_indices = []

        def generator(next_batch, next_legal_indices):
            # Run the model once on the accumulated batch, then slice the
            # outputs back per source board.
            all_stats = self.model(*next_batch, **kwargs)
            offset = 0
            for legal_indices in next_legal_indices:
                n_boards = legal_indices.shape[0] + 1 if use_next_boards else 1
                batch_stats = all_stats[offset : offset + n_boards]
                offset += n_boards
                yield legal_indices, batch_stats

        for board in boards:
            legal_indices = board.get_legal_indices()
            next_boards = list(board.get_next_legal_boards()) if use_next_boards else []
            # Flush the pending batch when adding this board would overflow
            # (batch_size == -1 means no limit).
            if len(next_batch) + len(next_boards) + 1 > batch_size and batch_size != -1:
                yield from generator(next_batch, next_legal_indices)
                next_batch = []
                next_legal_indices = []
            next_batch.extend([board] + next_boards)
            next_legal_indices.append(legal_indices)
        if next_batch:
            yield from generator(next_batch, next_legal_indices)

    def _get_q_values(self, batch_stats, to_log):
        """Q values of the next positions, from the value or wdl head."""
        if "value" in batch_stats.keys():
            to_log["value"] = batch_stats["value"][0].item()
            return batch_stats["value"][1:, 0]
        elif "wdl" in batch_stats.keys():
            wdl = batch_stats["wdl"]
            to_log["wdl_w"] = wdl[0][0].item()
            to_log["wdl_d"] = wdl[0][1].item()
            to_log["wdl_l"] = wdl[0][2].item()
            # Build the score vector on the same device/dtype as the model
            # output so GPU inference does not fail on a device mismatch.
            scores = torch.tensor([1.0, self.draw_score, -1.0], dtype=wdl.dtype, device=wdl.device)
            return wdl[1:] @ scores
        return torch.zeros(batch_stats.batch_size[0] - 1)

    def _get_m_values(self, batch_stats, q_values, to_log):
        """Moves-left bonus, active only for decisive Q values."""
        if "mlh" in batch_stats.keys():
            to_log["mlh"] = batch_stats["mlh"][0].item()
            delta_m_values = self.m_slope * (batch_stats["mlh"][1:, 0] - batch_stats["mlh"][0, 0])
            delta_m_values.clamp_(-self.m_max, self.m_max)
            # Zero below |Q| = q_threshold, polynomial in the rescaled |Q|
            # above it.
            scaled_q_values = torch.relu(q_values.abs() - self.q_threshold) / (1 - self.q_threshold)
            poly_q_values = self.k_0 + self.k_1 * scaled_q_values + self.k_2 * scaled_q_values**2
            return -q_values.sign() * delta_m_values * poly_q_values
        return torch.zeros(batch_stats.batch_size[0] - 1)

    def _get_p_values(
        self,
        batch_stats,
        legal_indices,
        to_log,
    ):
        """Policy-head logits of the legal moves (zeros when absent)."""
        if "policy" in batch_stats.keys():
            policy = batch_stats["policy"][0]
            # Gather on the policy's device for GPU support, consistent
            # with PolicySampler.
            legal_policy = policy.gather(0, legal_indices.to(policy.device))
            to_log["max_legal_policy"] = legal_policy.max().item()
            return legal_policy
        # Float zeros: the utility accumulator is floating point, while
        # legal_indices is an integer tensor.
        return torch.zeros_like(legal_indices, dtype=torch.float32)

    def _use_callback(self, callback, batch_stats, to_log):
        """Merge the entries returned by ``callback`` into ``to_log``."""
        if callback is not None:
            to_log_update = callback(batch_stats, to_log)
            if not isinstance(to_log_update, dict):
                raise ValueError("Callback must return a dictionary.")
            to_log |= to_log_update
191 |
192 |
@dataclass
class PolicySampler(ModelSampler):
    """Sampler using only the policy head of the model."""

    use_suboptimal: bool = False  # Discard the best move to force suboptimal play.

    @torch.no_grad()
    def get_utilities(
        self, boards: Iterable[LczeroBoard], **kwargs
    ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]:
        """Yield the legal policy logits as utilities.

        Parameters
        ----------
        boards : Iterable[LczeroBoard]
            The boards to evaluate.

        Returns
        -------
        Iterable[Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]]
            The iterable over utilities, legal indices, and log dictionaries.
        """
        batch_size = kwargs.pop("batch_size", -1)
        callback = kwargs.pop("callback", None)

        # Only the root boards are evaluated: the policy head alone is used.
        for legal_indices, batch_stats in self._get_batched_stats(boards, batch_size, use_next_boards=False, **kwargs):
            # A fresh log dict per board: the previous implementation reused
            # a single dict, leaking callback entries across positions.
            to_log = {}
            legal_policy = batch_stats["policy"][0].gather(0, legal_indices.to(batch_stats["policy"].device))
            if self.use_suboptimal:
                # Mask out the best move so another one gets chosen.
                idx = legal_policy.argmax()
                legal_policy[idx] = torch.tensor(-1e3)

            self._use_callback(callback, batch_stats, to_log)

            yield legal_policy, legal_indices, to_log
214 |
215 |
@dataclass
class SelfPlay:
    """A class for generating games between two samplers."""

    white: Sampler  # Sampler playing the white pieces.
    black: Sampler  # Sampler playing the black pieces.

    def play(
        self,
        board: Optional[LczeroBoard] = None,
        max_moves: int = 100,
        to_play: chess.Color = chess.WHITE,
        report_fn: Optional[Callable[[dict, chess.Color], None]] = None,
        white_kwargs: Optional[Dict] = None,
        black_kwargs: Optional[Dict] = None,
    ):
        """Play a game between the two samplers.

        Parameters
        ----------
        board : Optional[LczeroBoard]
            The starting position (a fresh board when None).
        max_moves : int
            Maximum number of half-moves to play.
        to_play : chess.Color
            The side to move first; when black, an unreported black move
            is played before the regular loop starts.
        report_fn : Optional[Callable[[dict, chess.Color], None]]
            Called with each move's log dict and the side that moved.
        white_kwargs : Optional[Dict]
            Extra kwargs forwarded to the white sampler.
        black_kwargs : Optional[Dict]
            Extra kwargs forwarded to the black sampler.

        Returns
        -------
        Tuple[list, LczeroBoard]
            The list of played moves and the final board.
        """
        board = LczeroBoard() if board is None else board
        white_kwargs = {} if white_kwargs is None else white_kwargs
        black_kwargs = {} if black_kwargs is None else black_kwargs

        game = []
        if to_play == chess.BLACK:
            # Opening black move, played without reporting.
            move, _ = next(iter(self.black.get_next_moves([board], **black_kwargs)))
            board.push(move)
            game.append(move)

        players = ((self.white, white_kwargs), (self.black, black_kwargs))
        finished = False
        for _ in range(max_moves):
            for sampler, sampler_kwargs in players:
                if board.is_game_over() or len(game) >= max_moves:
                    finished = True
                    break
                move, to_log = next(iter(sampler.get_next_moves([board], **sampler_kwargs)))
                if report_fn is not None:
                    # board.turn is read before the push: it is the mover.
                    report_fn(to_log, board.turn)
                board.push(move)
                game.append(move)
            if finished:
                break
        return game, board
265 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/__init__.py
--------------------------------------------------------------------------------
/tests/assets/error.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "assert False"
10 | ]
11 | }
12 | ],
13 | "metadata": {
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "nbformat": 4,
19 | "nbformat_minor": 2
20 | }
21 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """
Shared pytest fixtures and command-line options for the lczerolens test suite.
3 | """
4 |
5 | import onnxruntime as ort
6 | import pytest
7 | from lczero.backends import Backend, Weights
8 |
9 | from lczerolens import LczeroModel
10 | from lczerolens import backends as lczero_utils
11 |
12 |
13 | @pytest.fixture(scope="session")
14 | def tiny_lczero_backend():
15 | lczero_weights = Weights("assets/tinygyal-8.pb.gz")
16 | yield Backend(weights=lczero_weights)
17 |
18 |
19 | @pytest.fixture(scope="session")
20 | def tiny_ensure_network():
21 | lczero_utils.convert_to_onnx("assets/tinygyal-8.pb.gz", "assets/tinygyal-8.onnx")
22 | yield
23 |
24 |
25 | @pytest.fixture(scope="session")
26 | def tiny_model(tiny_ensure_network):
27 | yield LczeroModel.from_path("assets/tinygyal-8.onnx")
28 |
29 |
30 | @pytest.fixture(scope="session")
31 | def tiny_senet_ort(tiny_ensure_network):
32 | senet_ort = ort.InferenceSession("assets/tinygyal-8.onnx")
33 | yield senet_ort
34 |
35 |
36 | @pytest.fixture(scope="class")
37 | def maia_ensure_network():
38 | lczero_utils.convert_to_onnx("assets/maia-1100.pb.gz", "assets/maia-1100.onnx")
39 | yield
40 |
41 |
42 | @pytest.fixture(scope="class")
43 | def maia_model(maia_ensure_network):
44 | yield LczeroModel.from_path("assets/maia-1100.onnx")
45 |
46 |
47 | @pytest.fixture(scope="class")
48 | def maia_senet_ort(maia_ensure_network):
49 | senet_ort = ort.InferenceSession("assets/maia-1100.onnx")
50 | yield senet_ort
51 |
52 |
53 | @pytest.fixture(scope="class")
54 | def winner_ensure_network():
55 | lczero_utils.convert_to_onnx(
56 | "assets/384x30-2022_0108_1903_17_608.pb.gz",
57 | "assets/384x30-2022_0108_1903_17_608.onnx",
58 | )
59 | yield
60 |
61 |
62 | @pytest.fixture(scope="class")
63 | def winner_model(winner_ensure_network):
64 | yield LczeroModel.from_path("assets/384x30-2022_0108_1903_17_608.onnx")
65 |
66 |
67 | @pytest.fixture(scope="class")
68 | def winner_senet_ort(winner_ensure_network):
69 | yield ort.InferenceSession("assets/384x30-2022_0108_1903_17_608.onnx")
70 |
71 |
def pytest_addoption(parser):
    """Register the custom ``--run-*`` command-line flags."""
    for flag, help_text in (
        ("--run-slow", "run slow tests"),
        ("--run-fast", "run fast tests"),
        ("--run-backends", "run backends tests"),
    ):
        parser.addoption(flag, action="store_true", default=False, help=help_text)
76 |
77 |
def pytest_configure(config):
    """Register the custom markers used by the test suite.

    The ``fast`` marker is registered too: ``pytest_collection_modifyitems``
    filters on it, and an unregistered marker would otherwise trigger
    ``PytestUnknownMarkWarning``.
    """
    config.addinivalue_line("markers", "slow: mark test as slow to run")
    config.addinivalue_line("markers", "fast: mark test as fast to run")
    config.addinivalue_line("markers", "backends: mark test as backends test")
81 |
82 |
def pytest_collection_modifyitems(config, items):
    """Skip marked tests unless their corresponding cli flag was given."""
    marker_flags = (
        ("slow", "--run-slow"),
        ("fast", "--run-fast"),
        ("backends", "--run-backends"),
    )
    # Build one skip marker per category whose flag is absent.
    skips = {
        marker: pytest.mark.skip(reason=f"{flag} not given in cli: skipping {marker} tests")
        for marker, flag in marker_flags
        if not config.getoption(flag)
    }
    for item in items:
        for marker, skip in skips.items():
            if marker in item.keywords:
                item.add_marker(skip)
99 |
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/integration/__init__.py
--------------------------------------------------------------------------------
/tests/integration/test_notebooks.py:
--------------------------------------------------------------------------------
1 | """
2 | Integration tests using the notebooks.
3 | """
4 |
5 | import subprocess
6 | import pytest
7 |
# Notebooks executed end-to-end by the slow integration tests below.
NOTEBOOKS = [
    "docs/source/notebooks/features/visualise-heatmaps.ipynb",
    "docs/source/notebooks/features/probe-concepts.ipynb",
    "docs/source/notebooks/features/convert-official-weights.ipynb",
    "docs/source/notebooks/features/move-prediction.ipynb",
    "docs/source/notebooks/tutorials/piece-value-estimation-using-lrp.ipynb",
    "docs/source/notebooks/walkthrough.ipynb",
]
16 |
17 |
def run_notebook(notebook):
    """Execute ``notebook`` via nbconvert.

    Raises
    ------
    subprocess.CalledProcessError
        If execution fails; ``e.stderr`` carries the captured error output.
    """
    # check=True raises CalledProcessError with the captured stderr attached
    # to ``e.stderr``; the previous manual raise passed stderr as the third
    # positional argument, which is ``output``, leaving ``e.stderr`` empty.
    subprocess.run(
        ["uv", "run", "jupyter", "nbconvert", "--to", "notebook", "--execute", notebook],
        stderr=subprocess.PIPE,
        check=True,
    )
25 |
26 |
class TestNotebooks:
    """Integration tests executing the documentation notebooks."""

    def test_error_notebook(self):
        # Sanity check: a failing notebook must surface as an error.
        with pytest.raises(subprocess.CalledProcessError):
            run_notebook("tests/assets/error.ipynb")

    @pytest.mark.slow
    @pytest.mark.parametrize("notebook", NOTEBOOKS)
    def test_notebook(self, notebook):
        # Each documentation notebook must run end-to-end without errors.
        run_notebook(notebook)
36 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/__init__.py
--------------------------------------------------------------------------------
/tests/unit/concepts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/concepts/__init__.py
--------------------------------------------------------------------------------
/tests/unit/concepts/test_concepts.py:
--------------------------------------------------------------------------------
1 | """
2 | Test cases for the concept module.
3 | """
4 |
5 | from lczerolens.concept import BinaryConcept, AndBinaryConcept
6 | from lczerolens.concepts import (
7 | HasPiece,
8 | HasThreat,
9 | )
10 | from lczerolens import LczeroBoard
11 |
12 |
class TestBinaryConcept:
    """
    Test cases for the BinaryConcept class.
    """

    def test_compute_metrics(self):
        """
        Test the compute_metrics method.
        """
        # One false negative out of four samples.
        predictions = [0, 1, 0, 1]
        labels = [0, 1, 1, 1]
        metrics = BinaryConcept.compute_metrics(predictions, labels)
        assert metrics["accuracy"] == 0.75
        assert metrics["precision"] == 1.0
        assert metrics["recall"] == 0.6666666666666666

    def test_compute_label(self):
        """
        Test the compute_label method (AND of two piece-presence concepts).
        """
        concept = AndBinaryConcept(HasPiece("p"), HasPiece("n"))
        # Label is 1 only when both a black pawn and a black knight are present.
        assert concept.compute_label(LczeroBoard("8/8/8/8/8/8/8/8 w - - 0 1")) == 0
        assert concept.compute_label(LczeroBoard("8/p7/8/8/8/8/8/8 w - - 0 1")) == 0
        assert concept.compute_label(LczeroBoard("8/pn6/8/8/8/8/8/8 w - - 0 1")) == 1

    def test_relative_threat(self):
        """
        Test the relative threat concept.
        """
        concept = HasThreat("p", relative=True)  # Is an enemy pawn threatened?
        assert concept.compute_label(LczeroBoard("8/8/8/8/8/8/8/8 w - - 0 1")) == 0
        # Same position, different side to move, flips the relative label.
        assert concept.compute_label(LczeroBoard("R7/8/8/8/8/8/p7/8 w - - 0 1")) == 1
        assert concept.compute_label(LczeroBoard("R7/8/8/8/8/8/p7/8 b - - 0 1")) == 0
46 |
--------------------------------------------------------------------------------
/tests/unit/conftest.py:
--------------------------------------------------------------------------------
1 | """
Shared board and move-list fixtures for the unit tests.
3 | """
4 |
5 | import random
6 | import chess
7 | import pytest
8 |
9 | from lczerolens import LczeroBoard
10 |
11 |
12 | @pytest.fixture(scope="module")
13 | def random_move_board_list():
14 | board = LczeroBoard()
15 | seed = 42
16 | random.seed(seed)
17 | move_list = []
18 | board_list = [board.copy()]
19 | for _ in range(20):
20 | move = random.choice(list(board.legal_moves))
21 | move_list.append(move)
22 | board.push(move)
23 | board_list.append(board.copy(stack=8))
24 | return move_list, board_list
25 |
26 |
27 | @pytest.fixture(scope="module")
28 | def repetition_move_board_list():
29 | board = LczeroBoard()
30 | move_list = []
31 | board_list = [board.copy()]
32 | for uci_move in ("b1a3", "b8c6", "a3b1", "c6b8") * 4:
33 | move = chess.Move.from_uci(uci_move)
34 | move_list.append(move)
35 | board.push(move)
36 | board_list.append(board.copy(stack=True)) # Full stack is needed for repetition detection
37 | return move_list, board_list
38 |
39 |
40 | @pytest.fixture(scope="module")
41 | def long_move_board_list():
42 | board = LczeroBoard()
43 | seed = 6
44 | random.seed(seed)
45 | move_list = []
46 | board_list = [board.copy()]
47 | for _ in range(80):
48 | move = random.choice(list(board.legal_moves))
49 | move_list.append(move)
50 | board.push(move)
51 | board_list.append(board.copy(stack=8))
52 | return move_list, board_list
53 |
--------------------------------------------------------------------------------
/tests/unit/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/core/__init__.py
--------------------------------------------------------------------------------
/tests/unit/core/test_board.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the board utils.
3 | """
4 |
5 | from typing import List, Tuple
6 | import pytest
7 | import chess
8 | from lczero.backends import GameState
9 |
10 | from lczerolens import backends as lczero_utils
11 | from lczerolens import LczeroBoard
12 |
13 |
@pytest.mark.backends
class TestWithBackend:
    """Compare tensor encodings against the reference lczero backend."""

    def test_board_to_config_tensor(
        self, random_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend
    ):
        """
        Test that the board to config tensor matches the backend encoding.
        """
        move_list, board_list = random_move_board_list
        for i, board in enumerate(board_list):
            board_tensor = board.to_config_tensor()
            uci_moves = [move.uci() for move in move_list[:i]]
            lczero_game = GameState(moves=uci_moves)
            # Only the first 13 backend planes are compared to the config tensor.
            lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13)
            assert (board_tensor == lczero_input_tensor[:13]).all()

    def test_board_to_input_tensor(
        self, random_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend
    ):
        """
        Test that the board to input tensor matches the backend encoding.
        """
        move_list, board_list = random_move_board_list
        for i, board in enumerate(board_list):
            board_tensor = board.to_input_tensor()
            uci_moves = [move.uci() for move in move_list[:i]]
            lczero_game = GameState(moves=uci_moves)
            lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game)
            # assert (board_tensor == lczero_input_tensor).all()
            # Plane-by-plane comparison gives a clearer failure location.
            for plane in range(112):
                assert (board_tensor[plane] == lczero_input_tensor[plane]).all()
45 |
46 |
@pytest.mark.backends
class TestRepetition:
    """Encoding tests on a game with repeated positions."""

    def test_board_to_config_tensor(
        self, repetition_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend
    ):
        """
        Test the config tensor against the backend under repetitions.
        """
        move_list, board_list = repetition_move_board_list
        for i, board in enumerate(board_list):
            uci_moves = [move.uci() for move in move_list[:i]]
            board_tensor = board.to_config_tensor()
            lczero_game = GameState(moves=uci_moves)
            lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13)
            assert (board_tensor == lczero_input_tensor[:13]).all()

    def test_board_to_input_tensor(
        self, repetition_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend
    ):
        """
        Test the input tensor against the backend under repetitions.
        """
        move_list, board_list = repetition_move_board_list
        for i, board in enumerate(board_list):
            uci_moves = [move.uci() for move in move_list[:i]]
            board_tensor = board.to_input_tensor()
            lczero_game = GameState(moves=uci_moves)
            lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game)
            assert (board_tensor == lczero_input_tensor).all()
76 |
77 |
@pytest.mark.backends
class TestLong:
    """Encoding tests on a long (80 half-move) game."""

    def test_board_to_config_tensor(
        self, long_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend
    ):
        """
        Test the config tensor against the backend on a long game.
        """
        move_list, board_list = long_move_board_list
        for i, board in enumerate(board_list):
            uci_moves = [move.uci() for move in move_list[:i]]
            board_tensor = board.to_config_tensor()
            lczero_game = GameState(moves=uci_moves)
            lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game, planes=13)
            assert (board_tensor == lczero_input_tensor[:13]).all()

    def test_board_to_input_tensor(
        self, long_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]], tiny_lczero_backend
    ):
        """
        Test the input tensor against the backend on a long game.
        """
        move_list, board_list = long_move_board_list
        for i, board in enumerate(board_list):
            uci_moves = [move.uci() for move in move_list[:i]]
            board_tensor = board.to_input_tensor()
            lczero_game = GameState(moves=uci_moves)
            lczero_input_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game)
            assert (board_tensor == lczero_input_tensor).all()
107 |
108 |
class TestStability:
    """Round-trip tests for move encoding/decoding."""

    def test_encode_decode(self, random_move_board_list: Tuple[List[chess.Move], List[LczeroBoard]]):
        """
        Test that encoding and decoding a move is the identity.
        """
        # The perspective alternates every ply: encode from the mover's side.
        us, them = chess.WHITE, chess.BLACK
        for move, board in zip(*random_move_board_list):
            encoded_move = LczeroBoard.encode_move(move, us)
            decoded_move = board.decode_move(encoded_move)
            assert move == decoded_move
            us, them = them, us
120 |
121 |
@pytest.mark.backends
class TestBackend:
    """Move encoding tests against the reference lczero backend."""

    def test_encode_decode_random(self, random_move_board_list):
        """
        Test that encoding and decoding a move corresponds to the backend.
        """
        move_list, board_list = random_move_board_list
        for i, board in enumerate(board_list):
            lczero_game = GameState(moves=[move.uci() for move in move_list[:i]])
            legal_moves = [move.uci() for move in board.legal_moves]
            (
                lczero_legal_moves,
                lczero_policy_indices,
            ) = lczero_utils.moves_with_castling_swap(lczero_game, board)
            # Legal move sets and their policy indices must agree exactly.
            assert len(legal_moves) == len(lczero_legal_moves)
            assert set(legal_moves) == set(lczero_legal_moves)
            policy_indices = [LczeroBoard.encode_move(move, board.turn) for move in board.legal_moves]
            assert len(lczero_policy_indices) == len(policy_indices)
            assert set(lczero_policy_indices) == set(policy_indices)

    def test_encode_decode_long(self, long_move_board_list):
        """
        Test that encoding and decoding a move corresponds to the backend.
        """
        move_list, board_list = long_move_board_list
        for i, board in enumerate(board_list):
            lczero_game = GameState(moves=[move.uci() for move in move_list[:i]])
            legal_moves = [move.uci() for move in board.legal_moves]
            (
                lczero_legal_moves,
                lczero_policy_indices,
            ) = lczero_utils.moves_with_castling_swap(lczero_game, board)
            assert len(legal_moves) == len(lczero_legal_moves)
            assert set(legal_moves) == set(lczero_legal_moves)
            policy_indices = [LczeroBoard.encode_move(move, board.turn) for move in board.legal_moves]
            assert len(lczero_policy_indices) == len(policy_indices)
            assert set(lczero_policy_indices) == set(policy_indices)
159 |
--------------------------------------------------------------------------------
/tests/unit/core/test_input_encoding.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the input encoding.
3 | """
4 |
5 | import pytest
6 | import chess
7 |
8 | from lczerolens import LczeroBoard, InputEncoding
9 |
10 |
class TestInputEncoding:
    """Checks on the plane sums produced by each input encoding mode."""

    @pytest.mark.parametrize(
        "input_encoding_expected",
        [
            (InputEncoding.INPUT_CLASSICAL_112_PLANE, 32),
            (InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED, 32 * 8),
            (InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED, 32 * 8),
            (InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS, 32),
        ],
    )
    def test_sum_initial_config_planes(self, input_encoding_expected):
        """
        Test the sum of the config planes for the initial board.
        """
        input_encoding, expected = input_encoding_expected
        board = LczeroBoard()
        board_tensor = board.to_input_tensor(input_encoding=input_encoding)
        # Only the piece planes (first 104) are summed here.
        assert board_tensor[:104].sum() == expected

    @pytest.mark.parametrize(
        "input_encoding_expected",
        [
            (InputEncoding.INPUT_CLASSICAL_112_PLANE, 31 + 32 * 4),
            (InputEncoding.INPUT_CLASSICAL_112_PLANE_REPEATED, 31 + 32 * 7),
            (InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_REPEATED, 31 * 8),
            (InputEncoding.INPUT_CLASSICAL_112_PLANE_NO_HISTORY_ZEROS, 31),
        ],
    )
    def test_sum_qga_config_planes(self, input_encoding_expected):
        """
        Test the sum of the config planes for the queen's gambit accepted.
        """
        input_encoding, expected = input_encoding_expected
        board = LczeroBoard()
        moves = [
            chess.Move.from_uci("d2d4"),
            chess.Move.from_uci("d7d5"),
            chess.Move.from_uci("c2c4"),
            chess.Move.from_uci("d5c4"),
        ]
        for move in moves:
            board.push(move)
        board_tensor = board.to_input_tensor(input_encoding=input_encoding)
        assert board_tensor[:104].sum() == expected
55 |
--------------------------------------------------------------------------------
/tests/unit/core/test_lczero.py:
--------------------------------------------------------------------------------
1 | """
2 | LCZero utils tests.
3 | """
4 |
5 | import torch
6 | import pytest
7 | from lczero.backends import GameState
8 |
9 | from lczerolens import backends as lczero_utils
10 |
11 |
class TestExecution:
    """Tests for the lczero backend utility wrappers."""

    def test_describenet(self):
        """
        Test that the describenet function works.
        """
        description = lczero_utils.describenet("assets/tinygyal-8.pb.gz")
        assert isinstance(description, str)
        assert "Minimal Lc0 version:" in description

    def test_convertnet(self):
        """
        Test that the convert_to_onnx function works.
        """
        conversion = lczero_utils.convert_to_onnx("assets/tinygyal-8.pb.gz", "assets/tinygyal-8.onnx")
        assert isinstance(conversion, str)
        assert "INPUT_CLASSICAL_112_PLANE" in conversion

    def test_generic_command(self):
        """
        Test that the generic command function works.
        """
        generic_command = lczero_utils.generic_command(["--help"])
        assert isinstance(generic_command, str)
        assert "Usage: lc0" in generic_command

    @pytest.mark.backends
    def test_board_from_backend(self, tiny_lczero_backend):
        """
        Test that the board from backend function works.
        """
        lczero_game = GameState()
        lczero_board_tensor = lczero_utils.board_from_backend(tiny_lczero_backend, lczero_game)
        assert lczero_board_tensor.shape == (112, 8, 8)

    @pytest.mark.backends
    def test_prediction_from_backend(self, tiny_lczero_backend):
        """
        Test that the prediction from backend function works.
        """
        lczero_game = GameState()
        lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game)
        assert lczero_policy.shape == (1858,)
        assert (lczero_value >= -1) and (lczero_value <= 1)
        # The softmax option must match applying softmax to the raw policy.
        lczero_policy_softmax, _ = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game, softmax=True)
        assert lczero_policy_softmax.shape == (1858,)
        assert (lczero_policy_softmax >= 0).all() and (lczero_policy_softmax <= 1).all()
        assert torch.softmax(lczero_policy, dim=0).allclose(lczero_policy_softmax, atol=1e-4)
59 |
--------------------------------------------------------------------------------
/tests/unit/core/test_lens.py:
--------------------------------------------------------------------------------
1 | """Lens tests."""
2 |
3 | from typing import Any
4 | import pytest
5 |
6 | from lczerolens import Lens
7 | from lczerolens.model import LczeroModel
8 |
9 |
10 | @Lens.register("test_lens")
11 | class TestLens(Lens):
12 | """Test lens."""
13 |
14 | def is_compatible(self, model: LczeroModel) -> bool:
15 | return True
16 |
17 | def _intervene(self, model: LczeroModel, **kwargs) -> Any:
18 | pass
19 |
20 |
class TestLensRegistry:
    """Tests for the lens registration mechanism."""

    def test_lens_registry_duplicate(self):
        """Test that registering a lens with an existing name raises an error."""
        with pytest.raises(ValueError, match="Lens .* already registered"):

            @Lens.register("test_lens")
            class DuplicateLens(Lens):
                """Duplicate lens."""

                def is_compatible(self, model: LczeroModel) -> bool:
                    return True

                def analyse(self, *inputs, **kwargs) -> Any:
                    pass

    def test_lens_registry_missing(self):
        """Test that instantiating a non-registered lens raises an error."""
        with pytest.raises(KeyError, match="Lens .* not found"):
            Lens.from_name("non_existent_lens")

    def test_lens_type(self):
        """Test that the lens type is correct."""
        # Registration stores the name on the class as _lens_type.
        assert TestLens._lens_type == "test_lens"
44 |
--------------------------------------------------------------------------------
/tests/unit/core/test_model.py:
--------------------------------------------------------------------------------
1 | """Model tests."""
2 |
3 | import pytest
4 | import torch
5 | from lczero.backends import GameState
6 |
7 | from lczerolens import Flow, LczeroBoard
8 | from lczerolens import backends as lczero_utils
9 |
10 |
@pytest.mark.backends
class TestModel:
    """Compare model predictions against the reference lczero backend."""

    def test_model_prediction(self, tiny_lczero_backend, tiny_model):
        """Test that the model prediction works on the initial position."""
        board = LczeroBoard()
        (out,) = tiny_model(board)
        policy = out["policy"]
        value = out["value"]
        lczero_game = GameState()
        lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game)
        assert torch.allclose(policy, lczero_policy, atol=1e-4)
        assert torch.allclose(value, lczero_value, atol=1e-4)

    def test_model_prediction_random(self, tiny_lczero_backend, tiny_model, random_move_board_list):
        """Test that the model prediction works along a random game."""
        move_list, board_list = random_move_board_list
        for i, board in enumerate(board_list):
            (out,) = tiny_model(board)
            policy = out["policy"]
            value = out["value"]
            lczero_game = GameState(moves=[move.uci() for move in move_list[:i]])
            lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game)
            assert torch.allclose(policy, lczero_policy, atol=1e-4)
            assert torch.allclose(value, lczero_value, atol=1e-4)

    def test_model_prediction_repetition(self, tiny_lczero_backend, tiny_model, repetition_move_board_list):
        """Test that the model prediction works with repeated positions."""
        move_list, board_list = repetition_move_board_list
        for i, board in enumerate(board_list):
            (out,) = tiny_model(board)
            policy = out["policy"]
            value = out["value"]
            lczero_game = GameState(moves=[move.uci() for move in move_list[:i]])
            lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game)
            assert torch.allclose(policy, lczero_policy, atol=1e-4)
            assert torch.allclose(value, lczero_value, atol=1e-4)

    def test_model_prediction_long(self, tiny_lczero_backend, tiny_model, long_move_board_list):
        """Test that the model prediction works along a long game."""
        move_list, board_list = long_move_board_list
        for i, board in enumerate(board_list):
            (out,) = tiny_model(board)
            policy = out["policy"]
            value = out["value"]
            lczero_game = GameState(moves=[move.uci() for move in move_list[:i]])
            lczero_policy, lczero_value = lczero_utils.prediction_from_backend(tiny_lczero_backend, lczero_game)
            assert torch.allclose(policy, lczero_policy, atol=1e-4)
            assert torch.allclose(value, lczero_value, atol=1e-4)
59 |
60 |
class TestFlows:
    def test_policy_flow(self, tiny_model):
        """The policy flow reproduces the model's policy head."""
        board = LczeroBoard()
        flow = Flow.from_model("policy", tiny_model)
        (flow_policy,) = flow(board)
        assert torch.allclose(flow_policy, tiny_model(board)["policy"][0])

    def test_value_flow(self, tiny_model):
        """The value flow reproduces the model's value head."""
        board = LczeroBoard()
        flow = Flow.from_model("value", tiny_model)
        (flow_value,) = flow(board)
        assert torch.allclose(flow_value, tiny_model(board)["value"][0])

    def test_wdl_flow(self, winner_model):
        """The wdl flow reproduces the model's wdl head."""
        board = LczeroBoard()
        flow = Flow.from_model("wdl", winner_model)
        (flow_wdl,) = flow(board)
        assert torch.allclose(flow_wdl, winner_model(board)["wdl"][0])

    def test_mlh_flow(self, winner_model):
        """The mlh flow reproduces the model's mlh head."""
        board = LczeroBoard()
        flow = Flow.from_model("mlh", winner_model)
        (flow_mlh,) = flow(board)
        assert torch.allclose(flow_mlh, winner_model(board)["mlh"][0])

    def test_force_value_flow_value(self, tiny_model):
        """Force-value flow uses the value head when the model has one."""
        board = LczeroBoard()
        flow = Flow.from_model("force_value", tiny_model)
        (flow_value,) = flow(board)
        assert torch.allclose(flow_value, tiny_model(board)["value"][0])

    def test_force_value_flow_wdl(self, winner_model):
        """Force-value flow reduces the wdl head to a scalar value."""
        board = LczeroBoard()
        flow = Flow.from_model("force_value", winner_model)
        (flow_value,) = flow(board)
        model_wdl = winner_model(board)["wdl"][0]
        # Expected value is W - L, i.e. wdl projected onto (1, 0, -1).
        expected = model_wdl @ torch.tensor([1.0, 0.0, -1.0], device=model_wdl.device)
        assert torch.allclose(flow_value, expected)

    def test_incompatible_flows(self, tiny_model, winner_model):
        """Requesting a flow from a model lacking the head raises ValueError."""
        for flow_name, model in (
            ("value", winner_model),
            ("wdl", tiny_model),
            ("mlh", tiny_model),
        ):
            with pytest.raises(ValueError):
                Flow.from_model(flow_name, model)

        # Direct construction through the registry is also rejected.
        with pytest.raises(ValueError):
            Flow._registry["value"](tiny_model)
124 |
125 |
# Minimal concrete flow registered under "test_flow"; used by the registry
# tests below to exercise duplicate registration and type lookup.
@Flow.register("test_flow")
class TestFlow(Flow):
    """Test flow."""
129 |
130 |
class TestFlowRegistry:
    def test_flow_registry_duplicate(self):
        """Registering a second flow under an existing name must fail."""
        with pytest.raises(ValueError, match="Flow .* already registered"):

            @Flow.register("test_flow")
            class DuplicateFlow(Flow):
                """Duplicate flow."""

    def test_flow_registry_missing(self, tiny_model):
        """Looking up an unregistered flow name must raise KeyError."""
        with pytest.raises(KeyError, match="Flow .* not found"):
            Flow.from_model("non_existent_flow", tiny_model)

    def test_flow_type(self):
        """Registration records the registered name on the class."""
        assert TestFlow._flow_type == "test_flow"
148 |
--------------------------------------------------------------------------------
/tests/unit/lenses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/lenses/__init__.py
--------------------------------------------------------------------------------
/tests/unit/lenses/test_activation.py:
--------------------------------------------------------------------------------
1 | """Activation lens tests."""
2 |
3 | from lczerolens import Lens
4 | from lczerolens.lenses import ActivationLens
5 | from lczerolens.board import LczeroBoard
6 |
7 |
class TestLens:
    def test_is_compatible(self, tiny_model):
        """The lens registry yields an ActivationLens compatible with the model."""
        lens = Lens.from_name("activation")
        assert isinstance(lens, ActivationLens)
        assert lens.is_compatible(tiny_model)

    def test_analyse_board(self, tiny_model):
        """A catch-all pattern records at least one output activation."""
        results = ActivationLens(pattern=r".*").analyse(tiny_model, LczeroBoard())
        assert len(results) > 0
        assert all(key.endswith("_output") for key in results)

    def test_analyse_with_inputs(self, tiny_model):
        """save_inputs=True stores input activations alongside outputs."""
        results = ActivationLens(pattern=r".*").analyse(tiny_model, LczeroBoard(), save_inputs=True)
        assert any(key.endswith("_input") for key in results)
        assert any(key.endswith("_output") for key in results)

    def test_analyse_specific_modules(self, tiny_model):
        """A restrictive pattern only captures matching modules."""
        results = ActivationLens(pattern=r".*conv.*").analyse(tiny_model, LczeroBoard())
        assert len(results) > 0
        for key in results:
            assert "conv" in key.replace("_output", "")
43 |
--------------------------------------------------------------------------------
/tests/unit/lenses/test_gradient.py:
--------------------------------------------------------------------------------
1 | """Gradient lens tests."""
2 |
3 | from lczerolens import Lens
4 | from lczerolens.lenses import GradientLens
5 | from lczerolens.board import LczeroBoard
6 |
7 |
class TestLens:
    def test_is_compatible(self, tiny_model):
        """The lens registry yields a GradientLens compatible with the model."""
        lens = Lens.from_name("gradient")
        assert isinstance(lens, GradientLens)
        assert lens.is_compatible(tiny_model)

    def test_analyse_board(self, tiny_model):
        """By default the input gradient is part of the results."""
        results = GradientLens().analyse(tiny_model, LczeroBoard())
        assert "input_grad" in results

    def test_analyse_without_input_grad(self, tiny_model):
        """Disabling input_requires_grad drops the input gradient."""
        results = GradientLens(input_requires_grad=False).analyse(tiny_model, LczeroBoard())
        assert "input_grad" not in results

    def test_analyse_specific_modules(self, tiny_model):
        """A restrictive pattern only captures gradients of matching modules."""
        lens = GradientLens(pattern=r".*conv.*relu", input_requires_grad=False)
        results = lens.analyse(tiny_model, LczeroBoard())
        assert len(results) > 0
        for key in results:
            assert "conv" in key.replace("_output_grad", "")
37 |
--------------------------------------------------------------------------------
/tests/unit/lenses/test_lrp.py:
--------------------------------------------------------------------------------
1 | """LRP lens tests."""
2 |
3 | from lczerolens import Lens
4 | from lczerolens.lenses import LrpLens
5 |
6 |
class TestLens:
    def test_is_compatible(self, tiny_model):
        """The lens registry yields an LrpLens compatible with the model."""
        lens = Lens.from_name("lrp")
        assert isinstance(lens, LrpLens)
        assert lens.is_compatible(tiny_model)
12 |
--------------------------------------------------------------------------------
/tests/unit/play/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xmaster6y/lczerolens/fcef7f668b01781420003badba5627023b8f8700/tests/unit/play/__init__.py
--------------------------------------------------------------------------------
/tests/unit/play/test_puzzle.py:
--------------------------------------------------------------------------------
1 | """Puzzle tests."""
2 |
3 | import pytest
4 |
5 |
6 | from lczerolens.play.puzzle import Puzzle
7 | from lczerolens.play.sampling import RandomSampler, PolicySampler
8 |
9 |
@pytest.fixture
def opening_puzzle():
    """A four-move opening line in lichess puzzle-CSV field format."""
    return dict(
        PuzzleId="1",
        FEN="rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
        Moves="e2e4 e7e5 d2d4 d7d5",
        Rating=1000,
        RatingDeviation=100,
        Popularity=1000,
        NbPlays=1000,
        Themes="Opening",
        GameUrl="https://lichess.org/training/1",
        OpeningTags="Ruy Lopez",
    )
24 |
25 |
@pytest.fixture
def easy_puzzle():
    """A real lichess middlegame puzzle (id 00008) with a six-ply solution."""
    return dict(
        PuzzleId="00008",
        FEN="r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - - 0 24",
        Moves="f2g3 e6e7 b2b1 b3c1 b1c1 h6c1",
        Rating=1913,
        RatingDeviation=75,
        Popularity=94,
        NbPlays=6230,
        Themes="crushing hangingPiece long middlegame",
        GameUrl="https://lichess.org/787zsVup/black#47",
        OpeningTags=None,
    )
40 |
41 |
class TestPuzzle:
    def test_puzzle_creation(self, opening_puzzle):
        """All raw dict fields are parsed onto the Puzzle instance."""
        puzzle = Puzzle.from_dict(opening_puzzle)
        assert len(puzzle) == 3
        assert (puzzle.rating, puzzle.rating_deviation) == (1000, 100)
        assert (puzzle.popularity, puzzle.nb_plays) == (1000, 1000)
        assert puzzle.themes == ["Opening"]
        assert puzzle.game_url == "https://lichess.org/training/1"
        # Opening tags are whitespace-split into a list.
        assert puzzle.opening_tags == ["Ruy", "Lopez"]

    def test_puzzle_use(self, opening_puzzle):
        """Only solution moves are yielded unless all_moves is requested."""
        puzzle = Puzzle.from_dict(opening_puzzle)
        assert len(list(puzzle.board_move_generator())) == 2
        assert len(list(puzzle.board_move_generator(all_moves=True))) == 3
60 |
61 |
class TestRandomSampler:
    def test_puzzle_evaluation(self, opening_puzzle):
        """A random sampler scores imperfectly with near-uniform perplexity."""
        metrics = Puzzle.from_dict(opening_puzzle).evaluate(RandomSampler())
        assert metrics["score"] != 1.0
        # Geometric mean of the branching factors (20 and 30 legal moves).
        expected_perplexity = (20.0 * 30) ** 0.5
        assert abs(metrics["perplexity"] - expected_perplexity) < 1e-3

    def test_puzzle_multiple_evaluation_len(self, easy_puzzle):
        """Without metric aggregation, one result per evaluated move is yielded."""
        puzzles = [Puzzle.from_dict(easy_puzzle) for _ in range(10)]
        sampler = RandomSampler()
        all_results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=True, compute_metrics=False)
        assert len(list(all_results)) == 10 * 5
        results = Puzzle.evaluate_multiple(puzzles, sampler, compute_metrics=False)
        assert len(list(results)) == 10 * 3

    def test_puzzle_multiple_evaluation(self, easy_puzzle):
        """With metric aggregation, one result per puzzle is yielded."""
        puzzles = [Puzzle.from_dict(easy_puzzle) for _ in range(10)]
        sampler = RandomSampler()
        for all_moves in (True, False):
            results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=all_moves)
            assert len(list(results)) == 10

    def test_puzzle_multiple_evaluation_batch_size(self, easy_puzzle):
        """Batching the evaluation does not change the number of results."""
        puzzles = [Puzzle.from_dict(easy_puzzle) for _ in range(10)]
        sampler = RandomSampler()
        for all_moves in (True, False):
            results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=all_moves, batch_size=5)
            assert len(list(results)) == 10
97 |
98 |
class TestPolicySampler:
    def test_puzzle_evaluation(self, easy_puzzle, winner_model):
        """An argmax policy sampler beats chance on an easy puzzle."""
        sampler = PolicySampler(model=winner_model, use_argmax=True)
        metrics = Puzzle.from_dict(easy_puzzle).evaluate(sampler, all_moves=True)
        assert metrics["score"] > 0.0
        assert metrics["perplexity"] < 15.0

    def test_puzzle_multiple_evaluation(self, easy_puzzle, tiny_model):
        """Aggregated evaluation yields one result per puzzle."""
        puzzles = [Puzzle.from_dict(easy_puzzle) for _ in range(10)]
        sampler = PolicySampler(model=tiny_model, use_argmax=False)
        for all_moves in (True, False):
            results = Puzzle.evaluate_multiple(puzzles, sampler, all_moves=all_moves)
            assert len(list(results)) == 10
116 |
--------------------------------------------------------------------------------
/tests/unit/play/test_sampling.py:
--------------------------------------------------------------------------------
1 | """Sampling tests."""
2 |
3 | from lczerolens.play.sampling import ModelSampler, SelfPlay, PolicySampler, RandomSampler
4 | from lczerolens.board import LczeroBoard
5 |
6 |
class TestRandomSampler:
    def test_get_utilities(self):
        """Each starting board yields one utility per legal move (20)."""
        boards = [LczeroBoard(), LczeroBoard()]
        utility, _, _ = next(iter(RandomSampler().get_utilities(boards)))
        assert utility.shape[0] == 20
14 |
15 |
class TestModelSampler:
    @staticmethod
    def _first_utility(sampler):
        """Return the utility tensor for the first of two starting boards."""
        boards = [LczeroBoard(), LczeroBoard()]
        utility, _, _ = next(iter(sampler.get_utilities(boards)))
        return utility

    def test_get_utilities_tiny(self, tiny_model):
        """A tiny model scores all 20 legal opening moves."""
        sampler = ModelSampler(tiny_model, use_argmax=False)
        assert self._first_utility(sampler).shape[0] == 20

    def test_get_utilities_winner(self, winner_model):
        """A winner model scores all 20 legal opening moves."""
        sampler = ModelSampler(winner_model, use_argmax=False)
        assert self._first_utility(sampler).shape[0] == 20

    def test_policy_sampler_tiny(self, tiny_model):
        """A policy sampler scores all 20 legal opening moves."""
        sampler = PolicySampler(tiny_model, use_argmax=False)
        assert self._first_utility(sampler).shape[0] == 20
37 |
38 |
class TestSelfPlay:
    def test_play(self, tiny_model, winner_model):
        """A self-play game capped at 10 moves reports once per move."""
        white = ModelSampler(tiny_model, use_argmax=False)
        black = ModelSampler(winner_model, use_argmax=False)
        self_play = SelfPlay(white=white, black=black)

        logs = []
        game, _ = self_play.play(
            board=LczeroBoard(),
            max_moves=10,
            report_fn=lambda log, to_play: logs.append((log, to_play)),
        )

        assert len(game) == len(logs) == 10
54 |
55 |
class TestBatchedPolicySampler:
    def test_batched_policy_sampler_ag(self, tiny_model):
        """With argmax, identical boards must all receive the same move.

        Bug fix: the original bound `moves` to the raw return value of
        `get_next_moves`, consumed it via `len(list(moves))`, and then indexed
        `moves[0]`. If the sampler returns a generator (as the `list(...)`
        wrapping in the sibling tests suggests), the indexing raises TypeError
        and the `all(...)` loop sees an exhausted iterator. Materializing the
        result once is correct for both list and generator returns.
        """
        boards = [LczeroBoard() for _ in range(10)]

        sampler_ag = PolicySampler(tiny_model, use_argmax=True)
        moves = list(sampler_ag.get_next_moves(boards))
        assert len(moves) == 10
        assert all(move == moves[0] for move in moves)

    def test_batched_policy_sampler_no_ag(self, tiny_model):
        """Without argmax, sampling still yields one move per board."""
        boards = [LczeroBoard() for _ in range(10)]

        sampler_no_ag = PolicySampler(tiny_model, use_argmax=False)
        moves = list(sampler_no_ag.get_next_moves(boards))
        assert len(moves) == 10

    def test_batched_policy_sampler_no_ag_sub(self, tiny_model):
        """Suboptimal sampling also yields one move per board."""
        boards = [LczeroBoard() for _ in range(10)]

        sampler_no_ag = PolicySampler(tiny_model, use_argmax=False, use_suboptimal=True)
        moves = list(sampler_no_ag.get_next_moves(boards))
        assert len(moves) == 10
81 |
--------------------------------------------------------------------------------