├── .github └── workflows │ └── unittests.yml ├── .gitignore ├── .pylintrc ├── .readthedocs.yaml ├── .vscode ├── extensions.json └── settings.json ├── AUTHORS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── .gitignore ├── Makefile ├── _autogen_root.rst ├── _include │ └── _glue_figures.ipynb ├── _static │ ├── custom.css │ ├── custom.js │ └── readme_teaser.png ├── _templates │ ├── pzbase.rst │ ├── pzclass.rst │ ├── pzdata.rst │ ├── pzmodule.rst │ └── pzmodule_full.rst ├── api │ ├── penzai.deprecated.v1.rst │ ├── pz.nn.rst │ ├── pz.rst │ ├── pz.ts.rst │ └── treescope.rst ├── conf.py ├── ext │ ├── nb_output_cell_to_iframe.py │ └── pz_alias_rewrite.py ├── guides │ ├── howto_reference.md │ └── v2_differences.md ├── index.rst └── scripts │ └── readthedocs_fetch_notebook_outputs.sh ├── notebooks ├── how_to_think_in_penzai.ipynb ├── induction_heads.ipynb ├── induction_heads_2B.ipynb ├── jitting_and_sharding.ipynb ├── lora_from_scratch.ipynb ├── named_axes.ipynb └── selectors.ipynb ├── penzai ├── __init__.py ├── core │ ├── __init__.py │ ├── _treescope_handlers │ │ ├── __init__.py │ │ ├── named_axes_handlers.py │ │ ├── selection_rendering.py │ │ ├── shapecheck_handlers.py │ │ └── struct_handler.py │ ├── auto_order_types.py │ ├── named_axes.py │ ├── partitioning.py │ ├── random_stream.py │ ├── selectors.py │ ├── shapecheck.py │ ├── struct.py │ ├── syntactic_sugar.py │ ├── tree_util.py │ └── variables.py ├── deprecated │ ├── __init__.py │ └── v1 │ │ ├── __init__.py │ │ ├── core │ │ ├── __init__.py │ │ ├── _treescope_handlers │ │ │ ├── __init__.py │ │ │ └── layer_handler.py │ │ ├── layer.py │ │ └── random_stream.py │ │ ├── data_effects │ │ ├── __init__.py │ │ ├── _treescope_handlers.py │ │ ├── effect_base.py │ │ ├── local_state.py │ │ ├── random.py │ │ ├── side_input.py │ │ └── side_output.py │ │ ├── example_models │ │ ├── __init__.py │ │ ├── gemma │ │ │ ├── __init__.py │ │ │ ├── model_core.py │ │ │ ├── sampling_mode.py │ │ │ └── simple_decoding_loop.py │ │ └── 
simple_mlp.py │ │ ├── nn │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── basic_ops.py │ │ ├── combinators.py │ │ ├── dropout.py │ │ ├── embeddings.py │ │ ├── grouping.py │ │ ├── linear_and_affine.py │ │ ├── parameters.py │ │ └── standardization.py │ │ ├── pz │ │ ├── __init__.py │ │ ├── de.py │ │ └── nn.py │ │ └── toolshed │ │ ├── annotate_shapes.py │ │ ├── basic_training.py │ │ ├── check_layers_by_tracing.py │ │ ├── interleave_intermediates.py │ │ ├── isolate_submodel.py │ │ ├── jit_wrapper.py │ │ ├── lora.py │ │ ├── model_rewiring.py │ │ ├── sharding_util.py │ │ └── unflaxify.py ├── experimental │ ├── __init__.py │ └── v2 │ │ └── __init__.py ├── models │ ├── __init__.py │ ├── simple_mlp.py │ └── transformer │ │ ├── __init__.py │ │ ├── model_parts.py │ │ ├── sampling_mode.py │ │ ├── simple_decoding_loop.py │ │ └── variants │ │ ├── __init__.py │ │ ├── gemma.py │ │ ├── gpt_neox.py │ │ ├── llama.py │ │ ├── llamalike_common.py │ │ └── mistral.py ├── nn │ ├── __init__.py │ ├── _treescope_handlers │ │ ├── __init__.py │ │ └── layer_handler.py │ ├── attention.py │ ├── basic_ops.py │ ├── combinators.py │ ├── dropout.py │ ├── embeddings.py │ ├── grouping.py │ ├── layer.py │ ├── layer_stack.py │ ├── linear_and_affine.py │ ├── parameters.py │ └── standardization.py ├── pz │ ├── __init__.py │ ├── nn.py │ └── ts.py ├── toolshed │ ├── __init__.py │ ├── auto_nmap.py │ ├── basic_training.py │ ├── gradient_checkpointing.py │ ├── isolate_submodel.py │ ├── jit_wrapper.py │ ├── lora.py │ ├── model_rewiring.py │ ├── patch_ipdb.py │ ├── save_intermediates.py │ ├── sharding_util.py │ ├── token_visualization.py │ └── unflaxify.py └── treescope │ ├── __init__.py │ ├── _compatibility_setup.py │ ├── formatting_util.py │ └── repr_lib.py ├── pyproject.toml ├── run_tests.py ├── tests ├── __init__.py ├── core │ ├── auto_order_types_test.py │ ├── misc_util_test.py │ ├── named_axes_test.py │ ├── partitioning_test.py │ ├── selectors_test.py │ ├── shapecheck_test.py │ ├── 
struct_pytree_dataclass_test.py │ └── variables_test.py ├── deprecated │ ├── __init__.py │ └── v1 │ │ ├── __init__.py │ │ ├── data_effects │ │ ├── __init__.py │ │ ├── local_state_test.py │ │ ├── random_test.py │ │ ├── side_input_test.py │ │ └── side_output_test.py │ │ ├── example_models │ │ ├── __init__.py │ │ ├── gemma_test.py │ │ └── simple_mlp_test.py │ │ ├── misc_util_test.py │ │ ├── nn │ │ ├── __init__.py │ │ ├── basic_ops_test.py │ │ ├── embedding_test.py │ │ ├── grouping_test.py │ │ ├── linear_and_affine_test.py │ │ ├── parameters_test.py │ │ └── standardization_test.py │ │ ├── shapecheck_layer_test.py │ │ └── toolshed │ │ ├── __init__.py │ │ ├── isolate_submodel_test.py │ │ ├── lora_test.py │ │ ├── model_rewiring_test.py │ │ ├── sharding_util_test.py │ │ └── unflaxify_test.py ├── models │ ├── __init__.py │ ├── simple_mlp_test.py │ ├── transformer_consistency_test.py │ └── transformer_llamalike_test.py ├── nn │ ├── __init__.py │ ├── basic_ops_test.py │ ├── embedding_test.py │ ├── grouping_test.py │ ├── layer_stack_test.py │ ├── layer_test.py │ ├── linear_and_affine_test.py │ ├── parameters_test.py │ └── standardization_test.py ├── toolshed │ ├── __init__.py │ ├── auto_nmap_test.py │ ├── gradient_checkpointing_test.py │ ├── isolate_submodel_test.py │ ├── jit_wrapper_test.py │ ├── lora_test.py │ ├── model_rewiring_test.py │ ├── save_intermediates_test.py │ └── unflaxify_test.py └── treescope │ ├── fixtures │ ├── __init__.py │ └── treescope_examples_fixture.py │ ├── ndarray_adapters_test.py │ └── renderer_test.py └── uv.lock /.github/workflows/unittests.yml: -------------------------------------------------------------------------------- 1 | name: Unittests 2 | 3 | on: 4 | # Allow to trigger the workflow manually (e.g. 
when deps change) 5 | workflow_dispatch: 6 | # Run on pushes to main 7 | push: 8 | branches: 9 | - main 10 | # Run on pull requests to main (including test branches) 11 | pull_request: 12 | branches: 13 | - main 14 | 15 | jobs: 16 | unittest-job: 17 | runs-on: ubuntu-latest 18 | timeout-minutes: 30 19 | 20 | concurrency: 21 | group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}-${{ github.head_ref || 'none' }} 22 | cancel-in-progress: true 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | 27 | # Install deps 28 | - uses: actions/setup-python@v4 29 | with: 30 | python-version: 3.10.14 31 | # Uncomment to cache pip dependencies (if tests are too slow) 32 | # cache: pip 33 | # cache-dependency-path: '**/pyproject.toml' 34 | 35 | - uses: astral-sh/setup-uv@v3 36 | with: 37 | version: "0.4.17" 38 | 39 | - name: Install dependencies 40 | run: | 41 | uv sync --locked --extra extras --extra dev 42 | 43 | # Check formatting 44 | - name: Check pyink formatting 45 | run: uv run pyink penzai tests --check 46 | 47 | - name: Run pylint 48 | run: uv run pylint penzai 49 | 50 | # Run tests 51 | - name: Run tests 52 | run: uv run python run_tests.py 53 | 54 | # Run typechecker 55 | - name: Run pytype 56 | run: uv run pytype --jobs auto penzai 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | 8 | # Poetry, setuptools, PyPI distribution artifacts. 
9 | /*.egg-info 10 | .eggs/ 11 | build/ 12 | dist/ 13 | poetry.lock 14 | 15 | # Tests 16 | .pytest_cache/ 17 | 18 | # Type checking 19 | .pytype/ 20 | 21 | # Virtual env 22 | .venv/ 23 | 24 | # Other 25 | *.DS_Store 26 | 27 | # PyCharm 28 | .idea 29 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.11" 10 | commands: 11 | # Fetch precomputed notebook outputs (if any) 12 | - bash docs/scripts/readthedocs_fetch_notebook_outputs.sh 13 | # Install and build using uv 14 | - asdf plugin add uv 15 | - asdf install uv latest 16 | - asdf global uv latest 17 | - uv sync --extra docs --frozen 18 | - uv pip install readthedocs-sphinx-ext 19 | - cd docs && uv run python -m sphinx -T -b html -d docs/_build/doctrees -D language=en . 
$READTHEDOCS_OUTPUT/html 20 | 21 | sphinx: 22 | builder: html 23 | configuration: docs/conf.py 24 | fail_on_warning: false 25 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.black-formatter", 4 | "ms-python.pylint" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.insertFinalNewline": true, 3 | "files.trimFinalNewlines": true, 4 | "files.trimTrailingWhitespace": true, 5 | "files.associations": { 6 | ".pylintrc": "ini" 7 | }, 8 | "files.watcherExclude": { 9 | "**/.git/**": true 10 | }, 11 | "files.exclude": { 12 | "**/__pycache__": true, 13 | "**/.pytest_cache": true, 14 | "**/*.egg-info": true 15 | }, 16 | "python.testing.unittestEnabled": false, 17 | "python.testing.nosetestsEnabled": false, 18 | "python.testing.pytestEnabled": true, 19 | "python.linting.pylintUseMinimalCheckers": false, 20 | "[python]": { 21 | "editor.rulers": [80], 22 | "editor.tabSize": 2, 23 | "editor.formatOnSave": true, 24 | "editor.detectIndentation": false, 25 | "editor.defaultFormatter": "ms-python.black-formatter", 26 | }, 27 | "black-formatter.path": ["uvx", "--python=>=3.9,!=3.12.5", "pyink"], 28 | "pylint.enabled": true, 29 | } 30 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the list of Penzai's significant contributors. 2 | # 3 | # This does not necessarily list everyone who has contributed code, 4 | # especially since many employees of one corporation may be contributing. 5 | # To see the full list of contributors, see the revision history in 6 | # source control. 
7 | DeepMind Technologies Limited 8 | Daniel D. Johnson 9 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution; 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to <https://cla.developers.google.com/> to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | **/_autosummary 2 | _build 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | # You can set these variables from the command line. 4 | SPHINXOPTS = 5 | SPHINXBUILD = sphinx-build 6 | SOURCEDIR = . 7 | BUILDDIR = _build 8 | 9 | # Put it first so that "make" without argument is like "make help". 
10 | help: 11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 12 | 13 | .PHONY: help Makefile 14 | 15 | # Catch-all target: route all unknown targets to Sphinx using the new 16 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 17 | %: Makefile 18 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 19 | 20 | clean: 21 | rm -rf $(BUILDDIR)/* _collections _autosummary modules 22 | -------------------------------------------------------------------------------- /docs/_autogen_root.rst: -------------------------------------------------------------------------------- 1 | .. 2 | This file is not actually referenced in the docs, but it is the entry point 3 | that generates automatic summaries for Penzai's modules. `index.rst` points 4 | directly at the autosummary files generated while processing this one. 5 | We also reference other orphaned modules here so that Sphinx doesn't give 6 | warnings about them. 7 | 8 | :orphan: 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | :template: pzmodule_full.rst 13 | :recursive: 14 | 15 | penzai.core 16 | penzai.nn 17 | penzai.models 18 | penzai.toolshed 19 | 20 | penzai.deprecated.v1.core 21 | penzai.deprecated.v1.nn 22 | penzai.deprecated.v1.data_effects 23 | penzai.deprecated.v1.example_models 24 | penzai.deprecated.v1.toolshed 25 | 26 | .. 
toctree:: 27 | :hidden: 28 | 29 | notebooks/induction_heads_2B 30 | _include/_glue_figures 31 | -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | .cell_output .output.text_html, .glued-cell-output .output.text_html { 2 | background-color: white; 3 | line-height: 1.24; 4 | } 5 | 6 | iframe.cell_output_frame { 7 | width: 100%; 8 | border: 1px solid #7f7f7f2e; 9 | width: 100%; 10 | min-width: 100%; 11 | max-height: 50em; 12 | overflow: scroll; 13 | background-color: white; 14 | position: relative; 15 | } 16 | 17 | html { 18 | scroll-behavior: auto !important; 19 | } 20 | 21 | .title.logo__title { 22 | font-family:monospace; 23 | } 24 | .title.logo__title::before { 25 | color: #cccccc; 26 | position: relative; 27 | left: -0.75ch; 28 | content: '▼'; 29 | } 30 | 31 | html[data-theme=light], html[data-theme=dark] { 32 | --pst-color-inline-code: var(--pst-color-text-base) !important; 33 | } 34 | 35 | @media (min-width: 992px) { 36 | .bd-page-width { 37 | max-width: 100% !important; 38 | } 39 | .bd-sidebar-primary { 40 | flex-basis: calc(min(20%, 25em)) !important; 41 | } 42 | } 43 | 44 | html[data-theme="light"] .highlight .gd { 45 | color: #A00000 !important; 46 | } 47 | html[data-theme="light"] .highlight .gi { 48 | color: #008400 !important; 49 | } 50 | html[data-theme="dark"] .highlight .gd { 51 | color: #fe7b6e !important; 52 | } 53 | html[data-theme="dark"] .highlight .gi { 54 | color: #56d364 !important; 55 | } 56 | -------------------------------------------------------------------------------- /docs/_static/custom.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2024 The Penzai Authors. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | document.addEventListener('DOMContentLoaded', () => { 18 | // Move cell outputs into iframes to avoid JS/CSS conflicts and improve 19 | // responsiveness of the main page. 20 | const frameTpls = document.querySelectorAll('template.cell_output_frame_src'); 21 | for (let frameTpl of frameTpls) { 22 | let frame = document.createElement('iframe'); 23 | frame.classList.add('cell_output_frame'); 24 | frame.sandbox = [ 25 | 'allow-downloads', 'allow-forms', 'allow-pointer-lock', 'allow-popups', 26 | 'allow-same-origin', 'allow-scripts', 27 | 'allow-storage-access-by-user-activation', 28 | 'allow-popups-to-escape-sandbox' 29 | ].join(' '); 30 | frame.addEventListener("load", () => { 31 | frame.contentDocument.body.appendChild(frameTpl.content.cloneNode(true)); 32 | frame.contentDocument.body.style.height = 'fit-content'; 33 | frame.contentDocument.body.style.margin = '0'; 34 | frame.contentDocument.body.style.padding = '0.5em 1ch 0.5em 1ch'; 35 | const observer = new ResizeObserver(() => { 36 | const frameBounds = frame.getBoundingClientRect(); 37 | const bounds = frame.contentDocument.body.getBoundingClientRect(); 38 | if (frame.contentDocument.body.scrollWidth > frameBounds.width) { 39 | // Make room for the scrollbar. 
40 | frame.style.height = `calc(1em + ${bounds.height}px)`; 41 | } else { 42 | frame.style.height = `${bounds.height}px`; 43 | } 44 | }); 45 | observer.observe(frame.contentDocument.body); 46 | }); 47 | frame.src = "about:blank"; 48 | frameTpl.parentNode.replaceChild(frame, frameTpl); 49 | } 50 | 51 | // Add zero-width spaces to sidebar identifiers to improve word breaking. 52 | const sidebarNodes = document.querySelectorAll( 53 | '.bd-docs-nav .reference, .bd-docs-nav .reference *'); 54 | for (let parent of sidebarNodes) { 55 | for (let elt of parent.childNodes) { 56 | if (elt instanceof Text) { 57 | elt.textContent = 58 | elt.textContent 59 | .split(/(?=_)|(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])/g) 60 | .join('\u200b'); 61 | } 62 | } 63 | } 64 | }); 65 | -------------------------------------------------------------------------------- /docs/_static/readme_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/penzai/5533d8656506ebc09bf9d328a94cdc460e6c50e5/docs/_static/readme_teaser.png -------------------------------------------------------------------------------- /docs/_templates/pzbase.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. auto{{ objtype }}:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/_templates/pzclass.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. 
autoclass:: {{ objname }} 6 | :members: 7 | :special-members: 8 | :show-inheritance: 9 | 10 | 11 | {% set attr_ns = namespace(inherited=false) %} 12 | {% for item in attributes %} 13 | {% if item in inherited_members %} 14 | {% set attr_ns.inherited = true %} 15 | {% endif %} 16 | {%- endfor %} 17 | 18 | {% if attr_ns.inherited %} 19 | .. rubric:: {{ _('Inherited Attributes') }} 20 | .. autosummary:: 21 | {% for item in attributes %} 22 | {% if item in inherited_members %} 23 | ~{{ name }}.{{ item }} 24 | {% endif %} 25 | {%- endfor %} 26 | {% endif %} 27 | 28 | {% set method_ns = namespace(methods_ext=methods, own=false, inherited=false) %} 29 | {% for special in ["__call__", "__enter__", "__exit__"] %} 30 | {% if special in members %} 31 | {% set method_ns.methods_ext = method_ns.methods_ext + [special] %} 32 | {% endif %} 33 | {%- endfor %} 34 | {% for item in method_ns.methods_ext %} 35 | {% if item in inherited_members %} 36 | {% set method_ns.inherited = true %} 37 | {% else %} 38 | {% set method_ns.own = true %} 39 | {% endif %} 40 | {%- endfor %} 41 | 42 | {% if method_ns.own %} 43 | .. rubric:: {{ _('Methods') }} 44 | .. autosummary:: 45 | {% for item in method_ns.methods_ext %} 46 | {% if item not in inherited_members %} 47 | ~{{ name }}.{{ item }} 48 | {% endif %} 49 | {%- endfor %} 50 | {% endif %} 51 | 52 | {% if attributes %} 53 | .. rubric:: {{ _('Attributes') }} 54 | .. autosummary:: 55 | {% for item in attributes %} 56 | ~{{ name }}.{{ item }} 57 | {%- endfor %} 58 | {% endif %} 59 | 60 | {% if method_ns.inherited %} 61 | 62 | .. rubric:: {{ _('Inherited Methods') }} 63 | .. raw:: html 64 | 65 | <details>
66 | <summary>(expand to view inherited methods)</summary> 67 | 68 | .. autosummary:: 69 | {% for item in method_ns.methods_ext %} 70 | {% if item in inherited_members %} 71 | ~{{ name }}.{{ item }} 72 | {% endif %} 73 | {%- endfor %} 74 | 75 | .. raw:: html 76 | 77 | </details>
78 | 79 | {% endif %} 80 | -------------------------------------------------------------------------------- /docs/_templates/pzdata.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autodata:: {{ objname }} 6 | :no-value: 7 | -------------------------------------------------------------------------------- /docs/_templates/pzmodule.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block classes %} 6 | {% if classes %} 7 | .. rubric:: {{ _('Classes') }} 8 | 9 | .. autosummary:: 10 | :toctree: leaf 11 | :template: pzclass.rst 12 | {% for item in classes %} 13 | {{ item }} 14 | {%- endfor %} 15 | {% endif %} 16 | {% endblock %} 17 | 18 | {% block functions %} 19 | {% if functions %} 20 | .. rubric:: {{ _('Functions') }} 21 | 22 | .. autosummary:: 23 | :toctree: leaf 24 | :template: pzbase.rst 25 | {% for item in functions %} 26 | {{ item }} 27 | {%- endfor %} 28 | {% endif %} 29 | {% endblock %} 30 | 31 | {% block attributes %} 32 | {% if attributes %} 33 | .. rubric:: {{ _('Module Attributes') }} 34 | 35 | .. autosummary:: 36 | :toctree: leaf 37 | :template: pzdata.rst 38 | {% for item in attributes %} 39 | {{ item }} 40 | {%- endfor %} 41 | {% endif %} 42 | {% endblock %} 43 | 44 | {% block exceptions %} 45 | {% if exceptions %} 46 | .. rubric:: {{ _('Exceptions') }} 47 | 48 | .. autosummary:: 49 | :toctree: leaf 50 | :template: pzbase.rst 51 | {% for item in exceptions %} 52 | {{ item }} 53 | {%- endfor %} 54 | {% endif %} 55 | {% endblock %} 56 | 57 | {% block modules %} 58 | {% if modules %} 59 | .. rubric:: Submodules 60 | 61 | .. 
autosummary:: 62 | :toctree: 63 | :template: pzmodule.rst 64 | :recursive: 65 | {% for item in modules %} 66 | {{ item }} 67 | {%- endfor %} 68 | {% endif %} 69 | {% endblock %} 70 | -------------------------------------------------------------------------------- /docs/_templates/pzmodule_full.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block classes %} 6 | {% if classes %} 7 | .. rubric:: {{ _('Classes') }} 8 | 9 | .. autosummary:: 10 | :toctree: leaf 11 | :template: pzclass.rst 12 | {% for item in classes %} 13 | {{ item }} 14 | {%- endfor %} 15 | {% endif %} 16 | {% endblock %} 17 | 18 | {% block functions %} 19 | {% if functions %} 20 | .. rubric:: {{ _('Functions') }} 21 | 22 | .. autosummary:: 23 | :toctree: leaf 24 | :template: pzbase.rst 25 | {% for item in functions %} 26 | {{ item }} 27 | {%- endfor %} 28 | {% endif %} 29 | {% endblock %} 30 | 31 | {% block attributes %} 32 | {% if attributes %} 33 | .. rubric:: {{ _('Module Attributes') }} 34 | 35 | .. autosummary:: 36 | :toctree: leaf 37 | :template: pzdata.rst 38 | {% for item in attributes %} 39 | {{ item }} 40 | {%- endfor %} 41 | {% endif %} 42 | {% endblock %} 43 | 44 | {% block exceptions %} 45 | {% if exceptions %} 46 | .. rubric:: {{ _('Exceptions') }} 47 | 48 | .. autosummary:: 49 | :toctree: leaf 50 | :template: pzbase.rst 51 | {% for item in exceptions %} 52 | {{ item }} 53 | {%- endfor %} 54 | {% endif %} 55 | {% endblock %} 56 | 57 | {% block modules %} 58 | {% if modules %} 59 | .. rubric:: Submodules 60 | 61 | .. 
autosummary:: 62 | :toctree: 63 | :template: pzmodule.rst 64 | :recursive: 65 | {% for item in modules %} 66 | {{ item }} 67 | {%- endfor %} 68 | {% endif %} 69 | {% endblock %} 70 | -------------------------------------------------------------------------------- /docs/api/pz.nn.rst: -------------------------------------------------------------------------------- 1 | ``pz.nn``: Neural network alias namespace 2 | ========================================= 3 | 4 | .. module:: penzai.pz.nn 5 | .. currentmodule:: penzai 6 | 7 | 8 | Layers and Parameter Utilities 9 | ------------------------------ 10 | 11 | .. autosummary:: 12 | pz.nn.Layer 13 | pz.nn.ParameterLike 14 | pz.nn.derive_param_key 15 | pz.nn.make_parameter 16 | pz.nn.assert_no_parameter_slots 17 | 18 | 19 | Basic Combinators 20 | ----------------- 21 | 22 | .. autosummary:: 23 | 24 | pz.nn.Sequential 25 | pz.nn.NamedGroup 26 | pz.nn.CheckedSequential 27 | pz.nn.Residual 28 | pz.nn.BranchAndAddTogether 29 | pz.nn.BranchAndMultiplyTogether 30 | pz.nn.inline_anonymous_sequentials 31 | pz.nn.inline_groups 32 | pz.nn.is_anonymous_sequential 33 | pz.nn.is_sequential_or_named 34 | 35 | Basic Operations 36 | ---------------- 37 | 38 | .. autosummary:: 39 | 40 | pz.nn.Elementwise 41 | pz.nn.Softmax 42 | pz.nn.CheckStructure 43 | pz.nn.Identity 44 | pz.nn.CastToDType 45 | pz.nn.TanhSoftCap 46 | 47 | Linear and Affine Layers 48 | ------------------------ 49 | 50 | .. autosummary:: 51 | 52 | pz.nn.Linear 53 | pz.nn.RenameAxes 54 | pz.nn.AddBias 55 | pz.nn.Affine 56 | pz.nn.ConstantRescale 57 | pz.nn.NamedEinsum 58 | pz.nn.LinearInPlace 59 | pz.nn.LinearOperatorWeightInitializer 60 | pz.nn.contract 61 | pz.nn.variance_scaling_initializer 62 | pz.nn.xavier_normal_initializer 63 | pz.nn.xavier_uniform_initializer 64 | pz.nn.constant_initializer 65 | pz.nn.zero_initializer 66 | 67 | 68 | Standardization 69 | --------------- 70 | 71 | .. 
autosummary:: 72 | 73 | pz.nn.LayerNorm 74 | pz.nn.Standardize 75 | pz.nn.RMSLayerNorm 76 | pz.nn.RMSStandardize 77 | 78 | 79 | Dropout 80 | ------- 81 | 82 | .. autosummary:: 83 | 84 | pz.nn.StochasticDropout 85 | pz.nn.DisabledDropout 86 | pz.nn.maybe_dropout 87 | 88 | 89 | Language Modeling 90 | ----------------- 91 | 92 | .. autosummary:: 93 | pz.nn.Attention 94 | pz.nn.KVCachingAttention 95 | pz.nn.ApplyExplicitAttentionMask 96 | pz.nn.ApplyCausalAttentionMask 97 | pz.nn.ApplyCausalSlidingWindowAttentionMask 98 | pz.nn.EmbeddingTable 99 | pz.nn.EmbeddingLookup 100 | pz.nn.EmbeddingDecode 101 | pz.nn.ApplyRoPE 102 | 103 | 104 | Layer Stacks 105 | ------------ 106 | 107 | .. autosummary:: 108 | pz.nn.LayerStack 109 | pz.nn.LayerStackVarBehavior 110 | pz.nn.layerstack_axes_from_keypath 111 | pz.nn.LayerStackGetAttrKey 112 | -------------------------------------------------------------------------------- /docs/api/pz.rst: -------------------------------------------------------------------------------- 1 | ``pz``: Penzai's alias namespace 2 | ================================ 3 | 4 | .. module:: penzai.pz 5 | .. currentmodule:: penzai 6 | 7 | 8 | .. toctree:: 9 | :hidden: 10 | 11 | pz.nn 12 | pz.ts 13 | 14 | 15 | Structs 16 | ------- 17 | 18 | Most objects in Penzai models are subclasses of ``pz.Struct`` and 19 | decorated with ``pz.pytree_dataclass``, which makes them into frozen Python 20 | dataclasses that are also JAX PyTrees. 21 | 22 | .. autosummary:: 23 | 24 | pz.pytree_dataclass 25 | pz.Struct 26 | 27 | 28 | PyTree Manipulation 29 | ------------------- 30 | 31 | Penzai provides a number of utilities to make targeted modifications to PyTrees. 32 | Since Penzai models are PyTrees, you can use them to insert new layers into 33 | models, or modify the configuration of existing layers. 34 | 35 | .. 
autosummary:: 36 | pz.select 37 | pz.Selection 38 | pz.combine 39 | pz.NotInThisPartition 40 | pz.pretty_keystr 41 | 42 | 43 | Named Axes 44 | ---------- 45 | 46 | ``pz.nx`` is an alias for :obj:`penzai.core.named_axes`, which contains 47 | Penzai's named axis system. Some commonly-used attributes on ``pz.nx``: 48 | 49 | .. autosummary:: 50 | pz.nx.NamedArray 51 | pz.nx.nmap 52 | pz.nx.wrap 53 | 54 | See :obj:`penzai.core.named_axes` for documentation of all of the methods and 55 | classes accessible through the ``pz.nx`` alias. 56 | 57 | To simplify slicing named axes, Penzai also provides a helper object: 58 | 59 | .. autosummary:: 60 | pz.slice 61 | 62 | 63 | Parameters and State Variables 64 | ------------------------------ 65 | 66 | Penzai handles mutable state by embedding stateful parameters and variables into 67 | JAX pytrees. It provides a number of utilities to manipulate these stateful 68 | components and support passing them across JAX transformation boundaries. 69 | 70 | .. autosummary:: 71 | 72 | pz.Parameter 73 | pz.ParameterValue 74 | pz.ParameterSlot 75 | pz.StateVariable 76 | pz.StateVariableValue 77 | pz.StateVariableSlot 78 | pz.unbind_variables 79 | pz.bind_variables 80 | pz.freeze_variables 81 | pz.variable_jit 82 | pz.unbind_params 83 | pz.freeze_params 84 | pz.unbind_state_vars 85 | pz.freeze_state_vars 86 | 87 | 88 | .. autosummary:: 89 | 90 | pz.VariableConflictError 91 | pz.UnboundVariableError 92 | pz.VariableLabel 93 | pz.AbstractVariable 94 | pz.AbstractVariableValue 95 | pz.AbstractVariableSlot 96 | pz.AutoStateVarLabel 97 | pz.ScopedStateVarLabel 98 | pz.scoped_auto_state_var_labels 99 | pz.RandomStream 100 | 101 | 102 | Neural Networks 103 | --------------- 104 | 105 | :obj:`pz.nn` is an alias namespace for Penzai's declarative neural network 106 | system, which uses a combinator-based design to expose all of your model's 107 | operations as nodes in your model PyTree. 
:obj:`pz.nn` re-exports layers from 108 | submodules of :obj:`penzai.nn` in a single convenient namespace. 109 | 110 | See the documentation for :obj:`pz.nn` to view all of the 111 | methods and classes accessible through this alias namespace. 112 | 113 | 114 | Shape-Checking 115 | -------------- 116 | 117 | ``pz.chk`` is an alias for :obj:`penzai.core.shapecheck`, which contains 118 | utilities for checking the shapes of PyTrees of positional and named arrays. 119 | Some commonly-used attributes on ``pz.chk``: 120 | 121 | .. autosummary:: 122 | pz.chk.ArraySpec 123 | pz.chk.var 124 | pz.chk.vars_for_axes 125 | 126 | See :obj:`penzai.core.shapecheck` for documentation of all of the methods and 127 | classes accessible through the `pz.chk` alias. 128 | 129 | 130 | Dataclass and Struct Utilities 131 | ------------------------------ 132 | 133 | .. autosummary:: 134 | pz.is_pytree_dataclass_type 135 | pz.is_pytree_node_field 136 | pz.StructStaticMetadata 137 | pz.PyTreeDataclassSafetyError 138 | 139 | 140 | Rendering and Global Configuration Management 141 | --------------------------------------------- 142 | 143 | These utilities are available in the ``pz`` namespace for backwards 144 | compatibility. However, they have been moved to the separate Treescope 145 | pretty-printing package. See the 146 | `Treescope documentation <https://treescope.readthedocs.io/en/stable/>`_ 147 | for more information. 148 | 149 | .. autosummary:: 150 | 151 | pz.ts 152 | pz.show 153 | pz.ContextualValue 154 | pz.oklch_color 155 | pz.color_from_string 156 | pz.dataclass_from_attributes 157 | pz.init_takes_fields 158 | -------------------------------------------------------------------------------- /docs/api/pz.ts.rst: -------------------------------------------------------------------------------- 1 | ``pz.ts``: Treescope alias namespace 2 | ==================================== 3 | 4 | .. module:: penzai.pz.ts 5 | .. 
currentmodule:: penzai 6 | 7 | Treescope, Penzai's interactive pretty-printer, has moved to an independent 8 | package `treescope`. See the 9 | `Treescope documentation <https://treescope.readthedocs.io/en/stable/>`_ 10 | for more information. 11 | 12 | The alias namespace ``pz.ts`` contains shorthand aliases for commonly-used 13 | Treescope functions. However, we recommend that new users use `treescope` 14 | directly. 15 | 16 | Using Treescope in IPython Notebooks 17 | ------------------------------------ 18 | 19 | .. autosummary:: 20 | 21 | pz.ts.basic_interactive_setup 22 | pz.ts.register_as_default 23 | pz.ts.register_autovisualize_magic 24 | pz.ts.register_context_manager_magic 25 | 26 | 27 | Showing Objects Explicitly 28 | -------------------------- 29 | 30 | .. autosummary:: 31 | pz.ts.render_array 32 | pz.ts.render_array_sharding 33 | pz.ts.integer_digitbox 34 | pz.ts.text_on_color 35 | pz.ts.display 36 | pz.show 37 | 38 | 39 | Styling Displayed Objects 40 | ------------------------- 41 | 42 | .. autosummary:: 43 | 44 | pz.ts.inline 45 | pz.ts.indented 46 | pz.ts.with_font_size 47 | pz.ts.with_color 48 | pz.ts.bolded 49 | pz.ts.styled 50 | 51 | 52 | Configuring Treescope 53 | --------------------- 54 | 55 | .. autosummary:: 56 | 57 | pz.ts.active_renderer 58 | pz.ts.active_expansion_strategy 59 | pz.ts.using_expansion_strategy 60 | pz.ts.active_autovisualizer 61 | pz.ts.default_diverging_colormap 62 | pz.ts.default_sequential_colormap 63 | 64 | Building Autovisualizers 65 | ------------------------ 66 | 67 | .. autosummary:: 68 | 69 | pz.ts.ArrayAutovisualizer 70 | pz.ts.Autovisualizer 71 | pz.ts.ChildAutovisualizer 72 | pz.ts.IPythonVisualization 73 | pz.ts.vocab_autovisualizer 74 | pz.ts.default_magic_autovisualizer 75 | 76 | 77 | Rendering to Strings 78 | -------------------- 79 | 80 | .. 
autosummary:: 81 | 82 | pz.ts.render_to_text 83 | pz.ts.render_to_html 84 | 85 | -------------------------------------------------------------------------------- /docs/api/treescope.rst: -------------------------------------------------------------------------------- 1 | penzai.treescope (Moved!) 2 | ========================================================= 3 | 4 | .. module:: penzai.treescope 5 | 6 | Treescope, Penzai's interactive pretty-printer, has moved to an independent 7 | package `treescope`. See the 8 | `Treescope documentation <https://treescope.readthedocs.io/en/stable/>`_ 9 | for more information. 10 | 11 | The subpackage `penzai.treescope` contains compatibility stubs to ensure that 12 | libraries that use `penzai.treescope` to render custom types continue to work. 13 | We recommend that new users use `treescope` directly. 14 | 15 | One exception is if you want to use Treescope to render custom types that were 16 | configured using the old ``__penzai_repr__`` extension method and which have 17 | not yet been migrated to use the ``__treescope_repr__`` method. In this case, 18 | you can enable Treescope using the legacy Penzai API, using :: 19 | 20 | pz.ts.basic_interactive_setup() 21 | 22 | or, for more control: :: 23 | 24 | pz.ts.register_as_default() 25 | pz.ts.register_autovisualize_magic() 26 | pz.ts.active_autovisualizer.set_globally(pz.ts.ArrayAutovisualizer()) 27 | 28 | You can also pretty-print individual values using `pz.show` or `pz.ts.display`. 29 | -------------------------------------------------------------------------------- /docs/ext/nb_output_cell_to_iframe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Penzai Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Wraps MyST-nb output cells in HTML templates. 15 | 16 | This can be used to sandbox the environment of output cell renderings, which is 17 | useful for Treescope renderings since they make heavy use of Javascript and CSS, 18 | and can also improve responsiveness of the main page. 19 | 20 | This transformation replaces output cells with ``