├── requirements.txt
├── images
│   ├── logo.png
│   ├── alexnet.png
│   ├── gradients.png
│   ├── cheatsheet.pdf
│   ├── swin_v2_b_demo.jpg
│   ├── simple_recurrent.png
│   └── nested_modules_example.png
├── tests
│   └── test_metadata.py
├── requirements.test.txt
├── torchlens
│   ├── __init__.py
│   ├── cleanup.py
│   ├── model_history.py
│   ├── trace_model.py
│   ├── decorate_torch.py
│   ├── constants.py
│   ├── interface.py
│   ├── model_funcs.py
│   ├── user_funcs.py
│   ├── helper_funcs.py
│   ├── tensor_log.py
│   └── validation.py
├── .pre-commit-config.yaml
├── .github
│   └── workflows
│       └── lint.yml
├── setup.py
├── .gitignore
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | tqdm
4 | ipython
5 | graphviz
6 | torch
7 |
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/logo.png
--------------------------------------------------------------------------------
/images/alexnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/alexnet.png
--------------------------------------------------------------------------------
/images/gradients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/gradients.png
--------------------------------------------------------------------------------
/tests/test_metadata.py:
--------------------------------------------------------------------------------
1 | # Tests ensuring that the different kinds of metadata are computed properly.
2 |
--------------------------------------------------------------------------------
/images/cheatsheet.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/cheatsheet.pdf
--------------------------------------------------------------------------------
/images/swin_v2_b_demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/swin_v2_b_demo.jpg
--------------------------------------------------------------------------------
/images/simple_recurrent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/simple_recurrent.png
--------------------------------------------------------------------------------
/images/nested_modules_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/nested_modules_example.png
--------------------------------------------------------------------------------
/requirements.test.txt:
--------------------------------------------------------------------------------
1 | black
2 | cornet @ git+https://github.com/dicarlolab/CORnet
3 | lightning
4 | numpy
5 | pillow
6 | pytest
7 | requests
8 | timm
9 | torch
10 | torchaudio
11 | torch_geometric
12 | torchlens
13 | torchvision
14 | transformers
15 | visualpriors
16 |
--------------------------------------------------------------------------------
/torchlens/__init__.py:
--------------------------------------------------------------------------------
1 | """ Top level package: make the user-facing functions top-level, rest accessed as submodules.
2 | """
3 | from .user_funcs import log_forward_pass, show_model_graph, get_model_metadata, validate_saved_activations, \
4 | validate_batch_of_models_and_inputs
5 | from .model_history import ModelHistory
6 | from .tensor_log import TensorLogEntry, RolledTensorLogEntry
7 |
--------------------------------------------------------------------------------
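
A note on the public API above: the five imported functions are the intended user entry points. A minimal usage sketch follows (the toy model and input are illustrative; `layers_to_save` is the argument name used by `run_and_log_inputs_through_model` in trace_model.py, assumed here to be forwarded by `log_forward_pass`):

    import torch
    import torch.nn as nn
    import torchlens as tl

    # Any nn.Module works; a toy two-layer net keeps the example small.
    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    x = torch.randn(1, 8)

    # Trace one forward pass and get back a ModelHistory object.
    model_history = tl.log_forward_pass(model, x, layers_to_save="all")
    print(model_history)      # human-readable summary of the logged pass
    print(model_history[-1])  # TensorLogEntry for the last layer, via __getitem__
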
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v3.2.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-yaml
10 | - id: check-added-large-files
11 | - repo: https://github.com/psf/black
12 | rev: 23.1.0
13 | hooks:
14 | - id: black
15 | - repo: https://github.com/pycqa/isort
16 | rev: 5.12.0
17 | hooks:
18 | - id: isort
19 | args: [ "--profile", "black" ]
20 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 | on:
3 | push:
4 | branches: [ main ]
5 |
6 | concurrency:
7 | group: "${{ github.head_ref || github.ref }}-lint-and-test"
8 | cancel-in-progress: true
9 |
10 | jobs:
11 | # flake8:
12 | # runs-on: ubuntu-latest
13 | # container: python:3.8.5-slim-buster
14 | # steps:
15 | # - uses: actions/checkout@v2
16 | # - run: pip install flake8
17 | # - run: flake8 . --show-source
18 | black:
19 | runs-on: ubuntu-latest
20 | container: python:3.9-slim
21 | steps:
22 | - uses: actions/checkout@v2
23 | - run: pip install black
24 | - run: black --check .
25 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | with open("README.md") as readme_file:
4 | readme = readme_file.read()
5 |
6 | requirements = [
7 | "numpy",
8 | "pandas",
9 | "pytest",
10 | "tqdm",
11 | "ipython",
12 | "graphviz",
13 | ]
14 |
15 | setup(
16 | name="torchlens",
17 | version="0.1.36",
18 | description="A package for extracting activations from PyTorch models",
19 | long_description="A package for extracting activations from PyTorch models. Contains functionality for "
20 | "extracting model activations, visualizing a model's computational graph, and "
21 | "extracting exhaustive metadata about a model.",
22 | author="JohnMark Taylor",
23 | author_email="johnmarkedwardtaylor@gmail.com",
24 | url="https://github.com/johnmarktaylor91/torchlens",
25 | packages=["torchlens"],
26 | include_package_data=True,
27 | install_requires=requirements,
28 | license="GNU GPL v3",
29 | zip_safe=False,
30 | keywords="torch torchlens features",
31 | classifiers=[
32 | "Development Status :: 3 - Alpha",
33 | "Intended Audience :: Science/Research",
34 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
35 | "Natural Language :: English",
36 | "Programming Language :: Python :: 3.9",
37 | ],
38 | extras_require={"dev": ["black[jupyter]", "pytest", "pre-commit"]},
39 | )
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | /tests/visualization_outputs/
131 |
--------------------------------------------------------------------------------
/torchlens/cleanup.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 |
5 | from .constants import MODEL_HISTORY_FIELD_ORDER
6 | from .helper_funcs import remove_entry_from_list
7 | from .tensor_log import TensorLogEntry
8 |
9 |
10 | def cleanup(self):
11 | """Deletes all log entries in the model."""
12 | for tensor_log_entry in self:
13 | self._remove_log_entry(tensor_log_entry, remove_references=True)
14 | for attr in MODEL_HISTORY_FIELD_ORDER:
15 | delattr(self, attr)
16 | torch.cuda.empty_cache()
17 |
18 |
19 | def _remove_log_entry(
20 | self, log_entry: TensorLogEntry, remove_references: bool = True
21 | ):
22 | """Given a TensorLogEntry, destroys it and all references to it.
23 |
24 | Args:
25 | log_entry: Tensor log entry to remove.
26 | remove_references: Whether to also remove references to the log entry
27 | """
28 | if self._pass_finished:
29 | tensor_label = log_entry.layer_label
30 | else:
31 | tensor_label = log_entry.tensor_label_raw
32 | for attr in dir(log_entry):
33 | with warnings.catch_warnings():
34 | warnings.simplefilter("ignore")
35 | if not attr.startswith("_") and not callable(getattr(log_entry, attr)):
36 | delattr(log_entry, attr)
37 | del log_entry
38 | if remove_references:
39 | _remove_log_entry_references(self, tensor_label)
40 |
41 |
42 | def _remove_log_entry_references(self, layer_to_remove: str):
43 | """Removes all references to a given TensorLogEntry in the ModelHistory object.
44 |
45 | Args:
46 |         layer_to_remove: Label of the log entry whose references should be removed.
47 | """
48 | # Clear any fields in ModelHistory referring to the entry.
49 |
50 | remove_entry_from_list(self.input_layers, layer_to_remove)
51 | remove_entry_from_list(self.output_layers, layer_to_remove)
52 | remove_entry_from_list(self.buffer_layers, layer_to_remove)
53 | remove_entry_from_list(self.internally_initialized_layers, layer_to_remove)
54 | remove_entry_from_list(self.internally_terminated_layers, layer_to_remove)
55 | remove_entry_from_list(self.internally_terminated_bool_layers, layer_to_remove)
56 | remove_entry_from_list(self.layers_with_saved_activations, layer_to_remove)
57 | remove_entry_from_list(self.layers_with_saved_gradients, layer_to_remove)
58 | remove_entry_from_list(
59 | self._layers_where_internal_branches_merge_with_input, layer_to_remove
60 | )
61 |
62 | self.conditional_branch_edges = [
63 | tup for tup in self.conditional_branch_edges if layer_to_remove not in tup
64 | ]
65 |
66 | # Now any nested fields.
67 |
68 | for group_label, group_tensors in self.layers_computed_with_params.items():
69 | if layer_to_remove in group_tensors:
70 | group_tensors.remove(layer_to_remove)
71 | self.layers_computed_with_params = {
72 | k: v for k, v in self.layers_computed_with_params.items() if len(v) > 0
73 | }
74 |
75 | for group_label, group_tensors in self.equivalent_operations.items():
76 | if layer_to_remove in group_tensors:
77 | group_tensors.remove(layer_to_remove)
78 | self.equivalent_operations = {
79 | k: v for k, v in self.equivalent_operations.items() if len(v) > 0
80 | }
81 |
82 | for group_label, group_tensors in self.same_layer_operations.items():
83 | if layer_to_remove in group_tensors:
84 | group_tensors.remove(layer_to_remove)
85 | self.same_layer_operations = {
86 | k: v for k, v in self.same_layer_operations.items() if len(v) > 0
87 | }
88 |
--------------------------------------------------------------------------------
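
The three loops that close `_remove_log_entry_references` repeat a single idiom: remove the label from every group, then rebuild the dict keeping only non-empty groups. A condensed, standalone sketch of that remove-then-prune pattern (the names here are illustrative, not from the codebase):

    from collections import defaultdict
    from typing import Dict, List

    def remove_label_and_prune(groups: Dict[str, List[str]], label: str) -> Dict[str, List[str]]:
        # Remove the label from every group it appears in...
        for group_tensors in groups.values():
            if label in group_tensors:
                group_tensors.remove(label)
        # ...then keep only the groups that still have members.
        return {k: v for k, v in groups.items() if len(v) > 0}

    groups = defaultdict(list, {"conv": ["conv_1", "conv_2"], "relu": ["relu_1"]})
    print(remove_label_and_prune(groups, "relu_1"))  # {'conv': ['conv_1', 'conv_2']}
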
/torchlens/model_history.py:
--------------------------------------------------------------------------------
1 | # Defines the ModelHistory class, which stores the representation of a model's forward pass.
2 | import copy
3 | from collections import OrderedDict, defaultdict
4 | from typing import Any, Callable, Dict, List, Optional, Set, Tuple
5 |
6 | from .cleanup import _remove_log_entry, cleanup
7 | from .decorate_torch import decorate_pytorch
8 | from .helper_funcs import (
9 | human_readable_size,
10 | )
11 | from .interface import (_getitem_after_pass, _getitem_during_pass, _str_after_pass,
12 | _str_during_pass, to_pandas, print_all_fields)
13 | from .logging_funcs import save_new_activations
14 | from .model_funcs import cleanup_model, prepare_model
15 | from .postprocess import postprocess
16 | from .tensor_log import RolledTensorLogEntry, TensorLogEntry
17 | from .trace_model import run_and_log_inputs_through_model
18 | from .validation import validate_saved_activations
19 | from .vis import render_graph
20 |
21 |
22 | # todo add saved_layer field, remove the option to only keep saved layers
23 |
24 |
25 | class ModelHistory:
26 | def __init__(
27 | self,
28 | model_name: str,
29 | output_device: str = "same",
30 | activation_postfunc: Optional[Callable] = None,
31 | keep_unsaved_layers: bool = True,
32 | save_function_args: bool = False,
33 | save_gradients: bool = False,
34 | detach_saved_tensors: bool = False,
35 | mark_input_output_distances: bool = True,
36 | ):
37 | """Object that stores the history of a model's forward pass.
38 | Both logs the history in real time, and stores a nice
39 | representation of the full history for the user afterward.
40 | """
41 | # Setup:
42 | activation_postfunc = copy.deepcopy(activation_postfunc)
43 |
44 | # General info
45 | self.model_name = model_name
46 | self._pass_finished = False
47 | self._track_tensors = False
48 | self.logging_mode = "exhaustive"
49 | self._pause_logging = False
50 | self._all_layers_logged = False
51 | self._all_layers_saved = False
52 | self.keep_unsaved_layers = keep_unsaved_layers
53 | self.activation_postfunc = activation_postfunc
54 | self.current_function_call_barcode = None
55 | self.random_seed_used = None
56 | self.output_device = output_device
57 | self.detach_saved_tensors = detach_saved_tensors
58 | self.save_function_args = save_function_args
59 | self.save_gradients = save_gradients
60 | self.has_saved_gradients = False
61 | self.mark_input_output_distances = mark_input_output_distances
62 |
63 | # Model structure info
64 | self.model_is_recurrent = False
65 | self.model_max_recurrent_loops = 1
66 | self.model_has_conditional_branching = False
67 | self.model_is_branching = False
68 |
69 | # Tensor Tracking:
70 | self.layer_list: List[TensorLogEntry] = []
71 | self.layer_list_rolled: List[RolledTensorLogEntry] = []
72 | self.layer_dict_main_keys: Dict[str, TensorLogEntry] = OrderedDict()
73 | self.layer_dict_all_keys: Dict[str, TensorLogEntry] = OrderedDict()
74 | self.layer_dict_rolled: Dict[str, RolledTensorLogEntry] = OrderedDict()
75 | self.layer_labels: List[str] = []
76 | self.layer_labels_w_pass: List[str] = []
77 | self.layer_labels_no_pass: List[str] = []
78 | self.layer_num_passes: Dict[str, int] = OrderedDict()
79 | self._raw_tensor_dict: Dict[str, TensorLogEntry] = OrderedDict()
80 | self._raw_tensor_labels_list: List[str] = []
81 | self._tensor_nums_to_save: List[int] = []
82 | self._tensor_counter: int = 0
83 | self.num_operations: int = 0
84 | self._raw_layer_type_counter: Dict[str, int] = defaultdict(lambda: 0)
85 | self._unsaved_layers_lookup_keys: Set[str] = set()
86 |
87 | # Mapping from raw to final layer labels:
88 | self._raw_to_final_layer_labels: Dict[str, str] = {}
89 | self._final_to_raw_layer_labels: Dict[str, str] = {}
90 | self._lookup_keys_to_tensor_num_dict: Dict[str, int] = {}
91 | self._tensor_num_to_lookup_keys_dict: Dict[int, List[str]] = defaultdict(list)
92 |
93 | # Special Layers:
94 | self.input_layers: List[str] = []
95 | self.output_layers: List[str] = []
96 | self.buffer_layers: List[str] = []
97 | self.buffer_num_passes: Dict = {}
98 | self.internally_initialized_layers: List[str] = []
99 | self._layers_where_internal_branches_merge_with_input: List[str] = []
100 | self.internally_terminated_layers: List[str] = []
101 | self.internally_terminated_bool_layers: List[str] = []
102 | self.conditional_branch_edges: List[Tuple[str, str]] = []
103 | self.layers_with_saved_activations: List[str] = []
104 | self.orphan_layers: List[str] = []
105 | self.unlogged_layers: List[str] = []
106 | self.layers_with_saved_gradients: List[str] = []
107 | self.layers_computed_with_params: Dict[str, List] = defaultdict(list)
108 | self.equivalent_operations: Dict[str, set] = defaultdict(set)
109 | self.same_layer_operations: Dict[str, list] = defaultdict(list)
110 |
111 | # Tensor info:
112 | self.num_tensors_total: int = 0
113 | self.tensor_fsize_total: int = 0
114 | self.tensor_fsize_total_nice: str = human_readable_size(0)
115 | self.num_tensors_saved: int = 0
116 | self.tensor_fsize_saved: int = 0
117 | self.tensor_fsize_saved_nice: str = human_readable_size(0)
118 |
119 | # Param info:
120 | self.total_param_tensors: int = 0
121 | self.total_param_layers: int = 0
122 | self.total_params: int = 0
123 | self.total_params_fsize: int = 0
124 | self.total_params_fsize_nice: str = human_readable_size(0)
125 |
126 | # Module info:
127 | self.module_addresses: List[str] = []
128 | self.module_types: Dict[str, Any] = {}
129 | self.module_passes: List = []
130 | self.module_num_passes: Dict = defaultdict(lambda: 1)
131 | self.top_level_modules: List = []
132 | self.top_level_module_passes: List = []
133 | self.module_children: Dict = defaultdict(list)
134 | self.module_pass_children: Dict = defaultdict(list)
135 | self.module_nparams: Dict = defaultdict(lambda: 0)
136 | self.module_num_tensors: Dict = defaultdict(lambda: 0)
137 | self.module_pass_num_tensors: Dict = defaultdict(lambda: 0)
138 | self.module_layers: Dict = defaultdict(list)
139 | self.module_pass_layers: Dict = defaultdict(list)
140 | self.module_layer_argnames = defaultdict(list)
141 |
142 | # Time elapsed:
143 | self.pass_start_time: float = 0
144 | self.pass_end_time: float = 0
145 | self.elapsed_time_setup: float = 0
146 | self.elapsed_time_forward_pass: float = 0
147 | self.elapsed_time_cleanup: float = 0
148 | self.elapsed_time_total: float = 0
149 | self.elapsed_time_function_calls: float = 0
150 | self.elapsed_time_torchlens_logging: float = 0
151 |
152 | # Reference info
153 | self.func_argnames: Dict[str, tuple] = defaultdict(lambda: tuple([]))
154 |
155 | # ********************************************
156 | # ************ Built-in Methods **************
157 | # ********************************************
158 |
159 | def __len__(self):
160 | if self._pass_finished:
161 | return len(self.layer_list)
162 | else:
163 | return len(self._raw_tensor_dict)
164 |
165 | def __getitem__(self, ix) -> TensorLogEntry:
166 | """Returns an object logging a model layer given an index. If the pass is finished,
167 | it'll do this intelligently; if not, it simply queries based on the layer's raw barcode.
168 |
169 | Args:
170 | ix: desired index
171 |
172 | Returns:
173 | Tensor log entry object with info about specified layer.
174 | """
175 | if self._pass_finished:
176 | return _getitem_after_pass(self, ix)
177 | else:
178 | return _getitem_during_pass(self, ix)
179 |
180 | def __str__(self) -> str:
181 | if self._pass_finished:
182 | return _str_after_pass(self)
183 | else:
184 | return _str_during_pass(self)
185 |
186 | def __repr__(self):
187 | return self.__str__()
188 |
189 | def __iter__(self):
190 | """Loops through all tensors in the log."""
191 | if self._pass_finished:
192 | return iter(self.layer_list)
193 | else:
194 | return iter(list(self._raw_tensor_dict.values()))
195 |
196 | # ********************************************
197 | # ******** Assign Imported Methods ***********
198 | # ********************************************
199 |
200 | render_graph = render_graph
201 | print_all_fields = print_all_fields
202 | to_pandas = to_pandas
203 | save_new_activations = save_new_activations
204 | validate_saved_activations = validate_saved_activations
205 | cleanup = cleanup
206 | _postprocess = postprocess
207 | _decorate_pytorch = decorate_pytorch
208 | _prepare_model = prepare_model
209 | _cleanup_model = cleanup_model
210 | _run_and_log_inputs_through_model = run_and_log_inputs_through_model
211 | _remove_log_entry = _remove_log_entry
212 |
--------------------------------------------------------------------------------
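
The "Assign Imported Methods" block that closes the class is how torchlens splits one large class across many modules: each method is written as a plain function taking `self` explicitly (as in cleanup.py above), then bound as a class attribute. A minimal sketch of the same pattern, with made-up names:

    # In a real project this would live in a sibling module, as cleanup() does.
    def describe(self) -> str:
        return f"{self.name}: {len(self.entries)} entries"

    class Registry:
        def __init__(self, name: str):
            self.name = name
            self.entries = []

        # Binding the module-level function as a class attribute makes it an
        # ordinary method, exactly as ModelHistory does with render_graph,
        # to_pandas, cleanup, etc.
        describe = describe

    r = Registry("demo")
    r.entries.append("layer_1")
    print(r.describe())  # demo: 1 entries
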
/torchlens/trace_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import inspect
3 | import random
4 | import time
5 | from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union
6 |
7 | import torch
8 | from torch import nn
9 |
10 | if TYPE_CHECKING:
11 | from .model_history import ModelHistory
12 | from .decorate_torch import undecorate_pytorch
13 | from .helper_funcs import get_vars_of_type_from_obj, set_random_seed, nested_assign
14 | from .logging_funcs import log_source_tensor
15 | from .interface import _give_user_feedback_about_lookup_key
16 |
17 |
18 | def _get_input_arg_names(model, input_args):
19 | input_arg_names = inspect.getfullargspec(model.forward).args
20 | if "self" in input_arg_names:
21 | input_arg_names.remove("self")
22 | input_arg_names = input_arg_names[0: len(input_args)]
23 | return input_arg_names
24 |
25 |
26 | def _get_op_nums_from_user_labels(
27 | self: "ModelHistory", which_layers: Union[str, List[Union[str, int]]]
28 | ) -> List[int]:
29 | """Given list of user layer labels, returns the original tensor numbers for those labels (i.e.,
30 | the numbers that were generated on the fly during the forward pass, such that they can be
31 | saved on a subsequent pass). Raises an error if the user's labels don't correspond to any layers.
32 |
33 | Args:
34 | which_layers: List of layers to include, using any indexing desired: either the layer label,
35 | the module label, or the ordinal position of the layer. If a layer has multiple passes and
36 | none is specified, will return all of them.
37 |
38 | Returns:
39 | Ordered, unique list of raw tensor numbers associated with the specified layers.
40 | """
41 | if which_layers == "all":
42 |         return which_layers  # "all" is passed through as a sentinel rather than a list of numbers
43 | elif which_layers in [None, "none", "None", "NONE", []]:
44 | return []
45 |
46 | if type(which_layers) != list:
47 | which_layers = [which_layers]
48 | which_layers = [
49 | layer.lower() if (type(layer) == str) else layer for layer in which_layers
50 | ]
51 |
52 | raw_tensor_nums_to_save = set()
53 | for layer_key in which_layers:
54 | # First check if it matches a lookup key. If so, use that.
55 | if layer_key in self._lookup_keys_to_tensor_num_dict:
56 | raw_tensor_nums_to_save.add(
57 | self._lookup_keys_to_tensor_num_dict[layer_key]
58 | )
59 | continue
60 |
61 | # If not, pull out all layers for which the key is a substring.
62 | keys_with_substr = [
63 | key for key in self.layer_dict_all_keys if layer_key in str(key)
64 | ]
65 | if len(keys_with_substr) > 0:
66 | for key in keys_with_substr:
67 | raw_tensor_nums_to_save.add(
68 | self.layer_dict_all_keys[key].realtime_tensor_num
69 | )
70 | continue
71 |
72 | # If no luck, try to at least point user in right direction:
73 |
74 | _give_user_feedback_about_lookup_key(self, layer_key, "query_multiple")
75 |
76 | raw_tensor_nums_to_save = sorted(list(raw_tensor_nums_to_save))
77 | return raw_tensor_nums_to_save
78 |
79 |
80 | def _fetch_label_move_input_tensors(
81 | input_args: List[Any],
82 | input_arg_names: List[str],
83 | input_kwargs: Dict,
84 | model_device: str,
85 | ) -> Tuple[List[torch.Tensor], List[str]]:
86 | """Fetches input tensors, gets their addresses, and moves them to the model device.
87 |
88 | Args:
89 | input_args: input arguments
90 | input_arg_names: name of input arguments
91 | input_kwargs: input keyword arguments
92 | model_device: model device
93 |
94 | Returns:
95 | input tensors and their addresses
96 | """
97 | input_arg_tensors = [
98 | get_vars_of_type_from_obj(
99 | arg, torch.Tensor, search_depth=5, return_addresses=True
100 | )
101 | for arg in input_args
102 | ]
103 | input_kwarg_tensors = [
104 | get_vars_of_type_from_obj(
105 | kwarg, torch.Tensor, search_depth=5, return_addresses=True
106 | )
107 | for kwarg in input_kwargs.values()
108 | ]
109 | for a, arg in enumerate(input_args):
110 | for i, (t, addr, addr_full) in enumerate(input_arg_tensors[a]):
111 | t_moved = t.to(model_device)
112 | input_arg_tensors[a][i] = (t_moved, addr, addr_full)
113 | if not addr_full:
114 | input_args[a] = t_moved
115 | else:
116 | nested_assign(input_args[a], addr_full, t_moved)
117 |
118 | for k, (key, val) in enumerate(input_kwargs.items()):
119 | for i, (t, addr, addr_full) in enumerate(input_kwarg_tensors[k]):
120 | t_moved = t.to(model_device)
121 | input_kwarg_tensors[k][i] = (t_moved, addr, addr_full)
122 | if not addr_full:
123 | input_kwargs[key] = t_moved
124 | else:
125 | nested_assign(input_kwargs[key], addr_full, t_moved)
126 |
127 | input_tensors = []
128 | input_tensor_addresses = []
129 | for a, arg_tensors in enumerate(input_arg_tensors):
130 | for t, addr, addr_full in arg_tensors:
131 | input_tensors.append(t)
132 | tensor_addr = f"input.{input_arg_names[a]}"
133 | if addr != "":
134 | tensor_addr += f".{addr}"
135 | input_tensor_addresses.append(tensor_addr)
136 |
137 | for a, kwarg_tensors in enumerate(input_kwarg_tensors):
138 | for t, addr, addr_full in kwarg_tensors:
139 | input_tensors.append(t)
140 | tensor_addr = f"input.{list(input_kwargs.keys())[a]}"
141 | if addr != "":
142 | tensor_addr += f".{addr}"
143 | input_tensor_addresses.append(tensor_addr)
144 |
145 | return input_tensors, input_tensor_addresses
146 |
147 |
148 | def run_and_log_inputs_through_model(
149 | self: "ModelHistory",
150 | model: nn.Module,
151 | input_args: Union[torch.Tensor, List[Any]],
152 | input_kwargs: Dict[Any, Any] = None,
153 | layers_to_save: Optional[Union[str, List[Union[str, int]]]] = "all",
154 | random_seed: Optional[int] = None,
155 | ):
156 | """Runs input through model and logs it in ModelHistory.
157 |
158 | Args:
159 | model: Model for which to save activations
160 | input_args: Either a single tensor input to the model, or list of input arguments.
161 | input_kwargs: Dict of keyword arguments to the model.
162 |         layers_to_save: Which layers to save: "all", or a list of layer labels or indices
163 | random_seed: Which random seed to use
164 | Returns:
165 | Nothing, but now the ModelHistory object will have saved activations for the new input.
166 | """
167 | if random_seed is None: # set random seed
168 | random_seed = random.randint(1, 4294967294)
169 | self.random_seed_used = random_seed
170 | set_random_seed(random_seed)
171 |
172 | self._tensor_nums_to_save = _get_op_nums_from_user_labels(self, layers_to_save)
173 |
174 | if type(input_args) is tuple:
175 | input_args = list(input_args)
176 | elif (type(input_args) not in [list, tuple]) and (input_args is not None):
177 | input_args = [input_args]
178 |
179 | if not input_args:
180 | input_args = []
181 |
182 | if not input_kwargs:
183 | input_kwargs = {}
184 |
185 | if (
186 | type(model) == nn.DataParallel
187 | ): # Unwrap model from DataParallel if relevant:
188 | model = model.module
189 |
190 | if (
191 | len(list(model.parameters())) > 0
192 | ): # Get the model device by looking at the parameters:
193 | model_device = next(iter(model.parameters())).device
194 | else:
195 | model_device = "cpu"
196 |
197 | input_args = [copy.deepcopy(arg) for arg in input_args]
198 | input_arg_names = _get_input_arg_names(model, input_args)
199 | input_kwargs = {key: copy.deepcopy(val) for key, val in input_kwargs.items()}
200 |
201 | self.pass_start_time = time.time()
202 | module_orig_forward_funcs = {}
203 | orig_func_defs = []
204 |
205 | try:
206 | (
207 | input_tensors,
208 | input_tensor_addresses,
209 | ) = _fetch_label_move_input_tensors(
210 | input_args, input_arg_names, input_kwargs, model_device
211 | )
212 | buffer_tensors = list(model.buffers())
213 | tensors_to_decorate = input_tensors + buffer_tensors
214 | decorated_func_mapper = self._decorate_pytorch(torch, orig_func_defs)
215 | self._track_tensors = True
216 | for i, t in enumerate(input_tensors):
217 | log_source_tensor(self, t, "input", input_tensor_addresses[i])
218 | self._prepare_model(model, module_orig_forward_funcs, decorated_func_mapper)
219 | self.elapsed_time_setup = time.time() - self.pass_start_time
220 | outputs = model(*input_args, **input_kwargs)
221 | self.elapsed_time_forward_pass = (
222 | time.time() - self.pass_start_time - self.elapsed_time_setup
223 | )
224 | self._track_tensors = False
225 | output_tensors_w_addresses_all = get_vars_of_type_from_obj(
226 | outputs,
227 | torch.Tensor,
228 | search_depth=5,
229 | return_addresses=True,
230 | allow_repeats=True,
231 | )
232 |         # Remove duplicate addresses. TODO: make this match the validation procedure so the two work the same way.
233 | addresses_used = []
234 | output_tensors_w_addresses = []
235 | for entry in output_tensors_w_addresses_all:
236 | if entry[1] in addresses_used:
237 | continue
238 | output_tensors_w_addresses.append(entry)
239 | addresses_used.append(entry[1])
240 |
241 | output_tensors = [t for t, _, _ in output_tensors_w_addresses]
242 | output_tensor_addresses = [
243 | addr for _, addr, _ in output_tensors_w_addresses
244 | ]
245 |
246 | for t in output_tensors:
247 | if self.logging_mode == 'exhaustive':
248 | self.output_layers.append(t.tl_tensor_label_raw)
249 | self._raw_tensor_dict[t.tl_tensor_label_raw].is_output_parent = True
250 | tensors_to_undecorate = tensors_to_decorate + output_tensors
251 | undecorate_pytorch(torch, orig_func_defs, tensors_to_undecorate)
252 | self._cleanup_model(model, module_orig_forward_funcs, decorated_func_mapper)
253 | self._postprocess(output_tensors, output_tensor_addresses)
254 | decorated_func_mapper.clear()
255 |
256 | except (
257 | Exception
258 | ) as e: # if anything fails, make sure everything gets cleaned up
259 | undecorate_pytorch(torch, orig_func_defs, input_tensors)
260 | self._cleanup_model(model, module_orig_forward_funcs, decorated_func_mapper)
261 | print(
262 | "************\nFeature extraction failed; returning model and environment to normal\n*************"
263 | )
264 | raise e
265 |
266 | finally: # do garbage collection no matter what
267 |         if 'input_args' in locals():  # these are local names, so globals() would never match
268 |             del input_args
269 |         if 'input_kwargs' in locals():
270 |             del input_kwargs
271 |         if 'input_tensors' in locals():
272 |             del input_tensors
273 |         if 'output_tensors' in locals():
274 |             del output_tensors
275 |         if 'outputs' in locals():
276 |             del outputs
277 | torch.cuda.empty_cache()
278 |
--------------------------------------------------------------------------------
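
`run_and_log_inputs_through_model` is organized around a patch/run/restore discipline: decorate torch, run the forward pass inside `try`, and undo the patching in both the `except` and `finally` paths, so that a failed pass never leaves torch mutated. A stripped-down sketch of that discipline, patching a single function instead of all of torch:

    import torch

    def traced_call(model, *args):
        """Patch torch.relu to record output shapes, run the model, and
        restore the original no matter what (a one-function reduction of
        the try/finally structure above)."""
        orig_relu = torch.relu
        recorded_shapes = []

        def logged_relu(*a, **kw):
            out = orig_relu(*a, **kw)
            recorded_shapes.append(tuple(out.shape))
            return out

        torch.relu = logged_relu
        try:
            return model(*args), recorded_shapes
        finally:
            torch.relu = orig_relu  # restore even if the forward pass raises

    out, shapes = traced_call(lambda x: torch.relu(x), torch.randn(2, 3))
    print(shapes)  # [(2, 3)]
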
/torchlens/decorate_torch.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import time
3 | import types
4 | import warnings
5 | from functools import wraps
6 | from typing import Callable, Dict, List, TYPE_CHECKING, Tuple
7 |
8 | import torch
9 |
10 | from .constants import ORIG_TORCH_FUNCS
11 | from .helper_funcs import (clean_to, get_vars_of_type_from_obj, identity, log_current_rng_states, make_random_barcode,
12 | nested_getattr, print_override, safe_copy)
13 | from .logging_funcs import log_function_output_tensors, log_source_tensor
14 |
15 | if TYPE_CHECKING:
16 |     from .model_history import ModelHistory
17 |
18 | funcs_not_to_log = ["numpy", "__array__", "size", "dim"]
19 | print_funcs = ["__repr__", "__str__", "_str"]
20 |
21 |
22 | def torch_func_decorator(self, func: Callable, func_name: str):
23 | @wraps(func)
24 | def wrapped_func(*args, **kwargs):
25 | # Initial bookkeeping; check if it's a special function, organize the arguments.
26 | self.current_function_call_barcode = 0
27 | if (
28 | (func_name in funcs_not_to_log)
29 | or (not self._track_tensors)
30 | or self._pause_logging
31 | ):
32 | out = func(*args, **kwargs)
33 | return out
34 | all_args = list(args) + list(kwargs.values())
35 | arg_tensorlike = get_vars_of_type_from_obj(all_args, torch.Tensor)
36 |
37 | # Register any buffer tensors in the arguments.
38 |
39 | for t in arg_tensorlike:
40 | if hasattr(t, 'tl_buffer_address'):
41 | log_source_tensor(self, t, 'buffer', getattr(t, 'tl_buffer_address'))
42 |
43 | if (func_name in print_funcs) and (len(arg_tensorlike) > 0):
44 | out = print_override(args[0], func_name)
45 | return out
46 |
47 | # Copy the args and kwargs in case they change in-place:
48 | if self.save_function_args:
49 | arg_copies = tuple([safe_copy(arg) for arg in args])
50 | kwarg_copies = {k: safe_copy(v) for k, v in kwargs.items()}
51 | else:
52 | arg_copies = args
53 | kwarg_copies = kwargs
54 |
55 | # Call the function, tracking the timing, rng states, and whether it's a nested function
56 | func_call_barcode = make_random_barcode()
57 | self.current_function_call_barcode = func_call_barcode
58 | start_time = time.time()
59 | func_rng_states = log_current_rng_states()
60 | out_orig = func(*args, **kwargs)
61 | func_time_elapsed = time.time() - start_time
62 | is_bottom_level_func = (
63 | self.current_function_call_barcode == func_call_barcode
64 | )
65 |
66 | if func_name in ["__setitem__", "zero_", "__delitem__"]:
67 | out_orig = args[0]
68 |
69 |         if id(out_orig) == id(args[0]):  # special case if the function returns its input unchanged
70 | out_orig = safe_copy(out_orig)
71 |
72 | # Log all output tensors
73 | output_tensors = get_vars_of_type_from_obj(
74 | out_orig,
75 | which_type=torch.Tensor,
76 | subclass_exceptions=[torch.nn.Parameter],
77 | )
78 |
79 | if len(output_tensors) > 0:
80 | log_function_output_tensors(
81 | self,
82 | func,
83 | func_name,
84 | args,
85 | kwargs,
86 | arg_copies,
87 | kwarg_copies,
88 | out_orig,
89 | func_time_elapsed,
90 | func_rng_states,
91 | is_bottom_level_func,
92 | )
93 |
94 | return out_orig
95 |
96 | return wrapped_func
97 |
98 |
99 | def decorate_pytorch(
100 | self: "ModelHistory", torch_module: types.ModuleType, orig_func_defs: List[Tuple]
101 | ) -> Dict[Callable, Callable]:
102 | """Mutates all PyTorch functions (TEMPORARILY!) to save the outputs of any functions
103 | that return Tensors, along with marking them with metadata. Returns a list of tuples that
104 | save the current state of the functions, such that they can be restored when done.
105 |
106 | args:
107 | torch_module: The top-level torch module (i.e., from "import torch").
108 | This is supplied as an argument on the off-chance that the user has imported torch
109 | and done their own monkey-patching.
110 | tensors_to_mutate: A list of tensors that will be mutated (since any tensors created
111 | before calling the torch mutation function will not be mutated).
112 | orig_func_defs: Supply a list from outside to guarantee it can be cleaned up properly.
113 | tensor_record: A list to which the outputs of the functions will be appended.
114 |
115 | returns:
116 | List of tuples consisting of [namespace, func_name, orig_func], sufficient
117 | to return torch to normal when finished, and also a dict mapping mutated functions to original functions.
118 | """
119 |
120 | # Do a pass to save the original func defs.
121 | collect_orig_func_defs(torch_module, orig_func_defs)
122 | decorated_func_mapper = {}
123 |
124 | # Get references to the function classes.
125 | function_class = type(lambda: 0)
126 | builtin_class = type(torch.mean)
127 | method_class = type(torch.Tensor.__add__)
128 | wrapper_class = type(torch.Tensor.__getitem__)
129 | getset_class = type(torch.Tensor.real)
130 |
131 | for namespace_name, func_name in ORIG_TORCH_FUNCS:
132 | namespace_name_notorch = namespace_name.replace("torch.", "")
133 | local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
134 | if not hasattr(local_func_namespace, func_name):
135 | continue
136 | orig_func = getattr(local_func_namespace, func_name)
137 | if func_name not in self.func_argnames:
138 | get_func_argnames(self, orig_func, func_name)
139 | if getattr(orig_func, "__name__", False) == "wrapped_func":
140 | continue
141 |
142 | if type(orig_func) in [function_class, builtin_class, method_class, wrapper_class]:
143 | new_func = torch_func_decorator(self, orig_func, func_name)
144 | try:
145 | with warnings.catch_warnings():
146 | warnings.simplefilter("ignore")
147 | setattr(local_func_namespace, func_name, new_func)
148 | except (AttributeError, TypeError) as _:
149 | pass
150 | new_func.tl_is_decorated_function = True
151 | decorated_func_mapper[new_func] = orig_func
152 | decorated_func_mapper[orig_func] = new_func
153 |
154 | elif type(orig_func) == getset_class:
155 | getter_orig, setter_orig, deleter_orig = orig_func.__get__, orig_func.__set__, orig_func.__delete__
156 | getter_dec, setter_dec, deleter_dec = (torch_func_decorator(self, getter_orig, func_name),
157 | torch_func_decorator(self, setter_orig, func_name),
158 | torch_func_decorator(self, deleter_orig, func_name))
159 | getter_dec.tl_is_decorated_function = True
160 | setter_dec.tl_is_decorated_function = True
161 | deleter_dec.tl_is_decorated_function = True
162 | new_property = property(getter_dec, setter_dec, deleter_dec, doc=func_name)
163 | try:
164 | with warnings.catch_warnings():
165 | warnings.simplefilter("ignore")
166 | setattr(local_func_namespace, func_name, new_property)
167 | except (AttributeError, TypeError) as _:
168 | pass
169 | decorated_func_mapper[new_property] = orig_func
170 | decorated_func_mapper[orig_func] = new_property
171 |
172 | # Bolt on the identity function
173 | new_identity = torch_func_decorator(self, identity, 'identity')
174 | torch.identity = new_identity
175 |
176 | return decorated_func_mapper
177 |
178 |
179 | def undecorate_pytorch(
180 | torch_module, orig_func_defs: List[Tuple], input_tensors: List[torch.Tensor]
181 | ):
182 | """
183 |     Returns all PyTorch functions back to the definitions they had when decorate_pytorch was called.
184 |     This is done for the output tensors too, and the decorated versions of the functions
185 |     are deleted so that no references to the old ModelHistory object remain.
186 | 
187 |     args:
188 |         torch_module: The torch module object.
189 |         orig_func_defs: List of tuples consisting of [namespace_name, func_name, orig_func], sufficient
190 |             to regenerate the original functions.
191 |         input_tensors: List of input tensors from which the torchlens tracking
192 |             attributes (e.g., tl_tensor_label_raw) will be removed.
193 | """
194 | for namespace_name, func_name, orig_func in orig_func_defs:
195 | namespace_name_notorch = namespace_name.replace("torch.", "")
196 | local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
197 | with warnings.catch_warnings():
198 | warnings.simplefilter("ignore")
199 | decorated_func = getattr(local_func_namespace, func_name)
200 | del decorated_func
201 | try:
202 | with warnings.catch_warnings():
203 | warnings.simplefilter("ignore")
204 | setattr(local_func_namespace, func_name, orig_func)
205 | except (AttributeError, TypeError) as _:
206 | continue
207 | delattr(torch, "identity")
208 | for input_tensor in input_tensors:
209 | if hasattr(input_tensor, "tl_tensor_label_raw"):
210 | delattr(input_tensor, "tl_tensor_label_raw")
211 |
212 |
213 | def undecorate_tensor(t, device: str = "cpu"):
214 | """Convenience function to replace the tensor with an unmutated version of itself, keeping the same data.
215 |
216 | Args:
217 | t: tensor or parameter object
218 | device: device to move the tensor to
219 |
220 | Returns:
221 | Unmutated tensor.
222 | """
223 | if type(t) in [torch.Tensor, torch.nn.Parameter]:
224 | new_t = safe_copy(t)
225 | else:
226 | new_t = t
227 | del t
228 | for attr in dir(new_t):
229 | if attr.startswith("tl_"):
230 | delattr(new_t, attr)
231 | new_t = clean_to(new_t, device)
232 | return new_t
233 |
234 |
235 | def collect_orig_func_defs(
236 | torch_module: types.ModuleType, orig_func_defs: List[Tuple]
237 | ):
238 | """Collects the original torch function definitions, so they can be restored after the logging is done.
239 |
240 | Args:
241 | torch_module: The top-level torch module
242 | orig_func_defs: List of tuples keeping track of the original function definitions
243 | """
244 | for namespace_name, func_name in ORIG_TORCH_FUNCS:
245 | namespace_name_notorch = namespace_name.replace("torch.", "")
246 | local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
247 | if not hasattr(local_func_namespace, func_name):
248 | continue
249 | orig_func = getattr(local_func_namespace, func_name)
250 | orig_func_defs.append((namespace_name, func_name, orig_func))
251 |
252 |
253 | # TODO: hard-code some of the arg names; for example truediv, getitem, etc. Can crawl through and see what isn't working
254 | def get_func_argnames(self, orig_func: Callable, func_name: str):
255 | """Attempts to get the argument names for a function, first by checking the signature, then
256 | by checking the documentation. Adds these names to func_argnames if it can find them,
257 | doesn't do anything if it can't."""
258 | if func_name in ['real', 'imag', 'T', 'mT', 'data', 'H']:
259 | return
260 |
261 | try:
262 | argnames = list(inspect.signature(orig_func).parameters.keys())
263 | argnames = tuple([arg.replace('*', '') for arg in argnames if arg not in ['cls', 'self']])
264 | self.func_argnames[func_name] = argnames
265 | return
266 | except ValueError:
267 | pass
268 |
269 | docstring = orig_func.__doc__
270 | if (type(docstring) is not str) or (len(docstring) == 0): # if docstring missing, skip it
271 | return
272 |
273 | open_ind, close_ind = docstring.find('('), docstring.find(')')
274 | argstring = docstring[open_ind + 1: close_ind]
275 | arg_list = argstring.split(',')
276 | arg_list = [arg.strip(' ') for arg in arg_list]
277 | argnames = []
278 | for arg in arg_list:
279 | argname = arg.split('=')[0]
280 | if argname in ['*', '/', '//', '']:
281 | continue
282 | argname = argname.replace('*', '')
283 | argnames.append(argname)
284 | argnames = tuple([arg for arg in argnames if arg not in ['self', 'cls']])
285 | self.func_argnames[func_name] = argnames
286 | return
287 |
--------------------------------------------------------------------------------
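
`get_func_argnames` needs argument names even for C-implemented torch callables, for which `inspect.signature` often raises ValueError; in that case it falls back to parsing the name list out of the first parenthesized span of the docstring. A self-contained sketch of the same two-step strategy (the helper name is hypothetical):

    import inspect

    import torch

    def argnames_of(func):
        # Step 1: ask for a real signature (works for Python-level functions).
        try:
            params = inspect.signature(func).parameters
            return tuple(p for p in params if p not in ("self", "cls"))
        except ValueError:
            pass  # many torch builtins have no inspectable signature
        # Step 2: fall back to the "name(arg1, arg2=default, ...)" line that
        # conventionally opens a torch docstring.
        doc = func.__doc__ or ""
        open_ind, close_ind = doc.find("("), doc.find(")")
        if open_ind == -1 or close_ind == -1:
            return ()
        raw = [a.split("=")[0].strip(" *") for a in doc[open_ind + 1:close_ind].split(",")]
        return tuple(a for a in raw if a not in ("", "/", "self", "cls"))

    print(argnames_of(torch.add))  # e.g. ('input', 'other', 'alpha', 'out'), version-dependent
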
/torchlens/constants.py:
--------------------------------------------------------------------------------
1 | import __future__
2 | import functools
3 | import types
4 | from typing import List
5 | import warnings
6 |
7 | import torch
8 | from torch.overrides import get_ignored_functions, get_testing_overrides
9 |
10 | MODEL_HISTORY_FIELD_ORDER = [
11 | # General info
12 | "model_name",
13 | "_pass_finished",
14 | "_track_tensors",
15 | "logging_mode",
16 | "_pause_logging",
17 | "_all_layers_logged",
18 | "_all_layers_saved",
19 | "keep_unsaved_layers",
20 | "current_function_call_barcode",
21 | "random_seed_used",
22 | "detach_saved_tensors",
23 | "output_device",
24 | "save_function_args",
25 | "save_gradients",
26 | "has_saved_gradients",
27 | "activation_postfunc",
28 | "mark_input_output_distances",
29 | # Model structure info
30 | "model_is_recurrent",
31 | "model_max_recurrent_loops",
32 | "model_is_branching",
33 | "model_has_conditional_branching",
34 | # Tensor tracking logs
35 | "layer_list",
36 | "layer_list_rolled",
37 | "layer_dict_main_keys",
38 | "layer_dict_all_keys",
39 | "layer_dict_rolled",
40 | "layer_labels",
41 | "layer_labels_no_pass",
42 | "layer_labels_w_pass",
43 | "layer_num_passes",
44 | "_raw_tensor_dict",
45 | "_raw_tensor_labels_list",
46 | "_tensor_nums_to_save",
47 | "_tensor_counter",
48 | "num_operations",
49 | "_raw_layer_type_counter",
50 | "_unsaved_layers_lookup_keys",
51 | # Mapping from raw to final layer labels:
52 | "_raw_to_final_layer_labels",
53 | "_final_to_raw_layer_labels",
54 | "_lookup_keys_to_tensor_num_dict",
55 | "_tensor_num_to_lookup_keys_dict",
56 | # Special layers
57 | "input_layers",
58 | "output_layers",
59 | "buffer_layers",
60 | "buffer_num_passes",
61 | "internally_initialized_layers",
62 | "_layers_where_internal_branches_merge_with_input",
63 | "internally_terminated_layers",
64 | "internally_terminated_bool_layers",
65 | "conditional_branch_edges",
66 | "layers_computed_with_params",
67 | "equivalent_operations",
68 | "same_layer_operations",
69 | "layers_with_saved_activations",
70 | "unlogged_layers",
71 | "layers_with_saved_gradients",
72 | "orphan_layers",
73 | # Tensor info:
74 | "num_tensors_total",
75 | "tensor_fsize_total",
76 | "tensor_fsize_total_nice",
77 | "num_tensors_saved",
78 | "tensor_fsize_saved",
79 | "tensor_fsize_saved_nice",
80 | # Param info
81 | "total_param_tensors",
82 | "total_param_layers",
83 | "total_params",
84 | "total_params_fsize",
85 | "total_params_fsize_nice",
86 | # Module info
87 | "module_addresses",
88 | "module_types",
89 | "module_passes",
90 | "module_num_passes",
91 | "module_children",
92 | "module_pass_children",
93 | "top_level_modules",
94 | "top_level_module_passes",
95 | "module_nparams",
96 | "module_num_tensors",
97 | "module_pass_num_tensors",
98 | "module_layers",
99 | "module_layer_argnames",
100 | "module_pass_layers",
101 | # Time elapsed
102 | "pass_start_time",
103 | "pass_end_time",
104 | "elapsed_time_setup",
105 | "elapsed_time_forward_pass",
106 | "elapsed_time_cleanup",
107 | "elapsed_time_total",
108 | "elapsed_time_function_calls",
109 | "elapsed_time_torchlens_logging",
110 | # Lookup info
111 | "func_argnames"
112 | ]
113 |
114 | TENSOR_LOG_ENTRY_FIELD_ORDER = [
115 | # General info
116 | "layer_label",
117 | "tensor_label_raw",
118 | "layer_label_raw",
119 | "operation_num",
120 | "realtime_tensor_num",
121 | "source_model_history",
122 | "_pass_finished",
123 | # Other labeling info
124 | "layer_label_short",
125 | "layer_label_w_pass",
126 | "layer_label_w_pass_short",
127 | "layer_label_no_pass",
128 | "layer_label_no_pass_short",
129 | "layer_type",
130 | "layer_type_num",
131 | "layer_total_num",
132 | "pass_num",
133 | "layer_passes_total",
134 | "lookup_keys",
135 | # Saved tensor info
136 | "tensor_contents",
137 | "has_saved_activations",
138 | "output_device",
139 | "activation_postfunc",
140 | "detach_saved_tensor",
141 | "function_args_saved",
142 | "creation_args",
143 | "creation_kwargs",
144 | "tensor_shape",
145 | "tensor_dtype",
146 | "tensor_fsize",
147 | "tensor_fsize_nice",
148 | # Tensor slice-changing complications
149 | "was_getitem_applied",
150 | "children_tensor_versions",
151 | # Saved gradient info
152 | "grad_contents",
153 | "save_gradients",
154 | "has_saved_grad",
155 | "grad_shape",
156 | "grad_dtype",
157 | "grad_fsize",
158 | "grad_fsize_nice",
159 | # Function call info
160 | "func_applied",
161 | "func_applied_name",
162 | "func_call_stack",
163 | "func_time_elapsed",
164 | "func_rng_states",
165 | "func_argnames",
166 | "num_func_args_total",
167 | "num_position_args",
168 | "num_keyword_args",
169 | "func_position_args_non_tensor",
170 | "func_keyword_args_non_tensor",
171 | "func_all_args_non_tensor",
172 | "function_is_inplace",
173 | "gradfunc",
174 | "is_part_of_iterable_output",
175 | "iterable_output_index",
176 | # Param info
177 | "computed_with_params",
178 | "parent_params",
179 | "parent_param_barcodes",
180 | "parent_param_passes",
181 | "num_param_tensors",
182 | "parent_param_shapes",
183 | "num_params_total",
184 | "parent_params_fsize",
185 | "parent_params_fsize_nice",
186 | # Corresponding layer info
187 | "operation_equivalence_type",
188 | "equivalent_operations",
189 | "same_layer_operations",
190 | # Graph info
191 | "parent_layers",
192 | "has_parents",
193 | "parent_layer_arg_locs",
194 | "orig_ancestors",
195 | "child_layers",
196 | "has_children",
197 | "sibling_layers",
198 | "has_siblings",
199 | "spouse_layers",
200 | "has_spouses",
201 | "is_input_layer",
202 | "has_input_ancestor",
203 | "input_ancestors",
204 | "min_distance_from_input",
205 | "max_distance_from_input",
206 | "is_output_layer",
207 | "is_output_parent",
208 | "is_last_output_layer",
209 | "is_output_ancestor",
210 | "output_descendents",
211 | "input_output_address",
212 | "min_distance_from_output",
213 | "max_distance_from_output",
214 | "is_buffer_layer",
215 | "buffer_address",
216 | "buffer_pass",
217 | "buffer_parent",
218 | "initialized_inside_model",
219 | "has_internally_initialized_ancestor",
220 | "internally_initialized_parents",
221 | "internally_initialized_ancestors",
222 | "terminated_inside_model",
223 | # Conditional info
224 | "is_terminal_bool_layer",
225 | "is_atomic_bool_layer",
226 | "atomic_bool_val",
227 | "in_cond_branch",
228 | "cond_branch_start_children",
229 | # Module info
230 | "is_computed_inside_submodule",
231 | "containing_module_origin",
232 | "containing_modules_origin_nested",
233 | "module_nesting_depth",
234 | "modules_entered",
235 | "module_passes_entered",
236 | "modules_entered_argnames",
237 | "is_submodule_input",
238 | "modules_exited",
239 | "module_passes_exited",
240 | "is_submodule_output",
241 | "is_bottom_level_submodule_output",
242 | "bottom_level_submodule_pass_exited",
243 | "module_entry_exit_threads_inputs",
244 | "module_entry_exit_thread_output",
245 | ]
246 |
247 | # Taken from https://pytorch.org/docs/stable/_modules/torch/overrides.html#get_ignored_functions
248 | IGNORED_FUNCS = [
249 | ("torch", "load"),
250 | ("torch", "as_tensor"),
251 | ("torch", "from_numpy"),
252 | ("torch", "tensor"),
253 | ("torch", "arange"),
254 | ("torch", "as_strided"),
255 | ("torch", "bartlett_window"),
256 | ("torch", "blackman_window"),
257 | ("torch", "cudnn_affine_grid_generator"),
258 | ("torch", "cudnn_batch_norm"),
259 | ("torch", "cudnn_convolution"),
260 | ("torch", "cudnn_convolution_transpose"),
261 | ("torch", "cudnn_convolution_relu"),
262 | ("torch", "cudnn_convolution_add_relu"),
263 | ("torch", "cudnn_grid_sampler"),
264 | ("torch", "cudnn_is_acceptable"),
265 | ("torch", "eye"),
266 | ("torch.fft", "fftfreq"),
267 | ("torch.fft", "rfftfreq"),
268 | ("torch", "from_file"),
269 | ("torch", "full"),
270 | ("torch", "fill_"),
271 | ("torch", "hamming_window"),
272 | ("torch", "hann_window"),
273 | ("torch", "kaiser_window"),
274 | ("torch", "linspace"),
275 | ("torch", "logspace"),
276 | ("torch", "mkldnn_adaptive_avg_pool2d"),
277 | ("torch", "mkldnn_convolution"),
278 | ("torch", "mkldnn_max_pool2d"),
279 | ("torch", "mkldnn_max_pool3d"),
280 | ("torch", "mkldnn_linear_backward_weights"),
281 | ("torch", "normal"),
282 | ("torch", "ones"),
283 | ("torch", "rand"),
284 | ("torch", "randn"),
285 | ("torch", "randint"),
286 | ("torch", "randperm"),
287 | ("torch", "range"),
288 | ("torch", "scalar_tensor"),
289 | ("torch", "sparse_coo_tensor"),
290 | ("torch", "_sparse_csr_tensor"),
291 | ("torch", "tril_indices"),
292 | ("torch", "triu_indices"),
293 | ("torch", "vander"),
294 | ("torch", "zeros"),
295 | ("torch.nn.functional", "upsample"),
296 | ("torch.nn.functional", "upsample_bilinear"),
297 | ("torch.nn.functional", "upsample_nearest"),
298 | ("torch.nn.functional", "handle_torch_function"),
299 | ("torch.nn.functional", "sigmoid"),
300 | ("torch.nn.functional", "hardsigmoid"),
301 | ("torch.nn.functional", "tanh"),
302 | ("torch.nn.init", "calculate_gain"),
303 | ("torch.nn.init", "uniform"),
304 | ("torch.nn.init", "normal"),
305 | ("torch.nn.init", "constant"),
306 | ("torch.nn.init", "eye"),
307 | ("torch.nn.init", "dirac"),
308 | ("torch.nn.init", "xavier_uniform"),
309 | ("torch.nn.init", "xavier_normal"),
310 | ("torch.nn.init", "kaiming_uniform"),
311 | ("torch.nn.init", "kaiming_normal"),
312 | ("torch.nn.init", "orthogonal"),
313 | ("torch.nn.init", "sparse"),
314 | ("torch.nn.functional", "hardswish"),
315 | ("torch.Tensor", "__delitem__"),
316 | ("torch.Tensor", "__iter__"),
317 | ("torch.Tensor", "__init_subclass__"),
318 | ("torch.Tensor", "__torch_function__"),
319 | ("torch.Tensor", "__new__"),
320 | ("torch.Tensor", "__subclasshook__"),
321 | ("torch.Tensor", "as_subclass"),
322 | ("torch.Tensor", "reinforce"),
323 | ("torch.Tensor", "new"),
324 | ("torch.Tensor", "new_tensor"),
325 | ("torch.Tensor", "new_empty"),
326 | ("torch.Tensor", "new_empty_strided"),
327 | ("torch.Tensor", "new_zeros"),
328 | ("torch.Tensor", "new_ones"),
329 | ("torch.Tensor", "new_full"),
330 | ("torch.Tensor", "_make_subclass"),
331 | ("torch.Tensor", "solve"),
332 | ("torch.Tensor", "unflatten"),
333 | ("torch.Tensor", "real"),
334 | ("torch.Tensor", "imag"),
335 | ("torch.Tensor", "T"),
336 | ("torch.Tensor", "mT"),
337 | ("torch.Tensor", "H")
338 | ]
339 |
340 |
341 | @functools.lru_cache(None)
342 | def my_get_overridable_functions() -> List:
343 | index = {}
344 | func_names = []
345 | tested_namespaces = [
346 | ("torch", torch, torch.__all__ + dir(torch._C._VariableFunctions)),
347 | ("torch._VF", torch._VF, dir(torch._C._VariableFunctions)),
348 | ("torch.functional", torch.functional, torch.functional.__all__),
349 | ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
350 | ("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
351 | ("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
352 | ("torch.linalg", torch.linalg, dir(torch.linalg)),
353 | ("torch.fft", torch.fft, dir(torch.fft)),
354 | ]
355 | if hasattr(torch, "special"):
356 | tested_namespaces.append(("torch.special", torch.special, dir(torch.special)))
357 | for namespace_str, namespace, ns_funcs in tested_namespaces:
358 | for func_name in ns_funcs:
359 | ignore = False
360 | # ignore private functions or functions that are deleted in torch.__init__
361 | if namespace is not torch.Tensor:
362 | if func_name.startswith("__"):
363 | continue
364 | elif func_name[0].isupper():
365 | ignore = True
366 | elif func_name == "unique_dim":
367 | continue
368 | else:
369 | func = getattr(namespace, func_name)
370 | if getattr(object, func_name, None) == func:
371 | continue
372 | if func_name == "__weakref__":
373 | continue
374 | func = getattr(namespace, func_name)
375 | if namespace is torch.Tensor and getattr(object, func_name, None) == func:
376 | continue
377 | # ignore re-exported modules
378 | if isinstance(func, types.ModuleType):
379 | continue
380 | # ignore __future__ imports
381 | if isinstance(func, getattr(__future__, "_Feature")):
382 | continue
383 |
384 | if not callable(func) and hasattr(func, "__get__"):
385 | index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
386 | index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
387 | if ignore:
388 | continue
389 | if func.__get__ in get_ignored_functions():
390 | msg = (
391 | "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
392 | "but still has an explicit override"
393 | )
394 | assert func.__get__ not in get_testing_overrides(), msg.format(
395 | namespace, func.__name__
396 | )
397 | continue
398 | else:
399 | func_names.append((f"{namespace_str}.{func_name}", "__get__"))
400 | continue
401 |
402 | if not callable(func):
403 | continue
404 |
405 | index[func] = f"{namespace_str}.{func_name}"
406 |
407 | if ignore:
408 | continue
409 |
410 |             # cannot be overridden by __torch_function__
411 | if func in get_ignored_functions():
412 | msg = (
413 | "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
414 | "but still has an explicit override"
415 | )
416 | assert func not in get_testing_overrides(), msg.format(
417 | namespace, func.__name__
418 | )
419 | continue
420 | func_names.append((f"{namespace_str}", func_name))
421 | return func_names
422 |
423 |
424 | TORCHVISION_FUNCS = [
425 | ("torch.ops.torchvision.nms", "_op"),
426 | ("torch.ops.torchvision.deform_conv2d", "_op"),
427 | ("torch.ops.torchvision.ps_roi_align", "_op"),
428 | ("torch.ops.torchvision.ps_roi_pool", "_op"),
429 | ("torch.ops.torchvision.roi_align", "_op"),
430 | ("torch.ops.torchvision.roi_pool", "_op")]
431 |
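# Build the combined catalogue of torch functions: everything overridable via __torch_function__ plus the
# explicitly ignored functions (torchvision ops are appended below when torchvision is installed).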
432 | with warnings.catch_warnings():
433 | warnings.simplefilter("ignore")
434 | OVERRIDABLE_FUNCS = my_get_overridable_functions()
435 | ORIG_TORCH_FUNCS = OVERRIDABLE_FUNCS + IGNORED_FUNCS
436 |
437 | try:
438 | import torchvision
439 |
440 | ORIG_TORCH_FUNCS += TORCHVISION_FUNCS
441 | except ModuleNotFoundError:
442 | pass
443 |
--------------------------------------------------------------------------------
/torchlens/interface.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import TYPE_CHECKING, Union
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 | if TYPE_CHECKING:
8 | from .model_history import ModelHistory
9 |
10 | from .tensor_log import TensorLogEntry
11 |
12 |
13 | def _getitem_during_pass(self: "ModelHistory", ix) -> TensorLogEntry:
14 | """Fetches an item when the pass is unfinished, only based on its raw barcode.
15 |
16 | Args:
17 | ix: layer's barcode
18 |
19 | Returns:
20 | Tensor log entry object with info about specified layer.
21 | """
22 | if ix in self._raw_tensor_dict:
23 | return self._raw_tensor_dict[ix]
24 | else:
25 | raise ValueError(f"{ix} not found in the ModelHistory object.")
26 |
27 |
28 | def _getitem_after_pass(self, ix):
29 | """
30 | Overloaded such that entries can be fetched either by their position in the tensor log, their layer label,
31 | or their module address. It should say so and tell them which labels are valid.
32 | """
33 | if ix in self.layer_dict_all_keys:
34 | return self.layer_dict_all_keys[ix]
35 |
36 | keys_with_substr = [
37 | key for key in self.layer_dict_all_keys if str(ix) in str(key)
38 | ]
39 | if len(keys_with_substr) == 1:
40 | return self.layer_dict_all_keys[keys_with_substr[0]]
41 |
42 | _give_user_feedback_about_lookup_key(self, ix, "get_one_item")
43 |
44 |
45 | def _give_user_feedback_about_lookup_key(self, key: Union[int, str], mode: str):
46 | """For __getitem__ and get_op_nums_from_user_labels, gives the user feedback about the user key
47 | they entered if it doesn't yield any matches.
48 |
49 | Args:
50 | key: Lookup key used by the user.
51 | """
52 |     if isinstance(key, int) and (
53 |         key >= len(self.layer_list) or key < -len(self.layer_list)
54 |     ):
55 | raise ValueError(
56 | f"You specified the layer with index {key}, but there are only {len(self.layer_list)} "
57 | f"layers; please specify an index in the range "
58 | f"-{len(self.layer_list)} - {len(self.layer_list) - 1}."
59 | )
60 |
61 | if key in self.module_addresses:
62 | module_num_passes = self.module_num_passes[key]
63 | raise ValueError(
64 | f"You specified output of module {key}, but it has {module_num_passes} passes; "
65 | f"please specify e.g. {key}:2 for the second pass of {key}."
66 | )
67 |
68 |     if isinstance(key, str) and key.split(":")[0] in self.module_addresses:
69 | module, pass_num = key.split(":")
70 | module_num_passes = self.module_num_passes[module]
71 | raise ValueError(
72 | f"You specified module {module} pass {pass_num}, but {module} only has "
73 | f"{module_num_passes} passes; specify a lower number."
74 | )
75 |
76 | if key in self.layer_labels_no_pass:
77 | layer_num_passes = self.layer_num_passes[key]
78 | raise ValueError(
79 | f"You specified output of layer {key}, but it has {layer_num_passes} passes; "
80 | f"please specify e.g. {key}:2 for the second pass of {key}."
81 | )
82 |
83 |     if isinstance(key, str) and key.split(":")[0] in self.layer_labels_no_pass:
84 | layer_label, pass_num = key.split(":")
85 | layer_num_passes = self.layer_num_passes[layer_label]
86 | raise ValueError(
87 | f"You specified layer {layer_label} pass {pass_num}, but {layer_label} only has "
88 | f"{layer_num_passes} passes. Specify a lower number."
89 | )
90 |
91 | raise ValueError(_get_lookup_help_str(self, key, mode))
92 |
93 |
94 | def _str_after_pass(self) -> str:
95 | """Readable summary of the model history after the pass is finished.
96 |
97 | Returns:
98 | String summarizing the model.
99 | """
100 | s = f"Log of {self.model_name} forward pass:"
101 |
102 | # General info
103 |
104 | s += f"\n\tRandom seed: {self.random_seed_used}"
105 | s += (
106 | f"\n\tTime elapsed: {np.round(self.elapsed_time_total, 3)}s "
107 | f"({np.round(self.elapsed_time_torchlens_logging, 3)}s spent logging)"
108 | )
109 |
110 | # Overall model structure
111 |
112 | s += "\n\tStructure:"
113 | if self.model_is_recurrent:
114 | s += f"\n\t\t- recurrent (at most {self.model_max_recurrent_loops} loops)"
115 | else:
116 | s += "\n\t\t- purely feedforward, no recurrence"
117 |
118 | if self.model_is_branching:
119 | s += "\n\t\t- with branching"
120 | else:
121 | s += "\n\t\t- no branching"
122 |
123 | if self.model_has_conditional_branching:
124 | s += "\n\t\t- with conditional (if-then) branching"
125 | else:
126 | s += "\n\t\t- no conditional (if-then) branching"
127 |
128 | if len(self.buffer_layers) > 0:
129 | s += f"\n\t\t- contains {len(self.buffer_layers)} buffer layers"
130 |
131 | s += f"\n\t\t- {len(self.module_addresses)} total modules"
132 |
133 | # Model tensors:
134 |
135 | s += "\n\tTensor info:"
136 | s += (
137 | f"\n\t\t- {self.num_tensors_total} total tensors ({self.tensor_fsize_total_nice}) "
138 | f"computed in forward pass."
139 | )
140 | s += f"\n\t\t- {self.num_tensors_saved} tensors ({self.tensor_fsize_saved_nice}) with saved activations."
141 |
142 | # Model parameters:
143 |
144 | s += (
145 | f"\n\tParameters: {self.total_param_layers} parameter operations ({self.total_params} params total; "
146 | f"{self.total_params_fsize_nice})"
147 | )
148 |
149 | # Print the module hierarchy.
150 | s += "\n\tModule Hierarchy:"
151 | s += _module_hierarchy_str(self)
152 |
153 | # Now print all layers.
154 | s += "\n\tLayers"
155 | if self._all_layers_saved:
156 | s += " (all have saved activations):"
157 | elif self.num_tensors_saved == 0:
158 | s += " (no layer activations are saved):"
159 | else:
160 | s += " (* means layer has saved activations):"
161 | for layer_ind, layer_barcode in enumerate(self.layer_labels):
162 | pass_num = self.layer_dict_main_keys[layer_barcode].pass_num
163 | total_passes = self.layer_dict_main_keys[layer_barcode].layer_passes_total
164 | if total_passes > 1:
165 | pass_str = f" ({pass_num}/{total_passes} passes)"
166 | else:
167 | pass_str = ""
168 |
169 | if self.layer_dict_main_keys[layer_barcode].has_saved_activations and (
170 | not self._all_layers_saved
171 | ):
172 | s += "\n\t\t* "
173 | else:
174 | s += "\n\t\t "
175 | s += f"({layer_ind}) {layer_barcode} {pass_str}"
176 |
177 | return s
178 |
179 |
180 | def _str_during_pass(self) -> str:
181 | """Readable summary of the model history during the pass, as a debugging aid.
182 |
183 | Returns:
184 | String summarizing the model.
185 | """
186 | s = f"Log of {self.model_name} forward pass (pass still ongoing):"
187 | s += f"\n\tRandom seed: {self.random_seed_used}"
188 | s += f"\n\tInput tensors: {self.input_layers}"
189 | s += f"\n\tOutput tensors: {self.output_layers}"
190 | s += f"\n\tInternally initialized tensors: {self.internally_initialized_layers}"
191 | s += f"\n\tInternally terminated tensors: {self.internally_terminated_layers}"
192 | s += f"\n\tInternally terminated boolean tensors: {self.internally_terminated_bool_layers}"
193 | s += f"\n\tBuffer tensors: {self.buffer_layers}"
194 | s += "\n\tRaw layer labels:"
195 | for layer in self._raw_tensor_labels_list:
196 | s += f"\n\t\t{layer}"
197 | return s
198 |
199 |
200 | def pretty_print_list_w_line_breaks(lst, indent_chars: str, line_break_every=5):
201 | """
202 | Utility function to pretty print a list with line breaks, adding indent_chars every line.
203 | """
204 | s = f"\n{indent_chars}"
205 | for i, item in enumerate(lst):
206 | s += f"{item}"
207 | if i < len(lst) - 1:
208 | s += ", "
209 | if ((i + 1) % line_break_every == 0) and (i < len(lst) - 1):
210 | s += f"\n{indent_chars}"
211 | return s
212 |
213 |
214 | def _get_lookup_help_str(self, layer_label: Union[int, str], mode: str) -> str:
215 | """Generates a help string to be used in error messages when indexing fails."""
216 | sample_layer1 = random.choice(self.layer_labels_w_pass)
217 | sample_layer2 = random.choice(self.layer_labels_no_pass)
218 | if len(self.module_addresses) > 0:
219 | sample_module1 = random.choice(self.module_addresses)
220 | sample_module2 = random.choice(self.module_passes)
221 | else:
222 | sample_module1 = "features.3"
223 | sample_module2 = "features.4:2"
224 | module_str = f"(e.g., {sample_module1}, {sample_module2})"
225 | if mode == "get_one_item":
226 | msg = (
227 | "e.g., 'pool' will grab the maxpool2d or avgpool2d layer, 'maxpool' will grab the 'maxpool2d' "
228 | "layer, etc., but there must be only one such matching layer"
229 | )
230 | elif mode == "query_multiple":
231 | msg = (
232 | "e.g., 'pool' will grab all maxpool2d or avgpool2d layers, 'maxpool' will grab all 'maxpool2d' "
233 | "layers, etc."
234 | )
235 | else:
236 | raise ValueError("mode must be either get_one_item or query_multiple")
237 | help_str = (
238 | f"Layer {layer_label} not recognized; please specify either "
239 | f"\n\n\t1) an integer giving the ordinal position of the layer "
240 | f"(e.g. 2 for 3rd layer, -4 for fourth-to-last), "
241 | f"\n\t2) the layer label (e.g., {sample_layer1}, {sample_layer2}), "
242 | f"\n\t3) the module address {module_str}"
243 | f"\n\t4) A substring of any desired layer label ({msg})."
244 | f"\n\n(Label meaning: conv2d_3_4:2 means the second pass of the third convolutional layer, "
245 | f"and fourth layer overall in the model.)"
246 | )
247 | return help_str
248 |
249 |
250 | def _module_hierarchy_str(self):
251 | """
252 | Utility function to print the nested module hierarchy.
253 | """
254 | s = ""
255 | for module_pass in self.top_level_module_passes:
256 | module, pass_num = module_pass.split(":")
257 | s += f"\n\t\t{module}"
258 | if self.module_num_passes[module] > 1:
259 | s += f":{pass_num}"
260 | s += _module_hierarchy_str_helper(self, module_pass, 1)
261 | return s
262 |
263 |
264 | def _module_hierarchy_str_helper(self, module_pass, level):
265 | """
266 | Helper function for _module_hierarchy_str.
267 | """
268 | s = ""
269 | any_grandchild_modules = any(
270 | [
271 | len(self.module_pass_children[submodule_pass]) > 0
272 | for submodule_pass in self.module_pass_children[module_pass]
273 | ]
274 | )
275 | if any_grandchild_modules or len(self.module_pass_children[module_pass]) == 0:
276 | for submodule_pass in self.module_pass_children[module_pass]:
277 | submodule, pass_num = submodule_pass.split(":")
278 | s += f"\n\t\t{' ' * level}{submodule}"
279 | if self.module_num_passes[submodule] > 1:
280 | s += f":{pass_num}"
281 | s += _module_hierarchy_str_helper(self, submodule_pass, level + 1)
282 | else:
283 | submodule_list = []
284 | for submodule_pass in self.module_pass_children[module_pass]:
285 | submodule, pass_num = submodule_pass.split(":")
286 | if self.module_num_passes[submodule] == 1:
287 | submodule_list.append(submodule)
288 | else:
289 | submodule_list.append(submodule_pass)
290 | s += pretty_print_list_w_line_breaks(
291 | submodule_list, line_break_every=8, indent_chars=f"\t\t{' ' * level}"
292 | )
293 | return s
294 |
295 |
296 | def print_all_fields(self):
297 | """Print all data fields for ModelHistory."""
298 | fields_to_exclude = [
299 | "layer_list",
300 | "layer_dict_main_keys",
301 | "layer_dict_all_keys",
302 | "raw_tensor_dict",
303 | "decorated_to_orig_funcs_dict",
304 | ]
305 |
306 | for field in dir(self):
307 | attr = getattr(self, field)
308 | if not any(
309 | [field.startswith("_"), field in fields_to_exclude, callable(attr)]
310 | ):
311 | print(f"{field}: {attr}")
312 |
313 |
314 | def to_pandas(self) -> pd.DataFrame:
315 | """Returns a pandas dataframe with info about each layer.
316 |
317 | Returns:
318 | Pandas dataframe with info about each layer.
319 | """
320 | fields_for_df = [
321 | "layer_label",
322 | "layer_label_w_pass",
323 | "layer_label_no_pass",
324 | "layer_label_short",
325 | "layer_label_w_pass_short",
326 | "layer_label_no_pass_short",
327 | "layer_type",
328 | "layer_type_num",
329 | "layer_total_num",
330 | "layer_passes_total",
331 | "pass_num",
332 | "operation_num",
333 | "tensor_shape",
334 | "tensor_dtype",
335 | "tensor_fsize",
336 | "tensor_fsize_nice",
337 | "func_applied_name",
338 | "func_time_elapsed",
339 | "function_is_inplace",
340 | "gradfunc",
341 | "is_input_layer",
342 | "is_output_layer",
343 | "is_buffer_layer",
344 | "is_part_of_iterable_output",
345 | "iterable_output_index",
346 | "parent_layers",
347 | "has_parents",
348 | "orig_ancestors",
349 | "child_layers",
350 | "has_children",
351 | "output_descendents",
352 | "sibling_layers",
353 | "has_siblings",
354 | "spouse_layers",
355 | "has_spouses",
356 | "initialized_inside_model",
357 | "min_distance_from_input",
358 | "max_distance_from_input",
359 | "min_distance_from_output",
360 | "max_distance_from_output",
361 | "computed_with_params",
362 | "num_params_total",
363 | "parent_param_shapes",
364 | "parent_params_fsize",
365 | "parent_params_fsize_nice",
366 | "modules_entered",
367 | "modules_exited",
368 | "is_submodule_input",
369 | "is_submodule_output",
370 | "containing_module_origin",
371 | "containing_modules_origin_nested",
372 | ]
373 |
374 | fields_to_change_type = {
375 | "layer_type_num": int,
376 | "layer_total_num": int,
377 | "layer_passes_total": int,
378 | "pass_num": int,
379 | "operation_num": int,
380 | "function_is_inplace": bool,
381 | "is_input_layer": bool,
382 | "is_output_layer": bool,
383 | "is_buffer_layer": bool,
384 | "is_part_of_iterable_output": bool,
385 | "has_parents": bool,
386 | "has_children": bool,
387 | "has_siblings": bool,
388 | "has_spouses": bool,
389 | "computed_with_params": bool,
390 | "num_params_total": int,
391 | "parent_params_fsize": int,
392 | "tensor_fsize": int,
393 | "is_submodule_input": bool,
394 | "is_submodule_output": bool,
395 | }
396 |
397 | model_df_dictlist = []
398 | for tensor_entry in self.layer_list:
399 | tensor_dict = {}
400 | for field_name in fields_for_df:
401 | tensor_dict[field_name] = getattr(tensor_entry, field_name)
402 | model_df_dictlist.append(tensor_dict)
403 | model_df = pd.DataFrame(model_df_dictlist)
404 |
405 | for field in fields_to_change_type:
406 | model_df[field] = model_df[field].astype(fields_to_change_type[field])
407 |
408 | return model_df
409 |
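# Hypothetical usage sketch (not part of the module): once bound as a ModelHistory method, to_pandas
# gives one row of metadata per layer, e.g.:
#     df = model_history.to_pandas()
#     conv_layers = df[df["layer_type"] == "conv2d"]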
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TorchLens
2 |
3 | **Quick Links**
4 |
5 | - [Paper introducing TorchLens](https://www.nature.com/articles/s41598-023-40807-0)
6 | - [CoLab tutorial](https://colab.research.google.com/drive/1ORJLGZPifvdsVPFqq1LYT3t5hV560SoW?usp=sharing)
7 | - [\"Menagerie\" of model visualizations](https://drive.google.com/drive/u/0/folders/1BsM6WPf3eB79-CRNgZejMxjg38rN6VCb)
8 | - [Metadata provided by TorchLens](https://static-content.springer.com/esm/art%3A10.1038%2Fs41598-023-40807-0/MediaObjects/41598_2023_40807_MOESM1_ESM.pdf)
9 |
10 | ## Overview
11 |
12 | *TorchLens* is a package for doing exactly two things:
13 |
14 | 1) Easily extracting the activations from every single intermediate operation in a PyTorch model—no
15 | modifications needed—in one line of code. "Every operation" means every operation; "one line" means one line.
16 | 2) Understanding the model's computational structure via an intuitive automatic visualization and extensive
17 | metadata ([partial list here](https://static-content.springer.com/esm/art%3A10.1038%2Fs41598-023-40807-0/MediaObjects/41598_2023_40807_MOESM1_ESM.pdf))
18 | about the network's computational graph.
19 |
20 | Here it is in action for a very simple recurrent model; you just define the model as usual and pass it in, and
21 | *TorchLens* returns a full log of the forward pass along with a visualization (the snippet assumes `import torch`, `from torch import nn`, and `import torchlens as tl`, plus an input such as `x = torch.rand(6, 5)`):
22 |
23 | ```python
24 | class SimpleRecurrent(nn.Module):
25 | def __init__(self):
26 | super().__init__()
27 | self.fc = nn.Linear(in_features=5, out_features=5)
28 |
29 | def forward(self, x):
30 | for r in range(4):
31 | x = self.fc(x)
32 | x = x + 1
33 | x = x * 2
34 | return x
35 |
36 |
37 | simple_recurrent = SimpleRecurrent()
38 | model_history = tl.log_forward_pass(simple_recurrent, x,
39 | layers_to_save='all',
40 | vis_opt='rolled')
41 | print(model_history['linear_1_1:2'].tensor_contents) # second pass of first linear layer
42 |
43 | '''
44 | tensor([[-0.0690, -1.3957, -0.3231, -0.1980, 0.7197],
45 | [-0.1083, -1.5051, -0.2570, -0.2024, 0.8248],
46 | [ 0.1031, -1.4315, -0.5999, -0.4017, 0.7580],
47 | [-0.0396, -1.3813, -0.3523, -0.2008, 0.6654],
48 | [ 0.0980, -1.4073, -0.5934, -0.3866, 0.7371],
49 | [-0.1106, -1.2909, -0.3393, -0.2439, 0.7345]])
50 | '''
51 | ```
52 |
53 |
54 |
55 | And here it is for a very complex transformer model ([swin_v2_b](https://arxiv.org/abs/2103.14030)) with 1932 operations
56 | in its forward pass; you can grab the saved outputs of every last one:
57 |
58 |
59 |
60 | The goal of *TorchLens* is to do this for any PyTorch model whatsoever. You can see a bunch of example model
61 | visualizations in this [model menagerie](https://drive.google.com/drive/u/0/folders/1BsM6WPf3eB79-CRNgZejMxjg38rN6VCb).
62 |
63 | ## Installation
64 |
65 | To install *TorchLens*, first install graphviz if you haven't already (required to generate the network visualizations),
66 | and then install *TorchLens* using pip:
67 |
68 | ```bash
69 | sudo apt install graphviz
70 | pip install torchlens
71 | ```
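
(The `apt` command above is for Debian/Ubuntu; on macOS, `brew install graphviz` works instead, and installers
for other platforms are available from the [graphviz download page](https://graphviz.org/download/).)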
72 |
73 | *TorchLens* is compatible with versions 1.8.0+ of PyTorch.
74 |
75 | ## How-To Guide
76 |
77 | Below is a quick demo of how to use it; for an interactive demonstration, see
78 | the [CoLab walkthrough](https://colab.research.google.com/drive/1ORJLGZPifvdsVPFqq1LYT3t5hV560SoW?usp=sharing).
79 |
80 | The main function of *TorchLens* is `log_forward_pass`: when called on a model and input, it runs a
81 | forward pass on the model and returns a ModelHistory object containing the intermediate layer activations and
82 | accompanying metadata, along with a visual representation of every operation that occurred during the forward pass:
83 |
84 | ```python
85 | import torch
86 | import torchvision
87 | import torchlens as tl
88 |
89 | alexnet = torchvision.models.alexnet()
90 | x = torch.rand(1, 3, 224, 224)
91 | model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_opt='unrolled')
92 | print(model_history)
93 |
94 | '''
95 | Log of AlexNet forward pass:
96 | Model structure: purely feedforward, without branching; 23 total modules.
97 | 24 tensors (4.8 MB) computed in forward pass; 24 tensors (4.8 MB) saved.
98 | 16 parameter operations (61100840 params total; 248.7 MB).
99 | Random seed: 3210097511
100 | Time elapsed: 0.288s
101 | Module Hierarchy:
102 | features:
103 | features.0, features.1, features.2, features.3, features.4, features.5, features.6, features.7,
104 | features.8, features.9, features.10, features.11, features.12
105 | avgpool
106 | classifier:
107 | classifier.0, classifier.1, classifier.2, classifier.3, classifier.4, classifier.5, classifier.6
108 | Layers:
109 | 0: input_1_0
110 | 1: conv2d_1_1
111 | 2: relu_1_2
112 | 3: maxpool2d_1_3
113 | 4: conv2d_2_4
114 | 5: relu_2_5
115 | 6: maxpool2d_2_6
116 | 7: conv2d_3_7
117 | 8: relu_3_8
118 | 9: conv2d_4_9
119 | 10: relu_4_10
120 | 11: conv2d_5_11
121 | 12: relu_5_12
122 | 13: maxpool2d_3_13
123 | 14: adaptiveavgpool2d_1_14
124 | 15: flatten_1_15
125 | 16: dropout_1_16
126 | 17: linear_1_17
127 | 18: relu_6_18
128 | 19: dropout_2_19
129 | 20: linear_2_20
130 | 21: relu_7_21
131 | 22: linear_3_22
132 | 23: output_1_23
133 | '''
134 | ```
135 |
136 |
137 |
138 | You can pull out information about a given layer, including its activations and helpful metadata, by indexing
139 | the ModelHistory object in any of these equivalent ways:
140 |
141 | 1) the name of a layer (with the convention that 'conv2d_3_7' is the 3rd convolutional layer, and the 7th layer overall)
142 | 2) the name of a module (e.g., 'features' or 'classifier.3') for which that layer is an output, or
143 | 3) the ordinal position of the layer (e.g., 2 for the 2nd layer, -5 for the fifth-to-last; inputs and outputs count as
144 | layers here).
145 |
146 | To quickly figure out these names, you can look at the graph visualization, or at the output of printing the
147 | ModelHistory object (both shown above). Here are some examples of how to fetch information about a
148 | particular layer, and how to grab the actual activations from that layer:
149 |
150 | ```python
151 | print(model_history['conv2d_3_7']) # pulling out layer by its name
152 | # The following commented lines pull out the same layer:
153 | # model_history['conv2d_3'] you can omit the second number (since strictly speaking it's redundant)
154 | # model_history['conv2d_3_7:1'] colon indicates the pass of a layer (here just one)
155 | # model_history['features.6'] can grab a layer by the module for which it is an output
156 | # model_history[7] the 7th layer overall
157 | # model_history[-17] the 17th-to-last layer
158 | '''
159 | Layer conv2d_3_7, operation 8/24:
160 | Output tensor: shape=(1, 384, 13, 13), dype=torch.float32, size=253.5 KB
161 | tensor([[ 0.0503, -0.1089, -0.1210, -0.1034, -0.1254],
162 | [ 0.0789, -0.0752, -0.0581, -0.0372, -0.0181],
163 | [ 0.0949, -0.0780, -0.0401, -0.0209, -0.0095],
164 | [ 0.0929, -0.0353, -0.0220, -0.0324, -0.0295],
165 | [ 0.1100, -0.0337, -0.0330, -0.0479, -0.0235]])...
166 | Params: Computed from params with shape (384,), (384, 192, 3, 3); 663936 params total (2.5 MB)
167 | Parent Layers: maxpool2d_2_6
168 | Child Layers: relu_3_8
169 | Function: conv2d (gradfunc=ConvolutionBackward0)
170 | Computed inside module: features.6
171 | Time elapsed: 5.670E-04s
172 | Output of modules: features.6
173 | Output of bottom-level module: features.6
174 | Lookup keys: -17, 7, conv2d_3_7, conv2d_3_7:1, features.6, features.6:1
175 | '''
176 |
177 | # You can pull out the actual output activations from a layer with the tensor_contents field:
178 | print(model_history['conv2d_3_7'].tensor_contents)
179 | '''
180 | tensor([[[[-0.0867, -0.0787, -0.0817, ..., -0.0820, -0.0655, -0.0195],
181 | [-0.1213, -0.1130, -0.1386, ..., -0.1331, -0.1118, -0.0520],
182 | [-0.0959, -0.0973, -0.1078, ..., -0.1103, -0.1091, -0.0760],
183 | ...,
184 | [-0.0906, -0.1146, -0.1308, ..., -0.1076, -0.1129, -0.0689],
185 | [-0.1017, -0.1256, -0.1100, ..., -0.1160, -0.1035, -0.0801],
186 | [-0.1006, -0.0941, -0.1204, ..., -0.1146, -0.1065, -0.0631]]...
187 | '''
188 | ```
189 |
190 | If you do not wish to save the activations for all layers (e.g., to save memory), you can specify which layers to save
191 | with the `layers_to_save` argument when calling `log_forward_pass`; you can either indicate layers in the same
192 | way as when indexing them above, or pass in a substring for filtering the layers (e.g., 'conv'
193 | will pull out all conv layers):
194 |
195 | ```python
196 | # Pull out conv2d_3_7, the output of the 'features' module, the fifth-to-last layer, and all linear (i.e., fc) layers:
197 | model_history = tl.log_forward_pass(alexnet, x, vis_opt='unrolled',
198 | layers_to_save=['conv2d_3_7', 'features', -5, 'linear'])
199 | print(model_history.layer_labels)
200 | '''
201 | ['conv2d_3_7', 'maxpool2d_3_13', 'linear_1_17', 'dropout_2_19', 'linear_2_20', 'linear_3_22']
202 | '''
203 | ```
204 |
205 | The main function of *TorchLens* is `log_forward_pass`; the remaining functions are:
206 |
207 | 1) `get_model_metadata`, to retrieve all model metadata without saving any activations (e.g., to figure out which
208 | layers you wish to save; note that this is the same as calling `log_forward_pass` with `layers_to_save=None`)
209 | 2) `show_model_graph`, which visualizes the model graph without saving any activations
210 | 3) `validate_model_activations`, which runs a procedure to check that the activations are correct: specifically,
211 | it runs a forward pass and saves all intermediate activations, re-runs the forward pass from each intermediate
212 | layer, and checks that the resulting output matches the ground-truth output. It also checks that swapping in
213 | random nonsense activations instead of the saved activations generates the wrong output. **If this function ever
214 | returns False (i.e., the saved activations are wrong), please contact me via email (johnmarkedwardtaylor@gmail.com)
215 | or on this GitHub page with a description of the problem, and I will update TorchLens to fix the problem.**
216 |
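Here is a minimal sketch of these three calls, reusing `alexnet` and `x` from above (treat it as illustrative
rather than a full reference; see each function's docstring for the exact signatures):

```python
metadata = tl.get_model_metadata(alexnet, x)  # all the metadata, no saved activations
tl.show_model_graph(alexnet, x, vis_opt='unrolled')  # just render the graph
tl.validate_model_activations(alexnet, x)  # True if the saved activations pass the checks
```
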
217 | And that's it. *TorchLens* remains in active development, and the goal is for it to work with any PyTorch model
218 | whatsoever without exception. As of the time of this writing, it has been tested with over 700
219 | image, video, auditory, multimodal, and language models, including feedforward, recurrent, transformer,
220 | and graph neural networks.
221 |
222 | ## Miscellaneous Features
223 |
224 | - You can visualize models at different levels of nesting depth using the `vis_nesting_depth` argument
225 | to `log_forward_pass`; for example, here you can see one of GoogLeNet's "inception" modules at different levels of
226 | nesting depth:
227 |
228 |
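A minimal sketch of such a call (hypothetically assuming `googlenet = torchvision.models.googlenet()` and a
suitable input `x`; `vis_nesting_depth` is the only argument that changes):

```python
# Collapse the visualization to two levels of module nesting:
model_history = tl.log_forward_pass(googlenet, x, vis_opt='rolled', vis_nesting_depth=2)
```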
229 |
230 | - An experimental feature is to extract not just the activations from all of a model's operations, but also the
231 |   gradients from a backward pass (which you can compute based on any intermediate layer, not just the model's
232 |   output), and to visualize the path taken by the backward pass (shown with blue arrows below). See the CoLab
233 |   tutorial for instructions on how to do this.
236 |
237 |
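A minimal sketch of enabling this (the full workflow, including how to seed the backward pass from an
intermediate layer, is shown in the CoLab tutorial):

```python
# save_gradients=True logs the gradients from any backward pass you subsequently run:
model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all',
                                    save_gradients=True, vis_opt='unrolled')
```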
238 |
239 | - You can see the literal code that was used to run the model with the func_call_stack field:
240 |
241 | ```python
242 | print(model_history['conv2d_3'].func_call_stack[8])
243 | '''
244 | {'call_fname': '/usr/local/lib/python3.10/dist-packages/torchvision/models/alexnet.py',
245 | 'call_linenum': 48,
246 | 'function': 'forward',
247 | 'code_context': [' nn.Linear(256 * 6 * 6, 4096),\n',
248 | ' nn.ReLU(inplace=True),\n',
249 | ' nn.Dropout(p=dropout),\n',
250 | ' nn.Linear(4096, 4096),\n',
251 | ' nn.ReLU(inplace=True),\n',
252 | ' nn.Linear(4096, num_classes),\n',
253 | ' )\n',
254 | '\n',
255 | ' def forward(self, x: torch.Tensor) -> torch.Tensor:\n',
256 | ' x = self.features(x)\n',
257 | ' x = self.avgpool(x)\n',
258 | ' x = torch.flatten(x, 1)\n',
259 | ' x = self.classifier(x)\n',
260 | ' return x\n',
261 | '\n',
262 | '\n',
263 | 'class AlexNet_Weights(WeightsEnum):\n',
264 | ' IMAGENET1K_V1 = Weights(\n',
265 | ' url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",\n']}
266 | '''
267 | ```
268 |
269 | ## Planned Features
270 |
271 | 1) Further down the line, I am considering adding functionality to not just save activations,
272 | but counterfactually intervene on them (e.g., how would the output have changed if these parameters
273 | were different or if a different nonlinearity were used). Let me know if you'd find this useful
274 | and if so, what specific kind of functionality you'd want.
275 | 2) I am planning to add an option to only visualize a single submodule of a model rather than the full graph at once.
276 |
277 | ## Other Packages You Should Check Out
278 |
279 | The goal is for *TorchLens* to completely solve the problem of extracting activations and metadata
280 | from deep neural networks and visualizing their structure so that nobody has to think about this stuff ever again, but
281 | it intentionally leaves out certain functionality: for example, it has no functions for loading models or stimuli, or
282 | for analyzing the extracted activations. This is in part because it's impossible to predict all the things you might
283 | want to do with the activations, or all the possible models you might want to look at, but also because there are
284 | already outstanding packages for doing these things. Here are a few; let me know if I've missed any!
285 |
286 | - [Cerbrec](https://cerbrec.com): Program for interactively visualizing and debugging deep neural networks (uses TorchLens under
287 | the hood for extracting the graphs of PyTorch models!)
288 | - [ThingsVision](https://github.com/ViCCo-Group/thingsvision): has excellent functionality for loading vision models,
289 | loading stimuli, and analyzing the extracted activations
290 | - [Net2Brain](https://github.com/cvai-roig-lab/Net2Brain): similar excellent end-to-end functionality to ThingsVision,
291 | along with functionality for comparing extracted activations to neural data.
292 | - [surgeon-pytorch](https://github.com/archinetai/surgeon-pytorch): easy-to-use functionality for extracting activations
293 | from models, along with functionality for training a model using loss functions based on intermediate layer
294 | activations
295 | - [deepdive](https://github.com/ColinConwell/DeepDive): has outstanding functionality for loading and benchmarking
296 | many different models
297 | - [torchvision feature_extraction module](https://pytorch.org/vision/stable/feature_extraction.html): can extract
298 | activations from models with static computational graphs
299 | - [rsatoolbox3](https://github.com/rsagroup/rsatoolbox): total solution for performing representational similarity
300 | analysis on DNN activations and brain data
301 |
302 | ## Acknowledgments
303 |
304 | The development of *TorchLens* benefitted greatly from discussions with Nikolaus Kriegeskorte, George Alvarez,
305 | Alfredo Canziani, Tal Golan, and the Visual Inference Lab at Columbia University. Thank you to Kale Kundert
306 | for helpful discussion and for his code contributions enabling PyTorch Lightning compatibility.
307 | All network visualizations were created with graphviz. Logo created by Nikolaus Kriegeskorte.
308 |
309 | ## Citing TorchLens
310 |
311 | To cite *TorchLens*, you can
312 | cite [this paper describing the package](https://www.nature.com/articles/s41598-023-40807-0) (and consider adding a star
313 | to this repo if you find *TorchLens* useful):
314 |
315 | Taylor, J., Kriegeskorte, N. Extracting and visualizing hidden activations and computational graphs of PyTorch models
316 | with *TorchLens*. Sci Rep 13, 14375 (2023). https://doi.org/10.1038/s41598-023-40807-0
317 |
318 | ## Contact
319 |
320 | As *TorchLens* is still in active development, I would love your feedback. If you have any questions, comments,
321 | or suggestions (or if you'd be interested in collaborating!), please email johnmarkedwardtaylor@gmail.com,
322 | reach out via [twitter](https://twitter.com/johnmark_taylor), or post on
323 | the [issues](https://github.com/johnmarktaylor91/torchlens/issues)
324 | or [discussion](https://github.com/johnmarktaylor91/torchlens/discussions) page for this GitHub
325 | repository.
326 |
--------------------------------------------------------------------------------
/torchlens/model_funcs.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from functools import wraps
3 | from typing import Callable, Dict, List, TYPE_CHECKING
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from .helper_funcs import get_vars_of_type_from_obj, iter_accessible_attributes, remove_attributes_starting_with_str
9 | from .logging_funcs import log_source_tensor
10 |
11 | if TYPE_CHECKING:
12 | from .model_history import ModelHistory
13 |
14 |
15 | def prepare_model(
16 | model_log: "ModelHistory",
17 | model: nn.Module,
18 | module_orig_forward_funcs: Dict,
19 | decorated_func_mapper: Dict[Callable, Callable],
20 | ):
21 | """Adds annotations and hooks to the model, and decorates any functions in the model.
22 |
23 | Args:
24 | model: Model to prepare.
25 | module_orig_forward_funcs: Dict with the original forward funcs for each submodule
26 | decorated_func_mapper: Dictionary mapping decorated functions to original functions, so they can be restored
27 |
28 | Returns:
29 |         Nothing; the model is prepared in place (torch functions decorated, modules annotated).
30 | """
31 | model_log.model_name = str(type(model).__name__)
32 | model.tl_module_address = ""
33 | model.tl_source_model_history = model_log
34 |
35 |     module_stack = [(model, "")]  # list of tuples (module, parent_address)
36 |
37 | while len(module_stack) > 0:
38 | module, parent_address = module_stack.pop()
39 | module_children = list(module.named_children())
40 |
41 | # Decorate any torch functions in the model:
42 | for func_name, func in module.__dict__.items():
43 | if (
44 | (func_name[0:2] == "__")
45 | or (not callable(func))
46 | or (func not in decorated_func_mapper)
47 | ):
48 | continue
49 | module.__dict__[func_name] = decorated_func_mapper[func]
50 |
51 | # Annotate the children with the full address.
52 | for c, (child_name, child_module) in enumerate(module_children):
53 | child_address = (
54 | f"{parent_address}.{child_name}"
55 | if parent_address != ""
56 | else child_name
57 | )
58 | child_module.tl_module_address = child_address
59 | module_children[c] = (child_module, child_address)
60 | module_stack = module_children + module_stack
61 |
62 | if module == model: # don't tag the model itself.
63 | continue
64 |
65 | module.tl_source_model_history = model_log
66 | module.tl_module_type = str(type(module).__name__)
67 | model_log.module_types[module.tl_module_address] = module.tl_module_type
68 | module.tl_module_pass_num = 0
69 | module.tl_module_pass_labels = []
70 | module.tl_tensors_entered_labels = []
71 | module.tl_tensors_exited_labels = []
72 |
73 | # Add decorators.
74 |
75 | if hasattr(module, "forward") and not hasattr(
76 | module.forward, "tl_forward_call_is_decorated"
77 | ):
78 | module_orig_forward_funcs[module] = module.forward
79 | module.forward = module_forward_decorator(model_log, module.forward, module)
80 | module.forward.tl_forward_call_is_decorated = True
81 |
82 | # Mark all parameters with requires_grad = True, and mark what they were before, so they can be restored on cleanup.
83 | for param in model.parameters():
84 | param.tl_requires_grad = param.requires_grad
85 | param.requires_grad = True
86 |
87 | # And prepare any buffer tensors.
88 | prepare_buffer_tensors(model_log, model)
89 |
90 |
91 | def prepare_buffer_tensors(model_log, model: nn.Module):
92 | """Goes through a model and all its submodules, and prepares any "buffer" tensors: tensors
93 | attached to the module that aren't model parameters.
94 |
95 | Args:
96 | model: PyTorch model
97 |
98 | Returns:
99 | PyTorch model with all buffer tensors prepared and ready to track.
100 | """
101 | submodules = get_all_submodules(model)
102 | for submodule in submodules:
103 | attr_list = list(submodule.named_buffers()) + list(iter_accessible_attributes(submodule))
104 | for attribute_name, attribute in attr_list:
105 | if issubclass(type(attribute), torch.Tensor) and not issubclass(
106 | type(attribute), torch.nn.Parameter
107 | ) and not hasattr(attribute, 'tl_buffer_address'):
108 | if submodule.tl_module_address == "":
109 | buffer_address = attribute_name
110 | else:
111 | buffer_address = (
112 | submodule.tl_module_address + "." + attribute_name
113 | )
114 | setattr(attribute, 'tl_buffer_address', buffer_address)
115 |
116 |
117 | def module_forward_decorator(
118 | model_log, orig_forward: Callable, module: nn.Module
119 | ) -> Callable:
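    """Wraps a module's forward method so torchlens can log the pass: tags tensors entering and exiting the
    module, handles buffers and identity modules, and rolls back bookkeeping for the input tensors on exit."""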
120 | @wraps(orig_forward)
121 | def decorated_forward(*args, **kwargs):
122 | if model_log.logging_mode == "fast": # do bare minimum for logging.
123 | out = orig_forward(*args, **kwargs)
124 | output_tensors = get_vars_of_type_from_obj(
125 | out, torch.Tensor, search_depth=4
126 | )
127 | for t in output_tensors:
128 | # if identity module, run the function for bookkeeping
129 | if module.tl_module_type.lower() == "identity":
130 | t = getattr(torch, "identity")(t)
131 | return out
132 |
133 | # "Pre-hook" operations:
134 | module_address = module.tl_module_address
135 | module.tl_module_pass_num += 1
136 | module_pass_label = (module_address, module.tl_module_pass_num)
137 | module.tl_module_pass_labels.append(module_pass_label)
138 | input_tensors = get_vars_of_type_from_obj(
139 | [args, kwargs], torch.Tensor, [torch.nn.Parameter], search_depth=5
140 | )
141 | input_tensor_labels = set()
142 | for t in input_tensors:
143 | if (not hasattr(t, 'tl_tensor_label_raw')) and hasattr(t, 'tl_buffer_address'):
144 | log_source_tensor(model_log, t, 'buffer', getattr(t, 'tl_buffer_address'))
145 | tensor_entry = model_log._raw_tensor_dict[t.tl_tensor_label_raw]
146 | input_tensor_labels.add(t.tl_tensor_label_raw)
147 | module.tl_tensors_entered_labels.append(t.tl_tensor_label_raw)
148 | tensor_entry.modules_entered.append(module_address)
149 | tensor_entry.module_passes_entered.append(module_pass_label)
150 | tensor_entry.is_submodule_input = True
151 | for arg_key, arg_val in list(enumerate(args)) + list(kwargs.items()):
152 | if arg_val is t:
153 | tensor_entry.modules_entered_argnames[
154 | f"{module_pass_label[0]}:{module_pass_label[1]}"].append(arg_key)
155 | model_log.module_layer_argnames[(f"{module_pass_label[0]}:"
156 | f"{module_pass_label[1]}")].append(
157 | (t.tl_tensor_label_raw, arg_key))
158 | tensor_entry.module_entry_exit_thread_output.append(
159 | ("+", module_pass_label[0], module_pass_label[1])
160 | )
161 |
162 | # Check the buffers.
163 | for buffer_name, buffer_tensor in module.named_buffers():
164 | if hasattr(buffer_tensor, 'tl_buffer_address'):
165 | continue
166 | if module.tl_module_address == '':
167 | buffer_address = buffer_name
168 | else:
169 | buffer_address = f"{module.tl_module_address}.{buffer_name}"
170 | buffer_tensor.tl_buffer_address = buffer_address
171 | buffer_tensor.tl_buffer_parent = getattr(buffer_tensor, 'tl_tensor_label_raw')
172 | delattr(buffer_tensor, 'tl_tensor_label_raw')
173 |
174 | # The function call
175 | out = orig_forward(*args, **kwargs)
176 |
177 | # "Post-hook" operations:
178 | module_address = module.tl_module_address
179 | module_pass_num = module.tl_module_pass_num
180 | module_entry_label = module.tl_module_pass_labels.pop()
181 | output_tensors = get_vars_of_type_from_obj(
182 | out, torch.Tensor, search_depth=4
183 | )
184 | for t in output_tensors:
185 | # if identity module or tensor unchanged, run the identity function for bookkeeping
186 | if (module.tl_module_type.lower() == "identity") or (
187 | t.tl_tensor_label_raw in input_tensor_labels
188 | ):
189 | t = getattr(torch, "identity")(t)
190 | tensor_entry = model_log._raw_tensor_dict[t.tl_tensor_label_raw]
191 | tensor_entry.is_submodule_output = True
192 | tensor_entry.is_bottom_level_submodule_output = (
193 | log_whether_exited_submodule_is_bottom_level(model_log, t, module)
194 | )
195 | tensor_entry.modules_exited.append(module_address)
196 | tensor_entry.module_passes_exited.append(
197 | (module_address, module_pass_num)
198 | )
199 | tensor_entry.module_entry_exit_thread_output.append(
200 | ("-", module_entry_label[0], module_entry_label[1])
201 | )
202 | module.tl_tensors_exited_labels.append(t.tl_tensor_label_raw)
203 |
204 |         # Now that the module is finished, roll back the threads of all input tensors.
205 |         for t in input_tensors:
209 | tensor_entry = model_log._raw_tensor_dict[t.tl_tensor_label_raw]
210 | input_module_thread = tensor_entry.module_entry_exit_thread_output[:]
211 | if (
212 | "+",
213 | module_entry_label[0],
214 | module_entry_label[1],
215 | ) in input_module_thread:
216 | module_entry_ix = input_module_thread.index(
217 | ("+", module_entry_label[0], module_entry_label[1])
218 | )
219 | tensor_entry.module_entry_exit_thread_output = (
220 | tensor_entry.module_entry_exit_thread_output[:module_entry_ix]
221 | )
222 |
223 | return out
224 |
225 | return decorated_forward
226 |
227 |
228 | def log_whether_exited_submodule_is_bottom_level(
229 | model_log, t: torch.Tensor, submodule: nn.Module
230 | ):
231 | """Checks whether the submodule that a tensor is leaving is a "bottom-level" submodule;
232 | that is, that only one tensor operation happened inside the submodule.
233 |
234 | Args:
235 | t: the tensor leaving the module
236 | submodule: the module that the tensor is leaving
237 |
238 | Returns:
239 | Whether the tensor operation is bottom level.
240 | """
241 | tensor_entry = model_log._raw_tensor_dict[getattr(t, "tl_tensor_label_raw")]
242 | submodule_address = submodule.tl_module_address
243 |
244 | if tensor_entry.is_bottom_level_submodule_output:
245 | return True
246 |
247 | # If it was initialized inside the model and nothing entered the module, it's bottom-level.
248 | if (
249 | tensor_entry.initialized_inside_model
250 | and len(submodule.tl_tensors_entered_labels) == 0
251 | ):
252 | tensor_entry.is_bottom_level_submodule_output = True
253 | tensor_entry.bottom_level_submodule_pass_exited = (
254 | submodule_address,
255 | submodule.tl_module_pass_num,
256 | )
257 | return True
258 |
259 | # Else, all parents must have entered the submodule for it to be a bottom-level submodule.
260 | for parent_label in tensor_entry.parent_layers:
261 | parent_tensor = model_log[parent_label]
262 | parent_modules_entered = parent_tensor.modules_entered
263 | if (len(parent_modules_entered) == 0) or (
264 | parent_modules_entered[-1] != submodule_address
265 | ):
266 | tensor_entry.is_bottom_level_submodule_output = False
267 | return False
268 |
269 | # If it survived the above tests, it's a bottom-level submodule.
270 | tensor_entry.is_bottom_level_submodule_output = True
271 | tensor_entry.bottom_level_submodule_pass_exited = (
272 | submodule_address,
273 | submodule.tl_module_pass_num,
274 | )
275 | return True
276 |
277 |
278 | def get_all_submodules(
279 | model: nn.Module, is_top_level_model: bool = True
280 | ) -> List[nn.Module]:
281 | """Recursively gets list of all submodules for given module, no matter their level in the
282 | hierarchy; this includes the model itself.
283 |
284 | Args:
285 | model: PyTorch model.
286 | is_top_level_model: Whether it's the top-level model; just for the recursive logic of it.
287 |
288 | Returns:
289 | List of all submodules.
290 | """
291 | submodules = []
292 | if is_top_level_model:
293 | submodules.append(model)
294 | for module in model.children():
295 | if module not in submodules:
296 | submodules.append(module)
297 | submodules += get_all_submodules(module, is_top_level_model=False)
298 | return submodules
299 |
300 |
301 | def cleanup_model(
302 | model_log,
303 | model: nn.Module,
304 | module_orig_forward_funcs: Dict[nn.Module, Callable],
305 | decorated_func_mapper: Dict[Callable, Callable],
306 | ):
307 | """Reverses all temporary changes to the model (namely, the forward hooks and added
308 | model attributes) that were added for PyTorch x-ray (scout's honor; leave no trace).
309 |
310 | Args:
311 | model: PyTorch model.
312 | module_orig_forward_funcs: Dict containing the original, undecorated forward pass functions for
313 | each submodule
314 | decorated_func_mapper: Dict mapping between original and decorated PyTorch funcs
315 |
316 | Returns:
317 |         Nothing; the model is restored in place to its original state.
318 | """
319 | submodules = get_all_submodules(model, is_top_level_model=True)
320 | for submodule in submodules:
321 | if submodule == model:
322 | continue
323 | submodule.forward = module_orig_forward_funcs[submodule]
324 | restore_model_attributes(
325 | model, decorated_func_mapper=decorated_func_mapper, attribute_keyword="tl"
326 | )
327 | undecorate_model_tensors(model)
328 |
329 |
330 | def clear_hooks(hook_handles: List):
331 | """Takes in a list of tuples (module, hook_handle), and clears the hook at that
332 | handle for each module.
333 |
334 | Args:
335 | hook_handles: List of tuples (module, hook_handle)
336 |
337 | Returns:
338 | Nothing.
339 | """
340 | for hook_handle in hook_handles:
341 | hook_handle.remove()
342 |
343 |
344 | def restore_module_attributes(
345 | module: nn.Module,
346 | decorated_func_mapper: Dict[Callable, Callable],
347 | attribute_keyword: str = "tl",
348 | ):
349 | def del_attrs_with_prefix(module, attribute_name):
350 | if attribute_name.startswith(attribute_keyword):
351 | delattr(module, attribute_name)
352 | return True
353 |
354 | for attribute_name, attr in iter_accessible_attributes(module, short_circuit=del_attrs_with_prefix):
355 | if (
356 | isinstance(attr, Callable)
357 | and (attr in decorated_func_mapper)
358 | and (attribute_name[0:2] != "__")
359 | ):
360 | with warnings.catch_warnings():
361 | warnings.simplefilter("ignore")
362 | setattr(module, attribute_name, decorated_func_mapper[attr])
363 |
364 |
365 | def restore_model_attributes(
366 | model: nn.Module,
367 | decorated_func_mapper: Dict[Callable, Callable],
368 | attribute_keyword: str = "tl",
369 | ):
370 | """Recursively clears the given attribute from all modules in the model.
371 |
372 | Args:
373 | model: PyTorch model.
374 | decorated_func_mapper: Dict mapping between original and decorated PyTorch funcs
375 | attribute_keyword: Any attribute with this keyword will be cleared.
376 |
377 | Returns:
378 | Nothing.
379 | """
380 | for module in get_all_submodules(model):
381 | restore_module_attributes(
382 | module,
383 | decorated_func_mapper=decorated_func_mapper,
384 | attribute_keyword=attribute_keyword,
385 | )
386 |
387 | for param in model.parameters():
388 | if hasattr(param, "tl_requires_grad"):
389 | param.requires_grad = getattr(param, "tl_requires_grad")
390 | delattr(param, "tl_requires_grad")
391 |
392 |
393 | def undecorate_model_tensors(model: nn.Module):
394 | """Goes through a model and all its submodules, and unmutates any tensor attributes. Normally just clearing
395 | parameters would have done this, but some module types (e.g., batchnorm) contain attributes that are tensors,
396 | but not parameters.
397 |
398 | Args:
399 | model: PyTorch model
400 |
401 | Returns:
402 |         Nothing; the tensor attributes are cleaned in place.
403 | """
404 | submodules = get_all_submodules(model)
405 | for submodule in submodules:
406 | for attribute_name, attribute in iter_accessible_attributes(submodule):
407 | if issubclass(type(attribute), torch.Tensor):
408 | if not issubclass(type(attribute), torch.nn.Parameter) and hasattr(
409 | attribute, "tl_tensor_label_raw"
410 | ):
411 | delattr(attribute, "tl_tensor_label_raw")
412 | if hasattr(attribute, 'tl_buffer_address'):
413 | delattr(attribute, "tl_buffer_address")
414 | if hasattr(attribute, 'tl_buffer_parent'):
415 | delattr(attribute, "tl_buffer_parent")
416 | else:
417 | remove_attributes_starting_with_str(attribute, "tl_")
418 | elif type(attribute) in [list, tuple, set]:
419 | for item in attribute:
420 | if issubclass(type(item), torch.Tensor) and hasattr(
421 | item, "tl_tensor_label_raw"
422 | ):
423 | delattr(item, "tl_tensor_label_raw")
424 | if hasattr(item, 'tl_buffer_address'):
425 | delattr(item, "tl_buffer_address")
426 | if hasattr(item, 'tl_buffer_parent'):
427 | delattr(item, "tl_buffer_parent")
428 | elif type(attribute) == dict:
429 | for key, val in attribute.items():
430 | if issubclass(type(val), torch.Tensor) and hasattr(
431 | val, "tl_tensor_label_raw"
432 | ):
433 | delattr(val, "tl_tensor_label_raw")
434 | if hasattr(val, 'tl_buffer_address'):
435 | delattr(val, "tl_buffer_address")
436 | if hasattr(val, 'tl_buffer_parent'):
437 | delattr(val, "tl_buffer_parent")
438 |
--------------------------------------------------------------------------------
/torchlens/user_funcs.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import random
4 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5 |
6 | import pandas as pd
7 | import torch
8 | from torch import nn
9 | from tqdm import tqdm
10 |
11 | from .helper_funcs import (get_vars_of_type_from_obj, set_random_seed, warn_parallel)
12 | from .model_history import (
13 | ModelHistory,
14 | )
15 |
16 |
17 | def run_model_and_save_specified_activations(
18 | model: nn.Module,
19 | input_args: Union[torch.Tensor, List[Any]],
20 | input_kwargs: Dict[Any, Any],
21 | layers_to_save: Optional[Union[str, List[Union[int, str]]]] = "all",
22 | keep_unsaved_layers: bool = True,
23 | output_device: str = "same",
24 | activation_postfunc: Optional[Callable] = None,
25 | mark_input_output_distances: bool = False,
26 | detach_saved_tensors: bool = False,
27 | save_function_args: bool = False,
28 | save_gradients: bool = False,
29 | random_seed: Optional[int] = None,
30 | ) -> ModelHistory:
31 | """Internal function that runs the given input through the given model, and saves the
32 | specified activations, as given by the tensor numbers (these will not be visible to the user;
33 | they will be generated from the nicer human-readable names and then fed in).
34 |
35 | Args:
36 | model: PyTorch model.
37 | input_args: Input arguments to the model's forward pass: either a single tensor, or a list of arguments.
38 | input_kwargs: Keyword arguments to the model's forward pass.
39 | layers_to_save: List of layers to save
40 | keep_unsaved_layers: Whether to keep layers in the ModelHistory log if they don't have saved activations.
41 | output_device: device where saved tensors will be stored: either 'same' to keep unchanged, or
42 | 'cpu' or 'cuda' to move to cpu or cuda.
43 | activation_postfunc: Function to apply to activations before saving them (e.g., any averaging)
44 | mark_input_output_distances: Whether to compute the distance of each layer from the input or output.
45 | This is computationally expensive for large networks, so it is off by default.
46 |         detach_saved_tensors: whether to detach the saved tensors, so they are no longer attached to the computational graph
47 | save_function_args: whether to save the arguments to each function
48 | save_gradients: whether to save gradients from any subsequent backward pass
49 | random_seed: Which random seed to use.
50 |
51 | Returns:
52 | ModelHistory object with full log of the forward pass
53 | """
54 | model_name = str(type(model).__name__)
55 | model_history = ModelHistory(
56 | model_name,
57 | output_device,
58 | activation_postfunc,
59 | keep_unsaved_layers,
60 | save_function_args,
61 | save_gradients,
62 | detach_saved_tensors,
63 | mark_input_output_distances,
64 | )
65 | model_history._run_and_log_inputs_through_model(
66 | model, input_args, input_kwargs, layers_to_save, random_seed
67 | )
68 | return model_history
69 |
70 |
71 | def log_forward_pass(
72 | model: nn.Module,
73 | input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
74 | input_kwargs: Dict[Any, Any] = None,
75 | layers_to_save: Optional[Union[str, List]] = "all",
76 | keep_unsaved_layers: bool = True,
77 | output_device: str = "same",
78 | activation_postfunc: Optional[Callable] = None,
79 | mark_input_output_distances: bool = False,
80 | detach_saved_tensors: bool = False,
81 | save_function_args: bool = False,
82 | save_gradients: bool = False,
83 | vis_opt: str = "none",
84 | vis_nesting_depth: int = 1000,
85 | vis_outpath: str = "graph.gv",
86 | vis_save_only: bool = False,
87 | vis_fileformat: str = "pdf",
88 | vis_buffer_layers: bool = False,
89 | vis_direction: str = "bottomup",
90 | vis_graph_overrides: Dict = None,
91 | vis_node_overrides: Dict = None,
92 | vis_nested_node_overrides: Dict = None,
93 | vis_edge_overrides: Dict = None,
94 | vis_gradient_edge_overrides: Dict = None,
95 | vis_module_overrides: Dict = None,
96 | random_seed: Optional[int] = None,
97 | ) -> ModelHistory:
98 | """Runs a forward pass through a model given input x, and returns a ModelHistory object containing a log
99 | (layer activations and accompanying layer metadata) of the forward pass for all layers specified in which_layers,
100 | and optionally visualizes the model graph if vis_opt is set to 'rolled' or 'unrolled'.
101 |
102 | In which_layers, can specify 'all', for all layers (default), or a list containing any combination of:
103 | 1) desired layer names (e.g., 'conv2d_1_1'; if a layer has multiple passes, this includes all passes),
104 | 2) a layer pass (e.g., conv2d_1_1:2 for just the second pass), 3) a module name to fetch the output of a particular
105 | module, 4) the ordinal index of a layer in the model (e.g. 3 for the third layer, -2 for the second to last, etc.),
106 | or 5) a desired substring with which to filter desired layers (e.g., 'conv2d' for all conv2d layers).
107 |
108 | Args:
109 | model: PyTorch model
110 | input_args: input arguments for model forward pass; as a list if multiple, else as a single tensor.
111 | input_kwargs: keyword arguments for model forward pass
112 | layers_to_save: list of layers to include (described above), or 'all' to include all layers.
113 | keep_unsaved_layers: whether to keep layers without saved activations in the log (i.e., with just metadata)
114 | activation_postfunc: Function to apply to tensors before saving them (e.g., channelwise averaging).
115 | mark_input_output_distances: whether to mark the distance of each layer from the input or output;
116 | False by default since this is computationally expensive.
117 | output_device: device where saved tensors are to be stored. Either 'same' to keep on the same device,
118 | or 'cpu' or 'cuda' to move them to cpu or cuda when saved.
119 |         detach_saved_tensors: whether to detach the saved tensors, so they are no longer attached to the computational graph
120 | save_function_args: whether to save the arguments to each function involved in computing each tensor
121 | save_gradients: whether to save gradients from any subsequent backward pass
122 | vis_opt: whether, and how, to visualize the network; 'none' for
123 | no visualization, 'rolled' to show the graph in rolled-up format (i.e.,
124 | one node per layer if a recurrent network), or 'unrolled' to show the graph
125 | in unrolled format (i.e., one node per pass through a layer if a recurrent)
126 | vis_nesting_depth: How many levels of nested modules to show; 1 for only top-level modules, 2 for two
127 | levels, etc.
128 | vis_outpath: file path to save the graph visualization
129 | vis_save_only: whether to only save the graph visual without immediately showing it
130 |         vis_fileformat: the format of the visualization (e.g., 'pdf', 'jpg', etc.)
131 | vis_buffer_layers: whether to visualize the buffer layers
132 | vis_direction: either 'bottomup', 'topdown', or 'leftright'
133 | random_seed: which random seed to use in case model involves randomness
134 |
135 | Returns:
136 | ModelHistory object with layer activations and metadata
137 | """
138 | warn_parallel()
139 |
140 | if vis_opt not in ["none", "rolled", "unrolled"]:
141 | raise ValueError(
142 | "Visualization option must be either 'none', 'rolled', or 'unrolled'."
143 | )
144 |
145 | if output_device not in ["same", "cpu", "cuda"]:
146 | raise ValueError("output_device must be either 'same', 'cpu', or 'cuda'.")
147 |
148 | if type(layers_to_save) is str:
149 | layers_to_save = layers_to_save.lower()
150 |
151 | if layers_to_save in ["all", "none", None, []]:
152 | model_history = run_model_and_save_specified_activations(
153 | model=model,
154 | input_args=input_args,
155 | input_kwargs=input_kwargs,
156 | layers_to_save=layers_to_save,
157 | keep_unsaved_layers=keep_unsaved_layers,
158 | output_device=output_device,
159 | activation_postfunc=activation_postfunc,
160 | mark_input_output_distances=mark_input_output_distances,
161 | detach_saved_tensors=detach_saved_tensors,
162 | save_function_args=save_function_args,
163 | save_gradients=save_gradients,
164 | random_seed=random_seed,
165 | )
166 | else:
167 | model_history = run_model_and_save_specified_activations(
168 | model=model,
169 | input_args=input_args,
170 | input_kwargs=input_kwargs,
171 | layers_to_save=None,
172 | keep_unsaved_layers=True,
173 | output_device=output_device,
174 | activation_postfunc=activation_postfunc,
175 | mark_input_output_distances=mark_input_output_distances,
176 | detach_saved_tensors=detach_saved_tensors,
177 | save_function_args=save_function_args,
178 | save_gradients=save_gradients,
179 | random_seed=random_seed,
180 | )
181 | model_history.keep_unsaved_layers = keep_unsaved_layers
182 | model_history.save_new_activations(
183 | model=model,
184 | input_args=input_args,
185 | input_kwargs=input_kwargs,
186 | layers_to_save=layers_to_save,
187 | random_seed=random_seed,
188 | )
189 |
190 | # Visualize if desired.
191 | if vis_opt != "none":
192 | model_history.render_graph(
193 | vis_opt,
194 | vis_nesting_depth,
195 | vis_outpath,
196 | vis_graph_overrides,
197 | vis_node_overrides,
198 | vis_nested_node_overrides,
199 | vis_edge_overrides,
200 | vis_gradient_edge_overrides,
201 | vis_module_overrides,
202 | vis_save_only,
203 | vis_fileformat,
204 | vis_buffer_layers,
205 | vis_direction,
206 | )
207 |
208 | return model_history
209 |
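# Example (illustrative sketch, not part of the library): a minimal call to
# log_forward_pass, assuming torchvision is installed; the exact layer labels
# depend on torchlens' naming scheme:
#
#     import torch, torchvision, torchlens as tl
#     model = torchvision.models.alexnet()
#     x = torch.rand(1, 3, 224, 224)
#     model_history = tl.log_forward_pass(model, x, layers_to_save="all", vis_opt="rolled")
#     print(model_history.layer_labels)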
210 |
211 | def get_model_metadata(
212 | model: nn.Module,
213 | input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
214 | input_kwargs: Dict[Any, Any] = None,
215 | ) -> ModelHistory:
216 | """Logs all metadata for a given model and inputs without saving any activations. NOTE: this function
217 | will be removed in a future version of TorchLens, since calling it is identical to calling
218 | log_forward_pass without saving any layers.
219 |
220 | Args:
221 | model: model to inspect
222 | input_args: list of input positional arguments, or a single tensor
223 | input_kwargs: dict of keyword arguments
224 | Returns:
225 | ModelHistory object with metadata about the model.
226 | """
227 | model_history = log_forward_pass(
228 | model,
229 | input_args,
230 | input_kwargs,
231 | layers_to_save=None,
232 | mark_input_output_distances=True,
233 | )
234 | return model_history
235 |
236 |
237 | def show_model_graph(
238 | model: nn.Module,
239 | input_args: Union[torch.Tensor, List, Tuple],
240 | input_kwargs: Dict[Any, Any] = None,
241 | vis_opt: str = "unrolled",
242 | vis_nesting_depth: int = 1000,
243 | vis_outpath: str = "graph.gv",
244 | vis_graph_overrides: Dict = None,
245 | vis_node_overrides: Dict = None,
246 | vis_nested_node_overrides: Dict = None,
247 | vis_edge_overrides: Dict = None,
248 | vis_gradient_edge_overrides: Dict = None,
249 | vis_module_overrides: Dict = None,
250 | save_only: bool = False,
251 | vis_fileformat: str = "pdf",
252 | vis_buffer_layers: bool = False,
253 | vis_direction: str = "bottomup",
254 | random_seed: Optional[int] = None,
255 | ) -> None:
256 | """Visualize the model graph without saving any activations.
257 |
258 | Args:
259 | model: PyTorch model
260 | input_args: Arguments for model forward pass
261 | input_kwargs: Keyword arguments for model forward pass
262 |         vis_opt: whether, and how, to visualize the network; 'none' for
263 |             no visualization, 'rolled' to show the graph in rolled-up format (i.e.,
264 |             one node per layer if a recurrent network), or 'unrolled' to show the graph
265 |             in unrolled format (i.e., one node per pass through a layer if a recurrent network)
266 | vis_nesting_depth: How many levels of nested modules to show; 1 for only top-level modules, 2 for two
267 | levels, etc.
268 | vis_outpath: file path to save the graph visualization
269 | save_only: whether to only save the graph visual without immediately showing it
270 |         vis_fileformat: the format of the visualization (e.g., 'pdf', 'jpg', etc.)
271 | vis_buffer_layers: whether to visualize the buffer layers
272 | vis_direction: either 'bottomup', 'topdown', or 'leftright'
273 | random_seed: which random seed to use in case model involves randomness
274 |
275 | Returns:
276 | Nothing.
277 | """
278 | if not input_kwargs:
279 | input_kwargs = {}
280 |
281 | if vis_opt not in ["none", "rolled", "unrolled"]:
282 | raise ValueError(
283 | "Visualization option must be either 'none', 'rolled', or 'unrolled'."
284 | )
285 |
286 | model_history = run_model_and_save_specified_activations(
287 | model=model,
288 | input_args=input_args,
289 | input_kwargs=input_kwargs,
290 | layers_to_save=None,
291 | activation_postfunc=None,
292 | mark_input_output_distances=False,
293 | detach_saved_tensors=False,
294 | save_gradients=False,
295 | random_seed=random_seed,
296 | )
297 | model_history.render_graph(
298 | vis_opt,
299 | vis_nesting_depth,
300 | vis_outpath,
301 | vis_graph_overrides,
302 | vis_node_overrides,
303 | vis_nested_node_overrides,
304 | vis_edge_overrides,
305 | vis_gradient_edge_overrides,
306 | vis_module_overrides,
307 | save_only,
308 | vis_fileformat,
309 | vis_buffer_layers,
310 | vis_direction,
311 | )
312 |
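# Example sketch: render the unrolled graph two module levels deep and save it as a
# PDF without displaying it (model and x as in the earlier example; the output path
# is hypothetical):
#
#     tl.show_model_graph(model, x, vis_opt="unrolled", vis_nesting_depth=2,
#                         vis_outpath="alexnet_graph", save_only=True)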
313 |
314 | def validate_saved_activations(
315 | model: nn.Module,
316 | input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
317 | input_kwargs: Dict[Any, Any] = None,
318 | random_seed: Union[int, None] = None,
319 | verbose: bool = False,
320 | ) -> bool:
321 | """Validate that the saved model activations correctly reproduce the ground truth output. This function works by
322 | running a forward pass through the model, saving all activations, re-running the forward pass starting from
323 | the saved activations in each layer, and checking that the resulting output matches the original output.
324 |     Additionally, it substitutes in random activations and checks whether the output changes accordingly, for
325 |     at least a minimum proportion of the layers (since some layers may not affect the output for some
326 |     reason). Returns True if the model passes these tests for the given input, and False otherwise.
327 |
328 | Args:
329 | model: PyTorch model.
330 | input_args: Input for which to validate the saved activations.
331 | input_kwargs: Keyword arguments for model forward pass
332 | random_seed: random seed in case model is stochastic
333 | verbose: whether to show verbose error messages
334 | Returns:
335 | True if the saved activations correctly reproduce the ground truth output, false otherwise.
336 | """
337 | warn_parallel()
338 | if random_seed is None: # set random seed
339 | random_seed = random.randint(1, 4294967294)
340 | set_random_seed(random_seed)
341 | if type(input_args) is tuple:
342 | input_args = list(input_args)
343 | elif (type(input_args) not in [list, tuple]) and (input_args is not None):
344 | input_args = [input_args]
345 | if not input_args:
346 | input_args = []
347 | if not input_kwargs:
348 | input_kwargs = {}
349 | input_args_copy = [copy.deepcopy(arg) for arg in input_args]
350 | input_kwargs_copy = {key: copy.deepcopy(val) for key, val in input_kwargs.items()}
351 | state_dict = model.state_dict()
352 | ground_truth_output_tensors = get_vars_of_type_from_obj(
353 | model(*input_args_copy, **input_kwargs_copy),
354 | torch.Tensor,
355 | search_depth=5,
356 | allow_repeats=True,
357 | )
358 | model.load_state_dict(state_dict)
359 | model_history = run_model_and_save_specified_activations(
360 | model=model,
361 | input_args=input_args,
362 | input_kwargs=input_kwargs,
363 | layers_to_save="all",
364 | keep_unsaved_layers=True,
365 | activation_postfunc=None,
366 | mark_input_output_distances=False,
367 | detach_saved_tensors=False,
368 | save_gradients=False,
369 | save_function_args=True,
370 | random_seed=random_seed,
371 | )
372 | activations_are_valid = model_history.validate_saved_activations(
373 | ground_truth_output_tensors, verbose
374 | )
375 |
376 | model_history.cleanup()
377 | del model_history
378 | return activations_are_valid
379 |
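# Example sketch: spot-check a model before relying on its logged activations
# (model and x as in the earlier example):
#
#     assert tl.validate_saved_activations(model, x, random_seed=1, verbose=True)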
380 |
381 | def validate_batch_of_models_and_inputs(
382 | models_and_inputs_dict: Dict[str, Dict[str, Union[str, Callable, Dict]]],
383 | out_path: str,
384 | redo_model_if_already_run: bool = True,
385 | ) -> pd.DataFrame:
386 | """Given multiple models and several inputs for each, validates the saved activations for all of them
387 | and returns a Pandas dataframe summarizing the validation results.
388 |
389 | Args:
390 | models_and_inputs_dict: Dict mapping each model name to a dict of model info:
391 | model_category: category of model (e.g., torchvision; just for directory bookkeeping)
392 | model_loading_func: function to load the model
393 | model_sample_inputs: dict of example inputs {name: input}
394 | out_path: Path to save the validation results to
395 | redo_model_if_already_run: If True, will re-run the validation for a model even if already run
396 |
397 | Returns:
398 | Pandas dataframe with validation information for each model and input
399 | """
400 | if os.path.exists(out_path):
401 | current_csv = pd.read_csv(out_path)
402 | else:
403 | current_csv = pd.DataFrame.from_dict(
404 | {
405 | "model_category": [],
406 | "model_name": [],
407 | "input_name": [],
408 | "validation_success": [],
409 | }
410 | )
411 | models_already_run = current_csv["model_name"].unique()
412 | for model_name, model_info in tqdm(
413 | models_and_inputs_dict.items(), desc="Validating models"
414 | ):
415 | print(f"Validating model {model_name}")
416 | if model_name in models_already_run and not redo_model_if_already_run:
417 | continue
418 | model_category = model_info["model_category"]
419 | model_loading_func = model_info["model_loading_func"]
420 | model = model_loading_func()
421 | model_sample_inputs = model_info["model_sample_inputs"]
422 | for input_name, x in model_sample_inputs.items():
423 | validation_success = validate_saved_activations(model, x)
424 |             current_csv = pd.concat(  # DataFrame.append was removed in pandas 2.0
425 |                 [current_csv, pd.DataFrame([{
426 |                     "model_category": model_category,
427 |                     "model_name": model_name,
428 |                     "input_name": input_name,
429 |                     "validation_success": validation_success,
430 |                 }])],
431 |                 ignore_index=True,
432 |             )
433 | current_csv.to_csv(out_path, index=False)
434 | del model
435 | return current_csv
436 |
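# Example sketch of the models_and_inputs_dict schema this function expects (the
# model and input names here are hypothetical):
#
#     models_and_inputs_dict = {
#         "alexnet": {
#             "model_category": "torchvision",
#             "model_loading_func": torchvision.models.alexnet,
#             "model_sample_inputs": {"rand_image": torch.rand(1, 3, 224, 224)},
#         },
#     }
#     results_df = validate_batch_of_models_and_inputs(models_and_inputs_dict, "results.csv")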
--------------------------------------------------------------------------------
/torchlens/helper_funcs.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import copy
3 | import inspect
4 | import multiprocessing as mp
5 | import random
6 | import secrets
7 | import string
8 | import warnings
9 | from sys import getsizeof
10 | from typing import Any, Callable, Dict, List, Optional, Type
11 |
12 | import numpy as np
13 | import torch
14 | from IPython import get_ipython
15 | from torch import nn
16 |
17 | MAX_FLOATING_POINT_TOLERANCE = 3e-6
18 |
19 |
20 | def identity(x):
21 | return x
22 |
23 |
24 | def set_random_seed(seed: int):
25 | """Sets the random seed for all random number generators.
26 |
27 | Args:
28 | seed: Seed to set.
29 |
30 | Returns:
31 | Nothing.
32 | """
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed_all(seed)
37 |
38 |
39 | def log_current_rng_states() -> Dict:
40 | """Utility function to fetch sufficient information from all RNG states to recover the same state later.
41 |
42 | Returns:
43 | Dict with sufficient information to recover all RNG states.
44 | """
45 | rng_dict = {
46 | "random": random.getstate(),
47 | "np": np.random.get_state(),
48 | "torch": torch.random.get_rng_state(),
49 | }
50 | if torch.cuda.is_available():
51 | rng_dict["torch_cuda"] = torch.cuda.get_rng_state("cuda")
52 | return rng_dict
53 |
54 |
55 | def set_rng_from_saved_states(rng_states: Dict):
56 | """Utility function to set the state of random seeds to a cached value.
57 |
58 | Args:
59 |         rng_states: Dict of rng_states saved by log_current_rng_states
60 |
61 | Returns:
62 | Nothing, but correctly sets all random seed states.
63 | """
64 | random.setstate(rng_states["random"])
65 | np.random.set_state(rng_states["np"])
66 | torch.random.set_rng_state(rng_states["torch"])
67 | if torch.cuda.is_available():
68 | torch.cuda.set_rng_state(rng_states["torch_cuda"], "cuda")
69 |
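# Example sketch: a save/restore round trip reproduces the same random draws:
#
#     states = log_current_rng_states()
#     a = torch.rand(3)
#     set_rng_from_saved_states(states)
#     b = torch.rand(3)
#     assert torch.equal(a, b)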
70 |
71 | def make_random_barcode(barcode_len: int = 8) -> str:
72 |     """Generates a random alphanumeric barcode for a layer to use as an internal label (invisible from the user side).
73 |
74 | Args:
75 | barcode_len: Length of the desired barcode
76 |
77 | Returns:
78 |         Random barcode string.
79 | """
80 | alphabet = string.ascii_letters + string.digits
81 | barcode = "".join(secrets.choice(alphabet) for _ in range(barcode_len))
82 | return barcode
83 |
84 |
85 | def make_short_barcode_from_input(
86 | things_to_hash: List[Any], barcode_len: int = 16
87 | ) -> str:
88 | """Utility function that takes a list of anything and returns a short hash of it.
89 |
90 | Args:
91 | things_to_hash: List of things to hash; they must all be convertible to a string.
92 |         barcode_len: Length of the hash string to return.
93 |
94 | Returns:
95 |         Short hash of the input. Note that Python's built-in hash() is salted per process, so results are only stable across runs if PYTHONHASHSEED is fixed.
96 | """
97 | barcode = "".join([str(x) for x in things_to_hash])
98 | barcode = str(hash(barcode))
99 | barcode = barcode.encode("utf-8")
100 | barcode = base64.urlsafe_b64encode(barcode)
101 | barcode = barcode.decode("utf-8")
102 | barcode = barcode[0:barcode_len]
103 | return barcode
104 |
105 |
106 | def _get_call_stack_dicts():
107 | call_stack = inspect.stack()
108 | call_stack = [
109 | inspect.getframeinfo(call_stack[i][0], context=19)
110 | for i in range(len(call_stack))
111 | ]
112 | call_stack_dicts = [
113 | {
114 | "call_fname": caller.filename,
115 | "call_linenum": caller.lineno,
116 | "function": caller.function,
117 | "code_context": caller.code_context,
118 | }
119 | for caller in call_stack
120 | ]
121 |
122 | for call_stack_dict in call_stack_dicts:
123 | if is_iterable(call_stack_dict['code_context']):
124 | call_stack_dict['code_context_str'] = ''.join(call_stack_dict['code_context'])
125 | else:
126 | call_stack_dict['code_context_str'] = str(call_stack_dict['code_context'])
127 |
128 | # Only start at the level of that first forward pass, going from shallow to deep.
129 | tracking = False
130 | filtered_dicts = []
131 | for d in range(len(call_stack_dicts) - 1, -1, -1):
132 | call_stack_dict = call_stack_dicts[d]
133 | if any(
134 | [
135 | call_stack_dict["call_fname"].endswith("model_history.py"),
136 | call_stack_dict["call_fname"].endswith("torchlens/helper_funcs.py"),
137 | call_stack_dict["call_fname"].endswith("torchlens/user_funcs.py"),
138 | call_stack_dict["call_fname"].endswith("torchlens/trace_model.py"),
139 | call_stack_dict["call_fname"].endswith("torchlens/logging_funcs.py"),
140 | call_stack_dict["call_fname"].endswith("torchlens/decorate_torch.py"),
141 | call_stack_dict["call_fname"].endswith("torchlens/model_funcs.py"),
142 | "_call_impl" in call_stack_dict["function"],
143 | ]
144 | ):
145 | continue
146 | if call_stack_dict["function"] == "forward":
147 | tracking = True
148 |
149 | if tracking:
150 | filtered_dicts.append(call_stack_dict)
151 |
152 | return filtered_dicts
153 |
154 |
155 | def is_iterable(obj: Any) -> bool:
156 | """Checks if an object is iterable.
157 |
158 | Args:
159 | obj: Object to check.
160 |
161 | Returns:
162 | True if object is iterable, False otherwise.
163 | """
164 | try:
165 | iter(obj)
166 | return True
167 | except TypeError:
168 | return False
169 |
170 |
171 | def make_var_iterable(x):
172 | """Utility function to facilitate dealing with outputs:
173 | - If not a list, tuple, or dict, make it a list of length 1
174 | - If a dict, make it a list of the values
175 | - If a list or tuple, keep it.
176 |
177 | Args:
178 | x: Output of the function
179 |
180 | Returns:
181 | Iterable output
182 | """
183 | if any([issubclass(type(x), cls) for cls in [list, tuple, set]]):
184 | return x
185 | elif issubclass(type(x), dict):
186 | return list(x.values())
187 | else:
188 | return [x]
189 |
190 |
191 | def index_nested(x: Any, indices: List[int]) -> Any:
192 | """Utility function to index into a nested list or tuple.
193 |
194 | Args:
195 | x: Nested list or tuple.
196 | indices: List of indices to use.
197 |
198 | Returns:
199 | Indexed object.
200 | """
201 | indices = make_var_iterable(indices)
202 | for i in indices:
203 | x = x[i]
204 | return x
205 |
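# Example: index_nested([[1, 2], [3, 4]], [1, 0]) returns 3; a bare index also works,
# e.g. index_nested([10, 20], 1) returns 20.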
206 |
207 | def remove_entry_from_list(list_: List, entry: Any):
208 | """Removes all instances of an entry from a list if present, in-place.
209 |
210 | Args:
211 | list_: the list
212 | entry: the entry to remove
213 | """
214 | while entry in list_:
215 | list_.remove(entry)
216 |
217 |
218 | def tuple_tolerant_assign(obj_: Any, ind: int, new_value: Any):
219 | """Utility function to assign an entry of a list, tuple, or dict to a new value.
220 |
221 | Args:
222 |         obj_: List, tuple, or dict to change.
223 | ind: Index to change.
224 | new_value: The new value.
225 |
226 | Returns:
227 | Tuple with the new value swapped out.
228 | """
229 | if type(obj_) == tuple:
230 | list_ = list(obj_)
231 | list_[ind] = new_value
232 | return tuple(list_)
233 |
234 | obj_[ind] = new_value
235 | return obj_
236 |
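# Example: tuple_tolerant_assign((1, 2, 3), 1, 9) returns (1, 9, 3); lists and dicts
# are simply assigned in place and returned.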
237 |
238 | def int_list_to_compact_str(int_list: List[int]) -> str:
239 |     """Given a list of integers, returns a compact string representation of the list, where
240 |     contiguous stretches of the integers are represented as ranges (e.g., [1, 2, 3, 4] becomes "1-4"),
241 |     and all such ranges are separated by commas (e.g., [1, 2, 3, 7, 9, 10] becomes "1-3,7,9-10").
242 |
243 | Args:
244 | int_list: List of integers.
245 |
246 | Returns:
247 | Compact string representation of the list.
248 | """
249 | int_list = sorted(int_list)
250 | if len(int_list) == 0:
251 | return ""
252 | if len(int_list) == 1:
253 | return str(int_list[0])
254 | ranges = []
255 | start = int_list[0]
256 | end = int_list[0]
257 | for i in range(1, len(int_list)):
258 | if int_list[i] == end + 1:
259 | end = int_list[i]
260 | else:
261 | if start == end:
262 | ranges.append(str(start))
263 | else:
264 | ranges.append(f"{start}-{end}")
265 | start = int_list[i]
266 | end = int_list[i]
267 | if start == end:
268 | ranges.append(str(start))
269 | else:
270 | ranges.append(f"{start}-{end}")
271 | return ",".join(ranges)
272 |
273 |
274 | def get_vars_of_type_from_obj(
275 | obj: Any,
276 | which_type: Type,
277 | subclass_exceptions: Optional[List] = None,
278 | search_depth: int = 3,
279 | return_addresses=False,
280 | allow_repeats=False,
281 | ) -> List:
282 |     """Recursively finds all variables of a given type in an object, excluding specified subclasses
283 |     (e.g., parameters), up to the given search depth.
284 |
285 | Args:
286 | obj: Object to search.
287 | which_type: Type of variable to pull out
288 | subclass_exceptions: subclasses that you don't want to pull out.
289 | search_depth: How many layers deep to search before giving up.
290 | return_addresses: if True, then returns list of tuples (object, address), where the
291 | address is how you'd index to get the object
292 | allow_repeats: whether to allow repeats of the same tensor
293 |
294 | Returns:
295 | List of objects of desired type found in the input object.
296 | """
297 | if subclass_exceptions is None:
298 | subclass_exceptions = []
299 | this_stack = [(obj, "", [])]
300 | tensors_in_obj = []
301 | tensor_addresses = []
302 | tensor_addresses_full = []
303 | tensor_ids_in_obj = []
304 | for _ in range(search_depth):
305 | this_stack = search_stack_for_vars_of_type(
306 | this_stack,
307 | which_type,
308 | tensors_in_obj,
309 | tensor_addresses,
310 | tensor_addresses_full,
311 | tensor_ids_in_obj,
312 | subclass_exceptions,
313 | allow_repeats,
314 | )
315 |
316 | if return_addresses:
317 | return list(zip(tensors_in_obj, tensor_addresses, tensor_addresses_full))
318 | else:
319 | return tensors_in_obj
320 |
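# Example sketch: pull tensors (with their addresses) out of a nested container:
#
#     obj = {"a": torch.rand(2), "b": [torch.rand(3), "text"]}
#     get_vars_of_type_from_obj(obj, torch.Tensor, return_addresses=True)
#     # -> [(<tensor>, 'a', [('ind', 'a')]), (<tensor>, 'b.0', [('ind', 'b'), ('ind', 0)])]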
321 |
322 | def search_stack_for_vars_of_type(
323 | current_stack: List,
324 | which_type: Type,
325 | tensors_in_obj: List,
326 | tensor_addresses: List,
327 | tensor_addresses_full: List,
328 | tensor_ids_in_obj: List,
329 | subclass_exceptions: List,
330 | allow_repeats: bool,
331 | ):
332 | """Helper function that searches current stack for vars of a given type, and
333 | returns the next stack to search.
334 |
335 | Args:
336 | current_stack: The current stack.
337 | which_type: Type of variable to pull out
338 | tensors_in_obj: List of tensors found in the object so far
339 | tensor_addresses: Addresses of the tensors found so far
340 | tensor_addresses_full: explicit instructions for indexing the obj
341 | tensor_ids_in_obj: List of tensor ids found in the object so far
342 | subclass_exceptions: Subclasses of tensors not to use.
343 | allow_repeats: whether to allow repeat tensors
344 |
345 | Returns:
346 | The next stack.
347 | """
348 | next_stack = []
349 | if len(current_stack) == 0:
350 | return current_stack
351 | while len(current_stack) > 0:
352 | item, address, address_full = current_stack.pop(0)
353 | item_class = type(item)
354 | if any(
355 | [issubclass(item_class, subclass) for subclass in subclass_exceptions]
356 | ) or ((id(item) in tensor_ids_in_obj) and not allow_repeats):
357 | continue
358 | if issubclass(item_class, which_type):
359 | tensors_in_obj.append(item)
360 | tensor_addresses.append(address)
361 | tensor_addresses_full.append(address_full)
362 | tensor_ids_in_obj.append(id(item))
363 | continue
364 |         if item_class in [str, int, float, bool, np.ndarray, torch.Tensor]:  # torch.Tensor the class, not the torch.tensor factory function
365 | continue
366 | extend_search_stack_from_item(item, address, address_full, next_stack)
367 | return next_stack
368 |
369 |
370 | def extend_search_stack_from_item(
371 | item: Any, address: str, address_full, next_stack: List
372 | ):
373 |     """Iterates through a single item (and its attributes) to populate the next stack to search.
374 |     Args:
375 |         item: The item
376 |         address: String address of the item (address_full gives the explicit indexing route)
377 |         next_stack: Stack to add to
378 |     """
379 | if type(item) in [list, tuple, set]:
380 | if address == "":
381 | next_stack.extend(
382 | [(x, f"{i}", address_full + [("ind", i)]) for i, x in enumerate(item)]
383 | )
384 | else:
385 | next_stack.extend(
386 | [
387 | (x, f"{address}.{i}", address_full + [("ind", i)])
388 | for i, x in enumerate(item)
389 | ]
390 | )
391 |
392 | if issubclass(type(item), dict):
393 | if address == "":
394 | next_stack.extend(
395 | [(val, key, address_full + [("ind", key)]) for key, val in item.items()]
396 | )
397 | else:
398 | next_stack.extend(
399 | [
400 | (val, f"{address}.{key}", address_full + [("ind", key)])
401 | for key, val in item.items()
402 | ]
403 | )
404 |
405 | for attr_name in dir(item):
406 | if ((attr_name.startswith("__")) or
407 | (attr_name in ['T', 'mT', 'real', 'imag', 'H']) or
408 | ('grad' in attr_name)):
409 | continue
410 | try:
411 | with warnings.catch_warnings():
412 | warnings.simplefilter("ignore")
413 | attr = getattr(item, attr_name)
414 |         except Exception:
415 | continue
416 | attr_cls = type(attr)
417 | if attr_cls in [str, int, float, bool, np.ndarray]:
418 | continue
419 | if callable(attr) and not issubclass(attr_cls, nn.Module):
420 | continue
421 | if address == "":
422 | next_stack.append(
423 | (attr, attr_name.strip("_"), address_full + [("attr", attr_name)])
424 | )
425 | else:
426 | next_stack.append(
427 | (
428 | attr,
429 | f'{address}.{attr_name.strip("_")}',
430 | address_full + [("attr", attr_name)],
431 | )
432 | )
433 |
434 |
435 | def get_attr_values_from_tensor_list(
436 | tensor_list: List[torch.Tensor], field_name: str
437 | ) -> List[Any]:
438 | """For a list of tensors, gets the value of a given attribute from each tensor that has that attribute.
439 |
440 | Args:
441 | tensor_list: List of tensors to search.
442 | field_name: Name of the field to check in the tensor.
443 |
444 | Returns:
445 | List of marks from the tensors.
446 | """
447 | marks = []
448 | for tensor in tensor_list:
449 | mark = getattr(tensor, field_name, None)
450 | if mark is not None:
451 | marks.append(mark)
452 | return marks
453 |
454 |
455 | def nested_getattr(obj: Any, attr: str) -> Any:
456 | """Helper function that takes in an object, and a string of attributes separated by '.' and recursively
457 | returns the attribute.
458 |
459 | Args:
460 | obj: Any object, e.g. "torch"
461 | attr: String specifying the nested attribute, e.g. "nn.functional"
462 |
463 | Returns:
464 | The attribute specified by the string.
465 | """
466 | if attr == "":
467 | return obj
468 |
469 | attributes = attr.split(".")
470 | for i, a in enumerate(attributes):
471 | if a in [
472 | "volatile",
473 | "T",
474 | 'H',
475 | 'mH',
476 | 'mT'
477 | ]: # avoid annoying warning; if there's more, make a list
478 | with warnings.catch_warnings():
479 | warnings.simplefilter("ignore")
480 | obj = getattr(obj, a)
481 | else:
482 | obj = getattr(obj, a)
483 | return obj
484 |
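# Example: nested_getattr(torch, "nn.functional.relu") is torch.nn.functional.relu.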
485 |
486 | def nested_assign(obj, addr, val):
487 |     """Given an object and an address in that object (as built by get_vars_of_type_from_obj), assigns the value to that address."""
488 | for i, (entry_type, entry_val) in enumerate(addr):
489 | if i == len(addr) - 1:
490 | if entry_type == "ind":
491 | obj[entry_val] = val
492 | elif entry_type == "attr":
493 | setattr(obj, entry_val, val)
494 | else:
495 | if entry_type == "ind":
496 | obj = obj[entry_val]
497 | elif entry_type == "attr":
498 | with warnings.catch_warnings():
499 | warnings.simplefilter("ignore")
500 | obj = getattr(obj, entry_val)
501 |
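# Example sketch, using the address format produced by get_vars_of_type_from_obj:
#
#     obj = {"a": [0, 1]}
#     nested_assign(obj, [("ind", "a"), ("ind", 1)], 99)  # obj is now {"a": [0, 99]}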
502 |
503 | def iter_accessible_attributes(obj: Any, *, short_circuit: Optional[Callable[[Any, str], bool]] = None):
504 | for attr_name in dir(obj):
505 | if short_circuit and short_circuit(obj, attr_name):
506 | continue
507 |
508 | # Attribute access can fail for any number of reasons, especially when
509 | # working with objects that we don't know anything about. This
510 | # function makes a best-effort attempt to access every attribute, but
511 | # gracefully skips any that cause problems.
512 |
513 | with warnings.catch_warnings():
514 | warnings.simplefilter("ignore")
515 | try:
516 | attr = getattr(obj, attr_name)
517 | except Exception:
518 | continue
519 |
520 | yield attr_name, attr
521 |
522 |
523 | def remove_attributes_starting_with_str(obj: Any, s: str):
524 |     """Given an object, removes any attributes of that object beginning with a given
525 |     substring.
526 |
527 | Args:
528 | obj: object
529 | s: string that marks fields to remove
530 | """
531 | for field in dir(obj):
532 | if field.startswith(s):
533 | delattr(obj, field)
534 |
535 |
536 | def tensor_all_nan(t: torch.Tensor) -> bool:
537 | """Returns True if tensor is all nans, False otherwise."""
538 | if torch.isnan(t).int().sum() == t.numel():
539 | return True
540 | else:
541 | return False
542 |
543 |
544 | def tensor_nanequal(t1: torch.Tensor,
545 | t2: torch.Tensor,
546 | allow_tolerance=False) -> bool:
547 | """Returns True if the two tensors are equal, allowing for nans."""
548 | if t1.shape != t2.shape:
549 | return False
550 |
551 | if t1.dtype != t2.dtype:
552 | return False
553 |
554 | if not torch.equal(t1.isinf(), t2.isinf()):
555 | return False
556 |
557 | t1_nonan = torch.nan_to_num(t1, 0.7234691827346)
558 | t2_nonan = torch.nan_to_num(t2, 0.7234691827346)
559 |
560 | if torch.equal(t1_nonan, t2_nonan):
561 | return True
562 |
563 | if (
564 | allow_tolerance
565 | and (t1_nonan.dtype != torch.bool)
566 | and (t2_nonan.dtype != torch.bool)
567 | and ((t1_nonan - t2_nonan).abs().max() <= MAX_FLOATING_POINT_TOLERANCE)
568 | ):
569 | return True
570 |
571 | return False
572 |
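# Example: NaNs in matching positions count as equal, unlike with torch.equal:
#
#     a = torch.tensor([1.0, float("nan")])
#     b = torch.tensor([1.0, float("nan")])
#     torch.equal(a, b)       # False
#     tensor_nanequal(a, b)   # True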
573 |
574 | def safe_to(x: Any, device: str):
575 | """Moves object to device if it's a tensor, does nothing otherwise.
576 |
577 | Args:
578 | x: The object.
579 | device: which device to move to
580 |
581 | Returns:
582 | Object either moved to device if a tensor, same object if otherwise.
583 | """
584 | if type(x) == torch.Tensor:
585 | return clean_to(x, device)
586 | else:
587 | return x
588 |
589 |
590 | def get_tensor_memory_amount(t: torch.Tensor) -> int:
591 | """Returns the size of a tensor in bytes.
592 |
593 | Args:
594 | t: Tensor.
595 |
596 | Returns:
597 | Size of tensor in bytes.
598 | """
599 | cpu_data = clean_cpu(t.data)
600 | if cpu_data.dtype == torch.bfloat16:
601 | cpu_data = clean_to(cpu_data, torch.float16)
602 | return getsizeof(np.array(clean_dense(cpu_data)))
603 |
604 |
605 | def human_readable_size(size: int, decimal_places: int = 1) -> str:
606 | """Utility function to convert a size in bytes to a human-readable format.
607 |
608 | Args:
609 | size: Number of bytes.
610 | decimal_places: Number of decimal places to use.
611 |
612 | Returns:
613 | String with human-readable size.
614 | """
615 | for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
616 | if size < 1024.0 or unit == "PB":
617 | break
618 | size /= 1024.0
619 | if unit == "B":
620 | size = int(size)
621 | else:
622 | size = np.round(size, decimals=decimal_places)
623 | return f"{size} {unit}"
624 |
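# Examples: human_readable_size(512) returns "512 B"; human_readable_size(3_500_000)
# returns "3.3 MB".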
625 |
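# Pristine deepcopies of core torch functions, taken at import time so that torchlens'
# internal bookkeeping can keep calling the originals even while torch's functions are
# patched for logging during a traced forward pass.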
626 | clean_from_numpy = copy.deepcopy(torch.from_numpy)
627 | clean_new_param = copy.deepcopy(torch.nn.Parameter)
628 | clean_clone = copy.deepcopy(torch.clone)
629 | clean_cpu = copy.deepcopy(torch.Tensor.cpu)
630 | clean_cuda = copy.deepcopy(torch.Tensor.cuda)
631 | clean_to = copy.deepcopy(torch.Tensor.to)
632 | clean_dense = copy.deepcopy(torch.Tensor.to_dense)
633 |
634 |
635 | def print_override(t: torch.Tensor, func_name: str):
636 | """Overrides the __str__ and __repr__ methods of Tensor so as not to lead to any infinite recursion.
637 |
638 | Args:
639 | t: Tensor
640 | func_name: Either "__str__" or "__repr__"
641 |
642 | Returns:
643 | The string representation of the tensor.
644 | """
645 | cpu_data = clean_cpu(t.data)
646 | if cpu_data.dtype == torch.bfloat16:
647 | cpu_data = clean_to(cpu_data, torch.float16)
648 | n = np.array(cpu_data)
649 | np_str = getattr(n, func_name)()
650 | np_str = np_str.replace("array", "tensor")
651 | np_str = np_str.replace("\n", "\n ")
652 | if t.grad_fn is not None:
653 | grad_fn_str = f", grad_fn={type(t.grad_fn).__name__})"
654 | np_str = np_str[0:-1] + grad_fn_str
655 | elif t.requires_grad:
656 | np_str = np_str[0:-1] + ", requires_grad=True)"
657 | return np_str
658 |
659 |
660 | def safe_copy(x, detach_tensor: bool = False):
661 | """Utility function to make a copy of a tensor or parameter when torch is in mutated mode, or just copy
662 | the thing if it's not a tensor.
663 |
664 | Args:
665 | x: Input
666 | detach_tensor: Whether to detach the cloned tensor from the computational graph or not.
667 |
668 | Returns:
669 | Safely copied variant of the input with same values and same class, but different memory
670 | """
671 | if issubclass(type(x), (torch.Tensor, torch.nn.Parameter)):
672 | if not detach_tensor:
673 | return clean_clone(x)
674 | vals_cpu = clean_cpu(x.data)
675 | if vals_cpu.dtype == torch.bfloat16:
676 | vals_cpu = clean_to(vals_cpu, torch.float16)
677 | vals_np = vals_cpu.numpy()
678 | vals_tensor = clean_from_numpy(vals_np)
679 | if hasattr(x, "tl_tensor_label_raw"):
680 | vals_tensor.tl_tensor_label_raw = x.tl_tensor_label_raw
681 | if type(x) == torch.Tensor:
682 | return vals_tensor
683 | elif type(x) == torch.nn.Parameter:
684 | return clean_new_param(vals_tensor)
685 | else:
686 | return copy.copy(x)
687 |
688 |
689 | def in_notebook():
690 | try:
691 | if "IPKernelApp" not in get_ipython().config:
692 | return False
693 | except ImportError:
694 | return False
695 | except AttributeError:
696 | return False
697 | return True
698 |
699 |
700 | def warn_parallel():
701 | """
702 |     Utility function to raise an error if torchlens is being run in a parallel process.
703 | """
704 | if mp.current_process().name != "MainProcess":
705 | raise RuntimeError(
706 | "WARNING: It looks like you are using parallel execution; only run "
707 |             "torchlens in the main process, since certain operations "
708 | "depend on execution order."
709 | )
710 |
--------------------------------------------------------------------------------
/torchlens/tensor_log.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import defaultdict
3 | from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Tuple, Union
4 |
5 | import torch
6 |
7 | from .constants import TENSOR_LOG_ENTRY_FIELD_ORDER
8 | from .helper_funcs import clean_to, get_tensor_memory_amount, human_readable_size, print_override, safe_copy
9 |
10 | if TYPE_CHECKING:
11 | from .model_history import ModelHistory
12 |
13 |
14 | class TensorLogEntry:
15 | def __init__(self, fields_dict: Dict):
16 | """Object that stores information about a single tensor operation in the forward pass,
17 | including metadata and the tensor itself (if specified). Initialized by passing in a dictionary with
18 | values for all fields.
19 | Args:
20 | fields_dict: Dict with values for all fields in TensorLogEntry.
21 | """
22 | # Note: this all has to be tediously initialized instead of a for-loop in order for
23 | # autocomplete features to work well. But, this also serves as a reference for all attributes
24 | # of a tensor log entry.
25 |
26 | # Check that fields_dict contains all fields for TensorLogEntry:
27 | field_order_set = set(TENSOR_LOG_ENTRY_FIELD_ORDER)
28 | fields_dict_key_set = set(fields_dict.keys())
29 | if fields_dict_key_set != field_order_set:
30 | error_str = "Error initializing TensorLogEntry:"
31 | missing_fields = field_order_set - fields_dict_key_set
32 | extra_fields = fields_dict_key_set - field_order_set
33 | if len(missing_fields) > 0:
34 | error_str += f"\n\t- Missing fields {', '.join(missing_fields)}"
35 | if len(extra_fields) > 0:
36 | error_str += f"\n\t- Extra fields {', '.join(extra_fields)}"
37 | raise ValueError(error_str)
38 |
39 | # General info:
40 | self.tensor_label_raw = fields_dict["tensor_label_raw"]
41 | self.layer_label_raw = fields_dict["layer_label_raw"]
42 | self.operation_num = fields_dict["operation_num"]
43 | self.realtime_tensor_num = fields_dict["realtime_tensor_num"]
44 | self.source_model_history: "ModelHistory" = fields_dict["source_model_history"]
45 | self._pass_finished = fields_dict["_pass_finished"]
46 |
47 | # Label info:
48 | self.layer_label = fields_dict["layer_label"]
49 | self.layer_label_short = fields_dict["layer_label_short"]
50 | self.layer_label_w_pass = fields_dict["layer_label_w_pass"]
51 | self.layer_label_w_pass_short = fields_dict["layer_label_w_pass_short"]
52 | self.layer_label_no_pass = fields_dict["layer_label_no_pass"]
53 | self.layer_label_no_pass_short = fields_dict["layer_label_no_pass_short"]
54 | self.layer_type = fields_dict["layer_type"]
55 | self.layer_type_num = fields_dict["layer_type_num"]
56 | self.layer_total_num = fields_dict["layer_total_num"]
57 | self.pass_num = fields_dict["pass_num"]
58 | self.layer_passes_total = fields_dict["layer_passes_total"]
59 | self.lookup_keys = fields_dict["lookup_keys"]
60 |
61 | # Saved tensor info:
62 | self.tensor_contents = fields_dict["tensor_contents"]
63 | self.has_saved_activations = fields_dict["has_saved_activations"]
64 | self.output_device = fields_dict["output_device"]
65 | self.activation_postfunc = fields_dict["activation_postfunc"]
66 | self.detach_saved_tensor = fields_dict["detach_saved_tensor"]
67 | self.function_args_saved = fields_dict["function_args_saved"]
68 | self.creation_args = fields_dict["creation_args"]
69 | self.creation_kwargs = fields_dict["creation_kwargs"]
70 | self.tensor_shape = fields_dict["tensor_shape"]
71 | self.tensor_dtype = fields_dict["tensor_dtype"]
72 | self.tensor_fsize = fields_dict["tensor_fsize"]
73 | self.tensor_fsize_nice = fields_dict["tensor_fsize_nice"]
74 |
75 | # Dealing with getitem complexities
76 | self.was_getitem_applied = fields_dict["was_getitem_applied"]
77 | self.children_tensor_versions = fields_dict["children_tensor_versions"]
78 |
79 | # Saved gradient info
80 | self.grad_contents = fields_dict["grad_contents"]
81 | self.save_gradients = fields_dict["save_gradients"]
82 | self.has_saved_grad = fields_dict["has_saved_grad"]
83 | self.grad_shape = fields_dict["grad_shape"]
84 | self.grad_dtype = fields_dict["grad_dtype"]
85 | self.grad_fsize = fields_dict["grad_fsize"]
86 | self.grad_fsize_nice = fields_dict["grad_fsize_nice"]
87 |
88 | # Function call info:
89 | self.func_applied = fields_dict["func_applied"]
90 | self.func_applied_name = fields_dict["func_applied_name"]
91 | self.func_call_stack = fields_dict["func_call_stack"]
92 | self.func_time_elapsed = fields_dict["func_time_elapsed"]
93 | self.func_rng_states = fields_dict["func_rng_states"]
94 | self.func_argnames = fields_dict["func_argnames"]
95 | self.num_func_args_total = fields_dict["num_func_args_total"]
96 | self.num_position_args = fields_dict["num_position_args"]
97 | self.num_keyword_args = fields_dict["num_keyword_args"]
98 | self.func_position_args_non_tensor = fields_dict[
99 | "func_position_args_non_tensor"
100 | ]
101 | self.func_keyword_args_non_tensor = fields_dict["func_keyword_args_non_tensor"]
102 | self.func_all_args_non_tensor = fields_dict["func_all_args_non_tensor"]
103 | self.function_is_inplace = fields_dict["function_is_inplace"]
104 | self.gradfunc = fields_dict["gradfunc"]
105 | self.is_part_of_iterable_output = fields_dict["is_part_of_iterable_output"]
106 | self.iterable_output_index = fields_dict["iterable_output_index"]
107 |
108 | # Param info:
109 | self.computed_with_params = fields_dict["computed_with_params"]
110 | self.parent_params = fields_dict["parent_params"]
111 | self.parent_param_barcodes = fields_dict["parent_param_barcodes"]
112 | self.parent_param_passes = fields_dict["parent_param_passes"]
113 | self.num_param_tensors = fields_dict["num_param_tensors"]
114 | self.parent_param_shapes = fields_dict["parent_param_shapes"]
115 | self.num_params_total = fields_dict["num_params_total"]
116 | self.parent_params_fsize = fields_dict["parent_params_fsize"]
117 | self.parent_params_fsize_nice = fields_dict["parent_params_fsize_nice"]
118 |
119 | # Corresponding layer info:
120 | self.operation_equivalence_type = fields_dict["operation_equivalence_type"]
121 | self.equivalent_operations = fields_dict["equivalent_operations"]
122 | self.same_layer_operations = fields_dict["same_layer_operations"]
123 |
124 | # Graph info:
125 | self.parent_layers = fields_dict["parent_layers"]
126 | self.has_parents = fields_dict["has_parents"]
127 | self.parent_layer_arg_locs = fields_dict["parent_layer_arg_locs"]
128 | self.orig_ancestors = fields_dict["orig_ancestors"]
129 | self.child_layers = fields_dict["child_layers"]
130 | self.has_children = fields_dict["has_children"]
131 | self.sibling_layers = fields_dict["sibling_layers"]
132 | self.has_siblings = fields_dict["has_siblings"]
133 | self.spouse_layers = fields_dict["spouse_layers"]
134 | self.has_spouses = fields_dict["has_spouses"]
135 | self.is_input_layer = fields_dict["is_input_layer"]
136 | self.has_input_ancestor = fields_dict["has_input_ancestor"]
137 | self.input_ancestors = fields_dict["input_ancestors"]
138 | self.min_distance_from_input = fields_dict["min_distance_from_input"]
139 | self.max_distance_from_input = fields_dict["max_distance_from_input"]
140 | self.is_output_layer = fields_dict["is_output_layer"]
141 | self.is_output_parent = fields_dict["is_output_parent"]
142 | self.is_last_output_layer = fields_dict["is_last_output_layer"]
143 | self.is_output_ancestor = fields_dict["is_output_ancestor"]
144 | self.output_descendents = fields_dict["output_descendents"]
145 | self.min_distance_from_output = fields_dict["min_distance_from_output"]
146 | self.max_distance_from_output = fields_dict["max_distance_from_output"]
147 | self.input_output_address = fields_dict["input_output_address"]
148 | self.is_buffer_layer = fields_dict["is_buffer_layer"]
149 | self.buffer_address = fields_dict["buffer_address"]
150 | self.buffer_pass = fields_dict["buffer_pass"]
151 | self.buffer_parent = fields_dict["buffer_parent"]
152 | self.initialized_inside_model = fields_dict["initialized_inside_model"]
153 | self.has_internally_initialized_ancestor = fields_dict[
154 | "has_internally_initialized_ancestor"
155 | ]
156 | self.internally_initialized_parents = fields_dict[
157 | "internally_initialized_parents"
158 | ]
159 | self.internally_initialized_ancestors = fields_dict[
160 | "internally_initialized_ancestors"
161 | ]
162 | self.terminated_inside_model = fields_dict["terminated_inside_model"]
163 |
164 | # Conditional info
165 | self.is_terminal_bool_layer = fields_dict["is_terminal_bool_layer"]
166 | self.is_atomic_bool_layer = fields_dict["is_atomic_bool_layer"]
167 | self.atomic_bool_val = fields_dict["atomic_bool_val"]
168 | self.in_cond_branch = fields_dict["in_cond_branch"]
169 | self.cond_branch_start_children = fields_dict["cond_branch_start_children"]
170 |
171 | # Module info
172 | self.is_computed_inside_submodule = fields_dict["is_computed_inside_submodule"]
173 | self.containing_module_origin = fields_dict["containing_module_origin"]
174 | self.containing_modules_origin_nested = fields_dict[
175 | "containing_modules_origin_nested"
176 | ]
177 | self.module_nesting_depth = fields_dict["module_nesting_depth"]
178 | self.modules_entered = fields_dict["modules_entered"]
179 | self.modules_entered_argnames = fields_dict["modules_entered_argnames"]
180 | self.module_passes_entered = fields_dict["module_passes_entered"]
181 | self.is_submodule_input = fields_dict["is_submodule_input"]
182 | self.modules_exited = fields_dict["modules_exited"]
183 | self.module_passes_exited = fields_dict["module_passes_exited"]
184 | self.is_submodule_output = fields_dict["is_submodule_output"]
185 | self.is_bottom_level_submodule_output = fields_dict[
186 | "is_bottom_level_submodule_output"
187 | ]
188 | self.bottom_level_submodule_pass_exited = fields_dict[
189 | "bottom_level_submodule_pass_exited"
190 | ]
191 | self.module_entry_exit_threads_inputs = fields_dict[
192 | "module_entry_exit_threads_inputs"
193 | ]
194 | self.module_entry_exit_thread_output = fields_dict[
195 | "module_entry_exit_thread_output"
196 | ]
197 |
198 | # ********************************************
199 | # *********** User-Facing Functions **********
200 | # ********************************************
201 |
202 | def print_all_fields(self):
203 | """Print all data fields in the layer."""
204 | fields_to_exclude = ["source_model_history", "func_rng_states"]
205 |
206 | for field in dir(self):
207 | attr = getattr(self, field)
208 | if not any(
209 | [field.startswith("_"), field in fields_to_exclude, callable(attr)]
210 | ):
211 | print(f"{field}: {attr}")
212 |
213 | # ********************************************
214 | # ************* Logging Functions ************
215 | # ********************************************
216 |
217 | def copy(self):
218 | """Return a copy of itself.
219 |
220 | Returns:
221 | Copy of itself.
222 | """
223 | fields_dict = {}
224 | fields_not_to_deepcopy = [
225 | "func_applied",
226 | "gradfunc",
227 | "source_model_history",
228 | "func_rng_states",
229 | "creation_args",
230 | "creation_kwargs",
231 | "parent_params",
232 | "tensor_contents",
233 | "children_tensor_versions",
234 | ]
235 | for field in TENSOR_LOG_ENTRY_FIELD_ORDER:
236 | if field not in fields_not_to_deepcopy:
237 | fields_dict[field] = copy.deepcopy(getattr(self, field))
238 | else:
239 | fields_dict[field] = getattr(self, field)
240 | copied_entry = TensorLogEntry(fields_dict)
241 | return copied_entry
242 |
243 | def save_tensor_data(
244 | self,
245 | t: torch.Tensor,
246 | t_args: Union[List, Tuple],
247 | t_kwargs: Dict,
248 | save_function_args: bool,
249 | activation_postfunc: Optional[Callable] = None,
250 | ):
251 | """Saves the tensor data for a given tensor operation.
252 |
253 | Args:
254 | t: the tensor.
255 | t_args: tensor positional arguments for the operation
256 | t_kwargs: tensor keyword arguments for the operation
257 | save_function_args: whether to save the arguments to the function
258 | activation_postfunc: function to apply to activations before saving them
259 | """
260 | # The tensor itself:
261 | self.tensor_contents = safe_copy(t, self.detach_saved_tensor)
262 | if self.output_device not in [str(self.tensor_contents.device), "same"]:
263 | self.tensor_contents = clean_to(self.tensor_contents, self.output_device)
264 | if activation_postfunc is not None:
265 | self.source_model_history._pause_logging = True
266 | self.tensor_contents = activation_postfunc(self.tensor_contents)
267 | self.source_model_history._pause_logging = False
268 |
269 | self.has_saved_activations = True
270 |
271 | # Tensor args and kwargs:
272 | if save_function_args:
273 | self.function_args_saved = True
274 | creation_args = []
275 | for arg in t_args:
276 | if type(arg) == list:
277 | creation_args.append([safe_copy(a) for a in arg])
278 | else:
279 | creation_args.append(safe_copy(arg))
280 |
281 | creation_kwargs = {}
282 | for key, value in t_kwargs.items():
283 | if type(value) == list:
284 | creation_kwargs[key] = [safe_copy(v) for v in value]
285 | else:
286 | creation_kwargs[key] = safe_copy(value)
287 | self.creation_args = creation_args
288 | self.creation_kwargs = creation_kwargs
289 | else:
290 | self.creation_args = None
291 | self.creation_kwargs = None
292 |
293 | def log_tensor_grad(self, grad: torch.Tensor):
294 | """Logs the gradient for a tensor to the log entry
295 |
296 | Args:
297 | grad: The gradient to save.
298 | """
299 | self.grad_contents = grad
300 | self.has_saved_grad = True
301 | self.grad_shape = grad.shape
302 | self.grad_dtype = grad.dtype
303 | self.grad_fsize = get_tensor_memory_amount(grad)
304 | self.grad_fsize_nice = human_readable_size(get_tensor_memory_amount(grad))
305 |
306 | # ********************************************
307 | # ************* Fetcher Functions ************
308 | # ********************************************
309 |
310 | def get_child_layers(self):
311 | return [
312 | self.source_model_history[child_label] for child_label in self.child_layers
313 | ]
314 |
315 | def get_parent_layers(self):
316 | return [
317 | self.source_model_history[parent_label]
318 | for parent_label in self.parent_layers
319 | ]
320 |
321 | # ********************************************
322 | # ************* Built-in Methods *************
323 | # ********************************************
324 |
325 | def __str__(self):
326 | if self._pass_finished:
327 | return self._str_after_pass()
328 | else:
329 | return self._str_during_pass()
330 |
331 | def _str_during_pass(self):
332 | s = f"Tensor {self.tensor_label_raw} (layer {self.layer_label_raw}) (PASS NOT FINISHED):"
333 | s += f"\n\tPass: {self.pass_num}"
334 | s += f"\n\tTensor info: shape {self.tensor_shape}, dtype {self.tensor_dtype}"
335 | s += f"\n\tComputed from params: {self.computed_with_params}"
336 | s += f"\n\tComputed in modules: {self.containing_modules_origin_nested}"
337 | s += f"\n\tOutput of modules: {self.module_passes_exited}"
338 | if self.is_bottom_level_submodule_output:
339 | s += " (bottom-level submodule output)"
340 | else:
341 | s += " (not bottom-level submodule output)"
342 | s += "\n\tFamily info:"
343 | s += f"\n\t\tParents: {self.parent_layers}"
344 | s += f"\n\t\tChildren: {self.child_layers}"
345 | s += f"\n\t\tSpouses: {self.spouse_layers}"
346 | s += f"\n\t\tSiblings: {self.sibling_layers}"
347 | s += (
348 | f"\n\t\tOriginal Ancestors: {self.orig_ancestors} "
349 | f"(min dist {self.min_distance_from_input} nodes, max dist {self.max_distance_from_input} nodes)"
350 | )
351 | s += f"\n\t\tInput Ancestors: {self.input_ancestors}"
352 | s += f"\n\t\tInternal Ancestors: {self.internally_initialized_ancestors}"
353 | s += (
354 | f"\n\t\tOutput Descendents: {self.output_descendents} "
355 | f"(min dist {self.min_distance_from_output} nodes, max dist {self.max_distance_from_output} nodes)"
356 | )
357 | if self.tensor_contents is not None:
358 | s += f"\n\tTensor contents: \n{print_override(self.tensor_contents, '__str__')}"
359 | return s
360 |
361 | def _str_after_pass(self):
362 | if self.layer_passes_total > 1:
363 | pass_str = f" (pass {self.pass_num}/{self.layer_passes_total}), "
364 | else:
365 | pass_str = ", "
366 | s = (
367 | f"Layer {self.layer_label_no_pass}"
368 | f"{pass_str}operation {self.operation_num}/"
369 | f"{self.source_model_history.num_operations}:"
370 | )
371 |         s += f"\n\tOutput tensor: shape={self.tensor_shape}, dtype={self.tensor_dtype}, size={self.tensor_fsize_nice}"
372 | if not self.has_saved_activations:
373 | s += " (not saved)"
374 | s += self._tensor_contents_str_helper()
375 | s += self._tensor_family_str_helper()
376 | if len(self.parent_param_shapes) > 0:
377 | params_shapes_str = ", ".join(
378 | str(param_shape) for param_shape in self.parent_param_shapes
379 | )
380 | s += (
381 | f"\n\tParams: Computed from params with shape {params_shapes_str}; "
382 | f"{self.num_params_total} params total ({self.parent_params_fsize_nice})"
383 | )
384 | else:
385 | s += "\n\tParams: no params used"
386 | if self.containing_module_origin is None:
387 | module_str = "\n\tComputed inside module: not computed inside a module"
388 | else:
389 | module_str = f"\n\tComputed inside module: {self.containing_module_origin}"
390 | if not self.is_input_layer:
391 | s += (
392 | f"\n\tFunction: {self.func_applied_name} (grad_fn: {self.gradfunc}) "
393 | f"{module_str}"
394 | )
395 | s += f"\n\tTime elapsed: {self.func_time_elapsed: .3E}s"
396 | if len(self.modules_exited) > 0:
397 | modules_exited_str = ", ".join(self.modules_exited)
398 | s += f"\n\tOutput of modules: {modules_exited_str}"
399 | else:
400 | s += "\n\tOutput of modules: none"
401 | if self.is_bottom_level_submodule_output:
402 | s += f"\n\tOutput of bottom-level module: {self.bottom_level_submodule_pass_exited}"
403 | lookup_keys_str = ", ".join([str(key) for key in self.lookup_keys])
404 | s += f"\n\tLookup keys: {lookup_keys_str}"
405 |
406 | return s
407 |
408 | def _tensor_contents_str_helper(self) -> str:
409 | """Returns short, readable string for the tensor contents."""
410 | if self.tensor_contents is None:
411 | return ""
412 | else:
413 | s = ""
414 | tensor_size_shown = 8
415 | saved_shape = self.tensor_contents.shape
416 | if len(saved_shape) == 0:
417 | tensor_slice = self.tensor_contents
418 | elif len(saved_shape) == 1:
419 | num_dims = min(tensor_size_shown, saved_shape[0])
420 | tensor_slice = self.tensor_contents[0:num_dims]
421 | elif len(saved_shape) == 2:
422 | num_dims = min([tensor_size_shown, saved_shape[-2], saved_shape[-1]])
423 | tensor_slice = self.tensor_contents[0:num_dims, 0:num_dims]
424 | else:
425 | num_dims = min(
426 | [tensor_size_shown, self.tensor_shape[-2], self.tensor_shape[-1]]
427 | )
428 | tensor_slice = self.tensor_contents.data.clone()
429 | for _ in range(len(saved_shape) - 2):
430 | tensor_slice = tensor_slice[0]
431 | tensor_slice = tensor_slice[0:num_dims, 0:num_dims]
432 | tensor_slice = tensor_slice.detach()
433 | tensor_slice.requires_grad = False
434 | s += f"\n\t\t{str(tensor_slice)}"
435 | if (len(saved_shape) > 0) and (max(saved_shape) > tensor_size_shown):
436 | s += "..."
437 | return s
438 |
439 | def _tensor_family_str_helper(self) -> str:
440 | s = "\n\tRelated Layers:"
441 | if len(self.parent_layers) > 0:
442 | s += "\n\t\t- parent layers: " + ", ".join(self.parent_layers)
443 | else:
444 | s += "\n\t\t- no parent layers"
445 |
446 | if len(self.child_layers) > 0:
447 | s += "\n\t\t- child layers: " + ", ".join(self.child_layers)
448 | else:
449 | s += "\n\t\t- no child layers"
450 |
451 | if len(self.sibling_layers) > 0:
452 | s += "\n\t\t- shares parents with layers: " + ", ".join(self.sibling_layers)
453 | else:
454 | s += "\n\t\t- shares parents with no other layers"
455 |
456 | if len(self.spouse_layers) > 0:
457 | s += "\n\t\t- shares children with layers: " + ", ".join(self.spouse_layers)
458 | else:
459 | s += "\n\t\t- shares children with no other layers"
460 |
461 | if self.has_input_ancestor:
462 | s += "\n\t\t- descendent of input layers: " + ", ".join(
463 | self.input_ancestors
464 | )
465 | else:
466 | s += "\n\t\t- tensor was created de novo inside the model (not computed from input)"
467 |
468 | if self.is_output_ancestor:
469 | s += "\n\t\t- ancestor of output layers: " + ", ".join(
470 | self.output_descendents
471 | )
472 | else:
473 | s += "\n\t\t- tensor is not an ancestor of the model output; it terminates within the model"
474 |
475 | return s
476 |
477 | def __repr__(self):
478 | return self.__str__()
479 |
480 |
481 | class RolledTensorLogEntry:
482 | def __init__(self, source_entry: TensorLogEntry):
483 |         """Stripped-down version of TensorLogEntry that only encodes the information needed to plot the model
484 | in its rolled-up form.
485 |
486 | Args:
487 | source_entry: The source TensorLogEntry from which the rolled node is constructed
488 | """
489 | # Label & general info
490 | self.layer_label = source_entry.layer_label_no_pass
491 | self.layer_type = source_entry.layer_type
492 | self.layer_type_num = source_entry.layer_type_num
493 | self.layer_total_num = source_entry.layer_total_num
494 | self.layer_passes_total = source_entry.layer_passes_total
495 | self.source_model_history = source_entry.source_model_history
496 |
497 | # Saved tensor info
498 | self.tensor_shape = source_entry.tensor_shape
499 | self.tensor_fsize_nice = source_entry.tensor_fsize_nice
500 |
501 | # Param info:
502 | self.computed_with_params = source_entry.computed_with_params
503 | self.parent_param_shapes = source_entry.parent_param_shapes
504 | self.num_param_tensors = source_entry.num_param_tensors
505 |
506 | # Graph info
507 | self.is_input_layer = source_entry.is_input_layer
508 | self.has_input_ancestor = source_entry.has_input_ancestor
509 | self.is_output_layer = source_entry.is_output_layer
510 | self.is_last_output_layer = source_entry.is_last_output_layer
511 | self.is_buffer_layer = source_entry.is_buffer_layer
512 | self.buffer_address = source_entry.buffer_address
513 | self.buffer_pass = source_entry.buffer_pass
514 | self.input_output_address = source_entry.input_output_address
515 | self.cond_branch_start_children = source_entry.cond_branch_start_children
516 | self.is_terminal_bool_layer = source_entry.is_terminal_bool_layer
517 | self.atomic_bool_val = source_entry.atomic_bool_val
518 | self.child_layers = []
519 | self.parent_layers = []
520 | self.orphan_layers = []
521 |
522 | # Module info:
523 | self.containing_modules_origin_nested = (
524 | source_entry.containing_modules_origin_nested
525 | )
526 | self.modules_exited = source_entry.modules_exited
527 | self.module_passes_exited = source_entry.module_passes_exited
528 | self.is_bottom_level_submodule_output = False
529 | self.bottom_level_submodule_passes_exited = set()
530 |
531 | # Fields specific to rolled node to fill in:
532 | self.edges_vary_across_passes = False
533 | self.child_layers_per_pass = defaultdict(list)
534 | self.child_passes_per_layer = defaultdict(list)
535 | self.parent_layers_per_pass = defaultdict(list)
536 | self.parent_passes_per_layer = defaultdict(list)
537 |
538 | # Each one will now be a list of layers, since they can vary across passes.
539 | self.parent_layer_arg_locs = {
540 | "args": defaultdict(set),
541 | "kwargs": defaultdict(set),
542 | }
543 |
544 | def update_data(self, source_node: TensorLogEntry):
545 |         """Updates the data as needed.
546 | Args:
547 | source_node: the source node
548 | """
549 | if source_node.has_input_ancestor:
550 | self.has_input_ancestor = True
551 | if not any(
552 | [
553 | self.input_output_address is None,
554 | source_node.input_output_address is None,
555 | ]
556 | ):
557 | self.input_output_address = "".join(
558 | [
559 | char if (source_node.input_output_address[c] == char) else "*"
560 | for c, char in enumerate(self.input_output_address)
561 | ]
562 | )
563 | if self.input_output_address[-1] == ".":
564 | self.input_output_address = self.input_output_address[:-1]
565 | if self.input_output_address[-1] == "*":
566 | self.input_output_address = self.input_output_address.strip("*") + "*"
567 |
568 | def add_pass_info(self, source_node: TensorLogEntry):
569 |         """Adds information about another pass of the same layer: namely, it marks which
570 |         child and parent layers belong to each pass.
571 |
572 | Args:
573 | source_node: Information for the source pass
574 | """
575 | # Label the layers for each pass
576 | child_layer_labels = [
577 | self.source_model_history[child].layer_label_no_pass
578 | for child in source_node.child_layers
579 | ]
580 | for child_layer in child_layer_labels:
581 | if child_layer not in self.child_layers:
582 | self.child_layers.append(child_layer)
583 | if child_layer not in self.child_layers_per_pass[source_node.pass_num]:
584 | self.child_layers_per_pass[source_node.pass_num].append(child_layer)
585 |
586 | parent_layer_labels = [
587 | self.source_model_history[parent].layer_label_no_pass
588 | for parent in source_node.parent_layers
589 | ]
590 | for parent_layer in parent_layer_labels:
591 | if parent_layer not in self.parent_layers:
592 | self.parent_layers.append(parent_layer)
593 | if parent_layer not in self.parent_layers_per_pass[source_node.pass_num]:
594 | self.parent_layers_per_pass[source_node.pass_num].append(parent_layer)
595 |
596 | # Label the passes for each layer, and indicate if any layers vary based on the pass.
597 | for child_layer in source_node.child_layers:
598 | child_layer_label = self.source_model_history[
599 | child_layer
600 | ].layer_label_no_pass
601 | if (
602 | source_node.pass_num
603 | not in self.child_passes_per_layer[child_layer_label]
604 | ):
605 | self.child_passes_per_layer[child_layer_label].append(
606 | source_node.pass_num
607 | )
608 |
609 | for parent_layer in source_node.parent_layers:
610 | parent_layer_label = self.source_model_history[
611 | parent_layer
612 | ].layer_label_no_pass
613 | if (
614 | source_node.pass_num
615 | not in self.parent_passes_per_layer[parent_layer_label]
616 | ):
617 | self.parent_passes_per_layer[parent_layer_label].append(
618 | source_node.pass_num
619 | )
620 |
621 | # Check if any edges vary across passes.
622 | if source_node.pass_num == source_node.layer_passes_total:
623 | pass_lists = list(self.parent_passes_per_layer.values()) + list(
624 | self.child_passes_per_layer.values()
625 | )
626 | pass_lens = [len(passes) for passes in pass_lists]
627 | if any(
628 | [pass_len < source_node.layer_passes_total for pass_len in pass_lens]
629 | ):
630 | self.edges_vary_across_passes = True
631 | else:
632 | self.edges_vary_across_passes = False
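    |         # Illustrative example: if pass 1 of a recurrent layer feeds layer A but pass 2 instead
    |         # feeds layer B, each of A and B appears for fewer than layer_passes_total passes, so the
    |         # rolled-up graph must draw those edges per-pass rather than once for all passes.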
633 |
634 | # Add submodule info:
635 | if source_node.is_bottom_level_submodule_output:
636 | self.is_bottom_level_submodule_output = True
637 | self.bottom_level_submodule_passes_exited.add(
638 | source_node.bottom_level_submodule_pass_exited
639 | )
640 |
641 |         # For the parent arg locations, keep a set of layers rather than a single layer, since they can
642 |         # vary across passes.
643 |
644 | for arg_type in ["args", "kwargs"]:
645 | for arg_key, layer_label in source_node.parent_layer_arg_locs[
646 | arg_type
647 | ].items():
648 | layer_label_no_pass = self.source_model_history[
649 | layer_label
650 | ].layer_label_no_pass
651 | self.parent_layer_arg_locs[arg_type][arg_key].add(layer_label_no_pass)
652 |
653 | def __str__(self) -> str:
654 | fields_not_to_print = ["source_model_history"]
655 | s = ""
656 | for field in dir(self):
657 | attr = getattr(self, field)
658 | if (
659 | not field.startswith("_")
660 | and field not in fields_not_to_print
661 | and not (callable(attr))
662 | ):
663 | s += f"{field}: {attr}\n"
664 | return s
665 |
666 | def __repr__(self):
667 | return self.__str__()
668 |
--------------------------------------------------------------------------------
/torchlens/validation.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Any, Dict, List, Set, TYPE_CHECKING, Union
3 |
4 | import torch
5 |
6 | from .tensor_log import TensorLogEntry
7 |
8 | if TYPE_CHECKING:
9 |     pass  # no type-check-only imports are currently needed
10 |
11 | from .helper_funcs import (
12 | log_current_rng_states,
13 | set_rng_from_saved_states,
14 | tuple_tolerant_assign,
15 | tensor_nanequal,
16 | tensor_all_nan,
17 | )
18 |
19 | FUNCS_NOT_TO_PERTURB_IN_VALIDATION = [
20 | "expand_as",
21 | "new_zeros",
22 | "new_ones",
23 | "zero_",
24 | "copy_",
25 | "clamp",
26 | "fill_",
27 | "zeros_like",
28 | "ones_like",
29 | ]
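    | # These functions are exempt from the perturbation check because perturbing their tensor
    | # arguments is not guaranteed to change their output (e.g., zeros_like and new_ones depend only
    | # on the input's shape/dtype/device, and clamp can saturate), so the check would spuriously fail.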
30 |
31 |
32 | def validate_saved_activations(
33 | self, ground_truth_output_tensors: List[torch.Tensor], verbose: bool = False
34 | ) -> bool:
35 | """Starting from outputs and internally terminated tensors, checks whether computing their values from the saved
36 | values of their input tensors yields their actually saved values, and whether computing their values from
37 | their parent tensors yields their saved values.
38 |
39 | Returns:
40 | True if it passes the tests, False otherwise.
41 | """
42 | # First check that the ground truth output tensors are accurate:
43 | for i, output_layer_label in enumerate(self.output_layers):
44 | output_layer = self[output_layer_label]
45 | if not tensor_nanequal(
46 | output_layer.tensor_contents,
47 | ground_truth_output_tensors[i],
48 | allow_tolerance=False,
49 | ):
50 | print(
51 | f"The {i}th output layer, {output_layer_label}, does not match the ground truth output tensor."
52 | )
53 | return False
54 |
55 | # Validate the parents of each validated layer.
56 | validated_child_edges_for_each_layer = defaultdict(set)
57 | validated_layers = set(self.output_layers + self.internally_terminated_layers)
58 | layers_to_validate_parents_for = list(validated_layers)
59 |
60 | while len(layers_to_validate_parents_for) > 0:
61 | layer_to_validate_parents_for = layers_to_validate_parents_for.pop(0)
62 | parent_layers_valid = validate_parents_of_saved_layer(
63 | self,
64 | layer_to_validate_parents_for,
65 | validated_layers,
66 | validated_child_edges_for_each_layer,
67 | layers_to_validate_parents_for,
68 | verbose,
69 | )
70 | if not parent_layers_valid:
71 | return False
72 |
73 | if len(validated_layers) < len(self.layer_labels):
74 | print(
75 | f"All saved activations were accurate, but some layers were not reached (check that "
76 | f"child args logged accurately): {set(self.layer_labels) - validated_layers}"
77 | )
78 | return False
79 |
80 | return True
81 |
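    | # Minimal usage sketch (illustrative; `self` is the ModelHistory this function is bound to, and
    | # the exact `log_forward_pass` argument names are assumptions, not guaranteed by this module):
    | #
    | #     model_history = tl.log_forward_pass(model, x, save_function_args=True)
    | #     outputs = [model(x)]  # ground-truth outputs, recomputed outside torchlens
    | #     assert model_history.validate_saved_activations(outputs, verbose=True)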
82 |
83 | def validate_parents_of_saved_layer(
84 | self,
85 | layer_to_validate_parents_for_label: str,
86 | validated_layers: Set[str],
87 | validated_child_edges_for_each_layer: Dict[str, Set[str]],
88 | layers_to_validate_parents_for: List[str],
89 | verbose: bool = False,
90 | ) -> bool:
91 | """Given a layer, checks that 1) all parent tensors appear properly in the saved arguments for that layer,
92 | 2) that executing the function for that layer with the saved parent layer activations yields the
93 | ground truth activation values for that layer, and 3) that plugging in "perturbed" values for each
94 | child layer yields values different from the saved activations for that layer.
95 |
96 | Args:
97 |         layer_to_validate_parents_for_label: label of the layer whose parents are being validated
98 |         validated_layers: set of labels of layers that have been fully validated so far
99 |         validated_child_edges_for_each_layer: for each layer, the set of child layers it has been validated against
100 |         layers_to_validate_parents_for: queue of layers whose parents still need validating
101 | verbose: whether to print warning messages
102 | """
103 | layer_to_validate_parents_for = self[layer_to_validate_parents_for_label]
104 |
105 | # Check that the arguments are logged correctly:
106 | if not _check_layer_arguments_logged_correctly(
107 | self, layer_to_validate_parents_for_label
108 | ):
109 | print(
110 | f"Parent arguments for layer {layer_to_validate_parents_for_label} are not logged properly; "
111 | f"either a parent wasn't logged as an argument, or was logged an extra time"
112 | )
113 | return False
114 |
115 | # Check that executing the function based on the actual saved values of the parents yields the saved
116 | # values of the layer itself:
117 |
118 | if not _check_whether_func_on_saved_parents_yields_saved_tensor(
119 | self, layer_to_validate_parents_for_label, perturb=False
120 | ):
121 | return False
122 |
123 | # Check that executing the layer's function on the wrong version of the saved parent tensors
124 | # yields the wrong tensors, when each saved tensor is perturbed in turn:
125 |
126 |     # (the exemption depends only on the function, not the parent, so it is checked once, outside the loop)
127 |     if (
128 |         layer_to_validate_parents_for.func_applied_name
129 |         not in FUNCS_NOT_TO_PERTURB_IN_VALIDATION
130 |     ):
131 |         for perturb_layer in layer_to_validate_parents_for.parent_layers:
132 |             if not _check_whether_func_on_saved_parents_yields_saved_tensor(
133 |                 self,
134 |                 layer_to_validate_parents_for_label,
135 |                 perturb=True,
136 |                 layers_to_perturb=[perturb_layer],
137 |                 verbose=verbose,
138 |             ):
139 |                 return False
140 |
141 | # Log that each parent layer has been validated for this source layer.
142 |
143 | for parent_layer_label in layer_to_validate_parents_for.parent_layers:
144 | parent_layer = self[parent_layer_label]
145 | validated_child_edges_for_each_layer[parent_layer_label].add(
146 | layer_to_validate_parents_for_label
147 | )
148 | if validated_child_edges_for_each_layer[parent_layer_label] == set(
149 | parent_layer.child_layers
150 | ):
151 | validated_layers.add(parent_layer_label)
152 | if ((not parent_layer.is_input_layer) and
153 | not (parent_layer.is_buffer_layer and (parent_layer.buffer_parent is None))):
154 | layers_to_validate_parents_for.append(parent_layer_label)
155 |
156 | return True
157 |
158 |
159 | def _check_layer_arguments_logged_correctly(self, target_layer_label: str) -> bool:
160 | """Check whether the activations of the parent layers match the saved arguments of
161 | the target layer, and that the argument locations have been logged correctly.
162 |
163 | Args:
164 | target_layer_label: Layer to check
165 |
166 | Returns:
167 | True if arguments logged accurately, False otherwise
168 | """
169 | target_layer = self[target_layer_label]
170 |
171 | # Make sure that all parent layers appear in at least one argument and that no extra layers appear:
172 | parent_layers_in_args = set()
173 | for arg_type in ["args", "kwargs"]:
174 | parent_layers_in_args.update(
175 | list(target_layer.parent_layer_arg_locs[arg_type].values())
176 | )
177 | if parent_layers_in_args != set(target_layer.parent_layers):
178 | return False
179 |
180 | argtype_dict = {
181 | "args": (enumerate, "creation_args"),
182 | "kwargs": (lambda x: x.items(), "creation_kwargs"),
183 | }
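    |     # argtype_dict maps each argument type to (how to iterate it, which TensorLogEntry field
    |     # holds the saved values): positional args are enumerated by index, kwargs iterated by name.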
184 |
185 | # Check for each parent layer that it is logged as a saved argument when it matches an argument, and
186 | # is not logged when it does not match a saved argument.
187 |
188 | for parent_layer_label in target_layer.parent_layers:
189 | parent_layer = self[parent_layer_label]
190 | for arg_type in ["args", "kwargs"]:
191 | iterfunc, argtype_field = argtype_dict[arg_type]
192 | for key, val in iterfunc(getattr(target_layer, argtype_field)):
193 | validation_correct_for_arg_and_layer = (
194 | _validate_layer_against_arg(
195 | self, target_layer, parent_layer, arg_type, key, val
196 | )
197 | )
198 | if not validation_correct_for_arg_and_layer:
199 | return False
200 | return True
201 |
202 |
203 | def _validate_layer_against_arg(self, target_layer, parent_layer, arg_type, key, val):
204 |     """Checks one saved argument against a parent layer, recursing into lists,
205 |     tuples, and dicts to cover arguments nested inside containers."""
206 | if type(val) in [list, tuple]:
207 | for v, subval in enumerate(val):
208 | argloc_key = (key, v)
209 | validation_correct_for_arg_and_layer = (
210 | _check_arglocs_correct_for_arg(
211 | self, target_layer, parent_layer, arg_type, argloc_key, subval
212 | )
213 | )
214 | if not validation_correct_for_arg_and_layer:
215 | return False
216 |
217 | elif type(val) == dict:
218 | for subkey, subval in val.items():
219 | argloc_key = (key, subkey)
220 | validation_correct_for_arg_and_layer = (
221 | _check_arglocs_correct_for_arg(
222 | self, target_layer, parent_layer, arg_type, argloc_key, subval
223 | )
224 | )
225 | if not validation_correct_for_arg_and_layer:
226 | return False
227 | else:
228 | argloc_key = key
229 | validation_correct_for_arg_and_layer = _check_arglocs_correct_for_arg(
230 | self, target_layer, parent_layer, arg_type, argloc_key, val
231 | )
232 | if not validation_correct_for_arg_and_layer:
233 | return False
234 |
235 | return True
236 |
237 |
238 | def _check_arglocs_correct_for_arg(
239 | self,
240 | target_layer: TensorLogEntry,
241 | parent_layer: TensorLogEntry,
242 | arg_type: str,
243 | argloc_key: Union[str, tuple],
244 | saved_arg_val: Any,
245 | ):
246 | """For a given layer and an argument to its child layer, checks that it is logged correctly:
247 | that is, that it's logged as an argument if it matches, and is not logged as an argument if it doesn't match.
248 | """
249 | target_layer_label = target_layer.layer_label
250 | parent_layer_label = parent_layer.layer_label
251 | if target_layer_label in parent_layer.children_tensor_versions:
252 | parent_activations = parent_layer.children_tensor_versions[
253 | target_layer_label
254 | ]
255 | else:
256 | parent_activations = parent_layer.tensor_contents
257 |
258 | if type(saved_arg_val) == torch.Tensor:
259 | parent_layer_matches_arg = tensor_nanequal(
260 | saved_arg_val, parent_activations, allow_tolerance=False
261 | )
262 | else:
263 | parent_layer_matches_arg = False
264 | parent_layer_logged_as_arg = (
265 | argloc_key in target_layer.parent_layer_arg_locs[arg_type]
266 | ) and (
267 | target_layer.parent_layer_arg_locs[arg_type][argloc_key]
268 | == parent_layer_label
269 | )
270 |
271 | if (
272 | parent_layer_matches_arg
273 | and (not parent_layer_logged_as_arg)
274 | and (parent_activations.numel() != 0)
275 | and (parent_activations.dtype != torch.bool)
276 | and (not tensor_all_nan(parent_activations))
277 | and (parent_activations.abs().float().mean() != 0)
278 | and (parent_activations.abs().float().mean() != 1)
279 | and not any(
280 | [
281 | torch.equal(parent_activations, self[other_parent].tensor_contents)
282 | for other_parent in target_layer.parent_layers
283 | if other_parent != parent_layer_label
284 | ]
285 | )
286 | ):
287 | print(
288 | f"Parent {parent_layer_label} of {target_layer_label} has activations that match "
289 | f"{arg_type} {argloc_key} for {target_layer_label}, but is not logged as "
290 | f"such in parent_layer_arg_locs."
291 | )
292 | return False
293 |
294 | if (not parent_layer_matches_arg) and parent_layer_logged_as_arg:
295 | print(
296 | f"Parent {parent_layer_label} of {target_layer_label} is logged as {arg_type} {argloc_key} to "
297 | f"{target_layer_label}, but its saved activations don't match the saved argument."
298 | )
299 | return False
300 |
301 | return True
302 |
303 |
304 | def _check_whether_func_on_saved_parents_yields_saved_tensor(
305 | self,
306 | layer_to_validate_parents_for_label: str,
307 | perturb: bool = False,
308 | layers_to_perturb: List[str] = None,
309 | verbose: bool = False,
310 | ) -> bool:
311 | """Checks whether executing the saved function for a layer on the saved value of its parent layers
312 | in fact yields the saved activations for that layer.
313 |
314 | Args:
315 | layer_to_validate_parents_for_label: label of the layer to check the saved activations
316 | perturb: whether to perturb the saved activations
317 |         layers_to_perturb: layers for which to perturb the saved activations
318 |         verbose: whether to print details when a perturbation check is excused
319 | Returns:
320 | True if the activations match, False otherwise
321 | """
322 | if layers_to_perturb is None:
323 | layers_to_perturb = []
324 |
325 | layer_to_validate_parents_for = self[layer_to_validate_parents_for_label]
326 |
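    |     # Special-case exemptions: for the function/argument combinations below, perturbing the
    |     # flagged parent (e.g., an index, mask, length, or hidden-state tensor) is not guaranteed to
    |     # change the output, so the perturbation check is skipped for that parent.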
327 | if (
328 | perturb
329 | and (layer_to_validate_parents_for.func_applied_name == "__getitem__")
330 | and (type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor)
331 | and torch.equal(
332 | self[layers_to_perturb[0]].tensor_contents,
333 | layer_to_validate_parents_for.creation_args[1],
334 | )
335 | ):
336 | return True
337 | elif (
338 | perturb
339 | and (layer_to_validate_parents_for.func_applied_name == "__getitem__")
340 | and not torch.equal(
341 | self[layers_to_perturb[0]].tensor_contents,
342 | layer_to_validate_parents_for.creation_args[0],
343 | )
344 | ):
345 | return True
346 |     elif layer_to_validate_parents_for.func_applied_name == "empty_like":
347 |         return True  # empty_like returns uninitialized memory, so its values can never be reproduced
348 | elif (
349 | perturb
350 | and (layer_to_validate_parents_for.func_applied_name == "__setitem__")
351 | and (type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor)
352 | and (layer_to_validate_parents_for.creation_args[1].dtype == torch.bool)
353 | and torch.equal(
354 | self[layers_to_perturb[0]].tensor_contents,
355 | layer_to_validate_parents_for.creation_args[1],
356 | )
357 | ):
358 | return True
359 | elif (
360 | perturb
361 | and (layer_to_validate_parents_for.func_applied_name == "cross_entropy")
362 | and torch.equal(
363 | self[layers_to_perturb[0]].tensor_contents,
364 | layer_to_validate_parents_for.creation_args[1],
365 | )
366 | ):
367 | return True
368 | elif (
369 | perturb
370 | and (layer_to_validate_parents_for.func_applied_name == "__setitem__")
371 | and (type(layer_to_validate_parents_for.creation_args[1]) == tuple)
372 | and (
373 | type(layer_to_validate_parents_for.creation_args[1][0]) == torch.Tensor
374 | )
375 | and (layer_to_validate_parents_for.creation_args[1][0].dtype == torch.bool)
376 | and torch.equal(
377 | self[layers_to_perturb[0]].tensor_contents,
378 | layer_to_validate_parents_for.creation_args[1][0],
379 | )
380 | ):
381 | return True
382 | elif (
383 | perturb
384 | and (layer_to_validate_parents_for.func_applied_name == "index_select")
385 | and torch.equal(
386 | self[layers_to_perturb[0]].tensor_contents,
387 | layer_to_validate_parents_for.creation_args[2],
388 | )
389 | ):
390 | return True
391 | elif (
392 | perturb
393 | and (layer_to_validate_parents_for.func_applied_name == "lstm")
394 | and (torch.equal(
395 | self[layers_to_perturb[0]].tensor_contents,
396 | layer_to_validate_parents_for.creation_args[1][0]) or
397 | torch.equal(
398 | self[layers_to_perturb[0]].tensor_contents,
399 | layer_to_validate_parents_for.creation_args[1][1]) or
400 | torch.equal(
401 | self[layers_to_perturb[0]].tensor_contents,
402 | layer_to_validate_parents_for.creation_args[2][0]) or
403 | torch.equal(
404 | self[layers_to_perturb[0]].tensor_contents,
405 | layer_to_validate_parents_for.creation_args[2][1]) or
406 | ((type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor) and
407 | torch.equal(
408 | self[layers_to_perturb[0]].tensor_contents,
409 | layer_to_validate_parents_for.creation_args[1])
410 | ))):
411 | return True
412 | elif (
413 | perturb
414 | and (layer_to_validate_parents_for.func_applied_name == "_pad_packed_sequence")
415 | and torch.equal(
416 | self[layers_to_perturb[0]].tensor_contents,
417 | layer_to_validate_parents_for.creation_args[1]
418 | )):
419 | return True
420 | elif (
421 | perturb
422 | and (layer_to_validate_parents_for.func_applied_name == "masked_fill_")
423 | and torch.equal(
424 | self[layers_to_perturb[0]].tensor_contents,
425 | layer_to_validate_parents_for.creation_args[1]
426 | )):
427 | return True
428 | elif (
429 | perturb
430 | and (layer_to_validate_parents_for.func_applied_name == "scatter_")
431 | and torch.equal(
432 | self[layers_to_perturb[0]].tensor_contents,
433 | layer_to_validate_parents_for.creation_args[2]
434 | )):
435 | return True
436 | elif (
437 | perturb
438 | and (layer_to_validate_parents_for.func_applied_name == "interpolate")
439 | and ((('scale_factor' in layer_to_validate_parents_for.creation_kwargs)
440 | and (layer_to_validate_parents_for.creation_kwargs['scale_factor'] is not None)
441 | and torch.equal(
442 | self[layers_to_perturb[0]].tensor_contents,
443 | torch.tensor(layer_to_validate_parents_for.creation_kwargs['scale_factor'])))
444 | or ((len(layer_to_validate_parents_for.creation_args) >= 3)
445 | and torch.equal(
446 | self[layers_to_perturb[0]].tensor_contents,
447 | layer_to_validate_parents_for.creation_args[2])))
448 | ):
449 | return True
450 |
451 |     # Prepare input arguments: copy the saved values, swapping in perturbed versions for layers_to_perturb.
452 |
453 | input_args = _prepare_input_args_for_validating_layer(
454 | self, layer_to_validate_parents_for, layers_to_perturb
455 | )
456 |
457 | # set the saved rng value:
458 | layer_func = layer_to_validate_parents_for.func_applied
459 | current_rng_states = log_current_rng_states()
460 | set_rng_from_saved_states(layer_to_validate_parents_for.func_rng_states)
461 |     try:
462 |         recomputed_output = layer_func(*input_args["args"], **input_args["kwargs"])
463 |     except Exception as e:
464 |         raise RuntimeError(f"Invalid perturbed arguments for layer {layer_to_validate_parents_for_label}") from e
465 | set_rng_from_saved_states(current_rng_states)
466 |
467 | if layer_func.__name__ in [
468 | "__setitem__",
469 | "zero_",
470 | "__delitem__",
471 |     ]:  # TODO: generalize; these in-place funcs return their mutated first argument
472 | recomputed_output = input_args["args"][0]
473 |
474 |     if isinstance(recomputed_output, (list, tuple)):
475 | recomputed_output = recomputed_output[
476 | layer_to_validate_parents_for.iterable_output_index
477 | ]
478 |
479 | if (
480 | not (
481 | tensor_nanequal(
482 | recomputed_output,
483 | layer_to_validate_parents_for.tensor_contents,
484 | allow_tolerance=True,
485 | )
486 | )
487 | and not perturb
488 | ):
489 | print(
490 | f"Saved activations for layer {layer_to_validate_parents_for_label} do not match the "
491 | f"values computed based on the parent layers {layer_to_validate_parents_for.parent_layers}."
492 | )
493 | return False
494 |
495 | if (
496 | tensor_nanequal(
497 | recomputed_output,
498 | layer_to_validate_parents_for.tensor_contents,
499 | allow_tolerance=False,
500 | )
501 | and perturb
502 | ):
503 | return _posthoc_perturb_check(
504 | self, layer_to_validate_parents_for, layers_to_perturb, verbose
505 | )
506 |
507 | return True
508 |
509 |
510 | def _prepare_input_args_for_validating_layer(
511 | self,
512 | layer_to_validate_parents_for: TensorLogEntry,
513 | layers_to_perturb: List[str],
514 | ) -> Dict:
515 | """Prepares the input arguments for validating the saved activations of a layer.
516 |
517 | Args:
518 | layer_to_validate_parents_for: Layer being checked.
519 | layers_to_perturb: Layers for which to perturb the saved activations.
520 |
521 | Returns:
522 | Dict of input arguments.
523 | """
524 | input_args = {
525 | "args": list(layer_to_validate_parents_for.creation_args[:]),
526 | "kwargs": layer_to_validate_parents_for.creation_kwargs.copy(),
527 | }
528 | input_args = _copy_validation_args(input_args)
529 |
530 | # Swap in saved parent activations:
531 |
532 | for arg_type in ["args", "kwargs"]:
533 | for (
534 | key,
535 | parent_layer_arg,
536 | ) in layer_to_validate_parents_for.parent_layer_arg_locs[arg_type].items():
537 | parent_layer = self[parent_layer_arg]
538 | if (
539 | layer_to_validate_parents_for.layer_label
540 | in parent_layer.children_tensor_versions
541 | ):
542 | parent_values = parent_layer.children_tensor_versions[
543 | layer_to_validate_parents_for.layer_label
544 | ]
545 | else:
546 | parent_values = parent_layer.tensor_contents
547 | parent_values = parent_values.detach().clone()
548 |
549 | if parent_layer_arg in layers_to_perturb:
550 | parent_layer_func_values = _perturb_layer_activations(
551 | parent_values, layer_to_validate_parents_for.tensor_contents
552 | )
553 | else:
554 | parent_layer_func_values = parent_values
555 |
556 | if type(key) != tuple:
557 | input_args[arg_type][key] = parent_layer_func_values
558 | else:
559 | input_args[arg_type][key[0]] = tuple_tolerant_assign(
560 | input_args[arg_type][key[0]], key[1], parent_layer_func_values
561 | )
562 |
563 | return input_args
564 |
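    | # Key convention for parent_layer_arg_locs, as consumed above: a plain key (an int position or a
    | # kwarg name) addresses a top-level argument directly, while a tuple key like (1, 0) addresses
    | # element 0 of the iterable at position 1; tuple_tolerant_assign handles writing into tuples,
    | # which don't support ordinary item assignment.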
565 |
566 | def _copy_validation_args(input_args: Dict):
567 |     new_args = []  # detach and clone tensors so the saved originals aren't mutated in place
568 |     for val in input_args["args"]:
569 | if type(val) == torch.Tensor:
570 | new_args.append(val.detach().clone())
571 | elif type(val) in [list, tuple, set]:
572 | new_iter = []
573 |             for val2 in val:
574 | if type(val2) == torch.Tensor:
575 | new_iter.append(val2.detach().clone())
576 | else:
577 | new_iter.append(val2)
578 | new_args.append(type(val)(new_iter))
579 | else:
580 | new_args.append(val)
581 | input_args["args"] = new_args
582 |
583 | new_kwargs = {}
584 | for key, val in input_args["kwargs"].items():
585 | if type(val) == torch.Tensor:
586 | new_kwargs[key] = val.detach().clone()
587 | elif type(val) in [list, tuple, set]:
588 | new_iter = []
589 |             for val2 in val:
590 | if type(val2) == torch.Tensor:
591 | new_iter.append(val2.detach().clone())
592 | else:
593 | new_iter.append(val2)
594 | new_kwargs[key] = type(val)(new_iter)
595 | else:
596 | new_kwargs[key] = val
597 | input_args["kwargs"] = new_kwargs
598 | return input_args
599 |
600 |
601 | def _perturb_layer_activations(
602 | parent_activations: torch.Tensor, output_activations: torch.Tensor
603 | ) -> torch.Tensor:
604 | """
605 | Perturbs the values of a saved tensor.
606 |
607 | Args:
608 | parent_activations: Tensor of activation values for the parent tensor
609 | output_activations: Tensor of activation values for the tensor whose parents are being tested (the output)
610 |
611 | Returns:
612 | Perturbed version of saved tensor
613 | """
614 | device = parent_activations.device
615 | if parent_activations.numel() == 0:
616 | return parent_activations.detach().clone()
617 |
618 | if parent_activations.dtype in [
619 | torch.int,
620 | torch.long,
621 | torch.short,
622 | torch.uint8,
623 | torch.int8,
624 | torch.int16,
625 | torch.int32,
626 | torch.int64,
627 | ]:
628 | tensor_unique_vals = torch.unique(parent_activations)
629 | if len(tensor_unique_vals) > 1:
630 | perturbed_activations = parent_activations.detach().clone()
631 | while torch.equal(perturbed_activations, parent_activations):
632 | perturbed_activations = torch.randint(
633 | parent_activations.min(),
634 | parent_activations.max() + 1,
635 | size=parent_activations.shape,
636 | device=device,
637 | ).type(parent_activations.dtype)
638 | else:
639 | perturbed_activations = parent_activations.detach().clone()
640 | while torch.equal(perturbed_activations, parent_activations):
641 | if torch.min(parent_activations) < 0:
642 | perturbed_activations = torch.randint(
643 | -10, 11, size=parent_activations.shape, device=device
644 | ).type(parent_activations.dtype)
645 | else:
646 | perturbed_activations = torch.randint(
647 | 0, 11, size=parent_activations.shape, device=device
648 | ).type(parent_activations.dtype)
649 |
650 | elif parent_activations.dtype == torch.bool:
651 | perturbed_activations = parent_activations.detach().clone()
652 | while torch.equal(perturbed_activations, parent_activations):
653 | perturbed_activations = torch.randint(
654 | 0, 2, size=parent_activations.shape, device=device
655 | ).bool()
656 | else:
657 |         scale = output_activations.detach().float().abs().mean()  # random scale based on output magnitude
658 |         scale += torch.rand(scale.shape) * 100
659 |         scale *= torch.rand(scale.shape)
660 |         scale.requires_grad = False
661 |         perturbed_activations = torch.randn_like(
662 |             parent_activations.float(), device=device
663 |         ) * scale.to(device)
664 | perturbed_activations = perturbed_activations.type(parent_activations.dtype)
665 |
666 | return perturbed_activations
667 |
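    | # Summary of the strategy above: integer tensors are re-drawn over their observed value range,
    | # boolean tensors are re-randomized, and floating-point tensors are replaced with Gaussian noise
    | # scaled by a random multiple of the output's mean absolute value, keeping the perturbed input at
    | # a plausible magnitude; the while-loops guarantee the perturbed tensor actually differs.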
668 |
669 | def _posthoc_perturb_check(
670 | self,
671 | layer_to_validate_parents_for: TensorLogEntry,
672 | layers_to_perturb: List[str],
673 | verbose: bool = False,
674 | ) -> bool:
675 | """If a layer fails the "perturbation check"--that is, if perturbing the values of parent
676 | layers doesn't change the values relative to the layer's saved values--checks whether one of the
677 | remaining arguments is a "special" tensor, such as all-ones or all-zeros, such that perturbing a tensor
678 | wouldn't necessarily change the output of the layer.
679 |
680 | Args:
681 | layer_to_validate_parents_for: layer being checked.
682 | layers_to_perturb: parent layers being perturbed
683 |
684 | Returns:
685 | True if there's an "excuse" for the perturbation failing, False otherwise.
686 | """
687 | # Check if the tensor is all nans or all infinite:
688 | if layer_to_validate_parents_for.tensor_dtype == torch.bool:
689 | return True
690 | elif (
691 | (layer_to_validate_parents_for.func_applied_name == "to")
692 | and (len(layer_to_validate_parents_for.creation_args) > 1)
693 | and (type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor)
694 | ):
695 | return True
696 | elif (
697 | (layer_to_validate_parents_for.func_applied_name == "__setitem__")
698 | and (type(layer_to_validate_parents_for.creation_args[2]) == torch.Tensor)
699 | and (
700 | layer_to_validate_parents_for.creation_args[0].shape
701 | == layer_to_validate_parents_for.creation_args[2].shape
702 | )
703 | ):
704 | return True
705 | elif (
706 | layer_to_validate_parents_for.func_applied_name in ["__getitem__", "unbind"]
707 | ) and (
708 | layer_to_validate_parents_for.tensor_contents.numel() < 20
709 | ): # some elements can be the same by chance
710 | return True
711 | elif (
712 | (layer_to_validate_parents_for.func_applied_name == "__getitem__")
713 | and (type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor)
714 |         and (layer_to_validate_parents_for.creation_args[1].unique().numel() < 20)
715 | ):
716 | return True
717 | elif (layer_to_validate_parents_for.func_applied_name == "max") and len(
718 | layer_to_validate_parents_for.creation_args
719 | ) > 1:
720 | return True
721 | elif (
722 | layer_to_validate_parents_for.func_applied_name == "max"
723 | ) and not torch.is_floating_point(
724 | layer_to_validate_parents_for.creation_args[0]
725 | ):
726 | return True
727 | else:
728 | num_inf = (
729 | torch.isinf(layer_to_validate_parents_for.tensor_contents.abs())
730 | .int()
731 | .sum()
732 | )
733 | num_nan = (
734 | torch.isnan(layer_to_validate_parents_for.tensor_contents.abs())
735 | .int()
736 | .sum()
737 | )
738 | if (num_inf == layer_to_validate_parents_for.tensor_contents.numel()) or (
739 | num_nan == layer_to_validate_parents_for.tensor_contents.numel()
740 | ):
741 | return True
742 |
743 | arg_type_dict = {
744 | "args": (enumerate, "creation_args"),
745 | "kwargs": (lambda x: x.items(), "creation_kwargs"),
746 | }
747 |
748 | layer_to_validate_parents_for_label = layer_to_validate_parents_for.layer_label
749 | for arg_type in ["args", "kwargs"]:
750 | iterfunc, fieldname = arg_type_dict[arg_type]
751 | for key, val in iterfunc(getattr(layer_to_validate_parents_for, fieldname)):
752 | # Skip if it's the argument itself:
753 | if (
754 | key in layer_to_validate_parents_for.parent_layer_arg_locs[arg_type]
755 | ) and (
756 | layer_to_validate_parents_for.parent_layer_arg_locs[arg_type][key]
757 | ) in layers_to_perturb:
758 | continue
759 | arg_is_special = _check_if_arg_is_special_val(val)
760 | if arg_is_special:
761 | if verbose:
762 | print(
763 | f"Activations for layer {layer_to_validate_parents_for_label} do not change when "
764 | f"values for {layers_to_perturb} are changed (out of parent "
765 | f"layers {layer_to_validate_parents_for.parent_layers}), but {arg_type[:-1]} {key} is "
766 | f"all zeros or all-ones, so validation still succeeds..."
767 | )
768 | return True
769 |
770 | print(
771 | f"Activations for layer {layer_to_validate_parents_for_label} do not change when "
772 | f"values for {layers_to_perturb} are changed (out of parent "
773 | f"layers {layer_to_validate_parents_for.parent_layers}), and the other "
774 | f'arguments are not "special" (all-ones or all-zeros) tensors.'
775 | )
776 | return False
777 |
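    | # Illustrative "excuse": for y = x * mask with an all-zeros mask tensor, perturbing x leaves y
    | # unchanged, so the failed perturbation check is forgiven because another argument is "special".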
778 |
779 | def _check_if_arg_is_special_val(val: Any) -> bool:
780 |     # If it's one of the other arguments, check whether it's all-zeros, all-ones, or empty:
781 | if type(val) != torch.Tensor:
782 |         try:
783 |             val = torch.Tensor(val)
784 |         except Exception:
785 |             return True
786 | if torch.all(torch.eq(val, 0)) or torch.all(torch.eq(val, 1)) or (val.numel() == 0):
787 | return True
788 | else:
789 | return False
790 |
--------------------------------------------------------------------------------