├── .black ├── .flake8 ├── .github └── PULL_REQUEST_TEMPLATE │ └── release_self_review.md ├── .gitignore ├── .isort.cfg ├── .python-version ├── .readthedocs.yaml ├── .taplo.toml ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── demos ├── README.md ├── ROME │ ├── __init__.py │ ├── gpt2_ROME.ipynb │ ├── llama_ROME.ipynb │ └── rome.py └── transformers │ └── patch_attention_head.py ├── dev ├── README.md ├── build_base.sh ├── build_dev.sh ├── containers │ ├── base.Dockerfile │ ├── dev.Dockerfile │ └── init │ │ ├── jupyter.sh │ │ └── tailscale_init.sh ├── dev_local.sh ├── dev_shell.sh ├── external_requirements.txt ├── graphpatch.sh └── toxin.sh ├── docs ├── api.rst ├── compiled_graph_module.rst ├── conf.py ├── data_structures.rst ├── extraction_options.rst ├── extras_versions.rst ├── index.rst ├── multiply_invoked_module.rst ├── node_path.rst ├── notes_on_compilation.rst ├── opaque_graph_module.rst ├── patch.rst ├── patchable_graph.rst ├── requirements.txt ├── tweaks │ └── __init__.py ├── what_is_activation_patching.rst └── working_with_graphpatch.rst ├── graphpatch ├── __init__.py ├── exceptions.py ├── extraction │ ├── __init__.py │ ├── compiled_graph_module.py │ ├── extraction_context.py │ ├── extraction_options.py │ ├── graph_extraction.py │ ├── graphpatch_module.py │ ├── multiply_invoked_module.py │ ├── opaque_graph_module.py │ ├── wrapped_8_bit_linear.py │ └── wrapped_layer_norm.py ├── hacks.py ├── meta │ ├── __init__.py │ ├── graph_meta.py │ ├── node_data.py │ └── node_path.py ├── optional │ ├── __init__.py │ ├── accelerate.py │ ├── bitsandbytes.py │ ├── dataclasses.py │ ├── transformer_lens.py │ ├── transformers.py │ └── typing_extensions.py ├── patch.py └── patchable_graph.py ├── mypy.ini ├── pyproject.toml ├── pytest.ini ├── scripts └── export_extras_versions.py ├── tests ├── __init__.py ├── __snapshots__ │ ├── test_node_path.ambr │ └── test_node_path │ │ ├── test_patchable_graph_graph_repr[compiled] │ │ ├── 
patchable_accelerate_pretrained_module_2_0.ambr │ │ ├── patchable_accelerate_pretrained_module_2_1.ambr │ │ ├── patchable_accelerate_pretrained_module_2_2-2_3.ambr │ │ ├── patchable_accelerate_pretrained_module_2_4.ambr │ │ ├── patchable_accelerate_pretrained_module_2_5.ambr │ │ ├── patchable_attribute_module_2_0.ambr │ │ ├── patchable_attribute_module_2_1.ambr │ │ ├── patchable_attribute_module_2_2-2_3.ambr │ │ ├── patchable_attribute_module_2_4.ambr │ │ ├── patchable_attribute_module_2_5.ambr │ │ ├── patchable_buffer_module_2_0.ambr │ │ ├── patchable_buffer_module_2_1.ambr │ │ ├── patchable_buffer_module_2_2-2_3.ambr │ │ ├── patchable_buffer_module_2_4.ambr │ │ ├── patchable_buffer_module_2_5.ambr │ │ ├── patchable_container_module_2_0.ambr │ │ ├── patchable_container_module_2_1.ambr │ │ ├── patchable_container_module_2_2-2_3.ambr │ │ ├── patchable_container_module_2_4.ambr │ │ ├── patchable_container_module_2_5.ambr │ │ ├── patchable_deeply_nested_output_module_2_0.ambr │ │ ├── patchable_deeply_nested_output_module_2_1.ambr │ │ ├── patchable_deeply_nested_output_module_2_2-2_3.ambr │ │ ├── patchable_deeply_nested_output_module_2_4.ambr │ │ ├── patchable_deeply_nested_output_module_2_5.ambr │ │ ├── patchable_disk_offload_pretrained_module_2_0.ambr │ │ ├── patchable_disk_offload_pretrained_module_2_1.ambr │ │ ├── patchable_disk_offload_pretrained_module_2_2-2_3.ambr │ │ ├── patchable_disk_offload_pretrained_module_2_4.ambr │ │ ├── patchable_disk_offload_pretrained_module_2_5.ambr │ │ ├── patchable_graph_break_module_2_0.ambr │ │ ├── patchable_graph_break_module_2_1.ambr │ │ ├── patchable_graph_break_module_2_2-2_3.ambr │ │ ├── patchable_graph_break_module_2_4.ambr │ │ ├── patchable_graph_break_module_2_5.ambr │ │ ├── patchable_layer_norm_module_2_0.ambr │ │ ├── patchable_layer_norm_module_2_1.ambr │ │ ├── patchable_layer_norm_module_2_2-2_3.ambr │ │ ├── patchable_layer_norm_module_2_4.ambr │ │ ├── patchable_layer_norm_module_2_5.ambr │ │ ├── 
patchable_minimal_module_2_0.ambr │ │ ├── patchable_minimal_module_2_1.ambr │ │ ├── patchable_minimal_module_2_2-2_3.ambr │ │ ├── patchable_minimal_module_2_4.ambr │ │ ├── patchable_minimal_module_2_5.ambr │ │ ├── patchable_mixed_cpu_pretrained_module_2_0.ambr │ │ ├── patchable_mixed_cpu_pretrained_module_2_1.ambr │ │ ├── patchable_mixed_cpu_pretrained_module_2_2-2_3.ambr │ │ ├── patchable_mixed_cpu_pretrained_module_2_4.ambr │ │ ├── patchable_mixed_cpu_pretrained_module_2_5.ambr │ │ ├── patchable_nested_module_2_0.ambr │ │ ├── patchable_nested_module_2_1.ambr │ │ ├── patchable_nested_module_2_2-2_3.ambr │ │ ├── patchable_nested_module_2_4.ambr │ │ ├── patchable_nested_module_2_5.ambr │ │ ├── patchable_pretrained_module_2_0.ambr │ │ ├── patchable_pretrained_module_2_1.ambr │ │ ├── patchable_pretrained_module_2_2-2_3.ambr │ │ ├── patchable_pretrained_module_2_4.ambr │ │ ├── patchable_pretrained_module_2_5.ambr │ │ ├── patchable_protected_name_module_2_0.ambr │ │ ├── patchable_protected_name_module_2_1.ambr │ │ ├── patchable_protected_name_module_2_2-2_3.ambr │ │ ├── patchable_protected_name_module_2_4.ambr │ │ ├── patchable_protected_name_module_2_5.ambr │ │ ├── patchable_quantized_module_2_0.ambr │ │ ├── patchable_quantized_module_2_1.ambr │ │ ├── patchable_quantized_module_2_2-2_3.ambr │ │ ├── patchable_quantized_module_2_4.ambr │ │ ├── patchable_quantized_module_2_5.ambr │ │ ├── patchable_quantized_pretrained_module_2_0.ambr │ │ ├── patchable_quantized_pretrained_module_2_1.ambr │ │ ├── patchable_quantized_pretrained_module_2_2-2_3.ambr │ │ ├── patchable_quantized_pretrained_module_2_4.ambr │ │ ├── patchable_quantized_pretrained_module_2_5.ambr │ │ ├── patchable_tuple_output_module_2_0.ambr │ │ ├── patchable_tuple_output_module_2_1.ambr │ │ ├── patchable_tuple_output_module_2_2-2_3.ambr │ │ ├── patchable_tuple_output_module_2_4.ambr │ │ ├── patchable_tuple_output_module_2_5.ambr │ │ ├── patchable_unused_submodule_module_2_0.ambr │ │ ├── 
patchable_unused_submodule_module_2_1.ambr │ │ ├── patchable_unused_submodule_module_2_2-2_3.ambr │ │ ├── patchable_unused_submodule_module_2_4.ambr │ │ ├── patchable_unused_submodule_module_2_5.ambr │ │ ├── patchable_varargs_module_2_0.ambr │ │ ├── patchable_varargs_module_2_1.ambr │ │ ├── patchable_varargs_module_2_2-2_3.ambr │ │ ├── patchable_varargs_module_2_4.ambr │ │ └── patchable_varargs_module_2_5.ambr │ │ └── test_patchable_graph_graph_repr[opaque] │ │ ├── patchable_accelerate_pretrained_module.ambr │ │ ├── patchable_attribute_module.ambr │ │ ├── patchable_buffer_module.ambr │ │ ├── patchable_container_module.ambr │ │ ├── patchable_deeply_nested_output_module.ambr │ │ ├── patchable_disk_offload_pretrained_module.ambr │ │ ├── patchable_graph_break_module.ambr │ │ ├── patchable_layer_norm_module.ambr │ │ ├── patchable_minimal_module.ambr │ │ ├── patchable_mixed_cpu_pretrained_module.ambr │ │ ├── patchable_nested_module.ambr │ │ ├── patchable_pretrained_module.ambr │ │ ├── patchable_protected_name_module.ambr │ │ ├── patchable_quantized_module.ambr │ │ ├── patchable_quantized_pretrained_module.ambr │ │ ├── patchable_tuple_output_module.ambr │ │ ├── patchable_unused_submodule_module.ambr │ │ └── patchable_varargs_module.ambr ├── conftest.py ├── fixtures │ ├── __init__.py │ ├── attribute_module.py │ ├── buffer_module.py │ ├── container_module.py │ ├── deeply_nested_output_module.py │ ├── fixture_collections.py │ ├── gpt2_merges.txt │ ├── gpt2_tokenizer.json │ ├── gpt2_vocab.json │ ├── graph_break_module.py │ ├── layer_norm_module.py │ ├── llama_tokenizer.model │ ├── minimal_module.py │ ├── nested_module.py │ ├── pretrained │ │ ├── __init__.py │ │ ├── test_model.py │ │ ├── test_model_config.py │ │ └── test_model_tokenizer.py │ ├── pretrained_module.py │ ├── protected_name_module.py │ ├── quantized_module.py │ ├── tiny_gpt2.py │ ├── tiny_llama.py │ ├── tuple_output_module.py │ ├── unused_submodule_module.py │ └── varargs_module.py ├── test_extraction.py ├── 
test_meta.py ├── test_node_path.py ├── test_patch.py ├── test_quantization.py ├── test_real_models.py ├── test_rome.py ├── test_serialization.py ├── test_transformer_lens.py ├── test_validate_env.py └── util.py ├── tox.ini └── uv.lock /.black: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max_line_length = 120 3 | ignore = 4 | # Whitespace before colon in slice, incompatible with black 5 | E203 6 | # Linebreak before binary operator 7 | W503 8 | exclude = 9 | snapshots 10 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/release_self_review.md: -------------------------------------------------------------------------------- 1 | ### Overview 2 | 3 | ### Changes 4 | 5 | ### Pre-release checklist: 6 | * [ ] short tests across full tox matrix 7 | * [ ] long tests across full tox matrix 8 | * [ ] notebooks 9 | * [ ] lint, format, typecheck 10 | * [ ] regenerate README.md 11 | 12 | ### Post-release checklist: 13 | * [ ] Tagged release on GitHub 14 | * [ ] Verify readthedocs updated 15 | * [ ] Publish to PyPI 16 | * [ ] Docker images 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # 
Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # MacOS 156 | .DS_Store 157 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile=black 3 | combine_as_imports=True 4 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12.5 2 | 3.11.5 3 | 3.10.11 4 | 3.9.18 5 | 3.8.18 6 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See 
https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.10" 13 | commands: 14 | - python -m pip install --upgrade --no-cache-dir pip setuptools 15 | - python -m pip install --upgrade --no-cache-dir sphinx readthedocs-sphinx-ext 16 | - python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt 17 | - python -m sphinx -T -E -b html -d _build/doctrees -D language=en docs $READTHEDOCS_OUTPUT/html 18 | 19 | # Build documentation in the "docs/" directory with Sphinx 20 | sphinx: 21 | configuration: docs/conf.py 22 | 23 | # Optionally build your docs in additional formats such as PDF and ePub 24 | # formats: 25 | # - pdf 26 | # - epub 27 | 28 | # Optional but recommended, declare the Python requirements required 29 | # to build your documentation 30 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 31 | python: 32 | install: 33 | - requirements: docs/requirements.txt 34 | -------------------------------------------------------------------------------- /.taplo.toml: -------------------------------------------------------------------------------- 1 | include = [".poetry-lockfiles/*.in", "**/*.toml"] 2 | exclude = [] 3 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "black-formatter.args": ["--config=.black"], 3 | "isort.args": ["--settings-file", ".isort.cfg"], 4 | "[python]": { 5 | "editor.defaultFormatter": "ms-python.black-formatter" 6 | }, 7 | "editor.rulers": [ 8 | 100, 9 | 120 10 | ], 11 | "python.analysis.typeCheckingMode": "off" 12 | } 13 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-2024 Evan Lloyd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /demos/README.md: -------------------------------------------------------------------------------- 1 | # Demo index 2 | 3 | ## ROME 4 | Partial implementation of the Rank-One Model Editing (ROME) method from the paper 5 | [Locating and Editing Factual Associations in GPT](https://rome.baulab.info/)[^1]. Uses `graphpatch` 6 | to perform activation patching. Also see the GPT2 demo (modified slightly to work on free-tier hardware) [on Colab](https://colab.research.google.com/drive/1JlSp7Ckikb_r1bHvzHzphvR7nHPq_kbJ?usp=sharing). 7 | 8 | Files for the demo: 9 | 10 | ### `ROME/rome.py` 11 | Implementation of the ROME algorithm. 
The main API consists of [generate_key_vector](https://github.com/evan-lloyd/graphpatch/blob/5ebc57a12f8b23c869eb22581695f7e03688f941/demos/ROME/rome.py#L214) and [generate_value_vector](https://github.com/evan-lloyd/graphpatch/blob/5ebc57a12f8b23c869eb22581695f7e03688f941/demos/ROME/rome.py#L104C22-L104C22), which compute the vectors needed for the model editing, and [RomePatch](https://github.com/evan-lloyd/graphpatch/blob/5ebc57a12f8b23c869eb22581695f7e03688f941/demos/ROME/rome.py#L17) which applies the edit when used in the 12 | PatchableGraph [patch() context manager](https://graphpatch.readthedocs.io/en/latest/patchable_graph.html#graphpatch.PatchableGraph.patch). 13 | 14 | ### `ROME/gpt2_ROME.ipynb` 15 | Notebook demonstrating applying ROME to [GPT2-XL](https://huggingface.co/gpt2-xl). Assumes that the 16 | model weights have been saved to `/models/gpt2-xl`; change the `model_path` variable as appropriate. 17 | 18 | ### `ROME/llama_ROME.ipynb` 19 | Same example as above, but applied to [Llama-7B](https://huggingface.co/luodian/llama-7b-hf). 20 | 21 | [^1]: Kevin Meng, David Bau, Alex Andonian, and Yonatan Belinkov. "Locating and Editing Factual Associations in GPT." Advances in Neural Information Processing Systems 36 (2022). 
22 | -------------------------------------------------------------------------------- /demos/ROME/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evan-lloyd/graphpatch/d1ecec2949ea622eb04a4a364ef08942d6a8025f/demos/ROME/__init__.py -------------------------------------------------------------------------------- /demos/transformers/patch_attention_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | from graphpatch import PatchableGraph, ZeroPatch 5 | 6 | model = AutoModelForCausalLM.from_pretrained("gpt2") 7 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 8 | example_inputs = tokenizer("when in Paris, do as the", return_tensors="pt") 9 | pg = PatchableGraph(model, example_inputs.input_ids) 10 | print( 11 | "Unpatched top 5 logits:", 12 | [ 13 | (t[0], tokenizer.decode([int(t[1])])) 14 | for t in zip( 15 | *map( 16 | lambda x: x.tolist(), 17 | torch.topk(pg(**example_inputs).logits[0, 6, :], k=5, sorted=True), 18 | ) 19 | ) 20 | ], 21 | ) 22 | # Zeros out the attention pattern of the 11th attention head in the first layer. 23 | # NB: slice=(slice(None), slice(None), 10, slice(None)) corresponds to the tensor 24 | # indexing expression [:, :, 10, :]. 
25 | with pg.patch( 26 | { 27 | "transformer.h_0.attn.attn_output_1": ZeroPatch( 28 | slice=(slice(None), slice(None), 10, slice(None)) 29 | ) 30 | } 31 | ): 32 | print( 33 | "Patched top 5 logits:", 34 | [ 35 | (t[0], tokenizer.decode([int(t[1])])) 36 | for t in zip( 37 | *map( 38 | lambda x: x.tolist(), 39 | torch.topk(pg(**example_inputs).logits[0, 6, :], k=5, sorted=True), 40 | ) 41 | ) 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /dev/README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | Scripts and Docker images for developing on or with `graphpatch`. All scripts assume they are being run from the root of the repository. 4 | 5 | ## Docker images 6 | ### base 7 | `dev/containers/base.Dockerfile` 8 | 9 | Base container for running `graphpatch`; installs only the required dependencies. Uses [nvidia/cuda](https://hub.docker.com/r/nvidia/cuda/) as the base image. 10 | 11 | #### Building 12 | ``` 13 | ./dev/build_base.sh 14 | ``` 15 | Builds the base Docker image and tags it `graphpatch-base`. 16 | 17 | #### Basic usage 18 | ``` 19 | docker run -it --gpus all --rm graphpatch-base bash 20 | ``` 21 | Starts a shell within the image, granting access to host GPUs. If the host system lacks CUDA-supported GPUs, omit `--gpus all`. If you are developing experiments or an application using `graphpatch`, consider instead creating your own Docker image with this as a base using the [FROM](https://docs.docker.com/engine/reference/builder/#from) instruction. 22 | 23 | ### dev 24 | `dev/containers/dev.Dockerfile` 25 | 26 | Container which I used during development on [RunPod](https://www.runpod.io/) instances. No affiliation, nor is this an endorsement, but I did find this service, along with [Tailscale](https://tailscale.com/) to join to a private VPN, convenient for testing on multi-GPU setups. 
This container's setup is not RunPod-specific, but it does assume the use of Tailscale; you will need to set the `TAILSCALE_AUTH_KEY` and `TAILSCALE_HOST_NAME` environment variables in the container launch configuration. 27 | 28 | Installs all dev and optional dependencies, as well as Tailscale for joining to a private VPN. `TAILSCALE_AUTH_KEY` should be set to a [Tailscale auth key](https://tailscale.com/kb/1085/auth-keys) for your Tailnet. I found using ephemeral, reusable keys useful for development. The value of `TAILSCALE_HOST_NAME` will be used to set the name of the host when it joins your Tailnet. For example, during development I used the name `graphpatch`, which allowed me to do things like 29 | ``` 30 | ssh graphpatch 31 | ``` 32 | from any other device on my Tailnet. 33 | 34 | #### Building 35 | ``` 36 | ./dev/build_dev.sh 37 | ``` 38 | 39 | Builds the dev Docker image and tags it `graphpatch-dev`. 40 | 41 | #### Usage 42 | ``` 43 | ./dev/dev_shell.sh 44 | ``` 45 | 46 | Runs a bash shell inside the dev image, locally. Mounts the repository directory as a volume mapped to `/graphpatch` inside the container. 47 | 48 | ``` 49 | ./dev/dev_local.sh 50 | ``` 51 | 52 | Runs the container locally, using the same configuration as when running it on a hosting service. Connects to your Tailnet using the `TAILSCALE_AUTH_KEY` and `TAILSCALE_HOST_NAME` environment variables and launches the Tailscale SSH server. You can then connect to the container over SSH from any other device on your Tailnet. Also mounts the repository directory as a volume mapped to `/graphpatch` inside the container. 53 | 54 | #### Other recommendations 55 | A workflow I found useful was launching this image on RunPod, then setting up file synchronization between my local machine and the running container via [Mutagen](https://mutagen.io/). I created a Docker image to streamline this process, available [here](https://github.com/evan-lloyd/codesync/tree/main). 
56 | 57 | ## Utility Scripts 58 | `dev/toxin.sh` — run a command inside the designated `tox` environment; by [Daniel Pryden](https://github.com/dpryden), script taken from https://gist.github.com/dpryden/92a9a94ed21207bba549bbe7ac41ca9f 59 | 60 | Example usage: 61 | ``` 62 | ./dev/toxin.sh -e test-py38-torch21-extranone bash 63 | ``` 64 | Starts a shell within the Python 3.8, PyTorch 2.1, no-`transformers` test environment. Note that the environment must have been previously initialized. I found this extremely useful for debugging the dependency setup within the Tox testing matrix. 65 | -------------------------------------------------------------------------------- /dev/build_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker build . -f dev/containers/base.Dockerfile --tag graphpatch-base --platform linux/amd64 3 | -------------------------------------------------------------------------------- /dev/build_dev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | docker build . 
-f dev/containers/dev.Dockerfile --tag graphpatch-dev --platform linux/amd64 3 | -------------------------------------------------------------------------------- /dev/containers/base.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04 2 | 3 | WORKDIR /graphpatch 4 | COPY .python-version ./ 5 | 6 | RUN apt update -y && apt upgrade -y && \ 7 | DEBIAN_FRONTEND=noninteractive apt install -y wget openssh-server curl git tmux 8 | 9 | ADD https://astral.sh/uv/0.5.23/install.sh /uv-installer.sh 10 | RUN sh /uv-installer.sh && rm /uv-installer.sh 11 | ENV PATH=/root/.local/bin:/graphpatch/.venv/bin:$PATH \ 12 | GP_MODEL_DIR=/models \ 13 | UV_CACHE_DIR=/root/.cache/uv \ 14 | UV_LINK_MODE=symlink \ 15 | TERMINFO_DIRS=/etc/terminfo:/lib/terminfo:/usr/share/terminfo 16 | RUN uv python install `head -n 1 .python-version` 17 | 18 | # Bake in env vars so they'll be present when we SSH into a remote container 19 | RUN env > /etc/environment && mkdir /models && mkdir /graphpatch/.pytest_cache 20 | RUN echo "cd /graphpatch" >> "/root/.bashrc" 21 | 22 | COPY pyproject.toml uv.lock README.md ./ 23 | RUN uv sync --frozen --no-install-project --no-dev 24 | 25 | COPY graphpatch ./graphpatch 26 | RUN uv sync --frozen --group base --group torch25 --all-extras 27 | -------------------------------------------------------------------------------- /dev/containers/dev.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM graphpatch-base 2 | WORKDIR /graphpatch 3 | 4 | RUN wget -O - https://tailscale.com/install.sh | sh 5 | COPY --chmod=777 dev/containers/init/ /init/ 6 | 7 | RUN uv python install 8 | 9 | # Add dev packages for remote development/testing 10 | RUN --mount=type=bind,source=dev/external_requirements.txt,target=external_requirements.txt \ 11 | uv tool install tox --overrides external_requirements.txt --with tox-uv 12 | COPY pytest.ini mypy.ini 
tox.ini .black .flake8 .isort.cfg .taplo.toml ./ 13 | RUN uv sync --frozen --all-extras --group testenv-lint --group testenv-format \ 14 | --group testenv-typecheck --group testenv-test --group torch25 --group base --group dev 15 | 16 | COPY tests/ tests/ 17 | COPY demos/ demos/ 18 | 19 | ENTRYPOINT [ "/init/tailscale_init.sh" ] 20 | -------------------------------------------------------------------------------- /dev/containers/init/jupyter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PYDEVD_DISABLE_FILE_VALIDATION=1 jupyter lab --ip 0.0.0.0 --allow-root --no-browser 3 | -------------------------------------------------------------------------------- /dev/containers/init/tailscale_init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | tailscaled --tun=userspace-networking & 3 | tailscale up --authkey=$TAILSCALE_AUTH_KEY --hostname=$TAILSCALE_HOST_NAME --ssh 4 | sleep infinity 5 | -------------------------------------------------------------------------------- /dev/dev_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | . ./dev/graphpatch.sh 3 | docker_run -it --rm $(tailscale_env) graphpatch-dev 4 | -------------------------------------------------------------------------------- /dev/dev_shell.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | . 
./dev/graphpatch.sh 3 | docker_run -it --rm --entrypoint /bin/bash graphpatch-dev 4 | -------------------------------------------------------------------------------- /dev/external_requirements.txt: -------------------------------------------------------------------------------- 1 | uv==0.5.23 2 | tox-uv==1.20.1 3 | tox==4.24.1 4 | -------------------------------------------------------------------------------- /dev/graphpatch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | function docker_binds() { 4 | echo "--mount type=bind,src=$(pwd)/graphpatch,dst=/graphpatch/graphpatch \ 5 | --mount type=bind,src=$(pwd)/tests,dst=/graphpatch/tests \ 6 | --mount type=bind,src=$(pwd)/demos,dst=/graphpatch/demos \ 7 | --mount type=bind,src=$(pwd)/dev,dst=/graphpatch/dev \ 8 | --mount type=bind,src=$(pwd)/pyproject.toml,dst=/graphpatch/pyproject.toml \ 9 | --mount type=bind,src=$(pwd)/uv.lock,dst=/graphpatch/uv.lock \ 10 | --mount type=bind,src=$(pwd)/README.md,dst=/graphpatch/README.md \ 11 | --mount type=bind,src=$(pwd)/LICENSE,dst=/graphpatch/LICENSE \ 12 | --mount type=bind,src=$(pwd)/.black,dst=/graphpatch/.black \ 13 | --mount type=bind,src=$(pwd)/.flake8,dst=/graphpatch/.flake8 \ 14 | --mount type=bind,src=$(pwd)/.isort.cfg,dst=/graphpatch/.isort.cfg \ 15 | --mount type=bind,src=$(pwd)/.taplo.toml,dst=/graphpatch/.taplo.toml \ 16 | --mount type=bind,src=$(pwd)/tox.ini,dst=/graphpatch/tox.ini \ 17 | --mount type=bind,src=$(pwd)/pytest.ini,dst=/graphpatch/pytest.ini \ 18 | --mount type=bind,src=$(pwd)/.python-version,dst=/graphpatch/.python-version" 19 | } 20 | 21 | function docker_run() { 22 | docker run --gpus all $(docker_binds) "$@" || docker run $(docker_binds) "$@" 23 | } 24 | 25 | function tailscale_env() { 26 | echo "-e TAILSCALE_AUTH_KEY=$TAILSCALE_AUTH_KEY -e TAILSCALE_HOST_NAME=graphpatch" 27 | } 28 | -------------------------------------------------------------------------------- /dev/toxin.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Run a command inside a Tox environment. 3 | # Original: https://gist.github.com/dpryden/92a9a94ed21207bba549bbe7ac41ca9f 4 | # Extended to select an environment based on labels. 5 | 6 | if ! [[ -d .tox ]]; then 7 | echo 'Cannot find .tox in this directory!' 8 | exit 1 9 | fi 10 | if [[ "$1" == '-r' ]]; then 11 | maybe_refresh=-r 12 | shift 13 | fi 14 | if [[ "$1" == '-e' ]]; then 15 | toxenv="$2" 16 | elif [[ "$1" == '-m' ]]; then 17 | toxenv=$(tox -a -m $2 | head -n 1) 18 | else 19 | toxenv="$(cd .tox && echo py* | awk '{print $NF}')" 20 | fi 21 | toxdir="$(cd .tox && pwd)" 22 | bindir="$toxdir/$toxenv/bin" 23 | activate_script="$bindir/activate" 24 | 25 | if [[ ! -f $activate_script || "$maybe_refresh" == '-r' ]]; then 26 | tox $maybe_refresh $1 $2 --notest 27 | fi 28 | 29 | if ! [[ -f $activate_script ]]; then 30 | printf 'Cannot find tox env "%s" in current directory!\n' "$toxenv" 31 | exit 1 32 | fi 33 | shift 2 34 | if [[ "$1" == "" ]]; then 35 | PATH="$bindir:$PATH" /bin/bash --rcfile "$activate_script" 36 | else 37 | PATH="$bindir:$PATH" /bin/bash --rcfile "$activate_script" -ci $* 38 | fi 39 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. py:currentmodule:: graphpatch 2 | 3 | API 4 | === 5 | The public API to ``graphpatch`` consists of :class:`PatchableGraph`, which is a wrapper around :class:`Module`, 6 | and :class:`Patch `, which is a compact representation of an intervention 7 | to perform on an intermediate Tensor value. 8 | 9 | Reference 10 | ********* 11 | 12 | .. 
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = "graphpatch"
copyright = "2023–2024, Evan Lloyd"
author = "Evan Lloyd"
# Also interpolated into rst_epilog at the bottom of this file.
release = "0.2.3"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.intersphinx",
    "sphinx.ext.napoleon",
    "IPython.sphinxext.ipython_console_highlighting",
    "IPython.sphinxext.ipython_directive",
    # Project-local extension (docs/tweaks/__init__.py).
    "docs.tweaks",
    "sphinx_markdown_builder",
]
# External documentation sets that cross-references may resolve against.
intersphinx_mapping = {
    "bitsandbytes": ("https://huggingface.co/docs/bitsandbytes/main/en/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
    "torchcpp": ("https://pytorch.org/cppdocs/", None),
    "transformers": ("https://huggingface.co/docs/transformers/main/en/", None),
    "accelerate": ("https://huggingface.co/docs/accelerate/main/en/", None),
    "transformer_lens": ("https://transformerlensorg.github.io/TransformerLens/", None),
    "python": ("https://docs.python.org/3.11", None),
}
# Docstrings in this project are written in Google style only.
napoleon_google_docstring = True
napoleon_numpy_docstring = False
napoleon_include_init_with_doc = True

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "furo"
html_static_path = ["_static"]

add_module_names = False

autodoc_inherit_docstrings = False
# The keys of this mapping are also read by the "missing-reference" handler in
# docs/tweaks to decide which unresolved class references to retry as data.
autodoc_type_aliases = {
    "TensorSlice": "TensorSlice",
    "TensorSliceElement": "TensorSliceElement",
    "PatchTarget": "PatchTarget",
}
autodoc_typehints_format = "short"

# Settings for the sphinx_markdown_builder extension listed above.
markdown_http_base = "https://graphpatch.readthedocs.io/en/stable"
markdown_uri_doc_suffix = ".html"

# Defines the |graphpatch_version| substitution from `release`.
rst_epilog = f"""
.. |graphpatch_version| replace:: ``graphpatch {release}``
"""
toctree:: 14 | :titlesonly: 15 | 16 | compiled_graph_module 17 | multiply_invoked_module 18 | node_path 19 | opaque_graph_module 20 | -------------------------------------------------------------------------------- /docs/extraction_options.rst: -------------------------------------------------------------------------------- 1 | .. _extraction_options: 2 | 3 | ExtractionOptions 4 | ################# 5 | .. py:currentmodule:: graphpatch 6 | 7 | .. autoclass:: ExtractionOptions() 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/extras_versions.rst: -------------------------------------------------------------------------------- 1 | .. code:: 2 | 3 | accelerate==1.0.0 4 | bitsandbytes==0.44.1 5 | numpy==1.24.4 (Python 3.8) 6 | numpy==2.0.2 (Python 3.9) 7 | numpy==2.1.1 (later Python versions) 8 | sentencepiece==0.2.0 9 | transformer-lens==2.4.1 10 | transformers==4.45.2 11 | -------------------------------------------------------------------------------- /docs/multiply_invoked_module.rst: -------------------------------------------------------------------------------- 1 | MultiplyInvokedModule 2 | ##################### 3 | .. py:currentmodule:: graphpatch 4 | 5 | .. autoclass:: MultiplyInvokedModule() 6 | -------------------------------------------------------------------------------- /docs/node_path.rst: -------------------------------------------------------------------------------- 1 | NodePath 2 | ############## 3 | .. py:currentmodule:: graphpatch.meta 4 | 5 | .. autoclass:: NodePath() 6 | -------------------------------------------------------------------------------- /docs/opaque_graph_module.rst: -------------------------------------------------------------------------------- 1 | OpaqueGraphModule 2 | ################# 3 | .. py:currentmodule:: graphpatch 4 | 5 | .. 
autoclass:: OpaqueGraphModule() 6 | -------------------------------------------------------------------------------- /docs/patch.rst: -------------------------------------------------------------------------------- 1 | .. py:currentmodule:: graphpatch.patch 2 | 3 | .. _patch: 4 | 5 | Patch 6 | ##### 7 | 8 | .. automodule:: graphpatch.patch 9 | :members: 10 | :member-order: bysource 11 | 12 | Types 13 | ----- 14 | 15 | .. data:: TensorSlice 16 | :type: TensorSliceElement | List[TensorSlice] | Tuple[TensorSlice, ...] 17 | 18 | This is a datatype representing the indexing operation done when you slice a :class:`Tensor `, 19 | as happens in code like 20 | 21 | .. code:: 22 | 23 | x[:, 5:8, 2] = 3 24 | 25 | This is not a ``graphpatch``-specific type (we have merely aliased it for convenience), but interacts 26 | with :class:`Python internals ` which may be unfamiliar. 27 | 28 | Briefly, you will almost always want to pass a sequence (tuple or list) with as many elements as the dimensionality 29 | of your tensor. Within this sequence, elements can be either integers, subsequences, :class:`slices `, or Tensors. 30 | Each element of the sequence will select a subset of the Tensor along the dimension with the corresponding index. 31 | An integer will select a single "row" along that dimension. A subsequence will select multiple "rows". 32 | A slice will select a range of "rows". (``slice(None)`` selects all rows for that dimension, equivalent 33 | to writing a ":" within the bracket expression.) A Tensor will perform a complex operation 34 | that is out of the scope of this brief note. 35 | 36 | For a concrete example, we can accomplish the above operation with the following :class:`ReplacePatch`: 37 | 38 | .. code:: 39 | 40 | ReplacePatch(value=3, slice=((slice(None), slice(5, 8), 2))) 41 | 42 | See also: :std:doc:`torchcpp:notes/tensor_indexing`. 43 | 44 | .. data:: TensorSliceElement 45 | :type: int | slice | torch.Tensor 46 | 47 | One component of a :data:`TensorSlice`. 
def dequalify_intersphinx(app, doctree, docname):
    """Strip module qualification from cross-reference text in ``doctree``.

    Handler for Sphinx's ``doctree-resolved`` event. Any text node whose
    content starts with one of ``PREFIXES_TO_STRIP`` has its parent's children
    replaced by a single text node holding only the last dotted component
    (for example ``torch.nn.Module`` becomes ``Module``).

    Args:
        app: The Sphinx application (unused).
        doctree: The resolved doctree to rewrite in place.
        docname: Name of the document being processed (unused).
    """
    from docutils.nodes import NodeVisitor, Text

    class Visitor(NodeVisitor):
        def dispatch_visit(self, node):
            # Bit of a hammer, but this seems like the simplest way to get intersphinx to not
            # qualify cross-references on a case-by-case basis. We have no easy hook to modify
            # the text it generates because resolve_type_aliases is all-or-nothing--our own hook
            # would always be either too late or too early.
            #
            # Only text nodes can match: the string form of an element node is its markup,
            # which starts with "<". Checking the type first avoids rendering every subtree
            # to a string just to discard it.
            if isinstance(node, Text) and str(node).startswith(PREFIXES_TO_STRIP):
                node.parent.children = [Text(node.split(".")[-1])]

    doctree.walk(Visitor(doctree.document))
Or, in the other direction, by trying a bunch of local 11 | changes in a loop, we can *discover* what parts of a model are important for achieving a given behavior. 12 | Example: 13 | 14 | .. code:: 15 | 16 | pg = PatchableGraph(my_model, **clean_inputs) 17 | # Record clean activations 18 | with pg.patch({target_node: (clean := ProbePatch())}): 19 | clean_output = pg(**clean_inputs) 20 | # Evaluate output with corrupted inputs, patching in the "clean" value at the target node 21 | with pg.patch({target_node: ReplacePatch(value=clean.activation)}): 22 | corrupted_output = pg(**corrupted_inputs) 23 | evaluate_intervention(clean_output, corrupted_output) 24 | 25 | **Ablation** is a nearly identical concept; it involves running a model while intervening on 26 | specific intermediate outputs. The distinction is that I typically see "activation patching" used as 27 | a term for substituting activations observed in runs of the model under different conditions 28 | (for example, with a different input), whereas "ablation" generally refers to substitutions of more 29 | "constant" values (such as zeros or a sample mean of several different runs). Example: 30 | 31 | .. code:: 32 | 33 | pg = PatchableGraph(my_module, **inputs) 34 | # Zero ablation 35 | with pg.patch({target_node: ZeroPatch()}): 36 | ablated_output = pg(**inputs) 37 | 38 | In ``graphpatch`` I use these terms interchangeably; the point is to make the substitution of 39 | intermediate values as easy as possible. -------------------------------------------------------------------------------- /graphpatch/__init__.py: -------------------------------------------------------------------------------- 1 | from .extraction import ( 2 | CompiledGraphModule, 3 | ExtractionOptions, 4 | MultiplyInvokedModule, 5 | OpaqueGraphModule, 6 | ) 7 | 8 | # Load bitsandbytes early so we can suppress its rather chatty startup process. 
class GraphPatchWarning(Warning):
    """Warning category defined by ``graphpatch``."""


class GraphPatchException(Exception):
    """Exception type defined by ``graphpatch``."""
def compile_module(module: Module, *args: Any, **kwargs: Any) -> CompiledGraphModule:
    """Run ``module`` through :func:`torch.compile` and return the captured graph module.

    A custom backend callback captures the :class:`torch.fx.GraphModule` that torch produces
    and rewrites its (dynamically generated) class so that it derives from
    :class:`CompiledGraphModule`.

    Args:
        module: The module to compile.
        *args: Positional arguments passed to the compiled module when it is invoked.
        **kwargs: Keyword arguments passed to the compiled module when it is invoked.

    Returns:
        The captured graph module, with ``example_value`` removed from every node's ``meta``.

    Raises:
        ValueError: If no :class:`CompiledGraphModule` was captured by the callback.
    """
    try:
        # Cleared again in the callback and, unconditionally, in the finally clause below.
        hacks._CURRENTLY_COMPILING = True
        # Placeholder; replaced by the callback with the module torch hands us.
        graph_module = GraphModule({}, Graph())

        def callback(gm: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
            nonlocal graph_module
            graph_module = gm
            # There is no hook to choose a subclass of GraphModule to create during compilation, so
            # dynamically make it a subclass of CompiledGraphModule. GraphModules are always created
            # by torch as the sole instance of a dynamically generated class, so this is safe.
            assert type(gm) is not GraphModule

            # We don't want to get back a LazyGraphModule, which now happens in 2.3.
            if hacks.TORCH_VERSION >= (2, 3):
                from torch.fx._lazy_graph_module import _LazyGraphModule

                if _LazyGraphModule in type(gm).__bases__:
                    # Force an actual compilation of the GraphModule, which we need downstream.
                    gm.real_recompile()
                # Prepend CompiledGraphModule and swap any lazy base for plain GraphModule.
                type(gm).__bases__ = (CompiledGraphModule,) + tuple(
                    GraphModule if c is _LazyGraphModule else c for c in type(gm).__bases__
                )
            else:
                type(gm).__bases__ = (CompiledGraphModule,) + type(gm).__bases__
            type(gm).__name__ = CompiledGraphModule.__name__
            # NOTE(review): _init is defined outside this block; presumably it performs the
            # GraphPatchModule-specific setup from the original module -- confirm.
            gm._init(module)
            hacks._CURRENTLY_COMPILING = False
            return gm

        # We need to actually run inference to generate a GraphModule, which gets passed to
        # our callback above.
        compile(backend=callback, dynamic=True, fullgraph=True)(module)(*args, **kwargs)

        # Still the placeholder (or otherwise not re-parented) means nothing was captured.
        if not isinstance(graph_module, CompiledGraphModule):
            raise ValueError("Compilation callback was never called.")

        # In torch >= 2.1.0, FakeTensors get attached in each FXNode's meta, but they are
        # unpicklable. Make sure we don't keep them around.
        for node in graph_module.graph.nodes:
            node.meta.pop("example_value", None)

        return graph_module
    finally:
        # Reset the flag even when compilation raised before the callback could do so.
        hacks._CURRENTLY_COMPILING = False
Default: ``True``. 25 | custom_extraction_functions: Optional map from Module classes to callables generating 26 | :class:`torch.fx.Graph` to be used in place of graphpatch's normal extraction mechanism 27 | when encountering that class. Advanced feature; should not be necessary for ordinary 28 | use. See :ref:`custom_extraction_functions`. Default: ``dict()``. 29 | error_on_compilation_failure: Treat failure to compile a submodule as an error, rather than 30 | falling back to module-level patching via :class:`OpaqueGraphModule`. Default: ``False``. 31 | postprocessing_function: Optional function to call which will modify the generated 32 | :class:`torch.fx.GraphModule`. This function can modify the underlying 33 | :class:`torch.fx.Graph` in-place. The original module is passed for reference in case, 34 | for example, the needed modifications depend on its configuration. Advanced feature; 35 | should not be necessary for ordinary use. Default: ``None``. 36 | skip_compilation: Skip compilation on all modules. Only module inputs and outputs will be 37 | patchable. May be useful for faster iteration times if patching intermediate values 38 | isn't needed. Default: ``False``. 39 | warn_on_compilation_failure: Issue a warning when compilation fails, but then fall back 40 | to module-level patching for the failed module(s). Default: ``False``. 41 | 42 | Example: 43 | .. 
class MultiplyInvokedModule(ModuleList):
    """Container standing in for a submodule that its parent called more than once while
    ``graphpatch`` was converting it into a GraphModule. Each call gets its own entry, so
    the individual invocations can be patched separately.

    Example:
        .. code::

            class Foo(Module):
                def __init__(self):
                    super().__init__()
                    self.bar = Linear(3, 3)

                def forward(self, x, y):
                    return self.bar(x) + self.bar(y)

        .. ipython::
            :verbatim:

            In [1]: pg = PatchableGraph(Foo(), **inputs)
            In [2]: print(pg._graph_module)
            Out [2]:
            CompiledGraphModule(
              (bar): MultiplyInvokedModule(
                (0-1): 2 x CompiledGraphModule()
              )
            )
            In [3]: pg.graph
            Out[3]:
            <root>: CompiledGraphModule
            ├─x: Tensor(3, 3)
            ├─y: Tensor(3, 3)
            ├─bar_0: CompiledGraphModule
            │ ├─input: Tensor(3, 3)
            │ ├─weight: Tensor(3, 3)
            │ ├─bias: Tensor(3)
            │ ├─linear: Tensor(3, 3)
            │ └─output: Tensor(3, 3)
            ├─bar_1: CompiledGraphModule
            │ ├─input: Tensor(3, 3)
            │ ├─weight: Tensor(3, 3)
            │ ├─bias: Tensor(3)
            │ ├─linear: Tensor(3, 3)
            │ └─output: Tensor(3, 3)
            ├─add: Tensor(3, 3)
            └─output: Tensor(3, 3)

        The two calls to the submodule "bar" can then be patched independently:

        .. code::

            >>> with pg.patch({"bar_0": ZeroPatch(), "bar_1": AddPatch(value=1)}):
                ...

    See also :ref:`multiple_invocations`.

    """

    # Position of the entry that the next call to forward() will dispatch to.
    _graphpatch_invocation_index: int

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self._graphpatch_invocation_index = 0

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Dispatch to the entry for the current invocation, cycling back to the first
        entry once every entry has been used."""
        # TODO: surely any sane module will never vary how many times it calls its submodules
        # and the modulo doesn't matter? But we may want to make this configurable between
        # round-robin or throwing an exception, possibly a global "strict" mode?
        position = self._graphpatch_invocation_index % len(self._modules)
        target = self[position]
        self._graphpatch_invocation_index = position + 1
        return target(*args, **kwargs)
The user is expected instead to patch the "weight" node within the forward 46 | # computation. 47 | self.CB = Parameter(CB, requires_grad=False) 48 | self.SCB = Parameter(SCB.to(float16).unsqueeze(1), requires_grad=False) 49 | if original.bias is not None: 50 | self.bias = Parameter(original.bias.to(float16)) 51 | else: 52 | self.register_parameter("bias", None) 53 | 54 | def forward(self, x: Tensor) -> Any: 55 | weight = (self.CB * self.SCB) / 127 56 | return matmul_8bit(x, weight, self.bias, self.threshold) 57 | 58 | def __deepcopy__(self, memo: Any) -> "Wrapped8BitLinear": 59 | """Prevents an error when torch attempts to fakify our 8-bit parameters, which fails because 60 | they are a Tensor subclass.""" 61 | if hacks.in_fake_mode(): 62 | return self 63 | new_instance = type(self).__new__(type(self)) 64 | Module.__init__(new_instance) 65 | new_instance.CB = deepcopy(self.CB, memo) 66 | new_instance.SCB = deepcopy(self.SCB, memo) 67 | new_instance.threshold = deepcopy(self.threshold, memo) 68 | if self.bias is not None: 69 | new_instance.bias = Parameter(deepcopy(self.bias, memo)) 70 | else: 71 | new_instance.register_parameter("bias", None) 72 | return new_instance 73 | -------------------------------------------------------------------------------- /graphpatch/extraction/wrapped_layer_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import LayerNorm, Module, Parameter 6 | from torch.nn.functional import layer_norm as torch_layer_norm 7 | 8 | from .. 
@hacks.allow_in_graph  # type: ignore
def layer_norm(
    input: Tensor, normalized_shape: Tuple[int, ...], weight: Parameter, bias: Parameter, eps: float
) -> Tensor:
    """Apply ``torch.nn.functional.layer_norm``, with a placeholder result in fake mode.

    When ``hacks.in_fake_mode()`` is true the real kernel is not called. A zero tensor with the
    shape, device and dtype of ``input`` is returned instead.
    """
    if not hacks.in_fake_mode():
        return torch_layer_norm(input, normalized_shape, weight, bias, eps)
    return torch.zeros(input.shape, device=input.device, dtype=input.dtype)


class WrappedLayerNorm(Module):
    """Module that reuses a ``LayerNorm``'s settings and parameters but calls ``layer_norm``."""

    def __init__(self, original: LayerNorm):
        super().__init__()
        # Plain (non-parameter) attributes.
        self.eps = original.eps
        self.normalized_shape = original.normalized_shape
        # The original's weight and bias objects are assigned directly, not copied;
        # weight is assigned before bias, as in the original ordering.
        self.weight = original.weight
        self.bias = original.bias

    def forward(self, input: Tensor) -> Any:
        return layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
# Optional-dependency shim for accelerate: exposes the real hook helpers when the import
# succeeds, and inert stand-ins (with AVAILABLE = False) when anything in the try-body raises
# ImportError.
try:
    from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module

    AVAILABLE = True

    import torch
    from torch import mps

    # Fixes an annoying error with some combinations of accelerate/torch
    if not hasattr(torch, "mps"):
        torch.mps = mps

except ImportError:
    AVAILABLE = False

    class ModelHook:
        """Empty placeholder used when the accelerate import fails."""

    def add_hook_to_module(*args, **kwargs):
        """No-op placeholder: accepts any arguments and returns None."""
        return None

    def remove_hook_from_module(*args, **kwargs):
        """No-op placeholder: accepts any arguments and returns None."""
        return None


__all__ = ["ModelHook", "add_hook_to_module", "remove_hook_from_module"]
import sys
from dataclasses import field

if sys.version_info < (3, 10):
    # kw_only implementation adapted from
    # https://stackoverflow.com/questions/49908182/how-to-make-keyword-only-fields-with-dataclasses#comment129098154_49911616
    import dataclasses
    from contextlib import contextmanager
    from dataclasses import MISSING, dataclass as _orig_dataclass

    def _has_default(f):
        """Return True when the field declares a default value or a default factory."""
        return f.default is not MISSING or f.default_factory is not MISSING

    # Re-order fields so non-defaults come first. Would ordinarily be much more complicated than
    # this, but all our dataclasses are kw_only anyway, so we don't care about changing the order.
    @contextmanager
    def monkeypatch_init_fn():
        """Temporarily replace ``dataclasses._init_fn`` with a field-reordering wrapper."""
        saved_init_fn = dataclasses._init_fn

        def patched_init_fn(fields, *args, **kwargs):
            # Stable ascending sort on "has a default": fields without defaults move to the
            # front, and relative order inside each group is kept.
            reordered = sorted(fields, key=_has_default)
            return saved_init_fn(reordered, *args, **kwargs)

        dataclasses._init_fn = patched_init_fn
        try:
            yield
        finally:
            # Always restore the stdlib function, even if dataclass creation raised.
            dataclasses._init_fn = saved_init_fn

    # Backport good-enough kw_only behavior
    def dataclass(cls=None, *, kw_only=False, **dc_kwargs):
        """Drop-in for ``dataclasses.dataclass`` that also accepts ``kw_only``.

        Usable bare (``@dataclass``) or with arguments (``@dataclass(kw_only=True)``). When
        ``kw_only`` is true, the generated ``__init__`` is wrapped so that it only accepts
        keyword arguments.
        """

        def make_dataclass(cls):
            with monkeypatch_init_fn():
                built = _orig_dataclass(cls, **dc_kwargs)
            if kw_only:
                positional_init = built.__init__

                def kw_only_init(self, **kwargs):
                    positional_init(self, **kwargs)

                built.__init__ = kw_only_init
            return built

        return make_dataclass if cls is None else make_dataclass(cls)

else:
    # kw_only added in 3.10
    from dataclasses import dataclass


__all__ = ["dataclass", "field"]
PreTrainedModel, 15 | PreTrainedTokenizer, 16 | ) 17 | from transformers.modeling_outputs import CausalLMOutput 18 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention 19 | from transformers.utils.generic import ModelOutput 20 | 21 | AVAILABLE = True 22 | 23 | import torch 24 | 25 | # Fixes an annoying error with some combinations of transformers/torch 26 | if not hasattr(torch, "compiler"): 27 | torch.compiler = torch._dynamo 28 | if not hasattr(torch.compiler, "is_compiling"): 29 | torch.compiler.is_compiling = torch._dynamo.is_compiling 30 | 31 | except ImportError: 32 | 33 | class AutoConfig: 34 | pass 35 | 36 | class AutoModel: 37 | pass 38 | 39 | class AutoTokenizer: 40 | pass 41 | 42 | class GPT2LMHeadModel: 43 | pass 44 | 45 | class LlamaForCausalLM: 46 | pass 47 | 48 | class PretrainedConfig: 49 | pass 50 | 51 | class PreTrainedModel: 52 | pass 53 | 54 | class PreTrainedTokenizer: 55 | pass 56 | 57 | class GPT2Attention: 58 | pass 59 | 60 | class LlamaModel: 61 | pass 62 | 63 | class LlamaTokenizer: 64 | pass 65 | 66 | class BitsAndBytesConfig: 67 | pass 68 | 69 | class GenerationMixin: 70 | pass 71 | 72 | class ModelOutput: 73 | pass 74 | 75 | class GenerationConfig: 76 | pass 77 | 78 | class CausalLMOutput: 79 | pass 80 | 81 | AVAILABLE = False 82 | 83 | __all__ = [ 84 | "AutoConfig", 85 | "AutoModel", 86 | "AutoTokenizer", 87 | "BitsAndBytesConfig", 88 | "CausalLMOutput", 89 | "GenerationConfig", 90 | "GenerationMixin", 91 | "GPT2Attention", 92 | "GPT2LMHeadModel", 93 | "LlamaForCausalLM", 94 | "LlamaModel", 95 | "LlamaTokenizer", 96 | "ModelOutput", 97 | "PretrainedConfig", 98 | "PreTrainedModel", 99 | "PreTrainedTokenizer", 100 | ] 101 | -------------------------------------------------------------------------------- /graphpatch/optional/typing_extensions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info >= (3, 10): 4 | # Version 3.10: 
"""Print or save the pinned versions of selected optional dependencies as a reST code block.

Runs ``uv export --all-extras``, keeps the lines for the packages in ``PACKAGE_NAMES``, and
renders them under a ``.. code::`` directive. With a path as the first command-line argument
the result is written to that file; otherwise it is printed.
"""
import re
import subprocess
import sys
from typing import Iterable, List, Optional

PACKAGE_NAMES = (
    "accelerate",
    "bitsandbytes",
    "transformers",
    "numpy",
    "sentencepiece",
    "transformer-lens",
)

# A wanted package name at the start of the line, then the shortest run of characters that is
# followed by a space (i.e. the requirement spec without any trailing markers).
PACKAGE_PATTERN = re.compile(f"({'|'.join(PACKAGE_NAMES)}).+?(?= ;?)")
PYTHON_VERSION_PATTERN = re.compile(r"python_full_version ([!|<|>|=]+) ('.+?')")
SYS_PLATFORM_PATTERN = re.compile(r"sys_platform ([!|<|>|=]+) '(.+?)'")


def run_uv_export() -> str:
    """Return the decoded stdout of ``uv export --all-extras``.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero; its captured stdout and
            stderr are printed before the exception is re-raised.
    """
    try:
        completed = subprocess.run(
            "uv export --all-extras",
            check=True,
            shell=True,
            capture_output=True,
        )
    except subprocess.CalledProcessError as exc:
        print(exc.stdout, exc.stderr)
        raise
    return completed.stdout.decode()


def format_requirement_lines(export_text: str) -> List[str]:
    """Turn raw export text into tab-indented requirement lines for the wanted packages.

    Lines that do not start with a name from ``PACKAGE_NAMES`` are dropped. For the rest, the
    first ``python_full_version`` and ``sys_platform`` markers found on the line (if any) are
    appended after `` ; `` and joined with ``and`` when both are present.
    """
    out_lines = []
    for line in export_text.split("\n"):
        package_match = PACKAGE_PATTERN.match(line)
        if package_match is None:
            continue
        out_line = f"\t{package_match.group(0)}"
        version_match = PYTHON_VERSION_PATTERN.search(line)
        platform_match = SYS_PLATFORM_PATTERN.search(line)
        if version_match:
            out_line += f" ; {version_match.group(0)}"
        if platform_match:
            # The platform marker opens the marker section unless a version marker already did.
            joiner = "and" if version_match else ";"
            out_line += f" {joiner} {platform_match.group(0)}"
        out_lines.append(out_line)
    return out_lines


def render_code_block(requirement_lines: Iterable[str]) -> str:
    """Wrap the requirement lines in a reST ``.. code::`` block."""
    return ".. code::\n\n" + "\n".join(requirement_lines) + "\n\n"


def main(argv: Optional[List[str]] = None) -> None:
    """Export, format, and either write to ``argv[1]`` or print the result."""
    argv = sys.argv if argv is None else argv
    output = render_code_block(format_requirement_lines(run_uv_export()))
    if len(argv) > 1:
        # Context manager so the file is flushed and closed deterministically.
        with open(argv[1], "w") as out_file:
            out_file.write(output)
    else:
        print(output)


if __name__ == "__main__":
    main()
None 21 | add = linear_0 + 2 22 | linear_1 = getattr(self.linear, "1")(add); add = None 23 | return ((linear_0,), [([getitem_5], getitem_6), ([getitem_9], getitem_10), ([getitem_13], getitem_14)], {'nested_dict': [(linear_1,)]}) 24 | # --- 25 | # name: test_node_path_code[2_1] 26 | def forward(self, x : torch.Tensor): 27 | child_a = self.child_a(x); x = None 28 | getitem = child_a[0] 29 | getitem_1 = getitem[0]; getitem = None 30 | getitem_2 = child_a[1]; child_a = None 31 | getitem_3 = getitem_2[0] 32 | getitem_4 = getitem_3[0] 33 | getitem_5 = getitem_4[0]; getitem_4 = None 34 | getitem_6 = getitem_3[1]; getitem_3 = None 35 | getitem_7 = getitem_2[1] 36 | getitem_8 = getitem_7[0] 37 | getitem_9 = getitem_8[0]; getitem_8 = None 38 | getitem_10 = getitem_7[1]; getitem_7 = None 39 | getitem_11 = getitem_2[2]; getitem_2 = None 40 | getitem_12 = getitem_11[0] 41 | getitem_13 = getitem_12[0]; getitem_12 = None 42 | getitem_14 = getitem_11[1]; getitem_11 = None 43 | linear_0 = getattr(self.linear, "0")(getitem_1); getitem_1 = None 44 | add = linear_0 + 2 45 | linear_1 = getattr(self.linear, "1")(add); add = None 46 | return ((linear_0,), [([getitem_5], getitem_6), ([getitem_9], getitem_10), ([getitem_13], getitem_14)], {'nested_dict': [(linear_1,)]}) 47 | # --- 48 | -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_accelerate_pretrained_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─input_ids: Tensor(1, 100) 4 | ├─kwargs 5 | ├─getitem: Tensor(1, 100) 6 | ├─size: Size(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem_1: int 10 | ├─view: Tensor(1, 1, 100) 11 | ├─repeat: Tensor(1, 100, 100) 12 | ├─to: Tensor(1, 100, 100) 13 | ├─model: CompiledGraphModule 14 | │ ├─root_inputs: Tensor(1, 100, 100) 15 | │ ├─child_a: CompiledGraphModule 16 | │ │ ├─a_inputs: 
Tensor(1, 100, 100) 17 | │ │ ├─grandchildren_b_0: CompiledGraphModule 18 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 19 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 20 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 21 | │ │ │ ├─mul: Tensor(1, 100, 100) 22 | │ │ │ ├─c: CompiledGraphModule 23 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 24 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 25 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 26 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 27 | │ │ │ │ ├─add: Tensor(1, 100, 100) 28 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 29 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 30 | │ │ │ │ ├─c_linear: CompiledGraphModule 31 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 32 | │ │ │ │ │ ├─weight: Tensor(100, 100) 33 | │ │ │ │ │ ├─bias: Tensor(100) 34 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 35 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 36 | │ │ │ │ └─output: Tensor(1, 100, 100) 37 | │ │ │ ├─b_linear: CompiledGraphModule 38 | │ │ │ │ ├─input: Tensor(1, 100, 100) 39 | │ │ │ │ ├─weight: Tensor(100, 100) 40 | │ │ │ │ ├─bias: Tensor(100) 41 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 42 | │ │ │ │ └─output: Tensor(1, 100, 100) 43 | │ │ │ └─output: Tensor(1, 100, 100) 44 | │ │ ├─grandchildren_b_1: CompiledGraphModule 45 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 46 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 47 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 48 | │ │ │ ├─mul: Tensor(1, 100, 100) 49 | │ │ │ ├─c: CompiledGraphModule 50 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 51 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 52 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 53 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 54 | │ │ │ │ ├─add: Tensor(1, 100, 100) 55 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 56 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 57 | │ │ │ │ ├─c_linear: CompiledGraphModule 58 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 59 | │ │ │ │ │ ├─weight: Tensor(100, 100) 60 | │ │ │ │ │ ├─bias: Tensor(100) 61 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 62 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 63 | │ │ │ │ └─output: 
Tensor(1, 100, 100) 64 | │ │ │ ├─b_linear: CompiledGraphModule 65 | │ │ │ │ ├─input: Tensor(1, 100, 100) 66 | │ │ │ │ ├─weight: Tensor(100, 100) 67 | │ │ │ │ ├─bias: Tensor(100) 68 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 69 | │ │ │ │ └─output: Tensor(1, 100, 100) 70 | │ │ │ └─output: Tensor(1, 100, 100) 71 | │ │ ├─grandchildren_b_2: CompiledGraphModule 72 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 73 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 74 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 75 | │ │ │ ├─mul: Tensor(1, 100, 100) 76 | │ │ │ ├─c: CompiledGraphModule 77 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 78 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 79 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 80 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 81 | │ │ │ │ ├─add: Tensor(1, 100, 100) 82 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 83 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 84 | │ │ │ │ ├─c_linear: CompiledGraphModule 85 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 86 | │ │ │ │ │ ├─weight: Tensor(100, 100) 87 | │ │ │ │ │ ├─bias: Tensor(100) 88 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 89 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 90 | │ │ │ │ └─output: Tensor(1, 100, 100) 91 | │ │ │ ├─b_linear: CompiledGraphModule 92 | │ │ │ │ ├─input: Tensor(1, 100, 100) 93 | │ │ │ │ ├─weight: Tensor(100, 100) 94 | │ │ │ │ ├─bias: Tensor(100) 95 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 96 | │ │ │ │ └─output: Tensor(1, 100, 100) 97 | │ │ │ └─output: Tensor(1, 100, 100) 98 | │ │ ├─a_linear: CompiledGraphModule 99 | │ │ │ ├─input: Tensor(1, 100, 100) 100 | │ │ │ ├─weight: Tensor(100, 100) 101 | │ │ │ ├─bias: Tensor(100) 102 | │ │ │ ├─linear: Tensor(1, 100, 100) 103 | │ │ │ └─output: Tensor(1, 100, 100) 104 | │ │ └─output: Tensor(1, 100, 100) 105 | │ ├─root_linear: CompiledGraphModule 106 | │ │ ├─input: Tensor(1, 100, 100) 107 | │ │ ├─weight: Tensor(100, 100) 108 | │ │ ├─bias: Tensor(100) 109 | │ │ ├─linear: Tensor(1, 100, 100) 110 | │ │ └─output: Tensor(1, 100, 100) 111 | │ └─output: Tensor(1, 100, 100) 112 | 
└─output: CausalLMOutput(1) 113 | └─logits: Tensor(1, 100, 100) 114 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_accelerate_pretrained_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─input_ids: Tensor(1, 100) 4 | ├─kwargs 5 | ├─getitem: Tensor(1, 100) 6 | ├─size: Size(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem_1: int 10 | ├─getitem_2: int 11 | ├─view: Tensor(1, 1, 100) 12 | ├─repeat: Tensor(1, 100, 100) 13 | ├─to: Tensor(1, 100, 100) 14 | ├─model: CompiledGraphModule 15 | │ ├─root_inputs: Tensor(1, 100, 100) 16 | │ ├─child_a: CompiledGraphModule 17 | │ │ ├─a_inputs: Tensor(1, 100, 100) 18 | │ │ ├─grandchildren_b_0: CompiledGraphModule 19 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 20 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 21 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 22 | │ │ │ ├─mul: Tensor(1, 100, 100) 23 | │ │ │ ├─c: CompiledGraphModule 24 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 25 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 26 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 27 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 28 | │ │ │ │ ├─add: Tensor(1, 100, 100) 29 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 30 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 31 | │ │ │ │ ├─c_linear: CompiledGraphModule 32 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 33 | │ │ │ │ │ ├─weight: Tensor(100, 100) 34 | │ │ │ │ │ ├─bias: Tensor(100) 35 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 36 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 37 | │ │ │ │ └─output: Tensor(1, 100, 100) 38 | │ │ │ ├─b_linear: CompiledGraphModule 39 | │ │ │ │ ├─input: Tensor(1, 100, 100) 40 | │ │ │ │ ├─weight: Tensor(100, 100) 41 | │ │ │ │ ├─bias: Tensor(100) 42 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 43 | │ │ │ │ └─output: Tensor(1, 100, 100) 44 | │ │ │ └─output: Tensor(1, 100, 100) 45 | │ │ ├─grandchildren_b_1: 
CompiledGraphModule 46 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 47 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 48 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 49 | │ │ │ ├─mul: Tensor(1, 100, 100) 50 | │ │ │ ├─c: CompiledGraphModule 51 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 52 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 53 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 54 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 55 | │ │ │ │ ├─add: Tensor(1, 100, 100) 56 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 57 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 58 | │ │ │ │ ├─c_linear: CompiledGraphModule 59 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 60 | │ │ │ │ │ ├─weight: Tensor(100, 100) 61 | │ │ │ │ │ ├─bias: Tensor(100) 62 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 63 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 64 | │ │ │ │ └─output: Tensor(1, 100, 100) 65 | │ │ │ ├─b_linear: CompiledGraphModule 66 | │ │ │ │ ├─input: Tensor(1, 100, 100) 67 | │ │ │ │ ├─weight: Tensor(100, 100) 68 | │ │ │ │ ├─bias: Tensor(100) 69 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 70 | │ │ │ │ └─output: Tensor(1, 100, 100) 71 | │ │ │ └─output: Tensor(1, 100, 100) 72 | │ │ ├─grandchildren_b_2: CompiledGraphModule 73 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 74 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 75 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 76 | │ │ │ ├─mul: Tensor(1, 100, 100) 77 | │ │ │ ├─c: CompiledGraphModule 78 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 79 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 80 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 81 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 82 | │ │ │ │ ├─add: Tensor(1, 100, 100) 83 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 84 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 85 | │ │ │ │ ├─c_linear: CompiledGraphModule 86 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 87 | │ │ │ │ │ ├─weight: Tensor(100, 100) 88 | │ │ │ │ │ ├─bias: Tensor(100) 89 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 90 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 91 | │ │ │ │ └─output: Tensor(1, 100, 100) 92 | │ │ │ ├─b_linear: 
CompiledGraphModule 93 | │ │ │ │ ├─input: Tensor(1, 100, 100) 94 | │ │ │ │ ├─weight: Tensor(100, 100) 95 | │ │ │ │ ├─bias: Tensor(100) 96 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 97 | │ │ │ │ └─output: Tensor(1, 100, 100) 98 | │ │ │ └─output: Tensor(1, 100, 100) 99 | │ │ ├─a_linear: CompiledGraphModule 100 | │ │ │ ├─input: Tensor(1, 100, 100) 101 | │ │ │ ├─weight: Tensor(100, 100) 102 | │ │ │ ├─bias: Tensor(100) 103 | │ │ │ ├─linear: Tensor(1, 100, 100) 104 | │ │ │ └─output: Tensor(1, 100, 100) 105 | │ │ └─output: Tensor(1, 100, 100) 106 | │ ├─root_linear: CompiledGraphModule 107 | │ │ ├─input: Tensor(1, 100, 100) 108 | │ │ ├─weight: Tensor(100, 100) 109 | │ │ ├─bias: Tensor(100) 110 | │ │ ├─linear: Tensor(1, 100, 100) 111 | │ │ └─output: Tensor(1, 100, 100) 112 | │ └─output: Tensor(1, 100, 100) 113 | └─output: CausalLMOutput(1) 114 | └─logits: Tensor(1, 100, 100) 115 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_attribute_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─attribute_to_serialize: float 5 | ├─add: Tensor(3, 2) 6 | ├─linear: CompiledGraphModule 7 | │ ├─input: Tensor(3, 2) 8 | │ ├─weight: Tensor(3, 2) 9 | │ ├─bias: Tensor(3) 10 | │ ├─linear: Tensor(3, 3) 11 | │ └─output: Tensor(3, 3) 12 | └─output: Tensor(3, 3) 13 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_attribute_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─add: Tensor(3, 2) 5 | ├─linear: CompiledGraphModule 6 | │ ├─input: Tensor(3, 2) 7 | │ ├─weight: Tensor(3, 2) 8 | │ ├─bias: Tensor(3) 9 | │ ├─linear: Tensor(3, 3) 10 | │ 
└─output: Tensor(3, 3) 11 | └─output: Tensor(3, 3) 12 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_attribute_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─add: Tensor(3, 2) 5 | ├─linear: CompiledGraphModule 6 | │ ├─input: Tensor(3, 2) 7 | │ ├─weight: Tensor(3, 2) 8 | │ ├─bias: Tensor(3) 9 | │ ├─linear: Tensor(3, 3) 10 | │ └─output: Tensor(3, 3) 11 | └─output: Tensor(3, 3) 12 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_attribute_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─add: Tensor(3, 2) 5 | ├─linear: CompiledGraphModule 6 | │ ├─input: Tensor(3, 2) 7 | │ ├─weight: Tensor(3, 2) 8 | │ ├─bias: Tensor(3) 9 | │ ├─linear: Tensor(3, 3) 10 | │ └─output: Tensor(3, 3) 11 | └─output: Tensor(3, 3) 12 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_attribute_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─add: Tensor(3, 2) 5 | ├─linear: CompiledGraphModule 6 | │ ├─input: Tensor(3, 2) 7 | │ ├─weight: Tensor(3, 2) 8 | │ ├─bias: Tensor(3) 9 | │ ├─linear: Tensor(3, 3) 10 | │ └─output: Tensor(3, 3) 11 | └─output: Tensor(3, 3) 12 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_buffer_module_2_0.ambr: 
-------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─buffer: Tensor(3, 3) 11 | ├─add: Tensor(3, 3) 12 | └─output: Tensor(3, 3) 13 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_buffer_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─buffer: Tensor(3, 3) 11 | ├─add: Tensor(3, 3) 12 | └─output: Tensor(3, 3) 13 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_buffer_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─buffer: Tensor(3, 3) 11 | ├─add: Tensor(3, 3) 12 | └─output: Tensor(3, 3) 13 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_buffer_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: 
Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─buffer: Tensor(3, 3) 11 | ├─add: Tensor(3, 3) 12 | └─output: Tensor(3, 3) 13 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_buffer_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─buffer: Tensor(3, 3) 11 | ├─add: Tensor(3, 3) 12 | └─output: Tensor(3, 3) 13 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_disk_offload_pretrained_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─input_ids: Tensor(1, 100) 4 | ├─kwargs 5 | ├─getitem: Tensor(1, 100) 6 | ├─size: Size(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem_1: int 10 | ├─view: Tensor(1, 1, 100) 11 | ├─repeat: Tensor(1, 100, 100) 12 | ├─to: Tensor(1, 100, 100) 13 | ├─model: CompiledGraphModule 14 | │ ├─root_inputs: Tensor(1, 100, 100) 15 | │ ├─child_a: CompiledGraphModule 16 | │ │ ├─a_inputs: Tensor(1, 100, 100) 17 | │ │ ├─grandchildren_b_0: CompiledGraphModule 18 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 19 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 20 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 21 | │ │ │ ├─mul: Tensor(1, 100, 100) 22 | │ │ │ ├─c: CompiledGraphModule 23 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 24 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 25 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 26 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 27 | │ │ │ │ ├─add: Tensor(1, 100, 100) 28 | │ │ │ │ 
├─add_1: Tensor(1, 100, 100) 29 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 30 | │ │ │ │ ├─c_linear: CompiledGraphModule 31 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 32 | │ │ │ │ │ ├─weight: Tensor(100, 100) 33 | │ │ │ │ │ ├─bias: Tensor(100) 34 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 35 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 36 | │ │ │ │ └─output: Tensor(1, 100, 100) 37 | │ │ │ ├─b_linear: CompiledGraphModule 38 | │ │ │ │ ├─input: Tensor(1, 100, 100) 39 | │ │ │ │ ├─weight: Tensor(100, 100) 40 | │ │ │ │ ├─bias: Tensor(100) 41 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 42 | │ │ │ │ └─output: Tensor(1, 100, 100) 43 | │ │ │ └─output: Tensor(1, 100, 100) 44 | │ │ ├─grandchildren_b_1: CompiledGraphModule 45 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 46 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 47 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 48 | │ │ │ ├─mul: Tensor(1, 100, 100) 49 | │ │ │ ├─c: CompiledGraphModule 50 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 51 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 52 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 53 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 54 | │ │ │ │ ├─add: Tensor(1, 100, 100) 55 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 56 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 57 | │ │ │ │ ├─c_linear: CompiledGraphModule 58 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 59 | │ │ │ │ │ ├─weight: Tensor(100, 100) 60 | │ │ │ │ │ ├─bias: Tensor(100) 61 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 62 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 63 | │ │ │ │ └─output: Tensor(1, 100, 100) 64 | │ │ │ ├─b_linear: CompiledGraphModule 65 | │ │ │ │ ├─input: Tensor(1, 100, 100) 66 | │ │ │ │ ├─weight: Tensor(100, 100) 67 | │ │ │ │ ├─bias: Tensor(100) 68 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 69 | │ │ │ │ └─output: Tensor(1, 100, 100) 70 | │ │ │ └─output: Tensor(1, 100, 100) 71 | │ │ ├─grandchildren_b_2: CompiledGraphModule 72 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 73 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 74 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 75 | │ │ │ ├─mul: 
Tensor(1, 100, 100) 76 | │ │ │ ├─c: CompiledGraphModule 77 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 78 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 79 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 80 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 81 | │ │ │ │ ├─add: Tensor(1, 100, 100) 82 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 83 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 84 | │ │ │ │ ├─c_linear: CompiledGraphModule 85 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 86 | │ │ │ │ │ ├─weight: Tensor(100, 100) 87 | │ │ │ │ │ ├─bias: Tensor(100) 88 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 89 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 90 | │ │ │ │ └─output: Tensor(1, 100, 100) 91 | │ │ │ ├─b_linear: CompiledGraphModule 92 | │ │ │ │ ├─input: Tensor(1, 100, 100) 93 | │ │ │ │ ├─weight: Tensor(100, 100) 94 | │ │ │ │ ├─bias: Tensor(100) 95 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 96 | │ │ │ │ └─output: Tensor(1, 100, 100) 97 | │ │ │ └─output: Tensor(1, 100, 100) 98 | │ │ ├─a_linear: CompiledGraphModule 99 | │ │ │ ├─input: Tensor(1, 100, 100) 100 | │ │ │ ├─weight: Tensor(100, 100) 101 | │ │ │ ├─bias: Tensor(100) 102 | │ │ │ ├─linear: Tensor(1, 100, 100) 103 | │ │ │ └─output: Tensor(1, 100, 100) 104 | │ │ └─output: Tensor(1, 100, 100) 105 | │ ├─root_linear: CompiledGraphModule 106 | │ │ ├─input: Tensor(1, 100, 100) 107 | │ │ ├─weight: Tensor(100, 100) 108 | │ │ ├─bias: Tensor(100) 109 | │ │ ├─linear: Tensor(1, 100, 100) 110 | │ │ └─output: Tensor(1, 100, 100) 111 | │ └─output: Tensor(1, 100, 100) 112 | └─output: CausalLMOutput(1) 113 | └─logits: Tensor(1, 100, 100) 114 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_graph_break_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foo: int 5 | ├─sub_shape: tuple(2) 6 | │ ├─sub_0: int 7 | │ └─sub_1: int 8 | ├─bar: int 9 | 
├─instance_value: int 10 | ├─shadowed_class_var: int 11 | ├─linear_0: CompiledGraphModule 12 | │ ├─input: Tensor(3, 3) 13 | │ ├─weight: Tensor(3, 3) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | ├─linear_1: CompiledGraphModule 18 | │ ├─input: Tensor(3, 3) 19 | │ ├─weight: Tensor(3, 3) 20 | │ ├─bias: Tensor(3) 21 | │ ├─linear: Tensor(3, 3) 22 | │ └─output: Tensor(3, 3) 23 | ├─linear_2: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─unused_submodule: OpaqueGraphModule 30 | │ ├─input 31 | │ ├─constants 32 | │ ├─bias 33 | │ ├─in_features 34 | │ ├─out_features 35 | │ ├─weight 36 | │ └─output 37 | └─output: Tensor(3, 3) 38 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_graph_break_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foo: int 5 | ├─sub_shape: tuple(2) 6 | │ ├─sub_0: int 7 | │ └─sub_1: int 8 | ├─bar: int 9 | ├─instance_value: int 10 | ├─shadowed_class_var: int 11 | ├─linear_0: CompiledGraphModule 12 | │ ├─input: Tensor(3, 3) 13 | │ ├─weight: Tensor(3, 3) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | ├─linear_1: CompiledGraphModule 18 | │ ├─input: Tensor(3, 3) 19 | │ ├─weight: Tensor(3, 3) 20 | │ ├─bias: Tensor(3) 21 | │ ├─linear: Tensor(3, 3) 22 | │ └─output: Tensor(3, 3) 23 | ├─linear_2: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─unused_submodule: OpaqueGraphModule 30 | │ ├─input 31 | │ ├─constants 32 | │ ├─bias 33 | │ ├─in_features 34 | │ ├─out_features 35 | │ ├─weight 36 | │ └─output 37 | 
└─output: Tensor(3, 3) 38 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_graph_break_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foo: int 5 | ├─sub_shape: tuple(2) 6 | │ ├─sub_0: int 7 | │ └─sub_1: int 8 | ├─bar: int 9 | ├─instance_value: int 10 | ├─shadowed_class_var: int 11 | ├─linear_0: CompiledGraphModule 12 | │ ├─input: Tensor(3, 3) 13 | │ ├─weight: Tensor(3, 3) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | ├─linear_1: CompiledGraphModule 18 | │ ├─input: Tensor(3, 3) 19 | │ ├─weight: Tensor(3, 3) 20 | │ ├─bias: Tensor(3) 21 | │ ├─linear: Tensor(3, 3) 22 | │ └─output: Tensor(3, 3) 23 | ├─linear_2: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─unused_submodule: OpaqueGraphModule 30 | │ ├─input 31 | │ ├─constants 32 | │ ├─bias 33 | │ ├─in_features 34 | │ ├─out_features 35 | │ ├─weight 36 | │ └─output 37 | └─output: Tensor(3, 3) 38 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_graph_break_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foo: int 5 | ├─sub_shape: tuple(2) 6 | │ ├─sub_0: int 7 | │ └─sub_1: int 8 | ├─bar: int 9 | ├─instance_value: int 10 | ├─shadowed_class_var: int 11 | ├─linear_0: CompiledGraphModule 12 | │ ├─input: Tensor(3, 3) 13 | │ ├─weight: Tensor(3, 3) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | ├─linear_1: CompiledGraphModule 18 | │ ├─input: Tensor(3, 3) 19 
| │ ├─weight: Tensor(3, 3) 20 | │ ├─bias: Tensor(3) 21 | │ ├─linear: Tensor(3, 3) 22 | │ └─output: Tensor(3, 3) 23 | ├─linear_2: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─unused_submodule: OpaqueGraphModule 30 | │ ├─input 31 | │ ├─constants 32 | │ ├─bias 33 | │ ├─in_features 34 | │ ├─out_features 35 | │ ├─weight 36 | │ └─output 37 | └─output: Tensor(3, 3) 38 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_graph_break_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foo: int 5 | ├─sub_shape: tuple(2) 6 | │ ├─sub_0: int 7 | │ └─sub_1: int 8 | ├─bar: int 9 | ├─instance_value: int 10 | ├─shadowed_class_var: int 11 | ├─linear_0: CompiledGraphModule 12 | │ ├─input: Tensor(3, 3) 13 | │ ├─weight: Tensor(3, 3) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | ├─linear_1: CompiledGraphModule 18 | │ ├─input: Tensor(3, 3) 19 | │ ├─weight: Tensor(3, 3) 20 | │ ├─bias: Tensor(3) 21 | │ ├─linear: Tensor(3, 3) 22 | │ └─output: Tensor(3, 3) 23 | ├─linear_2: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─unused_submodule: OpaqueGraphModule 30 | │ ├─input 31 | │ ├─constants 32 | │ ├─bias 33 | │ ├─in_features 34 | │ ├─out_features 35 | │ ├─weight 36 | │ └─output 37 | └─output: Tensor(3, 3) 38 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_layer_norm_module_2_0.ambr: 
-------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─ln: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─eps: float 7 | │ ├─weight: Tensor(2) 8 | │ ├─bias: Tensor(2) 9 | │ ├─layer_norm: Tensor(3, 2) 10 | │ └─output: Tensor(3, 2) 11 | └─output: Tensor(3, 2) 12 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_layer_norm_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─ln: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(2) 7 | │ ├─bias: Tensor(2) 8 | │ ├─layer_norm: Tensor(3, 2) 9 | │ └─output: Tensor(3, 2) 10 | └─output: Tensor(3, 2) 11 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_layer_norm_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─ln: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(2) 7 | │ ├─bias: Tensor(2) 8 | │ ├─layer_norm: Tensor(3, 2) 9 | │ └─output: Tensor(3, 2) 10 | └─output: Tensor(3, 2) 11 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_layer_norm_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─ln: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─normalized_shape: tuple(1) 7 | │ │ └─sub_0: int 8 | │ ├─getitem: int 9 | │ ├─weight: Tensor(2) 10 | │ ├─bias: Tensor(2) 11 | │ ├─layer_norm: 
Tensor(3, 2) 12 | │ └─output: Tensor(3, 2) 13 | └─output: Tensor(3, 2) 14 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_layer_norm_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─ln: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(2) 7 | │ ├─bias: Tensor(2) 8 | │ ├─layer_norm: Tensor(3, 2) 9 | │ └─output: Tensor(3, 2) 10 | └─output: Tensor(3, 2) 11 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_minimal_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | └─output: Tensor(3, 3) 11 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_minimal_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | └─output: Tensor(3, 3) 11 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_minimal_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : 
CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | └─output: Tensor(3, 3) 11 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_minimal_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | └─output: Tensor(3, 3) 11 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_minimal_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | └─output: Tensor(3, 3) 11 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_mixed_cpu_pretrained_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─input_ids: Tensor(1, 100) 4 | ├─kwargs 5 | ├─getitem: Tensor(1, 100) 6 | ├─size: Size(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem_1: int 10 | ├─view: Tensor(1, 1, 100) 11 | ├─repeat: Tensor(1, 100, 100) 12 | ├─to: Tensor(1, 100, 100) 13 | ├─model: CompiledGraphModule 14 | │ ├─root_inputs: Tensor(1, 100, 100) 15 | │ 
├─child_a: CompiledGraphModule 16 | │ │ ├─a_inputs: Tensor(1, 100, 100) 17 | │ │ ├─grandchildren_b_0: CompiledGraphModule 18 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 19 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 20 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 21 | │ │ │ ├─mul: Tensor(1, 100, 100) 22 | │ │ │ ├─c: CompiledGraphModule 23 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 24 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 25 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 26 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 27 | │ │ │ │ ├─add: Tensor(1, 100, 100) 28 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 29 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 30 | │ │ │ │ ├─c_linear: CompiledGraphModule 31 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 32 | │ │ │ │ │ ├─weight: Tensor(100, 100) 33 | │ │ │ │ │ ├─bias: Tensor(100) 34 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 35 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 36 | │ │ │ │ └─output: Tensor(1, 100, 100) 37 | │ │ │ ├─b_linear: CompiledGraphModule 38 | │ │ │ │ ├─input: Tensor(1, 100, 100) 39 | │ │ │ │ ├─weight: Tensor(100, 100) 40 | │ │ │ │ ├─bias: Tensor(100) 41 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 42 | │ │ │ │ └─output: Tensor(1, 100, 100) 43 | │ │ │ └─output: Tensor(1, 100, 100) 44 | │ │ ├─grandchildren_b_1: CompiledGraphModule 45 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 46 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 47 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 48 | │ │ │ ├─mul: Tensor(1, 100, 100) 49 | │ │ │ ├─c: CompiledGraphModule 50 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 51 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 52 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 53 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 54 | │ │ │ │ ├─add: Tensor(1, 100, 100) 55 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 56 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 57 | │ │ │ │ ├─c_linear: CompiledGraphModule 58 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 59 | │ │ │ │ │ ├─weight: Tensor(100, 100) 60 | │ │ │ │ │ ├─bias: Tensor(100) 61 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 62 | │ │ │ │ │ 
└─output: Tensor(1, 100, 100) 63 | │ │ │ │ └─output: Tensor(1, 100, 100) 64 | │ │ │ ├─b_linear: CompiledGraphModule 65 | │ │ │ │ ├─input: Tensor(1, 100, 100) 66 | │ │ │ │ ├─weight: Tensor(100, 100) 67 | │ │ │ │ ├─bias: Tensor(100) 68 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 69 | │ │ │ │ └─output: Tensor(1, 100, 100) 70 | │ │ │ └─output: Tensor(1, 100, 100) 71 | │ │ ├─grandchildren_b_2: CompiledGraphModule 72 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 73 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 74 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 75 | │ │ │ ├─mul: Tensor(1, 100, 100) 76 | │ │ │ ├─c: CompiledGraphModule 77 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 78 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 79 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 80 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 81 | │ │ │ │ ├─add: Tensor(1, 100, 100) 82 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 83 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 84 | │ │ │ │ ├─c_linear: CompiledGraphModule 85 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 86 | │ │ │ │ │ ├─weight: Tensor(100, 100) 87 | │ │ │ │ │ ├─bias: Tensor(100) 88 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 89 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 90 | │ │ │ │ └─output: Tensor(1, 100, 100) 91 | │ │ │ ├─b_linear: CompiledGraphModule 92 | │ │ │ │ ├─input: Tensor(1, 100, 100) 93 | │ │ │ │ ├─weight: Tensor(100, 100) 94 | │ │ │ │ ├─bias: Tensor(100) 95 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 96 | │ │ │ │ └─output: Tensor(1, 100, 100) 97 | │ │ │ └─output: Tensor(1, 100, 100) 98 | │ │ ├─a_linear: CompiledGraphModule 99 | │ │ │ ├─input: Tensor(1, 100, 100) 100 | │ │ │ ├─weight: Tensor(100, 100) 101 | │ │ │ ├─bias: Tensor(100) 102 | │ │ │ ├─linear: Tensor(1, 100, 100) 103 | │ │ │ └─output: Tensor(1, 100, 100) 104 | │ │ └─output: Tensor(1, 100, 100) 105 | │ ├─root_linear: CompiledGraphModule 106 | │ │ ├─input: Tensor(1, 100, 100) 107 | │ │ ├─weight: Tensor(100, 100) 108 | │ │ ├─bias: Tensor(100) 109 | │ │ ├─linear: Tensor(1, 100, 100) 110 | │ │ └─output: Tensor(1, 100, 
100) 111 | │ └─output: Tensor(1, 100, 100) 112 | └─output: CausalLMOutput(1) 113 | └─logits: Tensor(1, 100, 100) 114 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_nested_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 
46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_2: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' 
-------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_nested_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 
| │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_2: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- 
/tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_nested_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ 
│ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_2: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_nested_module_2_4.ambr: 
-------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ 
├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_2: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_nested_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: 
Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | 
│ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_2: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_pretrained_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─input_ids: Tensor(1, 100) 4 | ├─kwargs 5 | ├─getitem: Tensor(1, 100) 6 | ├─size: Size(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem_1: int 
10 | ├─view: Tensor(1, 1, 100) 11 | ├─repeat: Tensor(1, 100, 100) 12 | ├─to: Tensor(1, 100, 100) 13 | ├─model: CompiledGraphModule 14 | │ ├─root_inputs: Tensor(1, 100, 100) 15 | │ ├─child_a: CompiledGraphModule 16 | │ │ ├─a_inputs: Tensor(1, 100, 100) 17 | │ │ ├─grandchildren_b_0: CompiledGraphModule 18 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 19 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 20 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 21 | │ │ │ ├─mul: Tensor(1, 100, 100) 22 | │ │ │ ├─c: CompiledGraphModule 23 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 24 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 25 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 26 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 27 | │ │ │ │ ├─add: Tensor(1, 100, 100) 28 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 29 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 30 | │ │ │ │ ├─c_linear: CompiledGraphModule 31 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 32 | │ │ │ │ │ ├─weight: Tensor(100, 100) 33 | │ │ │ │ │ ├─bias: Tensor(100) 34 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 35 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 36 | │ │ │ │ └─output: Tensor(1, 100, 100) 37 | │ │ │ ├─b_linear: CompiledGraphModule 38 | │ │ │ │ ├─input: Tensor(1, 100, 100) 39 | │ │ │ │ ├─weight: Tensor(100, 100) 40 | │ │ │ │ ├─bias: Tensor(100) 41 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 42 | │ │ │ │ └─output: Tensor(1, 100, 100) 43 | │ │ │ └─output: Tensor(1, 100, 100) 44 | │ │ ├─grandchildren_b_1: CompiledGraphModule 45 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 46 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 47 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 48 | │ │ │ ├─mul: Tensor(1, 100, 100) 49 | │ │ │ ├─c: CompiledGraphModule 50 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 51 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 52 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 53 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 54 | │ │ │ │ ├─add: Tensor(1, 100, 100) 55 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 56 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 57 | │ │ │ │ ├─c_linear: CompiledGraphModule 58 
| │ │ │ │ │ ├─input: Tensor(1, 100, 100) 59 | │ │ │ │ │ ├─weight: Tensor(100, 100) 60 | │ │ │ │ │ ├─bias: Tensor(100) 61 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 62 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 63 | │ │ │ │ └─output: Tensor(1, 100, 100) 64 | │ │ │ ├─b_linear: CompiledGraphModule 65 | │ │ │ │ ├─input: Tensor(1, 100, 100) 66 | │ │ │ │ ├─weight: Tensor(100, 100) 67 | │ │ │ │ ├─bias: Tensor(100) 68 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 69 | │ │ │ │ └─output: Tensor(1, 100, 100) 70 | │ │ │ └─output: Tensor(1, 100, 100) 71 | │ │ ├─grandchildren_b_2: CompiledGraphModule 72 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 73 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 74 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 75 | │ │ │ ├─mul: Tensor(1, 100, 100) 76 | │ │ │ ├─c: CompiledGraphModule 77 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 78 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 79 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 80 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 81 | │ │ │ │ ├─add: Tensor(1, 100, 100) 82 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 83 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 84 | │ │ │ │ ├─c_linear: CompiledGraphModule 85 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 86 | │ │ │ │ │ ├─weight: Tensor(100, 100) 87 | │ │ │ │ │ ├─bias: Tensor(100) 88 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 89 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 90 | │ │ │ │ └─output: Tensor(1, 100, 100) 91 | │ │ │ ├─b_linear: CompiledGraphModule 92 | │ │ │ │ ├─input: Tensor(1, 100, 100) 93 | │ │ │ │ ├─weight: Tensor(100, 100) 94 | │ │ │ │ ├─bias: Tensor(100) 95 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 96 | │ │ │ │ └─output: Tensor(1, 100, 100) 97 | │ │ │ └─output: Tensor(1, 100, 100) 98 | │ │ ├─a_linear: CompiledGraphModule 99 | │ │ │ ├─input: Tensor(1, 100, 100) 100 | │ │ │ ├─weight: Tensor(100, 100) 101 | │ │ │ ├─bias: Tensor(100) 102 | │ │ │ ├─linear: Tensor(1, 100, 100) 103 | │ │ │ └─output: Tensor(1, 100, 100) 104 | │ │ └─output: Tensor(1, 100, 100) 105 | │ ├─root_linear: CompiledGraphModule 
106 | │ │ ├─input: Tensor(1, 100, 100) 107 | │ │ ├─weight: Tensor(100, 100) 108 | │ │ ├─bias: Tensor(100) 109 | │ │ ├─linear: Tensor(1, 100, 100) 110 | │ │ └─output: Tensor(1, 100, 100) 111 | │ └─output: Tensor(1, 100, 100) 112 | └─output: CausalLMOutput(1) 113 | └─logits: Tensor(1, 100, 100) 114 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_pretrained_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─input_ids: Tensor(1, 100) 4 | ├─kwargs 5 | ├─getitem: Tensor(1, 100) 6 | ├─size: Size(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem_1: int 10 | ├─getitem_2: int 11 | ├─view: Tensor(1, 1, 100) 12 | ├─repeat: Tensor(1, 100, 100) 13 | ├─to: Tensor(1, 100, 100) 14 | ├─model: CompiledGraphModule 15 | │ ├─root_inputs: Tensor(1, 100, 100) 16 | │ ├─child_a: CompiledGraphModule 17 | │ │ ├─a_inputs: Tensor(1, 100, 100) 18 | │ │ ├─grandchildren_b_0: CompiledGraphModule 19 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 20 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 21 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 22 | │ │ │ ├─mul: Tensor(1, 100, 100) 23 | │ │ │ ├─c: CompiledGraphModule 24 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 25 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 26 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 27 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 28 | │ │ │ │ ├─add: Tensor(1, 100, 100) 29 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 30 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 31 | │ │ │ │ ├─c_linear: CompiledGraphModule 32 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 33 | │ │ │ │ │ ├─weight: Tensor(100, 100) 34 | │ │ │ │ │ ├─bias: Tensor(100) 35 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 36 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 37 | │ │ │ │ └─output: Tensor(1, 100, 100) 38 | │ │ │ ├─b_linear: CompiledGraphModule 39 | │ │ │ │ ├─input: Tensor(1, 100, 100) 40 | │ │ │ │ 
├─weight: Tensor(100, 100) 41 | │ │ │ │ ├─bias: Tensor(100) 42 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 43 | │ │ │ │ └─output: Tensor(1, 100, 100) 44 | │ │ │ └─output: Tensor(1, 100, 100) 45 | │ │ ├─grandchildren_b_1: CompiledGraphModule 46 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 47 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 48 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 49 | │ │ │ ├─mul: Tensor(1, 100, 100) 50 | │ │ │ ├─c: CompiledGraphModule 51 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 52 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 53 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 54 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 55 | │ │ │ │ ├─add: Tensor(1, 100, 100) 56 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 57 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 58 | │ │ │ │ ├─c_linear: CompiledGraphModule 59 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 60 | │ │ │ │ │ ├─weight: Tensor(100, 100) 61 | │ │ │ │ │ ├─bias: Tensor(100) 62 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 63 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 64 | │ │ │ │ └─output: Tensor(1, 100, 100) 65 | │ │ │ ├─b_linear: CompiledGraphModule 66 | │ │ │ │ ├─input: Tensor(1, 100, 100) 67 | │ │ │ │ ├─weight: Tensor(100, 100) 68 | │ │ │ │ ├─bias: Tensor(100) 69 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 70 | │ │ │ │ └─output: Tensor(1, 100, 100) 71 | │ │ │ └─output: Tensor(1, 100, 100) 72 | │ │ ├─grandchildren_b_2: CompiledGraphModule 73 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 74 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 75 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 76 | │ │ │ ├─mul: Tensor(1, 100, 100) 77 | │ │ │ ├─c: CompiledGraphModule 78 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 79 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 80 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 81 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 82 | │ │ │ │ ├─add: Tensor(1, 100, 100) 83 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 84 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 85 | │ │ │ │ ├─c_linear: CompiledGraphModule 86 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 87 | │ │ │ │ │ 
├─weight: Tensor(100, 100) 88 | │ │ │ │ │ ├─bias: Tensor(100) 89 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 90 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 91 | │ │ │ │ └─output: Tensor(1, 100, 100) 92 | │ │ │ ├─b_linear: CompiledGraphModule 93 | │ │ │ │ ├─input: Tensor(1, 100, 100) 94 | │ │ │ │ ├─weight: Tensor(100, 100) 95 | │ │ │ │ ├─bias: Tensor(100) 96 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 97 | │ │ │ │ └─output: Tensor(1, 100, 100) 98 | │ │ │ └─output: Tensor(1, 100, 100) 99 | │ │ ├─a_linear: CompiledGraphModule 100 | │ │ │ ├─input: Tensor(1, 100, 100) 101 | │ │ │ ├─weight: Tensor(100, 100) 102 | │ │ │ ├─bias: Tensor(100) 103 | │ │ │ ├─linear: Tensor(1, 100, 100) 104 | │ │ │ └─output: Tensor(1, 100, 100) 105 | │ │ └─output: Tensor(1, 100, 100) 106 | │ ├─root_linear: CompiledGraphModule 107 | │ │ ├─input: Tensor(1, 100, 100) 108 | │ │ ├─weight: Tensor(100, 100) 109 | │ │ ├─bias: Tensor(100) 110 | │ │ ├─linear: Tensor(1, 100, 100) 111 | │ │ └─output: Tensor(1, 100, 100) 112 | │ └─output: Tensor(1, 100, 100) 113 | └─output: CausalLMOutput(1) 114 | └─logits: Tensor(1, 100, 100) 115 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_pretrained_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─input_ids: Tensor(1, 100) 4 | ├─kwargs 5 | ├─getitem: Tensor(1, 100) 6 | ├─size: Size(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem_1: int 10 | ├─getitem_2: int 11 | ├─view: Tensor(1, 1, 100) 12 | ├─repeat: Tensor(1, 100, 100) 13 | ├─embedding: Tensor(1, 100, 100) 14 | ├─model: CompiledGraphModule 15 | │ ├─root_inputs: Tensor(1, 100, 100) 16 | │ ├─child_a: CompiledGraphModule 17 | │ │ ├─a_inputs: Tensor(1, 100, 100) 18 | │ │ ├─grandchildren_b_0: CompiledGraphModule 19 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 20 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 21 | │ │ │ 
├─ones_like: Tensor(1, 100, 100) 22 | │ │ │ ├─mul: Tensor(1, 100, 100) 23 | │ │ │ ├─c: CompiledGraphModule 24 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 25 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 26 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 27 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 28 | │ │ │ │ ├─add: Tensor(1, 100, 100) 29 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 30 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 31 | │ │ │ │ ├─c_linear: CompiledGraphModule 32 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 33 | │ │ │ │ │ ├─weight: Tensor(100, 100) 34 | │ │ │ │ │ ├─bias: Tensor(100) 35 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 36 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 37 | │ │ │ │ └─output: Tensor(1, 100, 100) 38 | │ │ │ ├─b_linear: CompiledGraphModule 39 | │ │ │ │ ├─input: Tensor(1, 100, 100) 40 | │ │ │ │ ├─weight: Tensor(100, 100) 41 | │ │ │ │ ├─bias: Tensor(100) 42 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 43 | │ │ │ │ └─output: Tensor(1, 100, 100) 44 | │ │ │ └─output: Tensor(1, 100, 100) 45 | │ │ ├─grandchildren_b_1: CompiledGraphModule 46 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 47 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 48 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 49 | │ │ │ ├─mul: Tensor(1, 100, 100) 50 | │ │ │ ├─c: CompiledGraphModule 51 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 52 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 53 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 54 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 55 | │ │ │ │ ├─add: Tensor(1, 100, 100) 56 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 57 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 58 | │ │ │ │ ├─c_linear: CompiledGraphModule 59 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 60 | │ │ │ │ │ ├─weight: Tensor(100, 100) 61 | │ │ │ │ │ ├─bias: Tensor(100) 62 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 63 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 64 | │ │ │ │ └─output: Tensor(1, 100, 100) 65 | │ │ │ ├─b_linear: CompiledGraphModule 66 | │ │ │ │ ├─input: Tensor(1, 100, 100) 67 | │ │ │ │ ├─weight: Tensor(100, 100) 68 | │ │ │ │ ├─bias: 
Tensor(100) 69 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 70 | │ │ │ │ └─output: Tensor(1, 100, 100) 71 | │ │ │ └─output: Tensor(1, 100, 100) 72 | │ │ ├─grandchildren_b_2: CompiledGraphModule 73 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 74 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 75 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 76 | │ │ │ ├─mul: Tensor(1, 100, 100) 77 | │ │ │ ├─c: CompiledGraphModule 78 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 79 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 80 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 81 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 82 | │ │ │ │ ├─add: Tensor(1, 100, 100) 83 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 84 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 85 | │ │ │ │ ├─c_linear: CompiledGraphModule 86 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 87 | │ │ │ │ │ ├─weight: Tensor(100, 100) 88 | │ │ │ │ │ ├─bias: Tensor(100) 89 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 90 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 91 | │ │ │ │ └─output: Tensor(1, 100, 100) 92 | │ │ │ ├─b_linear: CompiledGraphModule 93 | │ │ │ │ ├─input: Tensor(1, 100, 100) 94 | │ │ │ │ ├─weight: Tensor(100, 100) 95 | │ │ │ │ ├─bias: Tensor(100) 96 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 97 | │ │ │ │ └─output: Tensor(1, 100, 100) 98 | │ │ │ └─output: Tensor(1, 100, 100) 99 | │ │ ├─a_linear: CompiledGraphModule 100 | │ │ │ ├─input: Tensor(1, 100, 100) 101 | │ │ │ ├─weight: Tensor(100, 100) 102 | │ │ │ ├─bias: Tensor(100) 103 | │ │ │ ├─linear: Tensor(1, 100, 100) 104 | │ │ │ └─output: Tensor(1, 100, 100) 105 | │ │ └─output: Tensor(1, 100, 100) 106 | │ ├─root_linear: CompiledGraphModule 107 | │ │ ├─input: Tensor(1, 100, 100) 108 | │ │ ├─weight: Tensor(100, 100) 109 | │ │ ├─bias: Tensor(100) 110 | │ │ ├─linear: Tensor(1, 100, 100) 111 | │ │ └─output: Tensor(1, 100, 100) 112 | │ └─output: Tensor(1, 100, 100) 113 | └─output: CausalLMOutput(1) 114 | └─logits: Tensor(1, 100, 100) 115 | ''' -------------------------------------------------------------------------------- 
/tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_pretrained_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─input_ids: Tensor(1, 100) 4 | ├─kwargs 5 | ├─getitem: Tensor(1, 100) 6 | ├─size: Size(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem_1: int 10 | ├─getitem_2: int 11 | ├─view: Tensor(1, 1, 100) 12 | ├─repeat: Tensor(1, 100, 100) 13 | ├─embedding: Tensor(1, 100, 100) 14 | ├─model: CompiledGraphModule 15 | │ ├─root_inputs: Tensor(1, 100, 100) 16 | │ ├─child_a: CompiledGraphModule 17 | │ │ ├─a_inputs: Tensor(1, 100, 100) 18 | │ │ ├─grandchildren_b_0: CompiledGraphModule 19 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 20 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 21 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 22 | │ │ │ ├─mul: Tensor(1, 100, 100) 23 | │ │ │ ├─c: CompiledGraphModule 24 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 25 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 26 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 27 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 28 | │ │ │ │ ├─add: Tensor(1, 100, 100) 29 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 30 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 31 | │ │ │ │ ├─c_linear: CompiledGraphModule 32 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 33 | │ │ │ │ │ ├─weight: Tensor(100, 100) 34 | │ │ │ │ │ ├─bias: Tensor(100) 35 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 36 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 37 | │ │ │ │ └─output: Tensor(1, 100, 100) 38 | │ │ │ ├─b_linear: CompiledGraphModule 39 | │ │ │ │ ├─input: Tensor(1, 100, 100) 40 | │ │ │ │ ├─weight: Tensor(100, 100) 41 | │ │ │ │ ├─bias: Tensor(100) 42 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 43 | │ │ │ │ └─output: Tensor(1, 100, 100) 44 | │ │ │ └─output: Tensor(1, 100, 100) 45 | │ │ ├─grandchildren_b_1: CompiledGraphModule 46 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 47 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 48 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 49 | │ │ │ 
├─mul: Tensor(1, 100, 100) 50 | │ │ │ ├─c: CompiledGraphModule 51 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 52 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 53 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 54 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 55 | │ │ │ │ ├─add: Tensor(1, 100, 100) 56 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 57 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 58 | │ │ │ │ ├─c_linear: CompiledGraphModule 59 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 60 | │ │ │ │ │ ├─weight: Tensor(100, 100) 61 | │ │ │ │ │ ├─bias: Tensor(100) 62 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 63 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 64 | │ │ │ │ └─output: Tensor(1, 100, 100) 65 | │ │ │ ├─b_linear: CompiledGraphModule 66 | │ │ │ │ ├─input: Tensor(1, 100, 100) 67 | │ │ │ │ ├─weight: Tensor(100, 100) 68 | │ │ │ │ ├─bias: Tensor(100) 69 | │ │ │ │ ├─linear: Tensor(1, 100, 100) 70 | │ │ │ │ └─output: Tensor(1, 100, 100) 71 | │ │ │ └─output: Tensor(1, 100, 100) 72 | │ │ ├─grandchildren_b_2: CompiledGraphModule 73 | │ │ │ ├─b_inputs: Tensor(1, 100, 100) 74 | │ │ │ ├─a_inputs: Tensor(1, 100, 100) 75 | │ │ │ ├─ones_like: Tensor(1, 100, 100) 76 | │ │ │ ├─mul: Tensor(1, 100, 100) 77 | │ │ │ ├─c: CompiledGraphModule 78 | │ │ │ │ ├─c_inputs: Tensor(1, 100, 100) 79 | │ │ │ │ ├─inputs_2: Tensor(1, 100, 100) 80 | │ │ │ │ ├─inputs_3: Tensor(1, 100, 100) 81 | │ │ │ │ ├─inputs_4: Tensor(1, 100, 100) 82 | │ │ │ │ ├─add: Tensor(1, 100, 100) 83 | │ │ │ │ ├─add_1: Tensor(1, 100, 100) 84 | │ │ │ │ ├─add_2: Tensor(1, 100, 100) 85 | │ │ │ │ ├─c_linear: CompiledGraphModule 86 | │ │ │ │ │ ├─input: Tensor(1, 100, 100) 87 | │ │ │ │ │ ├─weight: Tensor(100, 100) 88 | │ │ │ │ │ ├─bias: Tensor(100) 89 | │ │ │ │ │ ├─linear: Tensor(1, 100, 100) 90 | │ │ │ │ │ └─output: Tensor(1, 100, 100) 91 | │ │ │ │ └─output: Tensor(1, 100, 100) 92 | │ │ │ ├─b_linear: CompiledGraphModule 93 | │ │ │ │ ├─input: Tensor(1, 100, 100) 94 | │ │ │ │ ├─weight: Tensor(100, 100) 95 | │ │ │ │ ├─bias: Tensor(100) 96 | │ │ │ │ ├─linear: Tensor(1, 
100, 100) 97 | │ │ │ │ └─output: Tensor(1, 100, 100) 98 | │ │ │ └─output: Tensor(1, 100, 100) 99 | │ │ ├─a_linear: CompiledGraphModule 100 | │ │ │ ├─input: Tensor(1, 100, 100) 101 | │ │ │ ├─weight: Tensor(100, 100) 102 | │ │ │ ├─bias: Tensor(100) 103 | │ │ │ ├─linear: Tensor(1, 100, 100) 104 | │ │ │ └─output: Tensor(1, 100, 100) 105 | │ │ └─output: Tensor(1, 100, 100) 106 | │ ├─root_linear: CompiledGraphModule 107 | │ │ ├─input: Tensor(1, 100, 100) 108 | │ │ ├─weight: Tensor(100, 100) 109 | │ │ ├─bias: Tensor(100) 110 | │ │ ├─linear: Tensor(1, 100, 100) 111 | │ │ └─output: Tensor(1, 100, 100) 112 | │ └─output: Tensor(1, 100, 100) 113 | └─output: CausalLMOutput(1) 114 | └─logits: Tensor(1, 100, 100) 115 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_protected_name_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─sub_shape: int 4 | ├─sub_code_1: int 5 | ├─add: Tensor(3, 2) 6 | ├─add_1: Tensor(3, 2) 7 | ├─add_2: Tensor(3, 2) 8 | ├─add_3: Tensor(3, 2) 9 | ├─_code_: CompiledGraphModule 10 | │ ├─input: Tensor(3, 2) 11 | │ ├─weight: Tensor(3, 2) 12 | │ ├─bias: Tensor(3) 13 | │ ├─linear: Tensor(3, 3) 14 | │ └─output: Tensor(3, 3) 15 | └─output: Tensor(3, 3) 16 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_protected_name_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─sub_shape_1: Tensor(3, 2) 4 | ├─sub_code_1: int 5 | ├─sub_shape: int 6 | ├─add: Tensor(3, 2) 7 | ├─add_1: Tensor(3, 2) 8 | ├─add_2: Tensor(3, 2) 9 | ├─add_3: Tensor(3, 2) 10 | ├─_code_: CompiledGraphModule 11 | │ ├─input: Tensor(3, 2) 12 | │ ├─weight: Tensor(3, 2) 13 | │ ├─bias: 
Tensor(3) 14 | │ ├─linear: Tensor(3, 3) 15 | │ └─output: Tensor(3, 3) 16 | └─output: Tensor(3, 3) 17 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_protected_name_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─sub_shape_1: Tensor(3, 2) 4 | ├─sub_code_1: int 5 | ├─sub_shape: int 6 | ├─add: Tensor(3, 2) 7 | ├─add_1: Tensor(3, 2) 8 | ├─add_2: Tensor(3, 2) 9 | ├─add_3: Tensor(3, 2) 10 | ├─_code_: CompiledGraphModule 11 | │ ├─input: Tensor(3, 2) 12 | │ ├─weight: Tensor(3, 2) 13 | │ ├─bias: Tensor(3) 14 | │ ├─linear: Tensor(3, 3) 15 | │ └─output: Tensor(3, 3) 16 | └─output: Tensor(3, 3) 17 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_protected_name_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─sub_shape_1: Tensor(3, 2) 4 | ├─sub_code_1: int 5 | ├─sub_shape: int 6 | ├─sub_shape_2: tuple(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─getitem: int 10 | ├─add: Tensor(3, 2) 11 | ├─add_1: Tensor(3, 2) 12 | ├─add_2: Tensor(3, 2) 13 | ├─add_3: Tensor(3, 2) 14 | ├─_code_: CompiledGraphModule 15 | │ ├─input: Tensor(3, 2) 16 | │ ├─weight: Tensor(3, 2) 17 | │ ├─bias: Tensor(3) 18 | │ ├─linear: Tensor(3, 3) 19 | │ └─output: Tensor(3, 3) 20 | └─output: Tensor(3, 3) 21 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_protected_name_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─sub_shape_1: Tensor(3, 2) 4 | ├─sub_code_1: int 5 | 
├─sub_shape: int 6 | ├─add: Tensor(3, 2) 7 | ├─add_1: Tensor(3, 2) 8 | ├─add_2: Tensor(3, 2) 9 | ├─add_3: Tensor(3, 2) 10 | ├─_code_: CompiledGraphModule 11 | │ ├─input: Tensor(3, 2) 12 | │ ├─weight: Tensor(3, 2) 13 | │ ├─bias: Tensor(3) 14 | │ ├─linear: Tensor(3, 3) 15 | │ └─output: Tensor(3, 3) 16 | └─output: Tensor(3, 3) 17 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_quantized_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─x: Tensor(3, 2) 6 | │ ├─cb: Tensor(3, 2) 7 | │ ├─scb: Tensor(3, 1) 8 | │ ├─bias: Tensor(3) 9 | │ ├─threshold: float 10 | │ ├─mul: Tensor(3, 2) 11 | │ ├─weight: Tensor(3, 2) 12 | │ ├─matmul_8bit: Tensor(3, 3) 13 | │ └─output: Tensor(3, 3) 14 | └─output: Tensor(3, 3) 15 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_quantized_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─x: Tensor(3, 2) 6 | │ ├─cb: Tensor(3, 2) 7 | │ ├─scb: Tensor(3, 1) 8 | │ ├─bias: Tensor(3) 9 | │ ├─threshold: float 10 | │ ├─mul: Tensor(3, 2) 11 | │ ├─weight: Tensor(3, 2) 12 | │ ├─matmul_8bit: Tensor(3, 3) 13 | │ └─output: Tensor(3, 3) 14 | └─output: Tensor(3, 3) 15 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_quantized_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: 
CompiledGraphModule 5 | │ ├─x: Tensor(3, 2) 6 | │ ├─cb: Tensor(3, 2) 7 | │ ├─scb: Tensor(3, 1) 8 | │ ├─bias: Tensor(3) 9 | │ ├─threshold: float 10 | │ ├─mul: Tensor(3, 2) 11 | │ ├─weight: Tensor(3, 2) 12 | │ ├─matmul_8bit: Tensor(3, 3) 13 | │ └─output: Tensor(3, 3) 14 | └─output: Tensor(3, 3) 15 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_quantized_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─x: Tensor(3, 2) 6 | │ ├─cb: Tensor(3, 2) 7 | │ ├─scb: Tensor(3, 1) 8 | │ ├─bias: Tensor(3) 9 | │ ├─threshold: float 10 | │ ├─mul: Tensor(3, 2) 11 | │ ├─weight: Tensor(3, 2) 12 | │ ├─matmul_8bit: Tensor(3, 3) 13 | │ └─output: Tensor(3, 3) 14 | └─output: Tensor(3, 3) 15 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_quantized_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear: CompiledGraphModule 5 | │ ├─x: Tensor(3, 2) 6 | │ ├─cb: Tensor(3, 2) 7 | │ ├─scb: Tensor(3, 1) 8 | │ ├─bias: Tensor(3) 9 | │ ├─threshold: float 10 | │ ├─mul: Tensor(3, 2) 11 | │ ├─weight: Tensor(3, 2) 12 | │ ├─matmul_8bit: Tensor(3, 3) 13 | │ └─output: Tensor(3, 3) 14 | └─output: Tensor(3, 3) 15 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_tuple_output_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear_0: CompiledGraphModule 5 | │ ├─input: Tensor(3, 
2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─add: Tensor(3, 2) 11 | ├─linear_1: CompiledGraphModule 12 | │ ├─input: Tensor(3, 2) 13 | │ ├─weight: Tensor(3, 2) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | └─output: tuple(2) 18 | ├─sub_0: Tensor(3, 3) 19 | └─sub_1: Tensor(3, 3) 20 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_tuple_output_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear_0: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─add: Tensor(3, 2) 11 | ├─linear_1: CompiledGraphModule 12 | │ ├─input: Tensor(3, 2) 13 | │ ├─weight: Tensor(3, 2) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | └─output: tuple(2) 18 | ├─sub_0: Tensor(3, 3) 19 | └─sub_1: Tensor(3, 3) 20 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_tuple_output_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear_0: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─add: Tensor(3, 2) 11 | ├─linear_1: CompiledGraphModule 12 | │ ├─input: Tensor(3, 2) 13 | │ ├─weight: Tensor(3, 2) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | └─output: tuple(2) 18 | ├─sub_0: Tensor(3, 3) 19 | └─sub_1: 
Tensor(3, 3) 20 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_tuple_output_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear_0: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─add: Tensor(3, 2) 11 | ├─linear_1: CompiledGraphModule 12 | │ ├─input: Tensor(3, 2) 13 | │ ├─weight: Tensor(3, 2) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | └─output: tuple(2) 18 | ├─sub_0: Tensor(3, 3) 19 | └─sub_1: Tensor(3, 3) 20 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_tuple_output_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─linear_0: CompiledGraphModule 5 | │ ├─input: Tensor(3, 2) 6 | │ ├─weight: Tensor(3, 2) 7 | │ ├─bias: Tensor(3) 8 | │ ├─linear: Tensor(3, 3) 9 | │ └─output: Tensor(3, 3) 10 | ├─add: Tensor(3, 2) 11 | ├─linear_1: CompiledGraphModule 12 | │ ├─input: Tensor(3, 2) 13 | │ ├─weight: Tensor(3, 2) 14 | │ ├─bias: Tensor(3) 15 | │ ├─linear: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | └─output: tuple(2) 18 | ├─sub_0: Tensor(3, 3) 19 | └─sub_1: Tensor(3, 3) 20 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_unused_submodule_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: 
CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 
100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_3: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_unused_submodule_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: 
Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_3: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 
62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_unused_submodule_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ 
├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_3: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: 
CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_unused_submodule_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ 
├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_3: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: 
Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ │ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_unused_submodule_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─child_a: CompiledGraphModule 5 | │ ├─a_inputs: Tensor(1, 100) 6 | │ ├─grandchildren_b_0: CompiledGraphModule 7 | │ │ ├─b_inputs: Tensor(1, 100) 8 | │ │ ├─a_inputs: Tensor(1, 100) 9 | │ │ ├─ones_like: Tensor(1, 100) 10 | │ │ ├─mul: Tensor(1, 100) 11 | │ │ ├─c: CompiledGraphModule 12 | │ │ │ ├─c_inputs: Tensor(1, 100) 13 | │ │ │ ├─inputs_2: Tensor(1, 100) 14 | │ │ │ ├─inputs_3: Tensor(1, 100) 15 | │ │ │ ├─inputs_4: Tensor(1, 100) 16 | │ │ │ ├─add: Tensor(1, 100) 17 | │ │ │ ├─add_1: Tensor(1, 
100) 18 | │ │ │ ├─add_2: Tensor(1, 100) 19 | │ │ │ ├─c_linear: CompiledGraphModule 20 | │ │ │ │ ├─input: Tensor(1, 100) 21 | │ │ │ │ ├─weight: Tensor(100, 100) 22 | │ │ │ │ ├─bias: Tensor(100) 23 | │ │ │ │ ├─linear: Tensor(1, 100) 24 | │ │ │ │ └─output: Tensor(1, 100) 25 | │ │ │ └─output: Tensor(1, 100) 26 | │ │ ├─b_linear: CompiledGraphModule 27 | │ │ │ ├─input: Tensor(1, 100) 28 | │ │ │ ├─weight: Tensor(100, 100) 29 | │ │ │ ├─bias: Tensor(100) 30 | │ │ │ ├─linear: Tensor(1, 100) 31 | │ │ │ └─output: Tensor(1, 100) 32 | │ │ └─output: Tensor(1, 100) 33 | │ ├─grandchildren_b_1: CompiledGraphModule 34 | │ │ ├─b_inputs: Tensor(1, 100) 35 | │ │ ├─a_inputs: Tensor(1, 100) 36 | │ │ ├─ones_like: Tensor(1, 100) 37 | │ │ ├─mul: Tensor(1, 100) 38 | │ │ ├─c: CompiledGraphModule 39 | │ │ │ ├─c_inputs: Tensor(1, 100) 40 | │ │ │ ├─inputs_2: Tensor(1, 100) 41 | │ │ │ ├─inputs_3: Tensor(1, 100) 42 | │ │ │ ├─inputs_4: Tensor(1, 100) 43 | │ │ │ ├─add: Tensor(1, 100) 44 | │ │ │ ├─add_1: Tensor(1, 100) 45 | │ │ │ ├─add_2: Tensor(1, 100) 46 | │ │ │ ├─c_linear: CompiledGraphModule 47 | │ │ │ │ ├─input: Tensor(1, 100) 48 | │ │ │ │ ├─weight: Tensor(100, 100) 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─linear: Tensor(1, 100) 51 | │ │ │ │ └─output: Tensor(1, 100) 52 | │ │ │ └─output: Tensor(1, 100) 53 | │ │ ├─b_linear: CompiledGraphModule 54 | │ │ │ ├─input: Tensor(1, 100) 55 | │ │ │ ├─weight: Tensor(100, 100) 56 | │ │ │ ├─bias: Tensor(100) 57 | │ │ │ ├─linear: Tensor(1, 100) 58 | │ │ │ └─output: Tensor(1, 100) 59 | │ │ └─output: Tensor(1, 100) 60 | │ ├─grandchildren_b_3: CompiledGraphModule 61 | │ │ ├─b_inputs: Tensor(1, 100) 62 | │ │ ├─a_inputs: Tensor(1, 100) 63 | │ │ ├─ones_like: Tensor(1, 100) 64 | │ │ ├─mul: Tensor(1, 100) 65 | │ │ ├─c: CompiledGraphModule 66 | │ │ │ ├─c_inputs: Tensor(1, 100) 67 | │ │ │ ├─inputs_2: Tensor(1, 100) 68 | │ │ │ ├─inputs_3: Tensor(1, 100) 69 | │ │ │ ├─inputs_4: Tensor(1, 100) 70 | │ │ │ ├─add: Tensor(1, 100) 71 | │ │ │ ├─add_1: Tensor(1, 100) 72 | │ │ 
│ ├─add_2: Tensor(1, 100) 73 | │ │ │ ├─c_linear: CompiledGraphModule 74 | │ │ │ │ ├─input: Tensor(1, 100) 75 | │ │ │ │ ├─weight: Tensor(100, 100) 76 | │ │ │ │ ├─bias: Tensor(100) 77 | │ │ │ │ ├─linear: Tensor(1, 100) 78 | │ │ │ │ └─output: Tensor(1, 100) 79 | │ │ │ └─output: Tensor(1, 100) 80 | │ │ ├─b_linear: CompiledGraphModule 81 | │ │ │ ├─input: Tensor(1, 100) 82 | │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ ├─bias: Tensor(100) 84 | │ │ │ ├─linear: Tensor(1, 100) 85 | │ │ │ └─output: Tensor(1, 100) 86 | │ │ └─output: Tensor(1, 100) 87 | │ ├─a_linear: CompiledGraphModule 88 | │ │ ├─input: Tensor(1, 100) 89 | │ │ ├─weight: Tensor(100, 100) 90 | │ │ ├─bias: Tensor(100) 91 | │ │ ├─linear: Tensor(1, 100) 92 | │ │ └─output: Tensor(1, 100) 93 | │ └─output: Tensor(1, 100) 94 | ├─root_linear: CompiledGraphModule 95 | │ ├─input: Tensor(1, 100) 96 | │ ├─weight: Tensor(100, 100) 97 | │ ├─bias: Tensor(100) 98 | │ ├─linear: Tensor(1, 100) 99 | │ └─output: Tensor(1, 100) 100 | └─output: Tensor(1, 100) 101 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_varargs_module_2_0.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foos: tuple(3) 5 | │ ├─sub_0: Tensor(3, 3) 6 | │ ├─sub_1: Tensor(3, 3) 7 | │ └─sub_2: Tensor(3, 3) 8 | ├─blah: int 9 | ├─bars: dict(2) 10 | │ ├─a: Tensor(3, 3) 11 | │ └─b: Tensor(3, 3) 12 | ├─getitem: Tensor(3, 3) 13 | ├─linear_0: CompiledGraphModule 14 | │ ├─input: Tensor(3, 3) 15 | │ ├─weight: Tensor(3, 3) 16 | │ ├─bias: Tensor(3) 17 | │ ├─linear: Tensor(3, 3) 18 | │ └─output: Tensor(3, 3) 19 | ├─iadd: Tensor(3, 3) 20 | ├─iadd_1: Tensor(3, 3) 21 | ├─iadd_2: Tensor(3, 3) 22 | ├─add: Tensor(3, 3) 23 | ├─linear_1: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: 
Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─iadd_3: Tensor(3, 3) 30 | ├─iadd_4: Tensor(3, 3) 31 | └─output: Tensor(3, 3) 32 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_varargs_module_2_1.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foos: tuple(3) 5 | │ ├─sub_0: Tensor(3, 3) 6 | │ ├─sub_1: Tensor(3, 3) 7 | │ └─sub_2: Tensor(3, 3) 8 | ├─blah: int 9 | ├─bars: dict(2) 10 | │ ├─a: Tensor(3, 3) 11 | │ └─b: Tensor(3, 3) 12 | ├─getitem: Tensor(3, 3) 13 | ├─linear_0: CompiledGraphModule 14 | │ ├─input: Tensor(3, 3) 15 | │ ├─weight: Tensor(3, 3) 16 | │ ├─bias: Tensor(3) 17 | │ ├─linear: Tensor(3, 3) 18 | │ └─output: Tensor(3, 3) 19 | ├─iadd: Tensor(3, 3) 20 | ├─iadd_1: Tensor(3, 3) 21 | ├─iadd_2: Tensor(3, 3) 22 | ├─add: Tensor(3, 3) 23 | ├─linear_1: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─iadd_3: Tensor(3, 3) 30 | ├─iadd_4: Tensor(3, 3) 31 | └─output: Tensor(3, 3) 32 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_varargs_module_2_2-2_3.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foos: tuple(3) 5 | │ ├─sub_0: Tensor(3, 3) 6 | │ ├─sub_1: Tensor(3, 3) 7 | │ └─sub_2: Tensor(3, 3) 8 | ├─blah: int 9 | ├─bars: dict(2) 10 | │ ├─a: Tensor(3, 3) 11 | │ └─b: Tensor(3, 3) 12 | ├─getitem: Tensor(3, 3) 13 | ├─linear_0: CompiledGraphModule 14 | │ ├─input: Tensor(3, 3) 15 | │ ├─weight: Tensor(3, 3) 16 | │ ├─bias: Tensor(3) 17 | │ ├─linear: Tensor(3, 3) 18 | │ └─output: Tensor(3, 3) 19 | 
├─result_1: Tensor(3, 3) 20 | ├─result_2: Tensor(3, 3) 21 | ├─result_3: Tensor(3, 3) 22 | ├─add: Tensor(3, 3) 23 | ├─linear_1: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─result_5: Tensor(3, 3) 30 | ├─result_6: Tensor(3, 3) 31 | └─output: Tensor(3, 3) 32 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_varargs_module_2_4.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foos: tuple(3) 5 | │ ├─sub_0: Tensor(3, 3) 6 | │ ├─sub_1: Tensor(3, 3) 7 | │ └─sub_2: Tensor(3, 3) 8 | ├─blah: int 9 | ├─bars: dict(2) 10 | │ ├─a: Tensor(3, 3) 11 | │ └─b: Tensor(3, 3) 12 | ├─getitem: Tensor(3, 3) 13 | ├─linear_0: CompiledGraphModule 14 | │ ├─input: Tensor(3, 3) 15 | │ ├─weight: Tensor(3, 3) 16 | │ ├─bias: Tensor(3) 17 | │ ├─linear: Tensor(3, 3) 18 | │ └─output: Tensor(3, 3) 19 | ├─result_1: Tensor(3, 3) 20 | ├─result_2: Tensor(3, 3) 21 | ├─result_3: Tensor(3, 3) 22 | ├─add: Tensor(3, 3) 23 | ├─linear_1: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─result_5: Tensor(3, 3) 30 | ├─result_6: Tensor(3, 3) 31 | └─output: Tensor(3, 3) 32 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[compiled]/patchable_varargs_module_2_5.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : CompiledGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foos: tuple(3) 5 | │ ├─sub_0: Tensor(3, 3) 6 | │ ├─sub_1: Tensor(3, 3) 7 | │ └─sub_2: Tensor(3, 3) 8 | ├─blah: int 9 | ├─bars: dict(2) 10 | │ ├─a: 
Tensor(3, 3) 11 | │ └─b: Tensor(3, 3) 12 | ├─getitem: Tensor(3, 3) 13 | ├─linear_0: CompiledGraphModule 14 | │ ├─input: Tensor(3, 3) 15 | │ ├─weight: Tensor(3, 3) 16 | │ ├─bias: Tensor(3) 17 | │ ├─linear: Tensor(3, 3) 18 | │ └─output: Tensor(3, 3) 19 | ├─result_1: Tensor(3, 3) 20 | ├─result_2: Tensor(3, 3) 21 | ├─result_3: Tensor(3, 3) 22 | ├─add: Tensor(3, 3) 23 | ├─linear_1: CompiledGraphModule 24 | │ ├─input: Tensor(3, 3) 25 | │ ├─weight: Tensor(3, 3) 26 | │ ├─bias: Tensor(3) 27 | │ ├─linear: Tensor(3, 3) 28 | │ └─output: Tensor(3, 3) 29 | ├─result_5: Tensor(3, 3) 30 | ├─result_6: Tensor(3, 3) 31 | └─output: Tensor(3, 3) 32 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_attribute_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─sub_shape: tuple(2) 5 | │ ├─sub_0: int 6 | │ └─sub_1: int 7 | ├─attribute_to_serialize: float 8 | ├─linear: OpaqueGraphModule 9 | │ ├─input: Tensor(3, 2) 10 | │ ├─constants: list(2) 11 | │ │ ├─sub_0: str 12 | │ │ └─sub_1: str 13 | │ ├─bias: Tensor(3) 14 | │ ├─in_features: int 15 | │ ├─out_features: int 16 | │ ├─weight: Tensor(3, 2) 17 | │ └─output: Tensor(3, 3) 18 | └─output: Tensor(3, 3) 19 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_buffer_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─sub_shape: tuple(2) 5 | │ ├─sub_0: int 6 | │ └─sub_1: int 7 | ├─buffer: Tensor(3, 3) 8 | ├─linear: OpaqueGraphModule 9 | │ ├─input: Tensor(3, 2) 10 | │ ├─constants: list(2) 11 | │ │ ├─sub_0: str 12 | │ │ └─sub_1: str 13 | │ ├─bias: Tensor(3) 14 | │ ├─in_features: int 15 | │ ├─out_features: int 16 
| │ ├─weight: Tensor(3, 2) 17 | │ └─output: Tensor(3, 3) 18 | └─output: Tensor(3, 3) 19 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_graph_break_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foo: int 5 | ├─sub_shape: tuple(2) 6 | │ ├─sub_0: int 7 | │ └─sub_1: int 8 | ├─bar: int 9 | ├─instance_value: int 10 | ├─shadowed_class_var: int 11 | ├─linear_0: OpaqueGraphModule 12 | │ ├─input: Tensor(3, 3) 13 | │ ├─constants: list(2) 14 | │ │ ├─sub_0: str 15 | │ │ └─sub_1: str 16 | │ ├─bias: Tensor(3) 17 | │ ├─in_features: int 18 | │ ├─out_features: int 19 | │ ├─weight: Tensor(3, 3) 20 | │ └─output: Tensor(3, 3) 21 | ├─linear_1: OpaqueGraphModule 22 | │ ├─input: Tensor(3, 3) 23 | │ ├─constants: list(2) 24 | │ │ ├─sub_0: str 25 | │ │ └─sub_1: str 26 | │ ├─bias: Tensor(3) 27 | │ ├─in_features: int 28 | │ ├─out_features: int 29 | │ ├─weight: Tensor(3, 3) 30 | │ └─output: Tensor(3, 3) 31 | ├─linear_2: OpaqueGraphModule 32 | │ ├─input: Tensor(3, 3) 33 | │ ├─constants: list(2) 34 | │ │ ├─sub_0: str 35 | │ │ └─sub_1: str 36 | │ ├─bias: Tensor(3) 37 | │ ├─in_features: int 38 | │ ├─out_features: int 39 | │ ├─weight: Tensor(3, 3) 40 | │ └─output: Tensor(3, 3) 41 | ├─unused_submodule: OpaqueGraphModule 42 | │ ├─input 43 | │ ├─constants 44 | │ ├─bias 45 | │ ├─in_features 46 | │ ├─out_features 47 | │ ├─weight 48 | │ └─output 49 | └─output: Tensor(3, 3) 50 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_layer_norm_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─sub_shape: tuple(2) 5 | │ ├─sub_0: int 6 | │ └─sub_1: int 7 | ├─ln: 
OpaqueGraphModule 8 | │ ├─input: Tensor(3, 2) 9 | │ ├─bias: Tensor(2) 10 | │ ├─eps: float 11 | │ ├─normalized_shape: tuple(1) 12 | │ │ └─sub_0: int 13 | │ ├─weight: Tensor(2) 14 | │ └─output: Tensor(3, 2) 15 | └─output: Tensor(3, 2) 16 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_minimal_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─sub_shape: tuple(2) 5 | │ ├─sub_0: int 6 | │ └─sub_1: int 7 | ├─linear: OpaqueGraphModule 8 | │ ├─input: Tensor(3, 2) 9 | │ ├─constants: list(2) 10 | │ │ ├─sub_0: str 11 | │ │ └─sub_1: str 12 | │ ├─bias: Tensor(3) 13 | │ ├─in_features: int 14 | │ ├─out_features: int 15 | │ ├─weight: Tensor(3, 2) 16 | │ └─output: Tensor(3, 3) 17 | └─output: Tensor(3, 3) 18 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_nested_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─root_inputs: Tensor(1, 100) 4 | ├─root_linear: OpaqueGraphModule 5 | │ ├─input: Tensor(1, 100) 6 | │ ├─constants: list(2) 7 | │ │ ├─sub_0: str 8 | │ │ └─sub_1: str 9 | │ ├─bias: Tensor(100) 10 | │ ├─in_features: int 11 | │ ├─out_features: int 12 | │ ├─weight: Tensor(100, 100) 13 | │ └─output: Tensor(1, 100) 14 | ├─child_a: OpaqueGraphModule 15 | │ ├─a_inputs: Tensor(1, 100) 16 | │ ├─a_linear: OpaqueGraphModule 17 | │ │ ├─input: Tensor(1, 100) 18 | │ │ ├─constants: list(2) 19 | │ │ │ ├─sub_0: str 20 | │ │ │ └─sub_1: str 21 | │ │ ├─bias: Tensor(100) 22 | │ │ ├─in_features: int 23 | │ │ ├─out_features: int 24 | │ │ ├─weight: Tensor(100, 100) 25 | │ │ └─output: Tensor(1, 100) 26 | │ ├─grandchildren_b_0: OpaqueGraphModule 27 | │ │ ├─b_inputs: Tensor(1, 
100) 28 | │ │ ├─a_inputs: Tensor(1, 100) 29 | │ │ ├─b_linear: OpaqueGraphModule 30 | │ │ │ ├─input: Tensor(1, 100) 31 | │ │ │ ├─constants: list(2) 32 | │ │ │ │ ├─sub_0: str 33 | │ │ │ │ └─sub_1: str 34 | │ │ │ ├─bias: Tensor(100) 35 | │ │ │ ├─in_features: int 36 | │ │ │ ├─out_features: int 37 | │ │ │ ├─weight: Tensor(100, 100) 38 | │ │ │ └─output: Tensor(1, 100) 39 | │ │ ├─c: OpaqueGraphModule 40 | │ │ │ ├─c_inputs: Tensor(1, 100) 41 | │ │ │ ├─inputs_2: Tensor(1, 100) 42 | │ │ │ ├─inputs_3: Tensor(1, 100) 43 | │ │ │ ├─inputs_4: Tensor(1, 100) 44 | │ │ │ ├─c_linear: OpaqueGraphModule 45 | │ │ │ │ ├─input: Tensor(1, 100) 46 | │ │ │ │ ├─constants: list(2) 47 | │ │ │ │ │ ├─sub_0: str 48 | │ │ │ │ │ └─sub_1: str 49 | │ │ │ │ ├─bias: Tensor(100) 50 | │ │ │ │ ├─in_features: int 51 | │ │ │ │ ├─out_features: int 52 | │ │ │ │ ├─weight: Tensor(100, 100) 53 | │ │ │ │ └─output: Tensor(1, 100) 54 | │ │ │ └─output: Tensor(1, 100) 55 | │ │ └─output: Tensor(1, 100) 56 | │ ├─grandchildren_b_1: OpaqueGraphModule 57 | │ │ ├─b_inputs: Tensor(1, 100) 58 | │ │ ├─a_inputs: Tensor(1, 100) 59 | │ │ ├─b_linear: OpaqueGraphModule 60 | │ │ │ ├─input: Tensor(1, 100) 61 | │ │ │ ├─constants: list(2) 62 | │ │ │ │ ├─sub_0: str 63 | │ │ │ │ └─sub_1: str 64 | │ │ │ ├─bias: Tensor(100) 65 | │ │ │ ├─in_features: int 66 | │ │ │ ├─out_features: int 67 | │ │ │ ├─weight: Tensor(100, 100) 68 | │ │ │ └─output: Tensor(1, 100) 69 | │ │ ├─c: OpaqueGraphModule 70 | │ │ │ ├─c_inputs: Tensor(1, 100) 71 | │ │ │ ├─inputs_2: Tensor(1, 100) 72 | │ │ │ ├─inputs_3: Tensor(1, 100) 73 | │ │ │ ├─inputs_4: Tensor(1, 100) 74 | │ │ │ ├─c_linear: OpaqueGraphModule 75 | │ │ │ │ ├─input: Tensor(1, 100) 76 | │ │ │ │ ├─constants: list(2) 77 | │ │ │ │ │ ├─sub_0: str 78 | │ │ │ │ │ └─sub_1: str 79 | │ │ │ │ ├─bias: Tensor(100) 80 | │ │ │ │ ├─in_features: int 81 | │ │ │ │ ├─out_features: int 82 | │ │ │ │ ├─weight: Tensor(100, 100) 83 | │ │ │ │ └─output: Tensor(1, 100) 84 | │ │ │ └─output: Tensor(1, 100) 85 | │ │ └─output: Tensor(1, 
100) 86 | │ ├─grandchildren_b_2: OpaqueGraphModule 87 | │ │ ├─b_inputs: Tensor(1, 100) 88 | │ │ ├─a_inputs: Tensor(1, 100) 89 | │ │ ├─b_linear: OpaqueGraphModule 90 | │ │ │ ├─input: Tensor(1, 100) 91 | │ │ │ ├─constants: list(2) 92 | │ │ │ │ ├─sub_0: str 93 | │ │ │ │ └─sub_1: str 94 | │ │ │ ├─bias: Tensor(100) 95 | │ │ │ ├─in_features: int 96 | │ │ │ ├─out_features: int 97 | │ │ │ ├─weight: Tensor(100, 100) 98 | │ │ │ └─output: Tensor(1, 100) 99 | │ │ ├─c: OpaqueGraphModule 100 | │ │ │ ├─c_inputs: Tensor(1, 100) 101 | │ │ │ ├─inputs_2: Tensor(1, 100) 102 | │ │ │ ├─inputs_3: Tensor(1, 100) 103 | │ │ │ ├─inputs_4: Tensor(1, 100) 104 | │ │ │ ├─c_linear: OpaqueGraphModule 105 | │ │ │ │ ├─input: Tensor(1, 100) 106 | │ │ │ │ ├─constants: list(2) 107 | │ │ │ │ │ ├─sub_0: str 108 | │ │ │ │ │ └─sub_1: str 109 | │ │ │ │ ├─bias: Tensor(100) 110 | │ │ │ │ ├─in_features: int 111 | │ │ │ │ ├─out_features: int 112 | │ │ │ │ ├─weight: Tensor(100, 100) 113 | │ │ │ │ └─output: Tensor(1, 100) 114 | │ │ │ └─output: Tensor(1, 100) 115 | │ │ └─output: Tensor(1, 100) 116 | │ └─output: Tensor(1, 100) 117 | └─output: Tensor(1, 100) 118 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_protected_name_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─sub_shape_1: Tensor(3, 2) 4 | ├─sub_code: int 5 | ├─sub_shape: int 6 | ├─_shape_1: tuple(2) 7 | │ ├─sub_0: int 8 | │ └─sub_1: int 9 | ├─bar: tuple(2) 10 | │ ├─sub_0: float 11 | │ └─sub_1: tuple(2) 12 | │ ├─sub_0: float 13 | │ └─sub_1: float 14 | ├─_code_: OpaqueGraphModule 15 | │ ├─input: Tensor(3, 2) 16 | │ ├─constants: list(2) 17 | │ │ ├─sub_0: str 18 | │ │ └─sub_1: str 19 | │ ├─bias: Tensor(3) 20 | │ ├─in_features: int 21 | │ ├─out_features: int 22 | │ ├─weight: Tensor(3, 2) 23 | │ └─output: Tensor(3, 3) 24 | └─output: Tensor(3, 3) 25 | ''' 
-------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_quantized_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─sub_shape: tuple(2) 5 | │ ├─sub_0: int 6 | │ └─sub_1: int 7 | ├─linear: CompiledGraphModule 8 | │ ├─x: Tensor(3, 2) 9 | │ ├─cb: Tensor(3, 2) 10 | │ ├─scb: Tensor(3, 1) 11 | │ ├─bias: Tensor(3) 12 | │ ├─threshold: float 13 | │ ├─mul: Tensor(3, 2) 14 | │ ├─weight: Tensor(3, 2) 15 | │ ├─matmul_8bit: Tensor(3, 3) 16 | │ └─output: Tensor(3, 3) 17 | └─output: Tensor(3, 3) 18 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_tuple_output_module.ambr: -------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 2) 4 | ├─sub_shape: tuple(2) 5 | │ ├─sub_0: int 6 | │ └─sub_1: int 7 | ├─linear_0: OpaqueGraphModule 8 | │ ├─input: Tensor(3, 2) 9 | │ ├─constants: list(2) 10 | │ │ ├─sub_0: str 11 | │ │ └─sub_1: str 12 | │ ├─bias: Tensor(3) 13 | │ ├─in_features: int 14 | │ ├─out_features: int 15 | │ ├─weight: Tensor(3, 2) 16 | │ └─output: Tensor(3, 3) 17 | ├─linear_1: OpaqueGraphModule 18 | │ ├─input: Tensor(3, 2) 19 | │ ├─constants: list(2) 20 | │ │ ├─sub_0: str 21 | │ │ └─sub_1: str 22 | │ ├─bias: Tensor(3) 23 | │ ├─in_features: int 24 | │ ├─out_features: int 25 | │ ├─weight: Tensor(3, 2) 26 | │ └─output: Tensor(3, 3) 27 | └─output: tuple(2) 28 | ├─sub_0: Tensor(3, 3) 29 | └─sub_1: Tensor(3, 3) 30 | ''' -------------------------------------------------------------------------------- /tests/__snapshots__/test_node_path/test_patchable_graph_graph_repr[opaque]/patchable_varargs_module.ambr: 
-------------------------------------------------------------------------------- 1 | ''' 2 | : OpaqueGraphModule 3 | ├─x: Tensor(3, 3) 4 | ├─foos: tuple(3) 5 | │ ├─sub_0: Tensor(3, 3) 6 | │ ├─sub_1: Tensor(3, 3) 7 | │ └─sub_2: Tensor(3, 3) 8 | ├─blah: int 9 | ├─bars: dict(2) 10 | │ ├─a: Tensor(3, 3) 11 | │ └─b: Tensor(3, 3) 12 | ├─sub_shape: tuple(2) 13 | │ ├─sub_0: int 14 | │ └─sub_1: int 15 | ├─linear_0: OpaqueGraphModule 16 | │ ├─input: Tensor(3, 3) 17 | │ ├─constants: list(2) 18 | │ │ ├─sub_0: str 19 | │ │ └─sub_1: str 20 | │ ├─bias: Tensor(3) 21 | │ ├─in_features: int 22 | │ ├─out_features: int 23 | │ ├─weight: Tensor(3, 3) 24 | │ └─output: Tensor(3, 3) 25 | ├─linear_1: OpaqueGraphModule 26 | │ ├─input: Tensor(3, 3) 27 | │ ├─constants: list(2) 28 | │ │ ├─sub_0: str 29 | │ │ └─sub_1: str 30 | │ ├─bias: Tensor(3) 31 | │ ├─in_features: int 32 | │ ├─out_features: int 33 | │ ├─weight: Tensor(3, 3) 34 | │ └─output: Tensor(3, 3) 35 | └─output: Tensor(3, 3) 36 | ''' -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import site 4 | import sys 5 | from pdb import Pdb 6 | 7 | import pytest 8 | from syrupy.data import SnapshotCollections 9 | from syrupy.report import SnapshotReport 10 | 11 | # No tests in the fixtures directory. 12 | collect_ignore = ["fixtures"] 13 | 14 | # Walk fixtures directory to import all fixtures automatically. 
def _discover_fixture_plugins(fixtures_root):
    """Return dotted module names for every public fixture module under *fixtures_root*.

    Directories and files whose names start with ``_`` are treated as private and skipped.
    """
    plugins = []
    for root, dirs, files in os.walk(fixtures_root):
        # Prune with slice assignment: os.walk only honours in-place changes, and calling
        # dirs.remove() while iterating over dirs skips the entry following each removal.
        dirs[:] = [name for name in dirs if not name.startswith("_")]
        for file_name in files:
            if file_name.endswith(".py") and not file_name.startswith("_"):
                module_path = os.path.join(root, file_name)[:-3]
                # Handle both the native separator and "/" so this also works on Windows.
                plugins.append(module_path.replace(os.sep, ".").replace("/", "."))
    return plugins


# Walk fixtures directory to import all fixtures automatically.
pytest_plugins = _discover_fixture_plugins("tests/fixtures")

pytest.register_assert_rewrite("tests.util")


def pytest_addoption(parser, pluginmanager):
    """Register the ``--gpbp`` option; it may be given multiple times."""
    # Convenience argument to add breakpoints based on parsed copy+pasted "remote file URLs" from,
    # eg, pytorch's python source. Super convenient for developing hacks.py!
    parser.addoption("--gpbp", dest="graphpatch_breakpoints", action="append")


# https://stackoverflow.com/a/54564137
class RunningTrace:
    """Mixin that attaches a debugger to the already-running call stack."""

    def set_running_trace(self):
        """Install ``trace_dispatch`` on every frame above the caller, then keep running."""
        frame = sys._getframe().f_back
        self.botframe = None
        self.setup(frame, None)
        while frame:
            frame.f_trace = self.trace_dispatch
            self.botframe = frame
            frame = frame.f_back
        self.set_continue()
        self.quitting = False
        sys.settrace(self.trace_dispatch)


class ProgrammaticPdb(Pdb, RunningTrace):
    """Pdb that can be armed with breakpoints from code rather than interactively."""


debugger = ProgrammaticPdb()


def pytest_configure(config):
    """Install ``--gpbp`` breakpoints and stop syrupy from reporting unused snapshots.

    Raises:
        pytest.UsageError: if a ``--gpbp`` value is not a GitHub blob URL with a ``#L<n>``
            line anchor.
    """
    site_packages_dir = site.getsitepackages()[0]
    breakpoints = config.option.graphpatch_breakpoints or []
    for bp in breakpoints:
        match = re.match(
            r"^https://github.com/.+?/blob/.+?/(.+?)#L(\d+),?(.+?)?$",
            bp,
        )
        if match is None:
            # Without this guard a typo in the URL surfaces as an opaque AttributeError.
            raise pytest.UsageError(
                f"--gpbp: cannot parse {bp!r}; expected"
                " https://github.com/<owner>/<repo>/blob/<ref>/<path>#L<line>[,<condition>]"
            )
        library_source = f"{site_packages_dir}/{match.group(1)}"
        line_number = int(match.group(2))
        condition = match.group(3)
        # set_break returns an error message (instead of raising) when the line does not exist.
        error = debugger.set_break(library_source, line_number, cond=condition)
        if error:
            print(f"Could not set breakpoint in {library_source}, line {line_number}: {error}")
            continue
        print(
            f"Set breakpoint in {library_source}, line {line_number}"
            f" {f'cond={condition}' if condition else ''}"
        )

    if breakpoints:
        debugger.set_running_trace()

    # Monkeypatch syrupy to not delete "unused" snapshots; we need to maintain different snapshots
    # depending on torch version, which won't get accessed on envs not using that version.
    class TorchVersionedReport(SnapshotReport):
        @property
        def unused(self):
            return SnapshotCollections()

    from syrupy import session

    session.SnapshotReport = TorchVersionedReport
class BufferModule(Module):
    """Test module whose forward pass adds a registered buffer to a Linear layer's output."""

    # Unpacked as Linear(in_features=2, out_features=3).
    _shape = (2, 3)

    def __init__(self):
        super().__init__()
        self.linear = Linear(*BufferModule._shape)
        # Square buffer of ones with side length _shape[1], i.e. 3x3.
        self.register_buffer("buffer", ones(*([BufferModule._shape[1]] * 2)))

    def forward(self, x):
        """Return ``linear(x) + buffer``."""
        return self.linear(x) + self.buffer
38 | buffer_module_inputs, 39 | ) 40 | -------------------------------------------------------------------------------- /tests/fixtures/container_module.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from torch import ones 3 | from torch.nn import Linear, Module, ModuleDict, ModuleList, Sequential 4 | 5 | from graphpatch import ExtractionOptions, PatchableGraph 6 | 7 | 8 | class ContainerModule(Module): 9 | _shape = (3, 3) 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.linear = Linear(*ContainerModule._shape) 14 | self.duped_linear = self.linear 15 | chained_linear = Linear(*ContainerModule._shape) 16 | chained_sequential = Sequential(self.duped_linear, self.linear) 17 | self.sequential = Sequential( 18 | chained_linear, self.linear, chained_sequential, chained_linear 19 | ) 20 | self.module_list = ModuleList( 21 | [Linear(*ContainerModule._shape), Linear(*ContainerModule._shape)] 22 | ) 23 | self.module_dict = ModuleDict( 24 | { 25 | "foo": self.linear, 26 | "bar": ModuleDict({"baz": ModuleList([self.linear, self.linear, self.sequential])}), 27 | } 28 | ) 29 | 30 | def forward(self, x): 31 | y = self.duped_linear(x) - self.module_dict["foo"](x) 32 | for i in range(3): 33 | y += self.module_dict["bar"]["baz"][i](x) * self.linear(x) 34 | y += self.module_list[0](x) 35 | y += self.module_list[1](x) 36 | return self.sequential(y) + self.linear(y) + self.module_list[0](y) 37 | 38 | 39 | @pytest.fixture 40 | def container_module(): 41 | return ContainerModule() 42 | 43 | 44 | @pytest.fixture 45 | def container_module_inputs(): 46 | return ones(*ContainerModule._shape).t() 47 | 48 | 49 | @pytest.fixture 50 | def patchable_container_module(request, container_module, container_module_inputs): 51 | return PatchableGraph( 52 | container_module, 53 | ExtractionOptions( 54 | skip_compilation=getattr(request, "param", None) == "opaque", 55 | error_on_compilation_failure=True, 56 | ), 57 | 
class C(Module):
    """Innermost module; returns a list that contains a tensor and a tuple of tensors."""

    def __init__(self):
        super().__init__()
        self.c_linear = Linear(10, 10)

    def forward(self, c_inputs, inputs_2=None, inputs_3=None, inputs_4=None):
        # The keyword arguments default to None but are used in additions below, so callers
        # must supply all of them.
        return [
            self.c_linear(c_inputs),
            (self.c_linear(c_inputs + inputs_2), self.c_linear(inputs_3 + inputs_4)),
        ]


class B(Module):
    """Middle module; calls C with out-of-order keyword arguments and nests its output."""

    def __init__(self):
        super().__init__()
        self.b_linear = Linear(10, 10)
        self.c = C()

    def forward(self, b_inputs, a_inputs):
        b_outputs = self.c(
            b_inputs,
            inputs_3=torch.ones_like(b_inputs),
            inputs_2=b_inputs * 2,
            inputs_4=a_inputs,
        )
        # Overwrites the first list entry; only b_outputs[1][0] is returned below.
        b_outputs[0] = self.b_linear(b_outputs[1][0])
        # Five levels of single-element tuples around one tensor.
        return (((((b_outputs[1][0],),),),),)


class A(Module):
    """Outer module; chains three B instances and collects per-step intermediate values."""

    def __init__(self):
        super().__init__()
        self.a_linear = Linear(10, 10)
        self.grandchildren_b = ModuleList([B() for _ in range(3)])

    def forward(self, a_inputs):
        b_outputs = a_inputs
        # Despite the name, this is a list of ([tensor], tensor) tuples.
        output_dict = []
        for i, b in enumerate(self.grandchildren_b):
            # Unwrap the five nested tuples produced by B.forward.
            b_outputs = b(b_outputs, a_inputs)[0][0][0][0][0]
            output_dict.append(([i * a_inputs], b_outputs.clone()))
        b_outputs = [self.a_linear(b_outputs)]
        return (b_outputs, output_dict)  # {k: v for k, v in output_dict})
57 | """ 58 | 59 | _shape = (10, 10) 60 | 61 | def __init__(self): 62 | super().__init__() 63 | self.child_a = A() 64 | self.linear = Linear(*DeeplyNestedOutputModule._shape) 65 | 66 | def forward(self, x): 67 | a_outputs, a_dict = self.child_a(x) 68 | a_outputs = self.linear(a_outputs[0]) 69 | return ((a_outputs,), a_dict, {"nested_dict": [(self.linear(a_outputs + 2),)]}) 70 | 71 | 72 | @pytest.fixture 73 | def deeply_nested_output_module(): 74 | return DeeplyNestedOutputModule() 75 | 76 | 77 | @pytest.fixture 78 | def deeply_nested_output_module_inputs(): 79 | return torch.ones((10, 10)) 80 | 81 | 82 | @pytest.fixture 83 | def patchable_deeply_nested_output_module( 84 | request, deeply_nested_output_module, deeply_nested_output_module_inputs 85 | ): 86 | return PatchableGraph( 87 | deeply_nested_output_module, 88 | ExtractionOptions( 89 | skip_compilation=getattr(request, "param", None) == "opaque", 90 | error_on_compilation_failure=True, 91 | ), 92 | deeply_nested_output_module_inputs, 93 | ) 94 | -------------------------------------------------------------------------------- /tests/fixtures/fixture_collections.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | # TODO: add metadata to the request defs somehow to avoid needing to hardcode this? 
REQUIRES_GPU = (
    "patchable_mixed_cpu_pretrained_module",
    "patchable_accelerate_pretrained_module",
    "patchable_quantized_pretrained_module",
    # QuantizedModule calls .cuda() and its inputs are created on the "cuda" device.
    "patchable_quantized_module",
)
REQUIRES_BITSANDBYTES = (
    "patchable_quantized_pretrained_module",
    "patchable_quantized_module",
)
REQUIRES_TRANSFORMERS = (
    "patchable_pretrained_module",
    "patchable_mixed_cpu_pretrained_module",
    "patchable_disk_offload_pretrained_module",
    "patchable_accelerate_pretrained_module",
    "patchable_quantized_pretrained_module",
)
REQUIRES_ACCELERATE = (
    "patchable_mixed_cpu_pretrained_module",
    "patchable_disk_offload_pretrained_module",
    "patchable_accelerate_pretrained_module",
    "patchable_quantized_pretrained_module",
)
# We don't yet have transformer_lens fixtures.
REQUIRES_TRANSFORMER_LENS = ()


def _filter_by_test_env(name):
    """Return True if the fixture called *name* can be built in the current environment.

    A fixture is usable when every requirement table that lists it is satisfied: at least one
    CUDA device for REQUIRES_GPU, and the corresponding optional package being available for
    the remaining tables.
    """
    # Imported lazily so that merely importing this module does not pull in the optional
    # dependency shims.
    from graphpatch.optional import (
        accelerate,
        bitsandbytes,
        transformer_lens,
        transformers,
    )

    has_gpu = torch.cuda.device_count() >= 1
    requirements = (
        (REQUIRES_GPU, has_gpu),
        (REQUIRES_BITSANDBYTES, bitsandbytes.AVAILABLE),
        (REQUIRES_TRANSFORMERS, transformers.AVAILABLE),
        (REQUIRES_ACCELERATE, accelerate.AVAILABLE),
        (REQUIRES_TRANSFORMER_LENS, transformer_lens.AVAILABLE),
    )
    return all(name not in fixture_names or satisfied for fixture_names, satisfied in requirements)
class GraphBreakModule(Module):
    """Test module whose forward pass contains explicit torch._dynamo graph breaks.

    Besides the graph breaks, forward() reads a class variable, an instance attribute, a class
    variable shadowed by an instance attribute, and calls a bound method.
    """

    _shape = (3, 3)
    bar = 5
    # Overridden per instance in __init__ (set to 8 there).
    shadowed_class_var = 1

    def __init__(self):
        super().__init__()
        self.linear = Linear(*GraphBreakModule._shape)
        self.instance_value = 7
        self.shadowed_class_var = 8
        # Registered as a submodule but never called in forward().
        self.unused_submodule = Linear(*GraphBreakModule._shape)

    def member_function(self, n):
        """Return a tensor of shape ``_shape`` filled with ``1 - n``."""
        return ones(self._shape) - n

    def forward(self, x, foo=3):
        # The same Linear layer is applied three times in sequence.
        x = self.linear(x)
        y = self.linear(x)
        z = self.linear(y)

        graph_break()

        y = x + y + z + self.bar + self.instance_value + self.shadowed_class_var

        graph_break()

        return y + 5 * self.member_function(foo)
class MinimalModule(Module):
    """Smallest test module: a single Linear layer applied to the input."""

    # Unpacked as Linear(in_features=2, out_features=3).
    _shape = (2, 3)

    def __init__(self):
        super().__init__()
        self.linear = Linear(*MinimalModule._shape)

    def forward(self, x):
        """Return ``linear(x)``."""
        return self.linear(x)
-------------------------------------------------------------------------------- /tests/fixtures/nested_module.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.nn import Linear, Module, ModuleList 4 | 5 | from graphpatch.patchable_graph import ExtractionOptions, PatchableGraph 6 | 7 | 8 | class C(Module): 9 | def __init__(self): 10 | super().__init__() 11 | self.c_linear = Linear(100, 100) 12 | 13 | def forward(self, c_inputs, inputs_2=None, inputs_3=None, inputs_4=None): 14 | return self.c_linear(c_inputs + inputs_2 + inputs_3 + inputs_4) 15 | 16 | 17 | class B(Module): 18 | def __init__(self): 19 | super().__init__() 20 | self.b_linear = Linear(100, 100) 21 | self.c = C() 22 | 23 | def forward(self, b_inputs, a_inputs): 24 | b_outputs = self.c( 25 | b_inputs, 26 | inputs_3=torch.ones_like(b_inputs), 27 | inputs_2=b_inputs * 2, 28 | inputs_4=a_inputs, 29 | ) 30 | b_outputs = self.b_linear(b_outputs) 31 | return b_outputs 32 | 33 | 34 | class A(Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.a_linear = Linear(100, 100) 38 | self.grandchildren_b = ModuleList([B() for _ in range(3)]) 39 | 40 | def forward(self, a_inputs): 41 | b_outputs = a_inputs 42 | for b in self.grandchildren_b: 43 | b_outputs = b(b_outputs, a_inputs) 44 | b_outputs = self.a_linear(b_outputs) 45 | return b_outputs 46 | 47 | 48 | class NestedModule(Module): 49 | def __init__(self): 50 | super().__init__() 51 | self.root_linear = Linear(100, 100) 52 | self.child_a = A() 53 | 54 | def forward(self, root_inputs): 55 | a_outputs = self.child_a(root_inputs) 56 | a_outputs = self.root_linear(a_outputs) 57 | return a_outputs 58 | 59 | 60 | @pytest.fixture 61 | def nested_module(): 62 | return NestedModule() 63 | 64 | 65 | @pytest.fixture 66 | def nested_module_inputs(): 67 | return torch.ones(1, 100) 68 | 69 | 70 | @pytest.fixture 71 | def patchable_nested_module(request, nested_module, 
class TestModel(PreTrainedModel):
    """PreTrainedModel wrapper around NestedModule that fakes an embedding from token ids."""

    config_class = TestModelConfig
    _no_split_modules = []

    def __init__(self, config):
        super().__init__(config)
        self.model = NestedModule()

    def forward(self, input_ids, **kwargs):
        """Return a CausalLMOutput whose logits are NestedModule applied to the fake embedding.

        Extra keyword arguments are accepted and ignored.
        """
        # Make a fake "embedding" of the input ids
        # The first 100 ids of each row are viewed as (batch, 1, 100) and then tiled to
        # (batch, 100, 100); the view requires exactly 100 ids per row after slicing.
        embedding = (
            input_ids[:, :100]
            .view((input_ids.shape[0], 1, 100))
            .repeat((1, 100, 1))
            .to(self.config.torch_dtype)
        )
        logits = self.model(embedding)
        return CausalLMOutput(logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return only ``input_ids``; any other generation kwargs are dropped."""
        return {"input_ids": input_ids}
class DummyTokens(UserDict):
    """Dict of tokenizer outputs whose entries can also be read as attributes."""

    def to(self, *args, **kwargs):
        """Accept and ignore any device/dtype arguments; return self unchanged."""
        return self

    def __getattr__(self, __name: str) -> Any:
        """Fall back to the stored dict when normal attribute lookup fails.

        Raises:
            AttributeError: if *__name* is not a stored key.
        """
        # __getattr__ runs only after regular lookup has failed. If "data" itself is missing
        # (the instance is still being constructed, e.g. while unpickling), touching self.data
        # here would re-enter this method and recurse until the stack overflows.
        if __name == "data":
            raise AttributeError(__name)
        try:
            return self.data[__name]
        except KeyError:
            # hasattr(), getattr(obj, name, default), copy and pickle only understand
            # AttributeError; a leaked KeyError crashes them instead of signalling "missing".
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {__name!r}"
            ) from None
class DummyTokenizer(PreTrainedTokenizer):
    """
    Fake tokenizer that will get a consistent "tokenization" for a given text string, for
    reproducibility in tests.
    """

    vocab_files_names = {"vocab_file": "dummy.model"}
    pretrained_vocab_files_map = {}

    def __init__(self, *args, **kwargs):
        # The vocab is assigned before calling the base constructor, as in the original code.
        self.vocab = DummyVocab()
        super().__init__(*args, **kwargs)

    def _tokenize(self, prompt, pad_to=None):
        """Return a 1-D int64 tensor of fake token ids for *prompt*.

        When *pad_to* is given the result has exactly that length: shorter prompts are padded
        with 0 and longer ones are truncated.
        """
        # Deterministically map each input word to one of 100 "tokens". Start with a dummy
        # "begin of sequence" token.
        ordsums = [-1] + [self.vocab[word] for word in prompt.split(" ")]
        if pad_to is not None:
            # Truncate first: a multiplier <= 0 below yields no padding, so without this an
            # over-long prompt kept its length and broke the fixed-width view in __call__.
            ordsums = ordsums[:pad_to]
            ordsums.extend([0] * (pad_to - len(ordsums)))

        # Build the integer tensor directly rather than via a float32 tensor and a cast.
        return torch.tensor(ordsums, dtype=torch.int64)

    def _convert_token_to_id(self, token):
        return token

    def __call__(self, prompt, *args, **kwargs):
        """Tokenize one prompt or a list of prompts into a (batch, 100) batch.

        Returns DummyTokens with ``input_ids`` (int64) and ``attention_mask`` (float; 1.0 where
        the id is greater than zero, 0.0 elsewhere).
        """
        prompts = prompt if isinstance(prompt, list) else [prompt]
        batch_size = len(prompts)
        input_ids = [self._tokenize(p, pad_to=100) for p in prompts]
        stacked_inputs = torch.vstack(input_ids).view((batch_size, 100)).to(torch.int64)
        stacked_attention = (stacked_inputs > 0) * 1.0
        return DummyTokens(input_ids=stacked_inputs, attention_mask=stacked_attention)

    def get_vocab(self):
        return self.vocab

    def convert_ids_to_tokens(self, t):
        return str(t)
Linear(*ProtectedNameModule._shape) 14 | self.bar = (5.0, (6.0, (7.09))) 15 | 16 | def forward(self, _shape, _code=5, sub_shape=7): 17 | return self._code(_shape + _code + self._shape[0] + self.bar[1][1] + sub_shape) 18 | 19 | 20 | @pytest.fixture 21 | def protected_name_module(): 22 | return ProtectedNameModule() 23 | 24 | 25 | @pytest.fixture 26 | def protected_name_module_inputs(): 27 | return ones(*ProtectedNameModule._shape).t() 28 | 29 | 30 | @pytest.fixture 31 | def patchable_protected_name_module(request, protected_name_module, protected_name_module_inputs): 32 | return PatchableGraph( 33 | protected_name_module, 34 | ExtractionOptions( 35 | skip_compilation=getattr(request, "param", None) == "opaque", 36 | error_on_compilation_failure=True, 37 | ), 38 | protected_name_module_inputs, 39 | ) 40 | -------------------------------------------------------------------------------- /tests/fixtures/quantized_module.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.nn import Linear 4 | 5 | from graphpatch import ExtractionOptions, PatchableGraph 6 | from graphpatch.optional.bitsandbytes import AVAILABLE as BNB_AVAILABLE, Linear8bitLt 7 | 8 | if BNB_AVAILABLE: 9 | 10 | class QuantizedModule(torch.nn.Module): 11 | _shape = (2, 3) 12 | 13 | def __init__(self): 14 | super().__init__() 15 | object.__setattr__(self, "original_linear", Linear(*QuantizedModule._shape)) 16 | self.linear = Linear8bitLt( 17 | *QuantizedModule._shape, has_fp16_weights=False, threshold=6.0 18 | ) 19 | self.linear.weight.data = self.original_linear.weight.data 20 | self.linear.bias.data = self.original_linear.bias.data 21 | self.linear.cuda() 22 | 23 | def forward(self, x): 24 | return self.linear(x) 25 | 26 | @pytest.fixture 27 | def quantized_module(): 28 | return QuantizedModule() 29 | 30 | @pytest.fixture 31 | def quantized_module_inputs(): 32 | return torch.ones(*QuantizedModule._shape, device="cuda", 
dtype=torch.float16).t() 33 | 34 | @pytest.fixture 35 | def patchable_quantized_module(request, quantized_module, quantized_module_inputs): 36 | return PatchableGraph( 37 | quantized_module, 38 | ExtractionOptions( 39 | skip_compilation=getattr(request, "param", None) == "opaque", 40 | error_on_compilation_failure=True, 41 | ), 42 | quantized_module_inputs, 43 | ) 44 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_gpt2.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | import torch 5 | 6 | from graphpatch.optional.transformers import AutoConfig, AutoTokenizer 7 | 8 | 9 | @pytest.fixture(scope="session") 10 | def tiny_gpt2_path(tmp_path_factory): 11 | save_path = tmp_path_factory.mktemp("tiny_llama") 12 | config = { 13 | "activation_function": "gelu_new", 14 | "architectures": ["GPT2LMHeadModel"], 15 | "attn_pdrop": 0.0, 16 | "bos_token_id": 50256, 17 | "embd_pdrop": 0.0, 18 | "eos_token_id": 50256, 19 | "initializer_range": 0.02, 20 | "layer_norm_epsilon": 1e-05, 21 | "model_type": "gpt2", 22 | "n_embd": 2, 23 | "n_head": 2, 24 | "n_layer": 1, 25 | "n_positions": 1024, 26 | "output_past": True, 27 | "resid_pdrop": 0.0, 28 | "summary_activation": None, 29 | "summary_first_dropout": 0.0, 30 | "summary_proj_to_labels": True, 31 | "summary_type": "cls_index", 32 | "summary_use_proj": True, 33 | "task_specific_params": {"text-generation": {"do_sample": True, "max_length": 50}}, 34 | "vocab_size": 50257, 35 | "pad_token": 50256, 36 | "padding_size": "right", 37 | "truncation_side": "left", 38 | "add_prefix_space": True, 39 | "add_bos_token": True, 40 | } 41 | json.dump(config, open(save_path / "config.json", "w")) 42 | with open("./tests/fixtures/gpt2_tokenizer.json", "rb") as in_file, open( 43 | save_path / "tokenizer.json", "wb" 44 | ) as out_file: 45 | out_file.write(in_file.read()) 46 | with open("./tests/fixtures/gpt2_merges.txt", "rb") 
@pytest.fixture(scope="session")
def tiny_llama_path(tmp_path_factory):
    """Build a throwaway directory laid out like a tiny pretrained Llama checkpoint.

    Writes ``config.json`` and ``tokenizer_config.json`` and copies the
    ``llama_tokenizer.model`` file that lives next to this module in as
    ``tokenizer.model``.

    Args:
        tmp_path_factory: pytest's built-in factory for session-scoped temp directories.

    Returns:
        The directory that was populated.
    """
    # Imported locally so the fixture needs no additional module-level imports.
    import shutil
    from pathlib import Path

    save_path = tmp_path_factory.mktemp("tiny_llama")
    config = {
        "architectures": ["LlamaForCausalLM"],
        "bos_token_id": 0,
        "eos_token_id": 1,
        "hidden_act": "silu",
        "hidden_size": 20,
        "intermediate_size": 2,
        "initializer_range": 0.02,
        "max_sequence_length": 2048,
        "model_type": "llama",
        "num_attention_heads": 2,
        "num_hidden_layers": 1,
        "pad_token_id": -1,
        "rms_norm_eps": 1e-6,
        "torch_dtype": "float16",
        "transformers_version": "4.27.0.dev0",
        "use_cache": True,
        "vocab_size": 32000,
        "num_key_value_heads": 2,
        "_attn_implementation": "eager",
    }
    # NOTE(review): the special-token strings below are empty as found; confirm that is intended.
    tokenizer_config = {
        "bos_token": "",
        "eos_token": "",
        "model_max_length": 2048,
        "tokenizer_class": "LlamaTokenizer",
        "unk_token": "",
        "pad_token": "",
    }
    # Context managers guarantee each file is flushed and closed before anything reads the
    # directory, rather than relying on garbage collection of an anonymous file object.
    with open(save_path / "config.json", "w") as config_file:
        json.dump(config, config_file)
    with open(save_path / "tokenizer_config.json", "w") as tokenizer_config_file:
        json.dump(tokenizer_config, tokenizer_config_file)
    # Resolve the bundled tokenizer relative to this file so the fixture does not depend on
    # the directory pytest was launched from.
    bundled_tokenizer = Path(__file__).resolve().parent / "llama_tokenizer.model"
    shutil.copyfile(bundled_tokenizer, save_path / "tokenizer.model")
    return save_path
class C(Module):
    """Innermost block: sums its four inputs and applies ``c_linear``.

    ``c_unused`` is registered as a submodule but never called in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.c_linear = Linear(100, 100)
        self.c_unused = Linear(1, 1)

    def forward(self, c_inputs, inputs_2=None, inputs_3=None, inputs_4=None):
        summed = c_inputs + inputs_2 + inputs_3 + inputs_4
        return self.c_linear(summed)


class B(Module):
    """Middle block: calls ``C`` with keyword arguments, then applies ``b_linear``."""

    def __init__(self):
        super().__init__()
        self.b_linear = Linear(100, 100)
        self.c = C()

    def forward(self, b_inputs, a_inputs):
        # Built in the same order the keywords are passed (inputs_3 before inputs_2).
        all_ones = torch.ones_like(b_inputs)
        doubled = b_inputs * 2
        from_c = self.c(b_inputs, inputs_3=all_ones, inputs_2=doubled, inputs_4=a_inputs)
        return self.b_linear(from_c)


class A(Module):
    """Chains the ``B`` blocks, skipping the one at index 2, then applies ``a_linear``."""

    def __init__(self):
        super().__init__()
        self.a_linear = Linear(100, 100)
        self.grandchildren_b = ModuleList([B() for _ in range(4)])

    def forward(self, a_inputs):
        hidden = a_inputs
        for index, block in enumerate(self.grandchildren_b):
            # The block at index 2 is registered but never invoked.
            if index != 2:
                hidden = block(hidden, a_inputs)
        return self.a_linear(hidden)


class UnusedSubmoduleModule(Module):
    """Root module: runs ``child_a`` and then ``root_linear`` on the result."""

    def __init__(self):
        super().__init__()
        self.root_linear = Linear(100, 100)
        self.child_a = A()

    def forward(self, root_inputs):
        from_child = self.child_a(root_inputs)
        return self.root_linear(from_child)
class VarargsModule(Module):
    """Module whose ``forward`` takes positional varargs, a keyword-only default and kwargs.

    The single ``linear`` layer is applied twice: once to ``x`` and once after every extra
    positional tensor and ``blah`` have been added. Every keyword tensor is then added to
    that second result.
    """

    _shape = (3, 3)

    def __init__(self):
        super().__init__()
        in_features, out_features = VarargsModule._shape
        self.linear = Linear(in_features, out_features)

    def forward(self, x, *foos, blah=3, **bars):
        running = self.linear(x)
        # In-place accumulation of the extra positional tensors.
        for positional_extra in foos:
            running += positional_extra
        running = self.linear(running + blah)
        # In-place accumulation of the keyword tensors, in the order they were passed.
        for keyword_extra in bars.values():
            running += keyword_extra
        return running
@requires_transformers
def test_rome(patchable_pretrained_module, pretrained_tokenizer, pretrained_module_inputs):
    """Smoke-test the ROME demo helpers against the pretrained fixture.

    Builds a key vector and a value vector with the demo helpers, then runs one forward pass
    with a ``RomePatch`` applied at ``model.root_linear.weight``. Nothing about the output is
    asserted; the test passes as long as no step raises.
    """
    edited_layer = "model.root_linear.linear"
    # presumably subject / relation / target of the edited fact; verify against demos.ROME.rome
    prompt_parts = ("foo", "is in", "bar")

    key_vector = generate_key_vector(
        patchable_pretrained_module, pretrained_tokenizer, edited_layer, *prompt_parts
    )
    value_vector = generate_value_vector(
        patchable_pretrained_module,
        pretrained_tokenizer,
        edited_layer,
        "model.child_a.a_linear.linear",
        *prompt_parts,
        key_vector,
    )

    rome_patch = RomePatch(key_vector=key_vector, value_vector=value_vector)
    with patchable_pretrained_module.patch({"model.root_linear.weight": [rome_patch]}):
        patchable_pretrained_module(pretrained_module_inputs)
graphpatch.optional.transformer_lens import ( 3 | HookedTransformer, 4 | loading_from_pretrained, 5 | ) 6 | 7 | from .util import assert_results_identical, requires_transformer_lens 8 | 9 | # Stub for now; better integration is TODO. 10 | 11 | 12 | def _convert_hf_model_config(config): 13 | # Monkeypatching convert_hf_model_config, since that unavoidably calls out to HFHub. 14 | # https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/loading_from_pretrained.py#L845-L859 15 | return { 16 | "d_model": config.n_embd, 17 | "d_head": config.n_embd // config.n_head, 18 | "n_heads": config.n_head, 19 | "d_mlp": config.n_embd * 4, 20 | "n_layers": config.n_layer, 21 | "n_ctx": config.n_embd, 22 | "eps": config.layer_norm_epsilon, 23 | "d_vocab": config.vocab_size, 24 | "act_fn": config.activation_function, 25 | "use_attn_scale": True, 26 | "use_local_attn": False, 27 | "scale_attn_by_inverse_layer_idx": config.scale_attn_by_inverse_layer_idx, 28 | "normalization_type": "LN", 29 | "original_architecture": config.architectures[0], 30 | "tokenizer_name": "gpt2-small", 31 | } 32 | 33 | 34 | @requires_transformer_lens 35 | def test_transformer_lens(tiny_gpt2_config, tiny_gpt2_tokenizer, mocker): 36 | mocker.patch.object( 37 | loading_from_pretrained, 38 | "convert_hf_model_config", 39 | lambda *args, **kwargs: _convert_hf_model_config(tiny_gpt2_config), 40 | ) 41 | config = loading_from_pretrained.get_pretrained_model_config( 42 | "gpt2-small", 43 | hf_cfg=tiny_gpt2_config.to_dict(), 44 | device="cpu", 45 | ) 46 | model = HookedTransformer(config, tiny_gpt2_tokenizer, move_to_device=False) 47 | model._init_weights_gpt2() 48 | pg = PatchableGraph(model, "foo") 49 | assert_results_identical( 50 | model, pg._graph_module, ["hello transformer_lens", "and yes it also handles batch"] 51 | ) 52 | -------------------------------------------------------------------------------- /tests/test_validate_env.py: 
def test_validate_tox_factors():
    """Fail when the running environment does not match the tox factors it was launched with.

    Reads ``GP_TOX_FACTOR_TORCH``, ``GP_TOX_FACTOR_PY`` and ``GP_TOX_FACTOR_EXTRA`` from the
    environment and compares them with the installed torch version, the running interpreter
    version, and whether ``transformers`` can be imported.

    Raises:
        KeyError: if any of the three environment variables is not set.
    """
    # From tox factor, formatted as torchXX
    torch_from_env = os.environ["GP_TOX_FACTOR_TORCH"]
    torch_major_minor = "".join(torch.__version__.split(".")[:2])
    # Compare the whole factor rather than slicing off five characters, so a value with the
    # wrong prefix cannot pass by accident.
    assert torch_from_env == "torch" + torch_major_minor

    # From tox factor, formatted as pyXX
    py_from_env = os.environ["GP_TOX_FACTOR_PY"]
    py_major_minor = "".join(map(str, sys.version_info[:2]))
    assert py_from_env == "py" + py_major_minor

    extra_from_env = os.environ["GP_TOX_FACTOR_EXTRA"]
    if extra_from_env == "extranone":
        # Without extras installed, the optional dependency must not be importable.
        with pytest.raises(ImportError):
            import transformers  # noqa: F401
    elif extra_from_env == "extraall":
        import transformers

        assert transformers is not None
    else:
        # Unlike ``assert False``, pytest.fail is not stripped when running under ``python -O``.
        pytest.fail("Invalid tox factor for extra")