├── requirements.txt
├── images
│   ├── logo.png
│   ├── alexnet.png
│   ├── gradients.png
│   ├── cheatsheet.pdf
│   ├── swin_v2_b_demo.jpg
│   ├── simple_recurrent.png
│   └── nested_modules_example.png
├── tests
│   └── test_metadata.py
├── requirements.test.txt
├── torchlens
│   ├── __init__.py
│   ├── cleanup.py
│   ├── model_history.py
│   ├── trace_model.py
│   ├── decorate_torch.py
│   ├── constants.py
│   ├── interface.py
│   ├── model_funcs.py
│   ├── user_funcs.py
│   ├── helper_funcs.py
│   ├── tensor_log.py
│   └── validation.py
├── .pre-commit-config.yaml
├── .github
│   └── workflows
│       └── lint.yml
├── setup.py
├── .gitignore
└── README.md
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
pandas
tqdm
ipython
graphviz
torch
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/logo.png
--------------------------------------------------------------------------------
/images/alexnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/alexnet.png
--------------------------------------------------------------------------------
/images/gradients.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/gradients.png
--------------------------------------------------------------------------------
/tests/test_metadata.py:
--------------------------------------------------------------------------------
# This is for making sure the different kinds of metadata work out properly.
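# A minimal sketch of the kind of test this file is meant to hold (the model and
# input are hypothetical; log_forward_pass is the public entry point re-exported
# in torchlens/__init__.py):
import torch

import torchlens as tl


def test_forward_pass_logs_layer_metadata():
    model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
    history = tl.log_forward_pass(model, torch.randn(2, 8))
    assert len(history) > 0  # ModelHistory.__len__ counts the logged layers
    assert len(history.layer_labels) > 0  # one user-facing label per layer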
2 | -------------------------------------------------------------------------------- /images/cheatsheet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/cheatsheet.pdf -------------------------------------------------------------------------------- /images/swin_v2_b_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/swin_v2_b_demo.jpg -------------------------------------------------------------------------------- /images/simple_recurrent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/simple_recurrent.png -------------------------------------------------------------------------------- /images/nested_modules_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmarktaylor91/torchlens/HEAD/images/nested_modules_example.png -------------------------------------------------------------------------------- /requirements.test.txt: -------------------------------------------------------------------------------- 1 | black 2 | cornet @ git+https://github.com/dicarlolab/CORnet 3 | lightning 4 | numpy 5 | pillow 6 | pytest 7 | requests 8 | timm 9 | torch 10 | torchaudio 11 | torch_geometric 12 | torchlens 13 | torchvision 14 | transformers 15 | visualpriors 16 | -------------------------------------------------------------------------------- /torchlens/__init__.py: -------------------------------------------------------------------------------- 1 | """ Top level package: make the user-facing functions top-level, rest accessed as submodules. 2 | """ 3 | from .user_funcs import log_forward_pass, show_model_graph, get_model_metadata, validate_saved_activations, \ 4 | validate_batch_of_models_and_inputs 5 | from .model_history import ModelHistory 6 | from .tensor_log import TensorLogEntry, RolledTensorLogEntry 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v3.2.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | - repo: https://github.com/psf/black 12 | rev: 23.1.0 13 | hooks: 14 | - id: black 15 | - repo: https://github.com/pycqa/isort 16 | rev: 5.12.0 17 | hooks: 18 | - id: isort 19 | args: [ "--profile", "black" ] 20 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: 3 | push: 4 | branches: [ main ] 5 | 6 | concurrency: 7 | group: "${{ github.head_ref || github.ref }}-lint-and-test" 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | # flake8: 12 | # runs-on: ubuntu-latest 13 | # container: python:3.8.5-slim-buster 14 | # steps: 15 | # - uses: actions/checkout@v2 16 | # - run: pip install flake8 17 | # - run: flake8 . 
--show-source 18 | black: 19 | runs-on: ubuntu-latest 20 | container: python:3.9-slim 21 | steps: 22 | - uses: actions/checkout@v2 23 | - run: pip install black 24 | - run: black --check . 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("README.md") as readme_file: 4 | readme = readme_file.read() 5 | 6 | requirements = [ 7 | "numpy", 8 | "pandas", 9 | "pytest", 10 | "tqdm", 11 | "ipython", 12 | "graphviz", 13 | ] 14 | 15 | setup( 16 | name="torchlens", 17 | version="0.1.36", 18 | description="A package for extracting activations from PyTorch models", 19 | long_description="A package for extracting activations from PyTorch models. Contains functionality for " 20 | "extracting model activations, visualizing a model's computational graph, and " 21 | "extracting exhaustive metadata about a model.", 22 | author="JohnMark Taylor", 23 | author_email="johnmarkedwardtaylor@gmail.com", 24 | url="https://github.com/johnmarktaylor91/torchlens", 25 | packages=["torchlens"], 26 | include_package_data=True, 27 | install_requires=requirements, 28 | license="GNU GPL v3", 29 | zip_safe=False, 30 | keywords="torch torchlens features", 31 | classifiers=[ 32 | "Development Status :: 3 - Alpha", 33 | "Intended Audience :: Science/Research", 34 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 35 | "Natural Language :: English", 36 | "Programming Language :: Python :: 3.9", 37 | ], 38 | extras_require={"dev": ["black[jupyter]", "pytest", "pre-commit"]}, 39 | ) 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | /tests/visualization_outputs/ 131 | -------------------------------------------------------------------------------- /torchlens/cleanup.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | from .constants import MODEL_HISTORY_FIELD_ORDER 6 | from .helper_funcs import remove_entry_from_list 7 | from .tensor_log import TensorLogEntry 8 | 9 | 10 | def cleanup(self): 11 | """Deletes all log entries in the model.""" 12 | for tensor_log_entry in self: 13 | self._remove_log_entry(tensor_log_entry, remove_references=True) 14 | for attr in MODEL_HISTORY_FIELD_ORDER: 15 | delattr(self, attr) 16 | torch.cuda.empty_cache() 17 | 18 | 19 | def _remove_log_entry( 20 | self, log_entry: TensorLogEntry, remove_references: bool = True 21 | ): 22 | """Given a TensorLogEntry, destroys it and all references to it. 23 | 24 | Args: 25 | log_entry: Tensor log entry to remove. 26 | remove_references: Whether to also remove references to the log entry 27 | """ 28 | if self._pass_finished: 29 | tensor_label = log_entry.layer_label 30 | else: 31 | tensor_label = log_entry.tensor_label_raw 32 | for attr in dir(log_entry): 33 | with warnings.catch_warnings(): 34 | warnings.simplefilter("ignore") 35 | if not attr.startswith("_") and not callable(getattr(log_entry, attr)): 36 | delattr(log_entry, attr) 37 | del log_entry 38 | if remove_references: 39 | _remove_log_entry_references(self, tensor_label) 40 | 41 | 42 | def _remove_log_entry_references(self, layer_to_remove: str): 43 | """Removes all references to a given TensorLogEntry in the ModelHistory object. 44 | 45 | Args: 46 | layer_to_remove: The log entry to remove. 47 | """ 48 | # Clear any fields in ModelHistory referring to the entry. 49 | 50 | remove_entry_from_list(self.input_layers, layer_to_remove) 51 | remove_entry_from_list(self.output_layers, layer_to_remove) 52 | remove_entry_from_list(self.buffer_layers, layer_to_remove) 53 | remove_entry_from_list(self.internally_initialized_layers, layer_to_remove) 54 | remove_entry_from_list(self.internally_terminated_layers, layer_to_remove) 55 | remove_entry_from_list(self.internally_terminated_bool_layers, layer_to_remove) 56 | remove_entry_from_list(self.layers_with_saved_activations, layer_to_remove) 57 | remove_entry_from_list(self.layers_with_saved_gradients, layer_to_remove) 58 | remove_entry_from_list( 59 | self._layers_where_internal_branches_merge_with_input, layer_to_remove 60 | ) 61 | 62 | self.conditional_branch_edges = [ 63 | tup for tup in self.conditional_branch_edges if layer_to_remove not in tup 64 | ] 65 | 66 | # Now any nested fields. 
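    # (Each dict below maps a group label to the layer labels in that group; once
    #  the target layer is removed, any group left empty is dropped as well.)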
67 | 68 | for group_label, group_tensors in self.layers_computed_with_params.items(): 69 | if layer_to_remove in group_tensors: 70 | group_tensors.remove(layer_to_remove) 71 | self.layers_computed_with_params = { 72 | k: v for k, v in self.layers_computed_with_params.items() if len(v) > 0 73 | } 74 | 75 | for group_label, group_tensors in self.equivalent_operations.items(): 76 | if layer_to_remove in group_tensors: 77 | group_tensors.remove(layer_to_remove) 78 | self.equivalent_operations = { 79 | k: v for k, v in self.equivalent_operations.items() if len(v) > 0 80 | } 81 | 82 | for group_label, group_tensors in self.same_layer_operations.items(): 83 | if layer_to_remove in group_tensors: 84 | group_tensors.remove(layer_to_remove) 85 | self.same_layer_operations = { 86 | k: v for k, v in self.same_layer_operations.items() if len(v) > 0 87 | } 88 | -------------------------------------------------------------------------------- /torchlens/model_history.py: -------------------------------------------------------------------------------- 1 | # This file is for defining the ModelHistory class that stores the representation of the forward pass. 2 | import copy 3 | from collections import OrderedDict, defaultdict 4 | from typing import Any, Callable, Dict, List, Optional, Set, Tuple 5 | 6 | from .cleanup import _remove_log_entry, cleanup 7 | from .decorate_torch import decorate_pytorch 8 | from .helper_funcs import ( 9 | human_readable_size, 10 | ) 11 | from .interface import (_getitem_after_pass, _getitem_during_pass, _str_after_pass, 12 | _str_during_pass, to_pandas, print_all_fields) 13 | from .logging_funcs import save_new_activations 14 | from .model_funcs import cleanup_model, prepare_model 15 | from .postprocess import postprocess 16 | from .tensor_log import RolledTensorLogEntry, TensorLogEntry 17 | from .trace_model import run_and_log_inputs_through_model 18 | from .validation import validate_saved_activations 19 | from .vis import render_graph 20 | 21 | 22 | # todo add saved_layer field, remove the option to only keep saved layers 23 | 24 | 25 | class ModelHistory: 26 | def __init__( 27 | self, 28 | model_name: str, 29 | output_device: str = "same", 30 | activation_postfunc: Optional[Callable] = None, 31 | keep_unsaved_layers: bool = True, 32 | save_function_args: bool = False, 33 | save_gradients: bool = False, 34 | detach_saved_tensors: bool = False, 35 | mark_input_output_distances: bool = True, 36 | ): 37 | """Object that stores the history of a model's forward pass. 38 | Both logs the history in real time, and stores a nice 39 | representation of the full history for the user afterward. 
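        Typically constructed and returned by the user-facing functions (e.g.
        torchlens.log_forward_pass) rather than instantiated directly.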
40 | """ 41 | # Setup: 42 | activation_postfunc = copy.deepcopy(activation_postfunc) 43 | 44 | # General info 45 | self.model_name = model_name 46 | self._pass_finished = False 47 | self._track_tensors = False 48 | self.logging_mode = "exhaustive" 49 | self._pause_logging = False 50 | self._all_layers_logged = False 51 | self._all_layers_saved = False 52 | self.keep_unsaved_layers = keep_unsaved_layers 53 | self.activation_postfunc = activation_postfunc 54 | self.current_function_call_barcode = None 55 | self.random_seed_used = None 56 | self.output_device = output_device 57 | self.detach_saved_tensors = detach_saved_tensors 58 | self.save_function_args = save_function_args 59 | self.save_gradients = save_gradients 60 | self.has_saved_gradients = False 61 | self.mark_input_output_distances = mark_input_output_distances 62 | 63 | # Model structure info 64 | self.model_is_recurrent = False 65 | self.model_max_recurrent_loops = 1 66 | self.model_has_conditional_branching = False 67 | self.model_is_branching = False 68 | 69 | # Tensor Tracking: 70 | self.layer_list: List[TensorLogEntry] = [] 71 | self.layer_list_rolled: List[RolledTensorLogEntry] = [] 72 | self.layer_dict_main_keys: Dict[str, TensorLogEntry] = OrderedDict() 73 | self.layer_dict_all_keys: Dict[str, TensorLogEntry] = OrderedDict() 74 | self.layer_dict_rolled: Dict[str, RolledTensorLogEntry] = OrderedDict() 75 | self.layer_labels: List[str] = [] 76 | self.layer_labels_w_pass: List[str] = [] 77 | self.layer_labels_no_pass: List[str] = [] 78 | self.layer_num_passes: Dict[str, int] = OrderedDict() 79 | self._raw_tensor_dict: Dict[str, TensorLogEntry] = OrderedDict() 80 | self._raw_tensor_labels_list: List[str] = [] 81 | self._tensor_nums_to_save: List[int] = [] 82 | self._tensor_counter: int = 0 83 | self.num_operations: int = 0 84 | self._raw_layer_type_counter: Dict[str, int] = defaultdict(lambda: 0) 85 | self._unsaved_layers_lookup_keys: Set[str] = set() 86 | 87 | # Mapping from raw to final layer labels: 88 | self._raw_to_final_layer_labels: Dict[str, str] = {} 89 | self._final_to_raw_layer_labels: Dict[str, str] = {} 90 | self._lookup_keys_to_tensor_num_dict: Dict[str, int] = {} 91 | self._tensor_num_to_lookup_keys_dict: Dict[int, List[str]] = defaultdict(list) 92 | 93 | # Special Layers: 94 | self.input_layers: List[str] = [] 95 | self.output_layers: List[str] = [] 96 | self.buffer_layers: List[str] = [] 97 | self.buffer_num_passes: Dict = {} 98 | self.internally_initialized_layers: List[str] = [] 99 | self._layers_where_internal_branches_merge_with_input: List[str] = [] 100 | self.internally_terminated_layers: List[str] = [] 101 | self.internally_terminated_bool_layers: List[str] = [] 102 | self.conditional_branch_edges: List[Tuple[str, str]] = [] 103 | self.layers_with_saved_activations: List[str] = [] 104 | self.orphan_layers: List[str] = [] 105 | self.unlogged_layers: List[str] = [] 106 | self.layers_with_saved_gradients: List[str] = [] 107 | self.layers_computed_with_params: Dict[str, List] = defaultdict(list) 108 | self.equivalent_operations: Dict[str, set] = defaultdict(set) 109 | self.same_layer_operations: Dict[str, list] = defaultdict(list) 110 | 111 | # Tensor info: 112 | self.num_tensors_total: int = 0 113 | self.tensor_fsize_total: int = 0 114 | self.tensor_fsize_total_nice: str = human_readable_size(0) 115 | self.num_tensors_saved: int = 0 116 | self.tensor_fsize_saved: int = 0 117 | self.tensor_fsize_saved_nice: str = human_readable_size(0) 118 | 119 | # Param info: 120 | self.total_param_tensors: int = 0 121 | 
self.total_param_layers: int = 0 122 | self.total_params: int = 0 123 | self.total_params_fsize: int = 0 124 | self.total_params_fsize_nice: str = human_readable_size(0) 125 | 126 | # Module info: 127 | self.module_addresses: List[str] = [] 128 | self.module_types: Dict[str, Any] = {} 129 | self.module_passes: List = [] 130 | self.module_num_passes: Dict = defaultdict(lambda: 1) 131 | self.top_level_modules: List = [] 132 | self.top_level_module_passes: List = [] 133 | self.module_children: Dict = defaultdict(list) 134 | self.module_pass_children: Dict = defaultdict(list) 135 | self.module_nparams: Dict = defaultdict(lambda: 0) 136 | self.module_num_tensors: Dict = defaultdict(lambda: 0) 137 | self.module_pass_num_tensors: Dict = defaultdict(lambda: 0) 138 | self.module_layers: Dict = defaultdict(list) 139 | self.module_pass_layers: Dict = defaultdict(list) 140 | self.module_layer_argnames = defaultdict(list) 141 | 142 | # Time elapsed: 143 | self.pass_start_time: float = 0 144 | self.pass_end_time: float = 0 145 | self.elapsed_time_setup: float = 0 146 | self.elapsed_time_forward_pass: float = 0 147 | self.elapsed_time_cleanup: float = 0 148 | self.elapsed_time_total: float = 0 149 | self.elapsed_time_function_calls: float = 0 150 | self.elapsed_time_torchlens_logging: float = 0 151 | 152 | # Reference info 153 | self.func_argnames: Dict[str, tuple] = defaultdict(lambda: tuple([])) 154 | 155 | # ******************************************** 156 | # ************ Built-in Methods ************** 157 | # ******************************************** 158 | 159 | def __len__(self): 160 | if self._pass_finished: 161 | return len(self.layer_list) 162 | else: 163 | return len(self._raw_tensor_dict) 164 | 165 | def __getitem__(self, ix) -> TensorLogEntry: 166 | """Returns an object logging a model layer given an index. If the pass is finished, 167 | it'll do this intelligently; if not, it simply queries based on the layer's raw barcode. 168 | 169 | Args: 170 | ix: desired index 171 | 172 | Returns: 173 | Tensor log entry object with info about specified layer. 
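        Examples (layer names illustrative): self[0] fetches a layer by ordinal
        position, self['conv2d_1_2'] by layer label, and self['features.0'] by
        module address.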
174 | """ 175 | if self._pass_finished: 176 | return _getitem_after_pass(self, ix) 177 | else: 178 | return _getitem_during_pass(self, ix) 179 | 180 | def __str__(self) -> str: 181 | if self._pass_finished: 182 | return _str_after_pass(self) 183 | else: 184 | return _str_during_pass(self) 185 | 186 | def __repr__(self): 187 | return self.__str__() 188 | 189 | def __iter__(self): 190 | """Loops through all tensors in the log.""" 191 | if self._pass_finished: 192 | return iter(self.layer_list) 193 | else: 194 | return iter(list(self._raw_tensor_dict.values())) 195 | 196 | # ******************************************** 197 | # ******** Assign Imported Methods *********** 198 | # ******************************************** 199 | 200 | render_graph = render_graph 201 | print_all_fields = print_all_fields 202 | to_pandas = to_pandas 203 | save_new_activations = save_new_activations 204 | validate_saved_activations = validate_saved_activations 205 | cleanup = cleanup 206 | _postprocess = postprocess 207 | _decorate_pytorch = decorate_pytorch 208 | _prepare_model = prepare_model 209 | _cleanup_model = cleanup_model 210 | _run_and_log_inputs_through_model = run_and_log_inputs_through_model 211 | _remove_log_entry = _remove_log_entry 212 | -------------------------------------------------------------------------------- /torchlens/trace_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import inspect 3 | import random 4 | import time 5 | from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union 6 | 7 | import torch 8 | from torch import nn 9 | 10 | if TYPE_CHECKING: 11 | from .model_history import ModelHistory 12 | from .decorate_torch import undecorate_pytorch 13 | from .helper_funcs import get_vars_of_type_from_obj, set_random_seed, nested_assign 14 | from .logging_funcs import log_source_tensor 15 | from .interface import _give_user_feedback_about_lookup_key 16 | 17 | 18 | def _get_input_arg_names(model, input_args): 19 | input_arg_names = inspect.getfullargspec(model.forward).args 20 | if "self" in input_arg_names: 21 | input_arg_names.remove("self") 22 | input_arg_names = input_arg_names[0: len(input_args)] 23 | return input_arg_names 24 | 25 | 26 | def _get_op_nums_from_user_labels( 27 | self: "ModelHistory", which_layers: Union[str, List[Union[str, int]]] 28 | ) -> List[int]: 29 | """Given list of user layer labels, returns the original tensor numbers for those labels (i.e., 30 | the numbers that were generated on the fly during the forward pass, such that they can be 31 | saved on a subsequent pass). Raises an error if the user's labels don't correspond to any layers. 32 | 33 | Args: 34 | which_layers: List of layers to include, using any indexing desired: either the layer label, 35 | the module label, or the ordinal position of the layer. If a layer has multiple passes and 36 | none is specified, will return all of them. 37 | 38 | Returns: 39 | Ordered, unique list of raw tensor numbers associated with the specified layers. 40 | """ 41 | if which_layers == "all": 42 | return which_layers 43 | elif which_layers in [None, "none", "None", "NONE", []]: 44 | return [] 45 | 46 | if type(which_layers) != list: 47 | which_layers = [which_layers] 48 | which_layers = [ 49 | layer.lower() if (type(layer) == str) else layer for layer in which_layers 50 | ] 51 | 52 | raw_tensor_nums_to_save = set() 53 | for layer_key in which_layers: 54 | # First check if it matches a lookup key. If so, use that. 
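        # (Lookup keys map every user-facing handle for a layer -- its label with
        #  or without a pass number, its module address, or its ordinal index --
        #  to the raw tensor number generated during the forward pass.)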
55 | if layer_key in self._lookup_keys_to_tensor_num_dict: 56 | raw_tensor_nums_to_save.add( 57 | self._lookup_keys_to_tensor_num_dict[layer_key] 58 | ) 59 | continue 60 | 61 | # If not, pull out all layers for which the key is a substring. 62 | keys_with_substr = [ 63 | key for key in self.layer_dict_all_keys if layer_key in str(key) 64 | ] 65 | if len(keys_with_substr) > 0: 66 | for key in keys_with_substr: 67 | raw_tensor_nums_to_save.add( 68 | self.layer_dict_all_keys[key].realtime_tensor_num 69 | ) 70 | continue 71 | 72 | # If no luck, try to at least point user in right direction: 73 | 74 | _give_user_feedback_about_lookup_key(self, layer_key, "query_multiple") 75 | 76 | raw_tensor_nums_to_save = sorted(list(raw_tensor_nums_to_save)) 77 | return raw_tensor_nums_to_save 78 | 79 | 80 | def _fetch_label_move_input_tensors( 81 | input_args: List[Any], 82 | input_arg_names: List[str], 83 | input_kwargs: Dict, 84 | model_device: str, 85 | ) -> Tuple[List[torch.Tensor], List[str]]: 86 | """Fetches input tensors, gets their addresses, and moves them to the model device. 87 | 88 | Args: 89 | input_args: input arguments 90 | input_arg_names: name of input arguments 91 | input_kwargs: input keyword arguments 92 | model_device: model device 93 | 94 | Returns: 95 | input tensors and their addresses 96 | """ 97 | input_arg_tensors = [ 98 | get_vars_of_type_from_obj( 99 | arg, torch.Tensor, search_depth=5, return_addresses=True 100 | ) 101 | for arg in input_args 102 | ] 103 | input_kwarg_tensors = [ 104 | get_vars_of_type_from_obj( 105 | kwarg, torch.Tensor, search_depth=5, return_addresses=True 106 | ) 107 | for kwarg in input_kwargs.values() 108 | ] 109 | for a, arg in enumerate(input_args): 110 | for i, (t, addr, addr_full) in enumerate(input_arg_tensors[a]): 111 | t_moved = t.to(model_device) 112 | input_arg_tensors[a][i] = (t_moved, addr, addr_full) 113 | if not addr_full: 114 | input_args[a] = t_moved 115 | else: 116 | nested_assign(input_args[a], addr_full, t_moved) 117 | 118 | for k, (key, val) in enumerate(input_kwargs.items()): 119 | for i, (t, addr, addr_full) in enumerate(input_kwarg_tensors[k]): 120 | t_moved = t.to(model_device) 121 | input_kwarg_tensors[k][i] = (t_moved, addr, addr_full) 122 | if not addr_full: 123 | input_kwargs[key] = t_moved 124 | else: 125 | nested_assign(input_kwargs[key], addr_full, t_moved) 126 | 127 | input_tensors = [] 128 | input_tensor_addresses = [] 129 | for a, arg_tensors in enumerate(input_arg_tensors): 130 | for t, addr, addr_full in arg_tensors: 131 | input_tensors.append(t) 132 | tensor_addr = f"input.{input_arg_names[a]}" 133 | if addr != "": 134 | tensor_addr += f".{addr}" 135 | input_tensor_addresses.append(tensor_addr) 136 | 137 | for a, kwarg_tensors in enumerate(input_kwarg_tensors): 138 | for t, addr, addr_full in kwarg_tensors: 139 | input_tensors.append(t) 140 | tensor_addr = f"input.{list(input_kwargs.keys())[a]}" 141 | if addr != "": 142 | tensor_addr += f".{addr}" 143 | input_tensor_addresses.append(tensor_addr) 144 | 145 | return input_tensors, input_tensor_addresses 146 | 147 | 148 | def run_and_log_inputs_through_model( 149 | self: "ModelHistory", 150 | model: nn.Module, 151 | input_args: Union[torch.Tensor, List[Any]], 152 | input_kwargs: Dict[Any, Any] = None, 153 | layers_to_save: Optional[Union[str, List[Union[str, int]]]] = "all", 154 | random_seed: Optional[int] = None, 155 | ): 156 | """Runs input through model and logs it in ModelHistory. 
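    Decorates torch, runs the forward pass, then restores torch and the model to
    their original state (even if the pass fails partway through).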
157 | 158 | Args: 159 | model: Model for which to save activations 160 | input_args: Either a single tensor input to the model, or list of input arguments. 161 | input_kwargs: Dict of keyword arguments to the model. 162 | layers_to_save: List of tensor numbers to save 163 | random_seed: Which random seed to use 164 | Returns: 165 | Nothing, but now the ModelHistory object will have saved activations for the new input. 166 | """ 167 | if random_seed is None: # set random seed 168 | random_seed = random.randint(1, 4294967294) 169 | self.random_seed_used = random_seed 170 | set_random_seed(random_seed) 171 | 172 | self._tensor_nums_to_save = _get_op_nums_from_user_labels(self, layers_to_save) 173 | 174 | if type(input_args) is tuple: 175 | input_args = list(input_args) 176 | elif (type(input_args) not in [list, tuple]) and (input_args is not None): 177 | input_args = [input_args] 178 | 179 | if not input_args: 180 | input_args = [] 181 | 182 | if not input_kwargs: 183 | input_kwargs = {} 184 | 185 | if ( 186 | type(model) == nn.DataParallel 187 | ): # Unwrap model from DataParallel if relevant: 188 | model = model.module 189 | 190 | if ( 191 | len(list(model.parameters())) > 0 192 | ): # Get the model device by looking at the parameters: 193 | model_device = next(iter(model.parameters())).device 194 | else: 195 | model_device = "cpu" 196 | 197 | input_args = [copy.deepcopy(arg) for arg in input_args] 198 | input_arg_names = _get_input_arg_names(model, input_args) 199 | input_kwargs = {key: copy.deepcopy(val) for key, val in input_kwargs.items()} 200 | 201 | self.pass_start_time = time.time() 202 | module_orig_forward_funcs = {} 203 | orig_func_defs = [] 204 | 205 | try: 206 | ( 207 | input_tensors, 208 | input_tensor_addresses, 209 | ) = _fetch_label_move_input_tensors( 210 | input_args, input_arg_names, input_kwargs, model_device 211 | ) 212 | buffer_tensors = list(model.buffers()) 213 | tensors_to_decorate = input_tensors + buffer_tensors 214 | decorated_func_mapper = self._decorate_pytorch(torch, orig_func_defs) 215 | self._track_tensors = True 216 | for i, t in enumerate(input_tensors): 217 | log_source_tensor(self, t, "input", input_tensor_addresses[i]) 218 | self._prepare_model(model, module_orig_forward_funcs, decorated_func_mapper) 219 | self.elapsed_time_setup = time.time() - self.pass_start_time 220 | outputs = model(*input_args, **input_kwargs) 221 | self.elapsed_time_forward_pass = ( 222 | time.time() - self.pass_start_time - self.elapsed_time_setup 223 | ) 224 | self._track_tensors = False 225 | output_tensors_w_addresses_all = get_vars_of_type_from_obj( 226 | outputs, 227 | torch.Tensor, 228 | search_depth=5, 229 | return_addresses=True, 230 | allow_repeats=True, 231 | ) 232 | # Remove duplicate addresses TODO: make this match the validation procedure so they work the same way 233 | addresses_used = [] 234 | output_tensors_w_addresses = [] 235 | for entry in output_tensors_w_addresses_all: 236 | if entry[1] in addresses_used: 237 | continue 238 | output_tensors_w_addresses.append(entry) 239 | addresses_used.append(entry[1]) 240 | 241 | output_tensors = [t for t, _, _ in output_tensors_w_addresses] 242 | output_tensor_addresses = [ 243 | addr for _, addr, _ in output_tensors_w_addresses 244 | ] 245 | 246 | for t in output_tensors: 247 | if self.logging_mode == 'exhaustive': 248 | self.output_layers.append(t.tl_tensor_label_raw) 249 | self._raw_tensor_dict[t.tl_tensor_label_raw].is_output_parent = True 250 | tensors_to_undecorate = tensors_to_decorate + output_tensors 251 | 
        undecorate_pytorch(torch, orig_func_defs, tensors_to_undecorate)
        self._cleanup_model(model, module_orig_forward_funcs, decorated_func_mapper)
        self._postprocess(output_tensors, output_tensor_addresses)
        decorated_func_mapper.clear()

    except (
        Exception
    ) as e:  # if anything fails, make sure everything gets cleaned up
        undecorate_pytorch(torch, orig_func_defs, input_tensors)
        self._cleanup_model(model, module_orig_forward_funcs, decorated_func_mapper)
        print(
            "************\nFeature extraction failed; returning model and environment to normal\n*************"
        )
        raise e

    finally:  # do garbage collection no matter what
        if 'input_args' in locals():
            del input_args
        if 'input_kwargs' in locals():
            del input_kwargs
        if 'input_tensors' in locals():
            del input_tensors
        if 'output_tensors' in locals():
            del output_tensors
        if 'outputs' in locals():
            del outputs
        torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/torchlens/decorate_torch.py:
--------------------------------------------------------------------------------
import inspect
import time
import types
import warnings
from functools import wraps
from typing import Callable, Dict, List, TYPE_CHECKING, Tuple

import torch

from .constants import ORIG_TORCH_FUNCS
from .helper_funcs import (clean_to, get_vars_of_type_from_obj, identity, log_current_rng_states, make_random_barcode,
                           nested_getattr, print_override, safe_copy)
from .logging_funcs import log_function_output_tensors, log_source_tensor

if TYPE_CHECKING:
    from .model_history import ModelHistory

funcs_not_to_log = ["numpy", "__array__", "size", "dim"]
print_funcs = ["__repr__", "__str__", "_str"]


def torch_func_decorator(self, func: Callable, func_name: str):
    @wraps(func)
    def wrapped_func(*args, **kwargs):
        # Initial bookkeeping; check if it's a special function, organize the arguments.
        self.current_function_call_barcode = 0
        if (
            (func_name in funcs_not_to_log)
            or (not self._track_tensors)
            or self._pause_logging
        ):
            out = func(*args, **kwargs)
            return out
        all_args = list(args) + list(kwargs.values())
        arg_tensorlike = get_vars_of_type_from_obj(all_args, torch.Tensor)

        # Register any buffer tensors in the arguments.
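        # (A "buffer" is a tensor attached to a module via register_buffer, e.g.
        #  batchnorm running statistics. Decoration tags such tensors with a
        #  tl_buffer_address attribute, which is checked for below.)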
        for t in arg_tensorlike:
            if hasattr(t, 'tl_buffer_address'):
                log_source_tensor(self, t, 'buffer', getattr(t, 'tl_buffer_address'))

        if (func_name in print_funcs) and (len(arg_tensorlike) > 0):
            out = print_override(args[0], func_name)
            return out

        # Copy the args and kwargs in case they change in-place:
        if self.save_function_args:
            arg_copies = tuple([safe_copy(arg) for arg in args])
            kwarg_copies = {k: safe_copy(v) for k, v in kwargs.items()}
        else:
            arg_copies = args
            kwarg_copies = kwargs

        # Call the function, tracking the timing, rng states, and whether it's a nested function
        func_call_barcode = make_random_barcode()
        self.current_function_call_barcode = func_call_barcode
        start_time = time.time()
        func_rng_states = log_current_rng_states()
        out_orig = func(*args, **kwargs)
        func_time_elapsed = time.time() - start_time
        is_bottom_level_func = (
            self.current_function_call_barcode == func_call_barcode
        )

        if func_name in ["__setitem__", "zero_", "__delitem__"]:
            out_orig = args[0]

        if id(out_orig) == id(args[0]):  # special case if the function does nothing
            out_orig = safe_copy(out_orig)

        # Log all output tensors
        output_tensors = get_vars_of_type_from_obj(
            out_orig,
            which_type=torch.Tensor,
            subclass_exceptions=[torch.nn.Parameter],
        )

        if len(output_tensors) > 0:
            log_function_output_tensors(
                self,
                func,
                func_name,
                args,
                kwargs,
                arg_copies,
                kwarg_copies,
                out_orig,
                func_time_elapsed,
                func_rng_states,
                is_bottom_level_func,
            )

        return out_orig

    return wrapped_func


def decorate_pytorch(
    self: "ModelHistory", torch_module: types.ModuleType, orig_func_defs: List[Tuple]
) -> Dict[Callable, Callable]:
    """Decorates all PyTorch functions (temporarily!) so that any function returning
    Tensors logs its outputs and marks them with metadata, and saves the current
    state of the functions so they can be restored when done.

    args:
        torch_module: The top-level torch module (i.e., from "import torch").
            This is supplied as an argument on the off-chance that the user has imported torch
            and done their own monkey-patching.
        orig_func_defs: Supply a list from outside to guarantee it can be cleaned up properly;
            it is filled with tuples of [namespace_name, func_name, orig_func], sufficient
            to return torch to normal when finished.

    returns:
        Dict mapping decorated functions to the original functions and vice versa.
    """

    # Do a pass to save the original func defs.
    collect_orig_func_defs(torch_module, orig_func_defs)
    decorated_func_mapper = {}

    # Get references to the function classes.
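    # (PyTorch exposes callables of several distinct types -- plain Python
    #  functions, C builtins, method/slot wrappers, and getset descriptors such as
    #  torch.Tensor.real; calling type() on an exemplar of each gives a handle to
    #  test against. Descriptors get their getter/setter/deleter wrapped
    #  individually below.)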
    function_class = type(lambda: 0)
    builtin_class = type(torch.mean)
    method_class = type(torch.Tensor.__add__)
    wrapper_class = type(torch.Tensor.__getitem__)
    getset_class = type(torch.Tensor.real)

    for namespace_name, func_name in ORIG_TORCH_FUNCS:
        namespace_name_notorch = namespace_name.replace("torch.", "")
        local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
        if not hasattr(local_func_namespace, func_name):
            continue
        orig_func = getattr(local_func_namespace, func_name)
        if func_name not in self.func_argnames:
            get_func_argnames(self, orig_func, func_name)
        if getattr(orig_func, "__name__", False) == "wrapped_func":
            continue

        if type(orig_func) in [function_class, builtin_class, method_class, wrapper_class]:
            new_func = torch_func_decorator(self, orig_func, func_name)
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    setattr(local_func_namespace, func_name, new_func)
            except (AttributeError, TypeError) as _:
                pass
            new_func.tl_is_decorated_function = True
            decorated_func_mapper[new_func] = orig_func
            decorated_func_mapper[orig_func] = new_func

        elif type(orig_func) == getset_class:
            getter_orig, setter_orig, deleter_orig = orig_func.__get__, orig_func.__set__, orig_func.__delete__
            getter_dec, setter_dec, deleter_dec = (torch_func_decorator(self, getter_orig, func_name),
                                                   torch_func_decorator(self, setter_orig, func_name),
                                                   torch_func_decorator(self, deleter_orig, func_name))
            getter_dec.tl_is_decorated_function = True
            setter_dec.tl_is_decorated_function = True
            deleter_dec.tl_is_decorated_function = True
            new_property = property(getter_dec, setter_dec, deleter_dec, doc=func_name)
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    setattr(local_func_namespace, func_name, new_property)
            except (AttributeError, TypeError) as _:
                pass
            decorated_func_mapper[new_property] = orig_func
            decorated_func_mapper[orig_func] = new_property

    # Bolt on the identity function
    new_identity = torch_func_decorator(self, identity, 'identity')
    torch.identity = new_identity

    return decorated_func_mapper


def undecorate_pytorch(
    torch_module, orig_func_defs: List[Tuple], input_tensors: List[torch.Tensor]
):
    """
    Returns all PyTorch functions back to the definitions they had when decorate_pytorch
    was called, and strips the torchlens markers from the tensors passed in. Also deletes
    the decorated versions of the functions to remove any references to the old ModelHistory object.

    args:
        torch_module: The torch module object.
        orig_func_defs: List of tuples consisting of [namespace_name, func_name, orig_func], sufficient
            to regenerate the original functions.
        input_tensors: List of input tensors whose functions will be undecorated.
    """
    for namespace_name, func_name, orig_func in orig_func_defs:
        namespace_name_notorch = namespace_name.replace("torch.", "")
        local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            decorated_func = getattr(local_func_namespace, func_name)
        del decorated_func
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                setattr(local_func_namespace, func_name, orig_func)
        except (AttributeError, TypeError) as _:
            continue
    delattr(torch, "identity")
    for input_tensor in input_tensors:
        if hasattr(input_tensor, "tl_tensor_label_raw"):
            delattr(input_tensor, "tl_tensor_label_raw")


def undecorate_tensor(t, device: str = "cpu"):
    """Convenience function to replace the tensor with an unmutated version of itself, keeping the same data.

    Args:
        t: tensor or parameter object
        device: device to move the tensor to

    Returns:
        Unmutated tensor.
    """
    if type(t) in [torch.Tensor, torch.nn.Parameter]:
        new_t = safe_copy(t)
    else:
        new_t = t
    del t
    for attr in dir(new_t):
        if attr.startswith("tl_"):
            delattr(new_t, attr)
    new_t = clean_to(new_t, device)
    return new_t


def collect_orig_func_defs(
    torch_module: types.ModuleType, orig_func_defs: List[Tuple]
):
    """Collects the original torch function definitions, so they can be restored after the logging is done.

    Args:
        torch_module: The top-level torch module
        orig_func_defs: List of tuples keeping track of the original function definitions
    """
    for namespace_name, func_name in ORIG_TORCH_FUNCS:
        namespace_name_notorch = namespace_name.replace("torch.", "")
        local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
        if not hasattr(local_func_namespace, func_name):
            continue
        orig_func = getattr(local_func_namespace, func_name)
        orig_func_defs.append((namespace_name, func_name, orig_func))


# TODO: hard-code some of the arg names; for example truediv, getitem, etc. Can crawl through and see what isn't working
def get_func_argnames(self, orig_func: Callable, func_name: str):
    """Attempts to get the argument names for a function, first by checking the signature, then
    by checking the documentation.
Adds these names to func_argnames if it can find them, 257 | doesn't do anything if it can't.""" 258 | if func_name in ['real', 'imag', 'T', 'mT', 'data', 'H']: 259 | return 260 | 261 | try: 262 | argnames = list(inspect.signature(orig_func).parameters.keys()) 263 | argnames = tuple([arg.replace('*', '') for arg in argnames if arg not in ['cls', 'self']]) 264 | self.func_argnames[func_name] = argnames 265 | return 266 | except ValueError: 267 | pass 268 | 269 | docstring = orig_func.__doc__ 270 | if (type(docstring) is not str) or (len(docstring) == 0): # if docstring missing, skip it 271 | return 272 | 273 | open_ind, close_ind = docstring.find('('), docstring.find(')') 274 | argstring = docstring[open_ind + 1: close_ind] 275 | arg_list = argstring.split(',') 276 | arg_list = [arg.strip(' ') for arg in arg_list] 277 | argnames = [] 278 | for arg in arg_list: 279 | argname = arg.split('=')[0] 280 | if argname in ['*', '/', '//', '']: 281 | continue 282 | argname = argname.replace('*', '') 283 | argnames.append(argname) 284 | argnames = tuple([arg for arg in argnames if arg not in ['self', 'cls']]) 285 | self.func_argnames[func_name] = argnames 286 | return 287 | -------------------------------------------------------------------------------- /torchlens/constants.py: -------------------------------------------------------------------------------- 1 | import __future__ 2 | import functools 3 | import types 4 | from typing import List 5 | import warnings 6 | 7 | import torch 8 | from torch.overrides import get_ignored_functions, get_testing_overrides 9 | 10 | MODEL_HISTORY_FIELD_ORDER = [ 11 | # General info 12 | "model_name", 13 | "_pass_finished", 14 | "_track_tensors", 15 | "logging_mode", 16 | "_pause_logging", 17 | "_all_layers_logged", 18 | "_all_layers_saved", 19 | "keep_unsaved_layers", 20 | "current_function_call_barcode", 21 | "random_seed_used", 22 | "detach_saved_tensors", 23 | "output_device", 24 | "save_function_args", 25 | "save_gradients", 26 | "has_saved_gradients", 27 | "activation_postfunc", 28 | "mark_input_output_distances", 29 | # Model structure info 30 | "model_is_recurrent", 31 | "model_max_recurrent_loops", 32 | "model_is_branching", 33 | "model_has_conditional_branching", 34 | # Tensor tracking logs 35 | "layer_list", 36 | "layer_list_rolled", 37 | "layer_dict_main_keys", 38 | "layer_dict_all_keys", 39 | "layer_dict_rolled", 40 | "layer_labels", 41 | "layer_labels_no_pass", 42 | "layer_labels_w_pass", 43 | "layer_num_passes", 44 | "_raw_tensor_dict", 45 | "_raw_tensor_labels_list", 46 | "_tensor_nums_to_save", 47 | "_tensor_counter", 48 | "num_operations", 49 | "_raw_layer_type_counter", 50 | "_unsaved_layers_lookup_keys", 51 | # Mapping from raw to final layer labels: 52 | "_raw_to_final_layer_labels", 53 | "_final_to_raw_layer_labels", 54 | "_lookup_keys_to_tensor_num_dict", 55 | "_tensor_num_to_lookup_keys_dict", 56 | # Special layers 57 | "input_layers", 58 | "output_layers", 59 | "buffer_layers", 60 | "buffer_num_passes", 61 | "internally_initialized_layers", 62 | "_layers_where_internal_branches_merge_with_input", 63 | "internally_terminated_layers", 64 | "internally_terminated_bool_layers", 65 | "conditional_branch_edges", 66 | "layers_computed_with_params", 67 | "equivalent_operations", 68 | "same_layer_operations", 69 | "layers_with_saved_activations", 70 | "unlogged_layers", 71 | "layers_with_saved_gradients", 72 | "orphan_layers", 73 | # Tensor info: 74 | "num_tensors_total", 75 | "tensor_fsize_total", 76 | "tensor_fsize_total_nice", 77 | 
"num_tensors_saved", 78 | "tensor_fsize_saved", 79 | "tensor_fsize_saved_nice", 80 | # Param info 81 | "total_param_tensors", 82 | "total_param_layers", 83 | "total_params", 84 | "total_params_fsize", 85 | "total_params_fsize_nice", 86 | # Module info 87 | "module_addresses", 88 | "module_types", 89 | "module_passes", 90 | "module_num_passes", 91 | "module_children", 92 | "module_pass_children", 93 | "top_level_modules", 94 | "top_level_module_passes", 95 | "module_nparams", 96 | "module_num_tensors", 97 | "module_pass_num_tensors", 98 | "module_layers", 99 | "module_layer_argnames", 100 | "module_pass_layers", 101 | # Time elapsed 102 | "pass_start_time", 103 | "pass_end_time", 104 | "elapsed_time_setup", 105 | "elapsed_time_forward_pass", 106 | "elapsed_time_cleanup", 107 | "elapsed_time_total", 108 | "elapsed_time_function_calls", 109 | "elapsed_time_torchlens_logging", 110 | # Lookup info 111 | "func_argnames" 112 | ] 113 | 114 | TENSOR_LOG_ENTRY_FIELD_ORDER = [ 115 | # General info 116 | "layer_label", 117 | "tensor_label_raw", 118 | "layer_label_raw", 119 | "operation_num", 120 | "realtime_tensor_num", 121 | "source_model_history", 122 | "_pass_finished", 123 | # Other labeling info 124 | "layer_label_short", 125 | "layer_label_w_pass", 126 | "layer_label_w_pass_short", 127 | "layer_label_no_pass", 128 | "layer_label_no_pass_short", 129 | "layer_type", 130 | "layer_type_num", 131 | "layer_total_num", 132 | "pass_num", 133 | "layer_passes_total", 134 | "lookup_keys", 135 | # Saved tensor info 136 | "tensor_contents", 137 | "has_saved_activations", 138 | "output_device", 139 | "activation_postfunc", 140 | "detach_saved_tensor", 141 | "function_args_saved", 142 | "creation_args", 143 | "creation_kwargs", 144 | "tensor_shape", 145 | "tensor_dtype", 146 | "tensor_fsize", 147 | "tensor_fsize_nice", 148 | # Tensor slice-changing complications 149 | "was_getitem_applied", 150 | "children_tensor_versions", 151 | # Saved gradient info 152 | "grad_contents", 153 | "save_gradients", 154 | "has_saved_grad", 155 | "grad_shape", 156 | "grad_dtype", 157 | "grad_fsize", 158 | "grad_fsize_nice", 159 | # Function call info 160 | "func_applied", 161 | "func_applied_name", 162 | "func_call_stack", 163 | "func_time_elapsed", 164 | "func_rng_states", 165 | "func_argnames", 166 | "num_func_args_total", 167 | "num_position_args", 168 | "num_keyword_args", 169 | "func_position_args_non_tensor", 170 | "func_keyword_args_non_tensor", 171 | "func_all_args_non_tensor", 172 | "function_is_inplace", 173 | "gradfunc", 174 | "is_part_of_iterable_output", 175 | "iterable_output_index", 176 | # Param info 177 | "computed_with_params", 178 | "parent_params", 179 | "parent_param_barcodes", 180 | "parent_param_passes", 181 | "num_param_tensors", 182 | "parent_param_shapes", 183 | "num_params_total", 184 | "parent_params_fsize", 185 | "parent_params_fsize_nice", 186 | # Corresponding layer info 187 | "operation_equivalence_type", 188 | "equivalent_operations", 189 | "same_layer_operations", 190 | # Graph info 191 | "parent_layers", 192 | "has_parents", 193 | "parent_layer_arg_locs", 194 | "orig_ancestors", 195 | "child_layers", 196 | "has_children", 197 | "sibling_layers", 198 | "has_siblings", 199 | "spouse_layers", 200 | "has_spouses", 201 | "is_input_layer", 202 | "has_input_ancestor", 203 | "input_ancestors", 204 | "min_distance_from_input", 205 | "max_distance_from_input", 206 | "is_output_layer", 207 | "is_output_parent", 208 | "is_last_output_layer", 209 | "is_output_ancestor", 210 | "output_descendents", 211 | 
"input_output_address", 212 | "min_distance_from_output", 213 | "max_distance_from_output", 214 | "is_buffer_layer", 215 | "buffer_address", 216 | "buffer_pass", 217 | "buffer_parent", 218 | "initialized_inside_model", 219 | "has_internally_initialized_ancestor", 220 | "internally_initialized_parents", 221 | "internally_initialized_ancestors", 222 | "terminated_inside_model", 223 | # Conditional info 224 | "is_terminal_bool_layer", 225 | "is_atomic_bool_layer", 226 | "atomic_bool_val", 227 | "in_cond_branch", 228 | "cond_branch_start_children", 229 | # Module info 230 | "is_computed_inside_submodule", 231 | "containing_module_origin", 232 | "containing_modules_origin_nested", 233 | "module_nesting_depth", 234 | "modules_entered", 235 | "module_passes_entered", 236 | "modules_entered_argnames", 237 | "is_submodule_input", 238 | "modules_exited", 239 | "module_passes_exited", 240 | "is_submodule_output", 241 | "is_bottom_level_submodule_output", 242 | "bottom_level_submodule_pass_exited", 243 | "module_entry_exit_threads_inputs", 244 | "module_entry_exit_thread_output", 245 | ] 246 | 247 | # Taken from https://pytorch.org/docs/stable/_modules/torch/overrides.html#get_ignored_functions 248 | IGNORED_FUNCS = [ 249 | ("torch", "load"), 250 | ("torch", "as_tensor"), 251 | ("torch", "from_numpy"), 252 | ("torch", "tensor"), 253 | ("torch", "arange"), 254 | ("torch", "as_strided"), 255 | ("torch", "bartlett_window"), 256 | ("torch", "blackman_window"), 257 | ("torch", "cudnn_affine_grid_generator"), 258 | ("torch", "cudnn_batch_norm"), 259 | ("torch", "cudnn_convolution"), 260 | ("torch", "cudnn_convolution_transpose"), 261 | ("torch", "cudnn_convolution_relu"), 262 | ("torch", "cudnn_convolution_add_relu"), 263 | ("torch", "cudnn_grid_sampler"), 264 | ("torch", "cudnn_is_acceptable"), 265 | ("torch", "eye"), 266 | ("torch.fft", "fftfreq"), 267 | ("torch.fft", "rfftfreq"), 268 | ("torch", "from_file"), 269 | ("torch", "full"), 270 | ("torch", "fill_"), 271 | ("torch", "hamming_window"), 272 | ("torch", "hann_window"), 273 | ("torch", "kaiser_window"), 274 | ("torch", "linspace"), 275 | ("torch", "logspace"), 276 | ("torch", "mkldnn_adaptive_avg_pool2d"), 277 | ("torch", "mkldnn_convolution"), 278 | ("torch", "mkldnn_max_pool2d"), 279 | ("torch", "mkldnn_max_pool3d"), 280 | ("torch", "mkldnn_linear_backward_weights"), 281 | ("torch", "normal"), 282 | ("torch", "ones"), 283 | ("torch", "rand"), 284 | ("torch", "randn"), 285 | ("torch", "randint"), 286 | ("torch", "randperm"), 287 | ("torch", "range"), 288 | ("torch", "scalar_tensor"), 289 | ("torch", "sparse_coo_tensor"), 290 | ("torch", "_sparse_csr_tensor"), 291 | ("torch", "tril_indices"), 292 | ("torch", "triu_indices"), 293 | ("torch", "vander"), 294 | ("torch", "zeros"), 295 | ("torch.nn.functional", "upsample"), 296 | ("torch.nn.functional", "upsample_bilinear"), 297 | ("torch.nn.functional", "upsample_nearest"), 298 | ("torch.nn.functional", "handle_torch_function"), 299 | ("torch.nn.functional", "sigmoid"), 300 | ("torch.nn.functional", "hardsigmoid"), 301 | ("torch.nn.functional", "tanh"), 302 | ("torch.nn.init", "calculate_gain"), 303 | ("torch.nn.init", "uniform"), 304 | ("torch.nn.init", "normal"), 305 | ("torch.nn.init", "constant"), 306 | ("torch.nn.init", "eye"), 307 | ("torch.nn.init", "dirac"), 308 | ("torch.nn.init", "xavier_uniform"), 309 | ("torch.nn.init", "xavier_normal"), 310 | ("torch.nn.init", "kaiming_uniform"), 311 | ("torch.nn.init", "kaiming_normal"), 312 | ("torch.nn.init", "orthogonal"), 313 | ("torch.nn.init", 
"sparse"), 314 | ("torch.nn.functional", "hardswish"), 315 | ("torch.Tensor", "__delitem__"), 316 | ("torch.Tensor", "__iter__"), 317 | ("torch.Tensor", "__init_subclass__"), 318 | ("torch.Tensor", "__torch_function__"), 319 | ("torch.Tensor", "__new__"), 320 | ("torch.Tensor", "__subclasshook__"), 321 | ("torch.Tensor", "as_subclass"), 322 | ("torch.Tensor", "reinforce"), 323 | ("torch.Tensor", "new"), 324 | ("torch.Tensor", "new_tensor"), 325 | ("torch.Tensor", "new_empty"), 326 | ("torch.Tensor", "new_empty_strided"), 327 | ("torch.Tensor", "new_zeros"), 328 | ("torch.Tensor", "new_ones"), 329 | ("torch.Tensor", "new_full"), 330 | ("torch.Tensor", "_make_subclass"), 331 | ("torch.Tensor", "solve"), 332 | ("torch.Tensor", "unflatten"), 333 | ("torch.Tensor", "real"), 334 | ("torch.Tensor", "imag"), 335 | ("torch.Tensor", "T"), 336 | ("torch.Tensor", "mT"), 337 | ("torch.Tensor", "H") 338 | ] 339 | 340 | 341 | @functools.lru_cache(None) 342 | def my_get_overridable_functions() -> List: 343 | index = {} 344 | func_names = [] 345 | tested_namespaces = [ 346 | ("torch", torch, torch.__all__ + dir(torch._C._VariableFunctions)), 347 | ("torch._VF", torch._VF, dir(torch._C._VariableFunctions)), 348 | ("torch.functional", torch.functional, torch.functional.__all__), 349 | ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)), 350 | ("torch.nn.init", torch.nn.init, dir(torch.nn.init)), 351 | ("torch.Tensor", torch.Tensor, dir(torch.Tensor)), 352 | ("torch.linalg", torch.linalg, dir(torch.linalg)), 353 | ("torch.fft", torch.fft, dir(torch.fft)), 354 | ] 355 | if hasattr(torch, "special"): 356 | tested_namespaces.append(("torch.special", torch.special, dir(torch.special))) 357 | for namespace_str, namespace, ns_funcs in tested_namespaces: 358 | for func_name in ns_funcs: 359 | ignore = False 360 | # ignore private functions or functions that are deleted in torch.__init__ 361 | if namespace is not torch.Tensor: 362 | if func_name.startswith("__"): 363 | continue 364 | elif func_name[0].isupper(): 365 | ignore = True 366 | elif func_name == "unique_dim": 367 | continue 368 | else: 369 | func = getattr(namespace, func_name) 370 | if getattr(object, func_name, None) == func: 371 | continue 372 | if func_name == "__weakref__": 373 | continue 374 | func = getattr(namespace, func_name) 375 | if namespace is torch.Tensor and getattr(object, func_name, None) == func: 376 | continue 377 | # ignore re-exported modules 378 | if isinstance(func, types.ModuleType): 379 | continue 380 | # ignore __future__ imports 381 | if isinstance(func, getattr(__future__, "_Feature")): 382 | continue 383 | 384 | if not callable(func) and hasattr(func, "__get__"): 385 | index[func.__get__] = f"{namespace_str}.{func_name}.__get__" 386 | index[func.__set__] = f"{namespace_str}.{func_name}.__set__" 387 | if ignore: 388 | continue 389 | if func.__get__ in get_ignored_functions(): 390 | msg = ( 391 | "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " 392 | "but still has an explicit override" 393 | ) 394 | assert func.__get__ not in get_testing_overrides(), msg.format( 395 | namespace, func.__name__ 396 | ) 397 | continue 398 | else: 399 | func_names.append((f"{namespace_str}.{func_name}", "__get__")) 400 | continue 401 | 402 | if not callable(func): 403 | continue 404 | 405 | index[func] = f"{namespace_str}.{func_name}" 406 | 407 | if ignore: 408 | continue 409 | 410 | # cannot be overriden by __torch_function__ 411 | if func in get_ignored_functions(): 412 | msg = ( 413 | "{}.{} is in the 
tuple returned by torch._overrides.get_ignored_functions " 414 | "but still has an explicit override" 415 | ) 416 | assert func not in get_testing_overrides(), msg.format( 417 | namespace, func.__name__ 418 | ) 419 | continue 420 | func_names.append((f"{namespace_str}", func_name)) 421 | return func_names 422 | 423 | 424 | TORCHVISION_FUNCS = [ 425 | ("torch.ops.torchvision.nms", "_op"), 426 | ("torch.ops.torchvision.deform_conv2d", "_op"), 427 | ("torch.ops.torchvision.ps_roi_align", "_op"), 428 | ("torch.ops.torchvision.ps_roi_pool", "_op"), 429 | ("torch.ops.torchvision.roi_align", "_op"), 430 | ("torch.ops.torchvision.roi_pool", "_op")] 431 | 432 | with warnings.catch_warnings(): 433 | warnings.simplefilter("ignore") 434 | OVERRIDABLE_FUNCS = my_get_overridable_functions() 435 | ORIG_TORCH_FUNCS = OVERRIDABLE_FUNCS + IGNORED_FUNCS 436 | 437 | try: 438 | import torchvision 439 | 440 | ORIG_TORCH_FUNCS += TORCHVISION_FUNCS 441 | except ModuleNotFoundError: 442 | pass 443 | -------------------------------------------------------------------------------- /torchlens/interface.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import TYPE_CHECKING, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | if TYPE_CHECKING: 8 | from .model_history import ModelHistory 9 | 10 | from .tensor_log import TensorLogEntry 11 | 12 | 13 | def _getitem_during_pass(self: "ModelHistory", ix) -> TensorLogEntry: 14 | """Fetches an item when the pass is unfinished, only based on its raw barcode. 15 | 16 | Args: 17 | ix: layer's barcode 18 | 19 | Returns: 20 | Tensor log entry object with info about specified layer. 21 | """ 22 | if ix in self._raw_tensor_dict: 23 | return self._raw_tensor_dict[ix] 24 | else: 25 | raise ValueError(f"{ix} not found in the ModelHistory object.") 26 | 27 | 28 | def _getitem_after_pass(self, ix): 29 | """ 30 | Overloaded such that entries can be fetched either by their position in the tensor log, their layer label, 31 | or their module address. It should say so and tell them which labels are valid. 32 | """ 33 | if ix in self.layer_dict_all_keys: 34 | return self.layer_dict_all_keys[ix] 35 | 36 | keys_with_substr = [ 37 | key for key in self.layer_dict_all_keys if str(ix) in str(key) 38 | ] 39 | if len(keys_with_substr) == 1: 40 | return self.layer_dict_all_keys[keys_with_substr[0]] 41 | 42 | _give_user_feedback_about_lookup_key(self, ix, "get_one_item") 43 | 44 | 45 | def _give_user_feedback_about_lookup_key(self, key: Union[int, str], mode: str): 46 | """For __getitem__ and get_op_nums_from_user_labels, gives the user feedback about the user key 47 | they entered if it doesn't yield any matches. 48 | 49 | Args: 50 | key: Lookup key used by the user. 51 | """ 52 | if (type(key) == int) and ( 53 | key >= len(self.layer_list) or key < -len(self.layer_list) 54 | ): 55 | raise ValueError( 56 | f"You specified the layer with index {key}, but there are only {len(self.layer_list)} " 57 | f"layers; please specify an index in the range " 58 | f"-{len(self.layer_list)} - {len(self.layer_list) - 1}." 59 | ) 60 | 61 | if key in self.module_addresses: 62 | module_num_passes = self.module_num_passes[key] 63 | raise ValueError( 64 | f"You specified output of module {key}, but it has {module_num_passes} passes; " 65 | f"please specify e.g. {key}:2 for the second pass of {key}." 
66 | ) 67 | 68 | if key.split(":")[0] in self.module_addresses: 69 | module, pass_num = key.split(":") 70 | module_num_passes = self.module_num_passes[module] 71 | raise ValueError( 72 | f"You specified module {module} pass {pass_num}, but {module} only has " 73 | f"{module_num_passes} passes; specify a lower number." 74 | ) 75 | 76 | if key in self.layer_labels_no_pass: 77 | layer_num_passes = self.layer_num_passes[key] 78 | raise ValueError( 79 | f"You specified output of layer {key}, but it has {layer_num_passes} passes; " 80 | f"please specify e.g. {key}:2 for the second pass of {key}." 81 | ) 82 | 83 | if key.split(":")[0] in self.layer_labels_no_pass: 84 | layer_label, pass_num = key.split(":") 85 | layer_num_passes = self.layer_num_passes[layer_label] 86 | raise ValueError( 87 | f"You specified layer {layer_label} pass {pass_num}, but {layer_label} only has " 88 | f"{layer_num_passes} passes. Specify a lower number." 89 | ) 90 | 91 | raise ValueError(_get_lookup_help_str(self, key, mode)) 92 | 93 | 94 | def _str_after_pass(self) -> str: 95 | """Readable summary of the model history after the pass is finished. 96 | 97 | Returns: 98 | String summarizing the model. 99 | """ 100 | s = f"Log of {self.model_name} forward pass:" 101 | 102 | # General info 103 | 104 | s += f"\n\tRandom seed: {self.random_seed_used}" 105 | s += ( 106 | f"\n\tTime elapsed: {np.round(self.elapsed_time_total, 3)}s " 107 | f"({np.round(self.elapsed_time_torchlens_logging, 3)}s spent logging)" 108 | ) 109 | 110 | # Overall model structure 111 | 112 | s += "\n\tStructure:" 113 | if self.model_is_recurrent: 114 | s += f"\n\t\t- recurrent (at most {self.model_max_recurrent_loops} loops)" 115 | else: 116 | s += "\n\t\t- purely feedforward, no recurrence" 117 | 118 | if self.model_is_branching: 119 | s += "\n\t\t- with branching" 120 | else: 121 | s += "\n\t\t- no branching" 122 | 123 | if self.model_has_conditional_branching: 124 | s += "\n\t\t- with conditional (if-then) branching" 125 | else: 126 | s += "\n\t\t- no conditional (if-then) branching" 127 | 128 | if len(self.buffer_layers) > 0: 129 | s += f"\n\t\t- contains {len(self.buffer_layers)} buffer layers" 130 | 131 | s += f"\n\t\t- {len(self.module_addresses)} total modules" 132 | 133 | # Model tensors: 134 | 135 | s += "\n\tTensor info:" 136 | s += ( 137 | f"\n\t\t- {self.num_tensors_total} total tensors ({self.tensor_fsize_total_nice}) " 138 | f"computed in forward pass." 139 | ) 140 | s += f"\n\t\t- {self.num_tensors_saved} tensors ({self.tensor_fsize_saved_nice}) with saved activations." 141 | 142 | # Model parameters: 143 | 144 | s += ( 145 | f"\n\tParameters: {self.total_param_layers} parameter operations ({self.total_params} params total; " 146 | f"{self.total_params_fsize_nice})" 147 | ) 148 | 149 | # Print the module hierarchy. 150 | s += "\n\tModule Hierarchy:" 151 | s += _module_hierarchy_str(self) 152 | 153 | # Now print all layers. 
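    # (Each printed entry has the form "* (7) conv2d_3_7", with an asterisk marking
    # layers whose activations were saved, plus e.g. "(2/3 passes)" appended for
    # layers that execute more than once.)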
154 | s += "\n\tLayers" 155 | if self._all_layers_saved: 156 | s += " (all have saved activations):" 157 | elif self.num_tensors_saved == 0: 158 | s += " (no layer activations are saved):" 159 | else: 160 | s += " (* means layer has saved activations):" 161 | for layer_ind, layer_barcode in enumerate(self.layer_labels): 162 | pass_num = self.layer_dict_main_keys[layer_barcode].pass_num 163 | total_passes = self.layer_dict_main_keys[layer_barcode].layer_passes_total 164 | if total_passes > 1: 165 | pass_str = f" ({pass_num}/{total_passes} passes)" 166 | else: 167 | pass_str = "" 168 | 169 | if self.layer_dict_main_keys[layer_barcode].has_saved_activations and ( 170 | not self._all_layers_saved 171 | ): 172 | s += "\n\t\t* " 173 | else: 174 | s += "\n\t\t " 175 | s += f"({layer_ind}) {layer_barcode} {pass_str}" 176 | 177 | return s 178 | 179 | 180 | def _str_during_pass(self) -> str: 181 | """Readable summary of the model history during the pass, as a debugging aid. 182 | 183 | Returns: 184 | String summarizing the model. 185 | """ 186 | s = f"Log of {self.model_name} forward pass (pass still ongoing):" 187 | s += f"\n\tRandom seed: {self.random_seed_used}" 188 | s += f"\n\tInput tensors: {self.input_layers}" 189 | s += f"\n\tOutput tensors: {self.output_layers}" 190 | s += f"\n\tInternally initialized tensors: {self.internally_initialized_layers}" 191 | s += f"\n\tInternally terminated tensors: {self.internally_terminated_layers}" 192 | s += f"\n\tInternally terminated boolean tensors: {self.internally_terminated_bool_layers}" 193 | s += f"\n\tBuffer tensors: {self.buffer_layers}" 194 | s += "\n\tRaw layer labels:" 195 | for layer in self._raw_tensor_labels_list: 196 | s += f"\n\t\t{layer}" 197 | return s 198 | 199 | 200 | def pretty_print_list_w_line_breaks(lst, indent_chars: str, line_break_every=5): 201 | """ 202 | Utility function to pretty print a list with line breaks, adding indent_chars every line. 203 | """ 204 | s = f"\n{indent_chars}" 205 | for i, item in enumerate(lst): 206 | s += f"{item}" 207 | if i < len(lst) - 1: 208 | s += ", " 209 | if ((i + 1) % line_break_every == 0) and (i < len(lst) - 1): 210 | s += f"\n{indent_chars}" 211 | return s 212 | 213 | 214 | def _get_lookup_help_str(self, layer_label: Union[int, str], mode: str) -> str: 215 | """Generates a help string to be used in error messages when indexing fails.""" 216 | sample_layer1 = random.choice(self.layer_labels_w_pass) 217 | sample_layer2 = random.choice(self.layer_labels_no_pass) 218 | if len(self.module_addresses) > 0: 219 | sample_module1 = random.choice(self.module_addresses) 220 | sample_module2 = random.choice(self.module_passes) 221 | else: 222 | sample_module1 = "features.3" 223 | sample_module2 = "features.4:2" 224 | module_str = f"(e.g., {sample_module1}, {sample_module2})" 225 | if mode == "get_one_item": 226 | msg = ( 227 | "e.g., 'pool' will grab the maxpool2d or avgpool2d layer, 'maxpool' will grab the 'maxpool2d' " 228 | "layer, etc., but there must be only one such matching layer" 229 | ) 230 | elif mode == "query_multiple": 231 | msg = ( 232 | "e.g., 'pool' will grab all maxpool2d or avgpool2d layers, 'maxpool' will grab all 'maxpool2d' " 233 | "layers, etc." 234 | ) 235 | else: 236 | raise ValueError("mode must be either get_one_item or query_multiple") 237 | help_str = ( 238 | f"Layer {layer_label} not recognized; please specify either " 239 | f"\n\n\t1) an integer giving the ordinal position of the layer " 240 | f"(e.g. 
2 for 3rd layer, -4 for fourth-to-last), " 241 | f"\n\t2) the layer label (e.g., {sample_layer1}, {sample_layer2}), " 242 | f"\n\t3) the module address {module_str}" 243 | f"\n\t4) A substring of any desired layer label ({msg})." 244 | f"\n\n(Label meaning: conv2d_3_4:2 means the second pass of the third convolutional layer, " 245 | f"and fourth layer overall in the model.)" 246 | ) 247 | return help_str 248 | 249 | 250 | def _module_hierarchy_str(self): 251 | """ 252 | Utility function to print the nested module hierarchy. 253 | """ 254 | s = "" 255 | for module_pass in self.top_level_module_passes: 256 | module, pass_num = module_pass.split(":") 257 | s += f"\n\t\t{module}" 258 | if self.module_num_passes[module] > 1: 259 | s += f":{pass_num}" 260 | s += _module_hierarchy_str_helper(self, module_pass, 1) 261 | return s 262 | 263 | 264 | def _module_hierarchy_str_helper(self, module_pass, level): 265 | """ 266 | Helper function for _module_hierarchy_str. 267 | """ 268 | s = "" 269 | any_grandchild_modules = any( 270 | [ 271 | len(self.module_pass_children[submodule_pass]) > 0 272 | for submodule_pass in self.module_pass_children[module_pass] 273 | ] 274 | ) 275 | if any_grandchild_modules or len(self.module_pass_children[module_pass]) == 0: 276 | for submodule_pass in self.module_pass_children[module_pass]: 277 | submodule, pass_num = submodule_pass.split(":") 278 | s += f"\n\t\t{' ' * level}{submodule}" 279 | if self.module_num_passes[submodule] > 1: 280 | s += f":{pass_num}" 281 | s += _module_hierarchy_str_helper(self, submodule_pass, level + 1) 282 | else: 283 | submodule_list = [] 284 | for submodule_pass in self.module_pass_children[module_pass]: 285 | submodule, pass_num = submodule_pass.split(":") 286 | if self.module_num_passes[submodule] == 1: 287 | submodule_list.append(submodule) 288 | else: 289 | submodule_list.append(submodule_pass) 290 | s += pretty_print_list_w_line_breaks( 291 | submodule_list, line_break_every=8, indent_chars=f"\t\t{' ' * level}" 292 | ) 293 | return s 294 | 295 | 296 | def print_all_fields(self): 297 | """Print all data fields for ModelHistory.""" 298 | fields_to_exclude = [ 299 | "layer_list", 300 | "layer_dict_main_keys", 301 | "layer_dict_all_keys", 302 | "raw_tensor_dict", 303 | "decorated_to_orig_funcs_dict", 304 | ] 305 | 306 | for field in dir(self): 307 | attr = getattr(self, field) 308 | if not any( 309 | [field.startswith("_"), field in fields_to_exclude, callable(attr)] 310 | ): 311 | print(f"{field}: {attr}") 312 | 313 | 314 | def to_pandas(self) -> pd.DataFrame: 315 | """Returns a pandas dataframe with info about each layer. 316 | 317 | Returns: 318 | Pandas dataframe with info about each layer. 
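
    Example (a sketch; assumes model_history is any completed forward-pass log):
        df = model_history.to_pandas()
        # e.g., list every conv2d layer along with its output shape:
        print(df[df['layer_type'] == 'conv2d'][['layer_label', 'tensor_shape']])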
319 | """ 320 | fields_for_df = [ 321 | "layer_label", 322 | "layer_label_w_pass", 323 | "layer_label_no_pass", 324 | "layer_label_short", 325 | "layer_label_w_pass_short", 326 | "layer_label_no_pass_short", 327 | "layer_type", 328 | "layer_type_num", 329 | "layer_total_num", 330 | "layer_passes_total", 331 | "pass_num", 332 | "operation_num", 333 | "tensor_shape", 334 | "tensor_dtype", 335 | "tensor_fsize", 336 | "tensor_fsize_nice", 337 | "func_applied_name", 338 | "func_time_elapsed", 339 | "function_is_inplace", 340 | "gradfunc", 341 | "is_input_layer", 342 | "is_output_layer", 343 | "is_buffer_layer", 344 | "is_part_of_iterable_output", 345 | "iterable_output_index", 346 | "parent_layers", 347 | "has_parents", 348 | "orig_ancestors", 349 | "child_layers", 350 | "has_children", 351 | "output_descendents", 352 | "sibling_layers", 353 | "has_siblings", 354 | "spouse_layers", 355 | "has_spouses", 356 | "initialized_inside_model", 357 | "min_distance_from_input", 358 | "max_distance_from_input", 359 | "min_distance_from_output", 360 | "max_distance_from_output", 361 | "computed_with_params", 362 | "num_params_total", 363 | "parent_param_shapes", 364 | "parent_params_fsize", 365 | "parent_params_fsize_nice", 366 | "modules_entered", 367 | "modules_exited", 368 | "is_submodule_input", 369 | "is_submodule_output", 370 | "containing_module_origin", 371 | "containing_modules_origin_nested", 372 | ] 373 | 374 | fields_to_change_type = { 375 | "layer_type_num": int, 376 | "layer_total_num": int, 377 | "layer_passes_total": int, 378 | "pass_num": int, 379 | "operation_num": int, 380 | "function_is_inplace": bool, 381 | "is_input_layer": bool, 382 | "is_output_layer": bool, 383 | "is_buffer_layer": bool, 384 | "is_part_of_iterable_output": bool, 385 | "has_parents": bool, 386 | "has_children": bool, 387 | "has_siblings": bool, 388 | "has_spouses": bool, 389 | "computed_with_params": bool, 390 | "num_params_total": int, 391 | "parent_params_fsize": int, 392 | "tensor_fsize": int, 393 | "is_submodule_input": bool, 394 | "is_submodule_output": bool, 395 | } 396 | 397 | model_df_dictlist = [] 398 | for tensor_entry in self.layer_list: 399 | tensor_dict = {} 400 | for field_name in fields_for_df: 401 | tensor_dict[field_name] = getattr(tensor_entry, field_name) 402 | model_df_dictlist.append(tensor_dict) 403 | model_df = pd.DataFrame(model_df_dictlist) 404 | 405 | for field in fields_to_change_type: 406 | model_df[field] = model_df[field].astype(fields_to_change_type[field]) 407 | 408 | return model_df 409 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchLens 2 | 3 | **Quick Links** 4 | 5 | - [Paper introducing TorchLens](https://www.nature.com/articles/s41598-023-40807-0) 6 | - [CoLab tutorial](https://colab.research.google.com/drive/1ORJLGZPifvdsVPFqq1LYT3t5hV560SoW?usp=sharing) 7 | - [\"Menagerie\" of model visualizations](https://drive.google.com/drive/u/0/folders/1BsM6WPf3eB79-CRNgZejMxjg38rN6VCb) 8 | - [Metadata provided by TorchLens](https://static-content.springer.com/esm/art%3A10.1038%2Fs41598-023-40807-0/MediaObjects/41598_2023_40807_MOESM1_ESM.pdf) 9 | 10 | ## Overview 11 | 12 | *TorchLens* is a package for doing exactly two things: 13 | 14 | 1) Easily extracting the activations from every single intermediate operation in a PyTorch model—no 15 | modifications needed—in one line of code. "Every operation" means every operation; "one line" means one line. 
16 | 2) Understanding the model's computational structure via an intuitive automatic visualization and extensive
17 | metadata ([partial list here](https://static-content.springer.com/esm/art%3A10.1038%2Fs41598-023-40807-0/MediaObjects/41598_2023_40807_MOESM1_ESM.pdf))
18 | about the network's computational graph.
19 | 
20 | Here it is in action for a very simple recurrent model; as you can see, you just define the model like normal and pass
21 | it in, and *TorchLens* returns a full log of the forward pass along with a visualization:
22 | 
23 | ```python
import torch
from torch import nn

import torchlens as tl


24 | class SimpleRecurrent(nn.Module):
25 |     def __init__(self):
26 |         super().__init__()
27 |         self.fc = nn.Linear(in_features=5, out_features=5)
28 | 
29 |     def forward(self, x):
30 |         for r in range(4):
31 |             x = self.fc(x)
32 |             x = x + 1
33 |             x = x * 2
34 |         return x
35 | 
36 | 
37 | simple_recurrent = SimpleRecurrent()
x = torch.rand(6, 5)  # example input; its shape matches the 6x5 output shown below
38 | model_history = tl.log_forward_pass(simple_recurrent, x,
39 |                                     layers_to_save='all',
40 |                                     vis_opt='rolled')
41 | print(model_history['linear_1_1:2'].tensor_contents)  # second pass of first linear layer
42 | 
43 | '''
44 | tensor([[-0.0690, -1.3957, -0.3231, -0.1980,  0.7197],
45 |         [-0.1083, -1.5051, -0.2570, -0.2024,  0.8248],
46 |         [ 0.1031, -1.4315, -0.5999, -0.4017,  0.7580],
47 |         [-0.0396, -1.3813, -0.3523, -0.2008,  0.6654],
48 |         [ 0.0980, -1.4073, -0.5934, -0.3866,  0.7371],
49 |         [-0.1106, -1.2909, -0.3393, -0.2439,  0.7345]])
50 | '''
51 | ```
52 | 
53 | 
54 | 
55 | And here it is for a very complex transformer model ([swin_v2_b](https://arxiv.org/abs/2103.14030)) with 1932 operations
56 | in its forward pass; you can grab the saved outputs of every last one:
57 | 
58 | 
59 | 
60 | The goal of *TorchLens* is to do this for any PyTorch model whatsoever. You can see a bunch of example model
61 | visualizations in this [model menagerie](https://drive.google.com/drive/u/0/folders/1BsM6WPf3eB79-CRNgZejMxjg38rN6VCb).
62 | 
63 | ## Installation
64 | 
65 | To install *TorchLens*, first install graphviz if you haven't already (required to generate the network visualizations),
66 | and then install *TorchLens* using pip:
67 | 
68 | ```bash
69 | sudo apt install graphviz
70 | pip install torchlens
71 | ```
72 | 
73 | *TorchLens* is compatible with versions 1.8.0+ of PyTorch.
74 | 
75 | ## How-To Guide
76 | 
77 | Below is a quick demo of how to use it; for an interactive demonstration, see
78 | the [CoLab walkthrough](https://colab.research.google.com/drive/1ORJLGZPifvdsVPFqq1LYT3t5hV560SoW?usp=sharing).
79 | 
80 | The main function of *TorchLens* is `log_forward_pass`: when called on a model and input, it runs a
81 | forward pass on the model and returns a ModelHistory object containing the intermediate layer activations and
82 | accompanying metadata, along with a visual representation of every operation that occurred during the forward pass:
83 | 
84 | ```python
85 | import torch
86 | import torchvision
87 | import torchlens as tl
88 | 
89 | alexnet = torchvision.models.alexnet()
90 | x = torch.rand(1, 3, 224, 224)
91 | model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_opt='unrolled')
92 | print(model_history)
93 | 
94 | '''
95 | Log of AlexNet forward pass:
96 |     Model structure: purely feedforward, without branching; 23 total modules.
97 |     24 tensors (4.8 MB) computed in forward pass; 24 tensors (4.8 MB) saved.
98 |     16 parameter operations (61100840 params total; 248.7 MB).
99 |     Random seed: 3210097511
100 |     Time elapsed: 0.288s
101 |     Module Hierarchy:
102 |         features:
103 |             features.0, features.1, features.2, features.3, features.4, features.5, features.6, features.7,
104 |             features.8, features.9, features.10, features.11, features.12
105 |         avgpool
106 |         classifier:
107 |             classifier.0, classifier.1, classifier.2, classifier.3, classifier.4, classifier.5, classifier.6
108 |     Layers:
109 |     0: input_1_0
110 |     1: conv2d_1_1
111 |     2: relu_1_2
112 |     3: maxpool2d_1_3
113 |     4: conv2d_2_4
114 |     5: relu_2_5
115 |     6: maxpool2d_2_6
116 |     7: conv2d_3_7
117 |     8: relu_3_8
118 |     9: conv2d_4_9
119 |     10: relu_4_10
120 |     11: conv2d_5_11
121 |     12: relu_5_12
122 |     13: maxpool2d_3_13
123 |     14: adaptiveavgpool2d_1_14
124 |     15: flatten_1_15
125 |     16: dropout_1_16
126 |     17: linear_1_17
127 |     18: relu_6_18
128 |     19: dropout_2_19
129 |     20: linear_2_20
130 |     21: relu_7_21
131 |     22: linear_3_22
132 |     23: output_1_23
133 | '''
134 | ```
135 | 
136 | 
137 | 
138 | You can pull out information about a given layer, including its activations and helpful metadata, by indexing
139 | the ModelHistory object in any of these equivalent ways:
140 | 
141 | 1) the name of a layer (with the convention that 'conv2d_3_7' is the 3rd convolutional layer, and the 7th layer overall)
142 | 2) the name of a module (e.g., 'features' or 'classifier.3') for which that layer is an output, or
143 | 3) the ordinal position of the layer (e.g., 2 for the 2nd layer, -5 for the fifth-to-last; inputs and outputs count as
144 | layers here).
145 | 
146 | To quickly figure out these names, you can look at the graph visualization, or at the output of printing the
147 | ModelHistory object (both shown above). Here are some examples of how to pull out information about a
148 | particular layer, and also how to pull out the actual activations from that layer:
149 | 
150 | ```python
151 | print(model_history['conv2d_3_7'])  # pulling out layer by its name
152 | # The following commented lines pull out the same layer:
153 | # model_history['conv2d_3']      you can omit the second number (since strictly speaking it's redundant)
154 | # model_history['conv2d_3_7:1']  colon indicates the pass of a layer (here just one)
155 | # model_history['features.6']    can grab a layer by the module for which it is an output
156 | # model_history[7]               the 7th layer overall
157 | # model_history[-17]             the 17th-to-last layer
158 | '''
159 | Layer conv2d_3_7, operation 8/24:
160 | Output tensor: shape=(1, 384, 13, 13), dtype=torch.float32, size=253.5 KB
161 |     tensor([[ 0.0503, -0.1089, -0.1210, -0.1034, -0.1254],
162 |             [ 0.0789, -0.0752, -0.0581, -0.0372, -0.0181],
163 |             [ 0.0949, -0.0780, -0.0401, -0.0209, -0.0095],
164 |             [ 0.0929, -0.0353, -0.0220, -0.0324, -0.0295],
165 |             [ 0.1100, -0.0337, -0.0330, -0.0479, -0.0235]])...
166 | Params: Computed from params with shape (384,), (384, 192, 3, 3); 663936 params total (2.5 MB)
167 | Parent Layers: maxpool2d_2_6
168 | Child Layers: relu_3_8
169 | Function: conv2d (gradfunc=ConvolutionBackward0)
170 | Computed inside module: features.6
171 | Time elapsed: 5.670E-04s
172 | Output of modules: features.6
173 | Output of bottom-level module: features.6
174 | Lookup keys: -17, 7, conv2d_3_7, conv2d_3_7:1, features.6, features.6:1
175 | '''
176 | 
177 | # You can pull out the actual output activations from a layer with the tensor_contents field:
178 | print(model_history['conv2d_3_7'].tensor_contents)
179 | '''
180 | tensor([[[[-0.0867, -0.0787, -0.0817,  ..., -0.0820, -0.0655, -0.0195],
181 |           [-0.1213, -0.1130, -0.1386,  ..., -0.1331, -0.1118, -0.0520],
182 |           [-0.0959, -0.0973, -0.1078,  ..., -0.1103, -0.1091, -0.0760],
183 |           ...,
184 |           [-0.0906, -0.1146, -0.1308,  ..., -0.1076, -0.1129, -0.0689],
185 |           [-0.1017, -0.1256, -0.1100,  ..., -0.1160, -0.1035, -0.0801],
186 |           [-0.1006, -0.0941, -0.1204,  ..., -0.1146, -0.1065, -0.0631]]...
187 | '''
188 | ```
189 | 
190 | If you do not wish to save the activations for all layers (e.g., to save memory), you can specify which layers to save
191 | with the `layers_to_save` argument when calling `log_forward_pass`; you can either indicate layers in the same way
192 | as indexing them above, or by passing in a desired substring for filtering the layers (e.g., 'conv'
193 | will pull out all conv layers):
194 | 
195 | ```python
196 | # Pull out conv2d_3_7, the output of the 'features' module, the fifth-to-last layer, and all linear (i.e., fc) layers:
197 | model_history = tl.log_forward_pass(alexnet, x, vis_opt='unrolled',
198 |                                     layers_to_save=['conv2d_3_7', 'features', -5, 'linear'])
199 | print(model_history.layer_labels)
200 | '''
201 | ['conv2d_3_7', 'maxpool2d_3_13', 'linear_1_17', 'dropout_2_19', 'linear_2_20', 'linear_3_22']
202 | '''
203 | ```
204 | 
205 | The main function of *TorchLens* is `log_forward_pass`; the remaining functions are:
206 | 
207 | 1) `get_model_metadata`, to retrieve all model metadata without saving any activations (e.g., to figure out which
208 | layers you wish to save; note that this is the same as calling `log_forward_pass` with `layers_to_save=None`)
209 | 2) `show_model_graph`, which visualizes the model graph without saving any activations
210 | 3) `validate_saved_activations`, which runs a procedure to check that the activations are correct: specifically,
211 | it runs a forward pass and saves all intermediate activations, re-runs the forward pass from each intermediate
212 | layer, and checks that the resulting output matches the ground-truth output. It also checks that swapping in
213 | random nonsense activations instead of the saved activations generates the wrong output. **If this function ever
214 | returns False (i.e., the saved activations are wrong), please contact me via email (johnmarkedwardtaylor@gmail.com)
215 | or on this GitHub page with a description of the problem, and I will update TorchLens to fix the problem.**
216 | 
217 | And that's it. *TorchLens* remains in active development, and the goal is for it to work with any PyTorch model
218 | whatsoever without exception. As of the time of this writing, it has been tested with over 700
219 | image, video, auditory, multimodal, and language models, including feedforward, recurrent, transformer,
220 | and graph neural networks.
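
For instance, here is a minimal sketch of the last two functions in action (reusing the AlexNet model and input
from above; the validation step can take a while, since it saves and re-checks every layer):

```python
import torch
import torchvision
import torchlens as tl

alexnet = torchvision.models.alexnet()
x = torch.rand(1, 3, 224, 224)

# Render the graph without saving any activations:
tl.show_model_graph(alexnet, x, vis_opt='rolled')

# Check that the saved activations faithfully reproduce the model's output:
assert tl.validate_saved_activations(alexnet, x)
```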
221 | 
222 | ## Miscellaneous Features
223 | 
224 | - You can visualize models at different levels of nesting depth using the `vis_nesting_depth` argument
225 | to `log_forward_pass`; for example, here you can see one of GoogLeNet's "inception" modules at different levels of
226 | nesting depth:
227 | 
228 | 
229 | 
230 | - An experimental feature is to extract not just the activations from all of a model's operations, but also
231 | the gradients from a backward pass (which you can compute based on any intermediate layer, not just the
232 | model's output), and also visualize the path taken by the backward pass (shown with blue arrows below).
233 | See the CoLab tutorial for instructions on how to do this.
234 | 
235 | 
236 | 
237 | 
238 | 
239 | - You can see the literal code that was used to run the model with the `func_call_stack` field:
240 | 
241 | ```python
242 | print(model_history['conv2d_3'].func_call_stack[8])
243 | '''
244 | {'call_fname': '/usr/local/lib/python3.10/dist-packages/torchvision/models/alexnet.py',
245 |  'call_linenum': 48,
246 |  'function': 'forward',
247 |  'code_context': ['            nn.Linear(256 * 6 * 6, 4096),\n',
248 |                   '            nn.ReLU(inplace=True),\n',
249 |                   '            nn.Dropout(p=dropout),\n',
250 |                   '            nn.Linear(4096, 4096),\n',
251 |                   '            nn.ReLU(inplace=True),\n',
252 |                   '            nn.Linear(4096, num_classes),\n',
253 |                   '        )\n',
254 |                   '\n',
255 |                   '    def forward(self, x: torch.Tensor) -> torch.Tensor:\n',
256 |                   '        x = self.features(x)\n',
257 |                   '        x = self.avgpool(x)\n',
258 |                   '        x = torch.flatten(x, 1)\n',
259 |                   '        x = self.classifier(x)\n',
260 |                   '        return x\n',
261 |                   '\n',
262 |                   '\n',
263 |                   'class AlexNet_Weights(WeightsEnum):\n',
264 |                   '    IMAGENET1K_V1 = Weights(\n',
265 |                   '        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",\n']}
266 | '''
267 | ```
268 | 
269 | ## Planned Features
270 | 
271 | 1) In the more distant future, I am considering adding functionality to not just save activations,
272 | but counterfactually intervene on them (e.g., how would the output have changed if these parameters
273 | were different or if a different nonlinearity were used). Let me know if you'd find this useful
274 | and if so, what specific kind of functionality you'd want.
275 | 2) I am planning to add an option to only visualize a single submodule of a model rather than the full graph at once.
276 | 
277 | ## Other Packages You Should Check Out
278 | 
279 | The goal is for *TorchLens* to completely solve the problem of extracting activations and metadata
280 | from deep neural networks and visualizing their structure so that nobody has to think about this stuff ever again, but
281 | it intentionally leaves out certain functionality: for example, it has no functions for loading models or stimuli, or
282 | for analyzing the extracted activations. This is in part because it's impossible to predict all the things you might
283 | want to do with the activations, or all the possible models you might want to look at, but also because there are
284 | already outstanding packages for doing these things. Here are a few; let me know if I've missed any!
285 | 
286 | - [Cerbrec](https://cerbrec.com): Program for interactively visualizing and debugging deep neural networks (uses TorchLens under
287 | the hood for extracting the graphs of PyTorch models!)
288 | - [ThingsVision](https://github.com/ViCCo-Group/thingsvision): has excellent functionality for loading vision models, 289 | loading stimuli, and analyzing the extracted activations 290 | - [Net2Brain](https://github.com/cvai-roig-lab/Net2Brain): similar excellent end-to-end functionality to ThingsVision, 291 | along with functionality for comparing extracted activations to neural data. 292 | - [surgeon-pytorch](https://github.com/archinetai/surgeon-pytorch): easy-to-use functionality for extracting activations 293 | from models, along with functionality for training a model using loss functions based on intermediate layer 294 | activations 295 | - [deepdive](https://github.com/ColinConwell/DeepDive): has outstanding functionality for loading and benchmarking 296 | many different models 297 | - [torchvision feature_extraction module](https://pytorch.org/vision/stable/feature_extraction.html): can extract 298 | activations from models with static computational graphs 299 | - [rsatoolbox3](https://github.com/rsagroup/rsatoolbox): total solution for performing representational similarity 300 | analysis on DNN activations and brain data 301 | 302 | ## Acknowledgments 303 | 304 | The development of *TorchLens* benefitted greatly from discussions with Nikolaus Kriegeskorte, George Alvarez, 305 | Alfredo Canziani, Tal Golan, and the Visual Inference Lab at Columbia University. Thank you to Kale Kundert 306 | for helpful discussion and for his code contributions enabling PyTorch Lightning compatibility. 307 | All network visualizations were created with graphviz. Logo created by Nikolaus Kriegeskorte. 308 | 309 | ## Citing Torchlens 310 | 311 | To cite *TorchLens*, you can 312 | cite [this paper describing the package](https://www.nature.com/articles/s41598-023-40807-0) (and consider adding a star 313 | to this repo if you find *TorchLens* useful): 314 | 315 | Taylor, J., Kriegeskorte, N. Extracting and visualizing hidden activations and computational graphs of PyTorch models 316 | with *TorchLens*. Sci Rep 13, 14375 (2023). https://doi.org/10.1038/s41598-023-40807-0 317 | 318 | ## Contact 319 | 320 | As *TorchLens* is still in active development, I would love your feedback. Please contact 321 | johnmarkedwardtaylor@gmail.com, 322 | contact me via [twitter](https://twitter.com/johnmark_taylor), or post on 323 | the [issues](https://github.com/johnmarktaylor91/torchlens/issues) 324 | or [discussion](https://github.com/johnmarktaylor91/torchlens/discussions) page for this GitHub 325 | repository, if you have any questions, comments, or suggestions (or if you'd be interested in collaborating!). 326 | -------------------------------------------------------------------------------- /torchlens/model_funcs.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import wraps 3 | from typing import Callable, Dict, List, TYPE_CHECKING 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from .helper_funcs import get_vars_of_type_from_obj, iter_accessible_attributes, remove_attributes_starting_with_str 9 | from .logging_funcs import log_source_tensor 10 | 11 | if TYPE_CHECKING: 12 | from .model_history import ModelHistory 13 | 14 | 15 | def prepare_model( 16 | model_log: "ModelHistory", 17 | model: nn.Module, 18 | module_orig_forward_funcs: Dict, 19 | decorated_func_mapper: Dict[Callable, Callable], 20 | ): 21 | """Adds annotations and hooks to the model, and decorates any functions in the model. 
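    (The model is modified in place; cleanup_model later reverses every change, leaving no trace.)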
22 | 23 | Args: 24 | model: Model to prepare. 25 | module_orig_forward_funcs: Dict with the original forward funcs for each submodule 26 | decorated_func_mapper: Dictionary mapping decorated functions to original functions, so they can be restored 27 | 28 | Returns: 29 | Model with hooks and attributes added. 30 | """ 31 | model_log.model_name = str(type(model).__name__) 32 | model.tl_module_address = "" 33 | model.tl_source_model_history = model_log 34 | 35 | module_stack = [(model, "")] # list of tuples (name, module) 36 | 37 | while len(module_stack) > 0: 38 | module, parent_address = module_stack.pop() 39 | module_children = list(module.named_children()) 40 | 41 | # Decorate any torch functions in the model: 42 | for func_name, func in module.__dict__.items(): 43 | if ( 44 | (func_name[0:2] == "__") 45 | or (not callable(func)) 46 | or (func not in decorated_func_mapper) 47 | ): 48 | continue 49 | module.__dict__[func_name] = decorated_func_mapper[func] 50 | 51 | # Annotate the children with the full address. 52 | for c, (child_name, child_module) in enumerate(module_children): 53 | child_address = ( 54 | f"{parent_address}.{child_name}" 55 | if parent_address != "" 56 | else child_name 57 | ) 58 | child_module.tl_module_address = child_address 59 | module_children[c] = (child_module, child_address) 60 | module_stack = module_children + module_stack 61 | 62 | if module == model: # don't tag the model itself. 63 | continue 64 | 65 | module.tl_source_model_history = model_log 66 | module.tl_module_type = str(type(module).__name__) 67 | model_log.module_types[module.tl_module_address] = module.tl_module_type 68 | module.tl_module_pass_num = 0 69 | module.tl_module_pass_labels = [] 70 | module.tl_tensors_entered_labels = [] 71 | module.tl_tensors_exited_labels = [] 72 | 73 | # Add decorators. 74 | 75 | if hasattr(module, "forward") and not hasattr( 76 | module.forward, "tl_forward_call_is_decorated" 77 | ): 78 | module_orig_forward_funcs[module] = module.forward 79 | module.forward = module_forward_decorator(model_log, module.forward, module) 80 | module.forward.tl_forward_call_is_decorated = True 81 | 82 | # Mark all parameters with requires_grad = True, and mark what they were before, so they can be restored on cleanup. 83 | for param in model.parameters(): 84 | param.tl_requires_grad = param.requires_grad 85 | param.requires_grad = True 86 | 87 | # And prepare any buffer tensors. 88 | prepare_buffer_tensors(model_log, model) 89 | 90 | 91 | def prepare_buffer_tensors(model_log, model: nn.Module): 92 | """Goes through a model and all its submodules, and prepares any "buffer" tensors: tensors 93 | attached to the module that aren't model parameters. 94 | 95 | Args: 96 | model: PyTorch model 97 | 98 | Returns: 99 | PyTorch model with all buffer tensors prepared and ready to track. 100 | """ 101 | submodules = get_all_submodules(model) 102 | for submodule in submodules: 103 | attr_list = list(submodule.named_buffers()) + list(iter_accessible_attributes(submodule)) 104 | for attribute_name, attribute in attr_list: 105 | if issubclass(type(attribute), torch.Tensor) and not issubclass( 106 | type(attribute), torch.nn.Parameter 107 | ) and not hasattr(attribute, 'tl_buffer_address'): 108 | if submodule.tl_module_address == "": 109 | buffer_address = attribute_name 110 | else: 111 | buffer_address = ( 112 | submodule.tl_module_address + "." 
+ attribute_name 113 | ) 114 | setattr(attribute, 'tl_buffer_address', buffer_address) 115 | 116 | 117 | def module_forward_decorator( 118 | model_log, orig_forward: Callable, module: nn.Module 119 | ) -> Callable: 120 | @wraps(orig_forward) 121 | def decorated_forward(*args, **kwargs): 122 | if model_log.logging_mode == "fast": # do bare minimum for logging. 123 | out = orig_forward(*args, **kwargs) 124 | output_tensors = get_vars_of_type_from_obj( 125 | out, torch.Tensor, search_depth=4 126 | ) 127 | for t in output_tensors: 128 | # if identity module, run the function for bookkeeping 129 | if module.tl_module_type.lower() == "identity": 130 | t = getattr(torch, "identity")(t) 131 | return out 132 | 133 | # "Pre-hook" operations: 134 | module_address = module.tl_module_address 135 | module.tl_module_pass_num += 1 136 | module_pass_label = (module_address, module.tl_module_pass_num) 137 | module.tl_module_pass_labels.append(module_pass_label) 138 | input_tensors = get_vars_of_type_from_obj( 139 | [args, kwargs], torch.Tensor, [torch.nn.Parameter], search_depth=5 140 | ) 141 | input_tensor_labels = set() 142 | for t in input_tensors: 143 | if (not hasattr(t, 'tl_tensor_label_raw')) and hasattr(t, 'tl_buffer_address'): 144 | log_source_tensor(model_log, t, 'buffer', getattr(t, 'tl_buffer_address')) 145 | tensor_entry = model_log._raw_tensor_dict[t.tl_tensor_label_raw] 146 | input_tensor_labels.add(t.tl_tensor_label_raw) 147 | module.tl_tensors_entered_labels.append(t.tl_tensor_label_raw) 148 | tensor_entry.modules_entered.append(module_address) 149 | tensor_entry.module_passes_entered.append(module_pass_label) 150 | tensor_entry.is_submodule_input = True 151 | for arg_key, arg_val in list(enumerate(args)) + list(kwargs.items()): 152 | if arg_val is t: 153 | tensor_entry.modules_entered_argnames[ 154 | f"{module_pass_label[0]}:{module_pass_label[1]}"].append(arg_key) 155 | model_log.module_layer_argnames[(f"{module_pass_label[0]}:" 156 | f"{module_pass_label[1]}")].append( 157 | (t.tl_tensor_label_raw, arg_key)) 158 | tensor_entry.module_entry_exit_thread_output.append( 159 | ("+", module_pass_label[0], module_pass_label[1]) 160 | ) 161 | 162 | # Check the buffers. 
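        # (Any buffer not yet tagged, e.g. a freshly created batchnorm running_mean,
        # gets a tl_buffer_address recording its place in the module hierarchy, and
        # its raw tensor label is cleared so it is re-logged as a buffer source the
        # next time it enters a module.)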
163 | for buffer_name, buffer_tensor in module.named_buffers(): 164 | if hasattr(buffer_tensor, 'tl_buffer_address'): 165 | continue 166 | if module.tl_module_address == '': 167 | buffer_address = buffer_name 168 | else: 169 | buffer_address = f"{module.tl_module_address}.{buffer_name}" 170 | buffer_tensor.tl_buffer_address = buffer_address 171 | buffer_tensor.tl_buffer_parent = getattr(buffer_tensor, 'tl_tensor_label_raw') 172 | delattr(buffer_tensor, 'tl_tensor_label_raw') 173 | 174 | # The function call 175 | out = orig_forward(*args, **kwargs) 176 | 177 | # "Post-hook" operations: 178 | module_address = module.tl_module_address 179 | module_pass_num = module.tl_module_pass_num 180 | module_entry_label = module.tl_module_pass_labels.pop() 181 | output_tensors = get_vars_of_type_from_obj( 182 | out, torch.Tensor, search_depth=4 183 | ) 184 | for t in output_tensors: 185 | # if identity module or tensor unchanged, run the identity function for bookkeeping 186 | if (module.tl_module_type.lower() == "identity") or ( 187 | t.tl_tensor_label_raw in input_tensor_labels 188 | ): 189 | t = getattr(torch, "identity")(t) 190 | tensor_entry = model_log._raw_tensor_dict[t.tl_tensor_label_raw] 191 | tensor_entry.is_submodule_output = True 192 | tensor_entry.is_bottom_level_submodule_output = ( 193 | log_whether_exited_submodule_is_bottom_level(model_log, t, module) 194 | ) 195 | tensor_entry.modules_exited.append(module_address) 196 | tensor_entry.module_passes_exited.append( 197 | (module_address, module_pass_num) 198 | ) 199 | tensor_entry.module_entry_exit_thread_output.append( 200 | ("-", module_entry_label[0], module_entry_label[1]) 201 | ) 202 | module.tl_tensors_exited_labels.append(t.tl_tensor_label_raw) 203 | 204 | for ( 205 | t 206 | ) in ( 207 | input_tensors 208 | ): # Now that module is finished, roll back the threads of all input tensors. 209 | tensor_entry = model_log._raw_tensor_dict[t.tl_tensor_label_raw] 210 | input_module_thread = tensor_entry.module_entry_exit_thread_output[:] 211 | if ( 212 | "+", 213 | module_entry_label[0], 214 | module_entry_label[1], 215 | ) in input_module_thread: 216 | module_entry_ix = input_module_thread.index( 217 | ("+", module_entry_label[0], module_entry_label[1]) 218 | ) 219 | tensor_entry.module_entry_exit_thread_output = ( 220 | tensor_entry.module_entry_exit_thread_output[:module_entry_ix] 221 | ) 222 | 223 | return out 224 | 225 | return decorated_forward 226 | 227 | 228 | def log_whether_exited_submodule_is_bottom_level( 229 | model_log, t: torch.Tensor, submodule: nn.Module 230 | ): 231 | """Checks whether the submodule that a tensor is leaving is a "bottom-level" submodule; 232 | that is, that only one tensor operation happened inside the submodule. 233 | 234 | Args: 235 | t: the tensor leaving the module 236 | submodule: the module that the tensor is leaving 237 | 238 | Returns: 239 | Whether the tensor operation is bottom level. 240 | """ 241 | tensor_entry = model_log._raw_tensor_dict[getattr(t, "tl_tensor_label_raw")] 242 | submodule_address = submodule.tl_module_address 243 | 244 | if tensor_entry.is_bottom_level_submodule_output: 245 | return True 246 | 247 | # If it was initialized inside the model and nothing entered the module, it's bottom-level. 
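    # (e.g., a submodule that generates its output tensor from scratch rather than
    # transforming a tensor passed into it)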
248 | if ( 249 | tensor_entry.initialized_inside_model 250 | and len(submodule.tl_tensors_entered_labels) == 0 251 | ): 252 | tensor_entry.is_bottom_level_submodule_output = True 253 | tensor_entry.bottom_level_submodule_pass_exited = ( 254 | submodule_address, 255 | submodule.tl_module_pass_num, 256 | ) 257 | return True 258 | 259 | # Else, all parents must have entered the submodule for it to be a bottom-level submodule. 260 | for parent_label in tensor_entry.parent_layers: 261 | parent_tensor = model_log[parent_label] 262 | parent_modules_entered = parent_tensor.modules_entered 263 | if (len(parent_modules_entered) == 0) or ( 264 | parent_modules_entered[-1] != submodule_address 265 | ): 266 | tensor_entry.is_bottom_level_submodule_output = False 267 | return False 268 | 269 | # If it survived the above tests, it's a bottom-level submodule. 270 | tensor_entry.is_bottom_level_submodule_output = True 271 | tensor_entry.bottom_level_submodule_pass_exited = ( 272 | submodule_address, 273 | submodule.tl_module_pass_num, 274 | ) 275 | return True 276 | 277 | 278 | def get_all_submodules( 279 | model: nn.Module, is_top_level_model: bool = True 280 | ) -> List[nn.Module]: 281 | """Recursively gets list of all submodules for given module, no matter their level in the 282 | hierarchy; this includes the model itself. 283 | 284 | Args: 285 | model: PyTorch model. 286 | is_top_level_model: Whether it's the top-level model; just for the recursive logic of it. 287 | 288 | Returns: 289 | List of all submodules. 290 | """ 291 | submodules = [] 292 | if is_top_level_model: 293 | submodules.append(model) 294 | for module in model.children(): 295 | if module not in submodules: 296 | submodules.append(module) 297 | submodules += get_all_submodules(module, is_top_level_model=False) 298 | return submodules 299 | 300 | 301 | def cleanup_model( 302 | model_log, 303 | model: nn.Module, 304 | module_orig_forward_funcs: Dict[nn.Module, Callable], 305 | decorated_func_mapper: Dict[Callable, Callable], 306 | ): 307 | """Reverses all temporary changes to the model (namely, the forward hooks and added 308 | model attributes) that were added for PyTorch x-ray (scout's honor; leave no trace). 309 | 310 | Args: 311 | model: PyTorch model. 312 | module_orig_forward_funcs: Dict containing the original, undecorated forward pass functions for 313 | each submodule 314 | decorated_func_mapper: Dict mapping between original and decorated PyTorch funcs 315 | 316 | Returns: 317 | Original version of the model. 318 | """ 319 | submodules = get_all_submodules(model, is_top_level_model=True) 320 | for submodule in submodules: 321 | if submodule == model: 322 | continue 323 | submodule.forward = module_orig_forward_funcs[submodule] 324 | restore_model_attributes( 325 | model, decorated_func_mapper=decorated_func_mapper, attribute_keyword="tl" 326 | ) 327 | undecorate_model_tensors(model) 328 | 329 | 330 | def clear_hooks(hook_handles: List): 331 | """Takes in a list of tuples (module, hook_handle), and clears the hook at that 332 | handle for each module. 333 | 334 | Args: 335 | hook_handles: List of tuples (module, hook_handle) 336 | 337 | Returns: 338 | Nothing. 
339 | """ 340 | for hook_handle in hook_handles: 341 | hook_handle.remove() 342 | 343 | 344 | def restore_module_attributes( 345 | module: nn.Module, 346 | decorated_func_mapper: Dict[Callable, Callable], 347 | attribute_keyword: str = "tl", 348 | ): 349 | def del_attrs_with_prefix(module, attribute_name): 350 | if attribute_name.startswith(attribute_keyword): 351 | delattr(module, attribute_name) 352 | return True 353 | 354 | for attribute_name, attr in iter_accessible_attributes(module, short_circuit=del_attrs_with_prefix): 355 | if ( 356 | isinstance(attr, Callable) 357 | and (attr in decorated_func_mapper) 358 | and (attribute_name[0:2] != "__") 359 | ): 360 | with warnings.catch_warnings(): 361 | warnings.simplefilter("ignore") 362 | setattr(module, attribute_name, decorated_func_mapper[attr]) 363 | 364 | 365 | def restore_model_attributes( 366 | model: nn.Module, 367 | decorated_func_mapper: Dict[Callable, Callable], 368 | attribute_keyword: str = "tl", 369 | ): 370 | """Recursively clears the given attribute from all modules in the model. 371 | 372 | Args: 373 | model: PyTorch model. 374 | decorated_func_mapper: Dict mapping between original and decorated PyTorch funcs 375 | attribute_keyword: Any attribute with this keyword will be cleared. 376 | 377 | Returns: 378 | Nothing. 379 | """ 380 | for module in get_all_submodules(model): 381 | restore_module_attributes( 382 | module, 383 | decorated_func_mapper=decorated_func_mapper, 384 | attribute_keyword=attribute_keyword, 385 | ) 386 | 387 | for param in model.parameters(): 388 | if hasattr(param, "tl_requires_grad"): 389 | param.requires_grad = getattr(param, "tl_requires_grad") 390 | delattr(param, "tl_requires_grad") 391 | 392 | 393 | def undecorate_model_tensors(model: nn.Module): 394 | """Goes through a model and all its submodules, and unmutates any tensor attributes. Normally just clearing 395 | parameters would have done this, but some module types (e.g., batchnorm) contain attributes that are tensors, 396 | but not parameters. 397 | 398 | Args: 399 | model: PyTorch model 400 | 401 | Returns: 402 | PyTorch model with unmutated versions of all tensor attributes. 
403 | """ 404 | submodules = get_all_submodules(model) 405 | for submodule in submodules: 406 | for attribute_name, attribute in iter_accessible_attributes(submodule): 407 | if issubclass(type(attribute), torch.Tensor): 408 | if not issubclass(type(attribute), torch.nn.Parameter) and hasattr( 409 | attribute, "tl_tensor_label_raw" 410 | ): 411 | delattr(attribute, "tl_tensor_label_raw") 412 | if hasattr(attribute, 'tl_buffer_address'): 413 | delattr(attribute, "tl_buffer_address") 414 | if hasattr(attribute, 'tl_buffer_parent'): 415 | delattr(attribute, "tl_buffer_parent") 416 | else: 417 | remove_attributes_starting_with_str(attribute, "tl_") 418 | elif type(attribute) in [list, tuple, set]: 419 | for item in attribute: 420 | if issubclass(type(item), torch.Tensor) and hasattr( 421 | item, "tl_tensor_label_raw" 422 | ): 423 | delattr(item, "tl_tensor_label_raw") 424 | if hasattr(item, 'tl_buffer_address'): 425 | delattr(item, "tl_buffer_address") 426 | if hasattr(item, 'tl_buffer_parent'): 427 | delattr(item, "tl_buffer_parent") 428 | elif type(attribute) == dict: 429 | for key, val in attribute.items(): 430 | if issubclass(type(val), torch.Tensor) and hasattr( 431 | val, "tl_tensor_label_raw" 432 | ): 433 | delattr(val, "tl_tensor_label_raw") 434 | if hasattr(val, 'tl_buffer_address'): 435 | delattr(val, "tl_buffer_address") 436 | if hasattr(val, 'tl_buffer_parent'): 437 | delattr(val, "tl_buffer_parent") 438 | -------------------------------------------------------------------------------- /torchlens/user_funcs.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 5 | 6 | import pandas as pd 7 | import torch 8 | from torch import nn 9 | from tqdm import tqdm 10 | 11 | from .helper_funcs import (get_vars_of_type_from_obj, set_random_seed, warn_parallel) 12 | from .model_history import ( 13 | ModelHistory, 14 | ) 15 | 16 | 17 | def run_model_and_save_specified_activations( 18 | model: nn.Module, 19 | input_args: Union[torch.Tensor, List[Any]], 20 | input_kwargs: Dict[Any, Any], 21 | layers_to_save: Optional[Union[str, List[Union[int, str]]]] = "all", 22 | keep_unsaved_layers: bool = True, 23 | output_device: str = "same", 24 | activation_postfunc: Optional[Callable] = None, 25 | mark_input_output_distances: bool = False, 26 | detach_saved_tensors: bool = False, 27 | save_function_args: bool = False, 28 | save_gradients: bool = False, 29 | random_seed: Optional[int] = None, 30 | ) -> ModelHistory: 31 | """Internal function that runs the given input through the given model, and saves the 32 | specified activations, as given by the tensor numbers (these will not be visible to the user; 33 | they will be generated from the nicer human-readable names and then fed in). 34 | 35 | Args: 36 | model: PyTorch model. 37 | input_args: Input arguments to the model's forward pass: either a single tensor, or a list of arguments. 38 | input_kwargs: Keyword arguments to the model's forward pass. 39 | layers_to_save: List of layers to save 40 | keep_unsaved_layers: Whether to keep layers in the ModelHistory log if they don't have saved activations. 41 | output_device: device where saved tensors will be stored: either 'same' to keep unchanged, or 42 | 'cpu' or 'cuda' to move to cpu or cuda. 
43 |         activation_postfunc: Function to apply to activations before saving them (e.g., any averaging)
44 |         mark_input_output_distances: Whether to compute the distance of each layer from the input or output.
45 |             This is computationally expensive for large networks, so it is off by default.
46 |         detach_saved_tensors: whether to detach the saved tensors, so they are no longer attached to the computational graph
47 |         save_function_args: whether to save the arguments to each function
48 |         save_gradients: whether to save gradients from any subsequent backward pass
49 |         random_seed: Which random seed to use.
50 | 
51 |     Returns:
52 |         ModelHistory object with full log of the forward pass
53 |     """
54 |     model_name = str(type(model).__name__)
55 |     model_history = ModelHistory(
56 |         model_name,
57 |         output_device,
58 |         activation_postfunc,
59 |         keep_unsaved_layers,
60 |         save_function_args,
61 |         save_gradients,
62 |         detach_saved_tensors,
63 |         mark_input_output_distances,
64 |     )
65 |     model_history._run_and_log_inputs_through_model(
66 |         model, input_args, input_kwargs, layers_to_save, random_seed
67 |     )
68 |     return model_history
69 | 
70 | 
71 | def log_forward_pass(
72 |     model: nn.Module,
73 |     input_args: Union[torch.Tensor, List[Any], Tuple[Any]],
74 |     input_kwargs: Dict[Any, Any] = None,
75 |     layers_to_save: Optional[Union[str, List]] = "all",
76 |     keep_unsaved_layers: bool = True,
77 |     output_device: str = "same",
78 |     activation_postfunc: Optional[Callable] = None,
79 |     mark_input_output_distances: bool = False,
80 |     detach_saved_tensors: bool = False,
81 |     save_function_args: bool = False,
82 |     save_gradients: bool = False,
83 |     vis_opt: str = "none",
84 |     vis_nesting_depth: int = 1000,
85 |     vis_outpath: str = "graph.gv",
86 |     vis_save_only: bool = False,
87 |     vis_fileformat: str = "pdf",
88 |     vis_buffer_layers: bool = False,
89 |     vis_direction: str = "bottomup",
90 |     vis_graph_overrides: Dict = None,
91 |     vis_node_overrides: Dict = None,
92 |     vis_nested_node_overrides: Dict = None,
93 |     vis_edge_overrides: Dict = None,
94 |     vis_gradient_edge_overrides: Dict = None,
95 |     vis_module_overrides: Dict = None,
96 |     random_seed: Optional[int] = None,
97 | ) -> ModelHistory:
98 |     """Runs a forward pass through a model given input x, and returns a ModelHistory object containing a log
99 |     (layer activations and accompanying layer metadata) of the forward pass for all layers specified in layers_to_save,
100 |     and optionally visualizes the model graph if vis_opt is set to 'rolled' or 'unrolled'.
101 | 
102 |     In layers_to_save, you can specify 'all', for all layers (default), or a list containing any combination of:
103 |     1) desired layer names (e.g., 'conv2d_1_1'; if a layer has multiple passes, this includes all passes),
104 |     2) a layer pass (e.g., conv2d_1_1:2 for just the second pass), 3) a module name to fetch the output of a particular
105 |     module, 4) the ordinal index of a layer in the model (e.g. 3 for the third layer, -2 for the second to last, etc.),
106 |     or 5) a desired substring with which to filter desired layers (e.g., 'conv2d' for all conv2d layers).
107 | 
108 |     Args:
109 |         model: PyTorch model
110 |         input_args: input arguments for model forward pass; as a list if multiple, else as a single tensor.
111 |         input_kwargs: keyword arguments for model forward pass
112 |         layers_to_save: list of layers to include (described above), or 'all' to include all layers.
113 |         keep_unsaved_layers: whether to keep layers without saved activations in the log (i.e., with just metadata)
114 |         activation_postfunc: Function to apply to tensors before saving them (e.g., channelwise averaging).
115 |         mark_input_output_distances: whether to mark the distance of each layer from the input or output;
116 |             False by default since this is computationally expensive.
117 |         output_device: device where saved tensors are to be stored. Either 'same' to keep on the same device,
118 |             or 'cpu' or 'cuda' to move them to cpu or cuda when saved.
119 |         detach_saved_tensors: whether to detach the saved tensors, so they are no longer attached to the computational graph
120 |         save_function_args: whether to save the arguments to each function involved in computing each tensor
121 |         save_gradients: whether to save gradients from any subsequent backward pass
122 |         vis_opt: whether, and how, to visualize the network; 'none' for
123 |             no visualization, 'rolled' to show the graph in rolled-up format (i.e.,
124 |             one node per layer if a recurrent network), or 'unrolled' to show the graph
125 |             in unrolled format (i.e., one node per pass through a layer if a recurrent network)
126 |         vis_nesting_depth: How many levels of nested modules to show; 1 for only top-level modules, 2 for two
127 |             levels, etc.
128 |         vis_outpath: file path to save the graph visualization
129 |         vis_save_only: whether to only save the graph visual without immediately showing it
130 |         vis_fileformat: the format of the visualization (e.g., 'pdf', 'jpg', etc.)
131 |         vis_buffer_layers: whether to visualize the buffer layers
132 |         vis_direction: either 'bottomup', 'topdown', or 'leftright'
133 |         random_seed: which random seed to use in case the model involves randomness
134 | 
135 |     Returns:
136 |         ModelHistory object with layer activations and metadata
137 |     """
138 |     warn_parallel()
139 | 
140 |     if vis_opt not in ["none", "rolled", "unrolled"]:
141 |         raise ValueError(
142 |             "Visualization option must be either 'none', 'rolled', or 'unrolled'."
143 | ) 144 | 145 | if output_device not in ["same", "cpu", "cuda"]: 146 | raise ValueError("output_device must be either 'same', 'cpu', or 'cuda'.") 147 | 148 | if type(layers_to_save) is str: 149 | layers_to_save = layers_to_save.lower() 150 | 151 | if layers_to_save in ["all", "none", None, []]: 152 | model_history = run_model_and_save_specified_activations( 153 | model=model, 154 | input_args=input_args, 155 | input_kwargs=input_kwargs, 156 | layers_to_save=layers_to_save, 157 | keep_unsaved_layers=keep_unsaved_layers, 158 | output_device=output_device, 159 | activation_postfunc=activation_postfunc, 160 | mark_input_output_distances=mark_input_output_distances, 161 | detach_saved_tensors=detach_saved_tensors, 162 | save_function_args=save_function_args, 163 | save_gradients=save_gradients, 164 | random_seed=random_seed, 165 | ) 166 | else: 167 | model_history = run_model_and_save_specified_activations( 168 | model=model, 169 | input_args=input_args, 170 | input_kwargs=input_kwargs, 171 | layers_to_save=None, 172 | keep_unsaved_layers=True, 173 | output_device=output_device, 174 | activation_postfunc=activation_postfunc, 175 | mark_input_output_distances=mark_input_output_distances, 176 | detach_saved_tensors=detach_saved_tensors, 177 | save_function_args=save_function_args, 178 | save_gradients=save_gradients, 179 | random_seed=random_seed, 180 | ) 181 | model_history.keep_unsaved_layers = keep_unsaved_layers 182 | model_history.save_new_activations( 183 | model=model, 184 | input_args=input_args, 185 | input_kwargs=input_kwargs, 186 | layers_to_save=layers_to_save, 187 | random_seed=random_seed, 188 | ) 189 | 190 | # Visualize if desired. 191 | if vis_opt != "none": 192 | model_history.render_graph( 193 | vis_opt, 194 | vis_nesting_depth, 195 | vis_outpath, 196 | vis_graph_overrides, 197 | vis_node_overrides, 198 | vis_nested_node_overrides, 199 | vis_edge_overrides, 200 | vis_gradient_edge_overrides, 201 | vis_module_overrides, 202 | vis_save_only, 203 | vis_fileformat, 204 | vis_buffer_layers, 205 | vis_direction, 206 | ) 207 | 208 | return model_history 209 | 210 | 211 | def get_model_metadata( 212 | model: nn.Module, 213 | input_args: Union[torch.Tensor, List[Any], Tuple[Any]], 214 | input_kwargs: Dict[Any, Any] = None, 215 | ) -> ModelHistory: 216 | """Logs all metadata for a given model and inputs without saving any activations. NOTE: this function 217 | will be removed in a future version of TorchLens, since calling it is identical to calling 218 | log_forward_pass without saving any layers. 219 | 220 | Args: 221 | model: model to inspect 222 | input_args: list of input positional arguments, or a single tensor 223 | input_kwargs: dict of keyword arguments 224 | Returns: 225 | ModelHistory object with metadata about the model. 
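
    Example (a sketch):
        meta = get_model_metadata(model, x)
        print(meta.layer_labels)  # metadata for every layer is present, but no saved activations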
226 | """ 227 | model_history = log_forward_pass( 228 | model, 229 | input_args, 230 | input_kwargs, 231 | layers_to_save=None, 232 | mark_input_output_distances=True, 233 | ) 234 | return model_history 235 | 236 | 237 | def show_model_graph( 238 | model: nn.Module, 239 | input_args: Union[torch.Tensor, List, Tuple], 240 | input_kwargs: Dict[Any, Any] = None, 241 | vis_opt: str = "unrolled", 242 | vis_nesting_depth: int = 1000, 243 | vis_outpath: str = "graph.gv", 244 | vis_graph_overrides: Dict = None, 245 | vis_node_overrides: Dict = None, 246 | vis_nested_node_overrides: Dict = None, 247 | vis_edge_overrides: Dict = None, 248 | vis_gradient_edge_overrides: Dict = None, 249 | vis_module_overrides: Dict = None, 250 | save_only: bool = False, 251 | vis_fileformat: str = "pdf", 252 | vis_buffer_layers: bool = False, 253 | vis_direction: str = "bottomup", 254 | random_seed: Optional[int] = None, 255 | ) -> None: 256 | """Visualize the model graph without saving any activations. 257 | 258 | Args: 259 | model: PyTorch model 260 | input_args: Arguments for model forward pass 261 | input_kwargs: Keyword arguments for model forward pass 262 | vis_opt: whether, and how, to visualize the network; 'none' for 263 | no visualization, 'rolled' to show the graph in rolled-up format (i.e., 264 | one node per layer if a recurrent network), or 'unrolled' to show the graph 265 | in unrolled format (i.e., one node per pass through a layer if a recurrent) 266 | vis_nesting_depth: How many levels of nested modules to show; 1 for only top-level modules, 2 for two 267 | levels, etc. 268 | vis_outpath: file path to save the graph visualization 269 | save_only: whether to only save the graph visual without immediately showing it 270 | vis_fileformat: the format of the visualization (e.g,. 'pdf', 'jpg', etc.) 271 | vis_buffer_layers: whether to visualize the buffer layers 272 | vis_direction: either 'bottomup', 'topdown', or 'leftright' 273 | random_seed: which random seed to use in case model involves randomness 274 | 275 | Returns: 276 | Nothing. 277 | """ 278 | if not input_kwargs: 279 | input_kwargs = {} 280 | 281 | if vis_opt not in ["none", "rolled", "unrolled"]: 282 | raise ValueError( 283 | "Visualization option must be either 'none', 'rolled', or 'unrolled'." 284 | ) 285 | 286 | model_history = run_model_and_save_specified_activations( 287 | model=model, 288 | input_args=input_args, 289 | input_kwargs=input_kwargs, 290 | layers_to_save=None, 291 | activation_postfunc=None, 292 | mark_input_output_distances=False, 293 | detach_saved_tensors=False, 294 | save_gradients=False, 295 | random_seed=random_seed, 296 | ) 297 | model_history.render_graph( 298 | vis_opt, 299 | vis_nesting_depth, 300 | vis_outpath, 301 | vis_graph_overrides, 302 | vis_node_overrides, 303 | vis_nested_node_overrides, 304 | vis_edge_overrides, 305 | vis_gradient_edge_overrides, 306 | vis_module_overrides, 307 | save_only, 308 | vis_fileformat, 309 | vis_buffer_layers, 310 | vis_direction, 311 | ) 312 | 313 | 314 | def validate_saved_activations( 315 | model: nn.Module, 316 | input_args: Union[torch.Tensor, List[Any], Tuple[Any]], 317 | input_kwargs: Dict[Any, Any] = None, 318 | random_seed: Union[int, None] = None, 319 | verbose: bool = False, 320 | ) -> bool: 321 | """Validate that the saved model activations correctly reproduce the ground truth output. 
322 |     running a forward pass through the model, saving all activations, re-running the forward pass starting from
323 |     the saved activations in each layer, and checking that the resulting output matches the original output.
324 |     Additionally, it substitutes in perturbed activations for each layer in turn and checks that the output
325 |     changes accordingly (skipping functions that ignore their inputs, such as zeros_like and fill_).
326 |     Returns True if a model passes these tests for the given input, and False otherwise.
327 |
328 |     Args:
329 |         model: PyTorch model.
330 |         input_args: Input for which to validate the saved activations.
331 |         input_kwargs: Keyword arguments for model forward pass
332 |         random_seed: random seed in case model is stochastic
333 |         verbose: whether to show verbose error messages
334 |     Returns:
335 |         True if the saved activations correctly reproduce the ground truth output, False otherwise.
336 |     """
337 |     warn_parallel()
338 |     if random_seed is None:  # set random seed
339 |         random_seed = random.randint(1, 4294967294)
340 |     set_random_seed(random_seed)
341 |     if type(input_args) is tuple:
342 |         input_args = list(input_args)
343 |     elif (type(input_args) not in [list, tuple]) and (input_args is not None):
344 |         input_args = [input_args]
345 |     if not input_args:
346 |         input_args = []
347 |     if not input_kwargs:
348 |         input_kwargs = {}
349 |     input_args_copy = [copy.deepcopy(arg) for arg in input_args]
350 |     input_kwargs_copy = {key: copy.deepcopy(val) for key, val in input_kwargs.items()}
351 |     state_dict = model.state_dict()
352 |     ground_truth_output_tensors = get_vars_of_type_from_obj(
353 |         model(*input_args_copy, **input_kwargs_copy),
354 |         torch.Tensor,
355 |         search_depth=5,
356 |         allow_repeats=True,
357 |     )
358 |     model.load_state_dict(state_dict)
359 |     model_history = run_model_and_save_specified_activations(
360 |         model=model,
361 |         input_args=input_args,
362 |         input_kwargs=input_kwargs,
363 |         layers_to_save="all",
364 |         keep_unsaved_layers=True,
365 |         activation_postfunc=None,
366 |         mark_input_output_distances=False,
367 |         detach_saved_tensors=False,
368 |         save_gradients=False,
369 |         save_function_args=True,
370 |         random_seed=random_seed,
371 |     )
372 |     activations_are_valid = model_history.validate_saved_activations(
373 |         ground_truth_output_tensors, verbose
374 |     )
375 |
376 |     model_history.cleanup()
377 |     del model_history
378 |     return activations_are_valid
379 |
380 |
381 | def validate_batch_of_models_and_inputs(
382 |     models_and_inputs_dict: Dict[str, Dict[str, Union[str, Callable, Dict]]],
383 |     out_path: str,
384 |     redo_model_if_already_run: bool = True,
385 | ) -> pd.DataFrame:
386 |     """Given multiple models and several inputs for each, validates the saved activations for all of them
387 |     and returns a Pandas dataframe summarizing the validation results.
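    An illustrative (hypothetical) entry, mirroring the keys documented below:

        models_and_inputs_dict = {
            "alexnet": {
                "model_category": "torchvision",
                "model_loading_func": lambda: torchvision.models.alexnet(),
                "model_sample_inputs": {"rand_image": torch.rand(1, 3, 224, 224)},
            },
        }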
388 |
389 |     Args:
390 |         models_and_inputs_dict: Dict mapping each model name to a dict of model info:
391 |             model_category: category of model (e.g., torchvision; just for directory bookkeeping)
392 |             model_loading_func: function to load the model
393 |             model_sample_inputs: dict of example inputs {name: input}
394 |         out_path: Path to save the validation results to
395 |         redo_model_if_already_run: If True, will re-run the validation for a model even if already run
396 |
397 |     Returns:
398 |         Pandas dataframe with validation information for each model and input
399 |     """
400 |     if os.path.exists(out_path):
401 |         current_csv = pd.read_csv(out_path)
402 |     else:
403 |         current_csv = pd.DataFrame.from_dict(
404 |             {
405 |                 "model_category": [],
406 |                 "model_name": [],
407 |                 "input_name": [],
408 |                 "validation_success": [],
409 |             }
410 |         )
411 |     models_already_run = current_csv["model_name"].unique()
412 |     for model_name, model_info in tqdm(
413 |         models_and_inputs_dict.items(), desc="Validating models"
414 |     ):
415 |         print(f"Validating model {model_name}")
416 |         if model_name in models_already_run and not redo_model_if_already_run:
417 |             continue
418 |         model_category = model_info["model_category"]
419 |         model_loading_func = model_info["model_loading_func"]
420 |         model = model_loading_func()
421 |         model_sample_inputs = model_info["model_sample_inputs"]
422 |         for input_name, x in model_sample_inputs.items():
423 |             validation_success = validate_saved_activations(model, x)
424 |             new_row = pd.DataFrame(
425 |                 [{
426 |                     "model_category": model_category,
427 |                     "model_name": model_name,
428 |                     "input_name": input_name,
429 |                     "validation_success": validation_success,
430 |                 }]
431 |             )
432 |             current_csv = pd.concat([current_csv, new_row], ignore_index=True)  # DataFrame.append was removed in pandas 2.0
433 |             current_csv.to_csv(out_path, index=False)
434 |         del model
435 |     return current_csv
436 |
--------------------------------------------------------------------------------
/torchlens/helper_funcs.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import copy
3 | import inspect
4 | import multiprocessing as mp
5 | import random
6 | import secrets
7 | import string
8 | import warnings
9 | from sys import getsizeof
10 | from typing import Any, Callable, Dict, List, Optional, Type
11 |
12 | import numpy as np
13 | import torch
14 | from IPython import get_ipython
15 | from torch import nn
16 |
17 | MAX_FLOATING_POINT_TOLERANCE = 3e-6
18 |
19 |
20 | def identity(x):
21 |     return x
22 |
23 |
24 | def set_random_seed(seed: int):
25 |     """Sets the random seed for all random number generators.
26 |
27 |     Args:
28 |         seed: Seed to set.
29 |
30 |     Returns:
31 |         Nothing.
32 |     """
33 |     random.seed(seed)
34 |     np.random.seed(seed)
35 |     torch.manual_seed(seed)
36 |     torch.cuda.manual_seed_all(seed)
37 |
38 |
39 | def log_current_rng_states() -> Dict:
40 |     """Utility function to fetch sufficient information from all RNG states to recover the same state later.
41 |
42 |     Returns:
43 |         Dict with sufficient information to recover all RNG states.
44 |     """
45 |     rng_dict = {
46 |         "random": random.getstate(),
47 |         "np": np.random.get_state(),
48 |         "torch": torch.random.get_rng_state(),
49 |     }
50 |     if torch.cuda.is_available():
51 |         rng_dict["torch_cuda"] = torch.cuda.get_rng_state("cuda")
52 |     return rng_dict
53 |
54 |
55 | def set_rng_from_saved_states(rng_states: Dict):
56 |     """Utility function to set the state of random seeds to a cached value.
57 |
58 |     Args:
59 |         rng_states: Dict of rng_states saved by log_current_rng_states
60 |
61 |     Returns:
62 |         Nothing, but correctly sets all random seed states.
63 |     """
64 |     random.setstate(rng_states["random"])
65 |     np.random.set_state(rng_states["np"])
66 |     torch.random.set_rng_state(rng_states["torch"])
67 |     if torch.cuda.is_available():
68 |         torch.cuda.set_rng_state(rng_states["torch_cuda"], "cuda")
69 |
70 |
71 | def make_random_barcode(barcode_len: int = 8) -> str:
72 |     """Generates a random alphanumeric hash for a layer to use as internal label (invisible from user side).
73 |
74 |     Args:
75 |         barcode_len: Length of the desired barcode
76 |
77 |     Returns:
78 |         Random hash.
79 |     """
80 |     alphabet = string.ascii_letters + string.digits
81 |     barcode = "".join(secrets.choice(alphabet) for _ in range(barcode_len))
82 |     return barcode
83 |
84 |
85 | def make_short_barcode_from_input(
86 |     things_to_hash: List[Any], barcode_len: int = 16
87 | ) -> str:
88 |     """Utility function that takes a list of anything and returns a short hash of it.
89 |
90 |     Args:
91 |         things_to_hash: List of things to hash; they must all be convertible to a string.
92 |         barcode_len: Length of the desired barcode.
93 |
94 |     Returns:
95 |         Short hash of the input.
96 |     """
97 |     barcode = "".join([str(x) for x in things_to_hash])
98 |     barcode = str(hash(barcode))
99 |     barcode = barcode.encode("utf-8")
100 |     barcode = base64.urlsafe_b64encode(barcode)
101 |     barcode = barcode.decode("utf-8")
102 |     barcode = barcode[0:barcode_len]
103 |     return barcode
104 |
105 |
106 | def _get_call_stack_dicts():
107 |     call_stack = inspect.stack()
108 |     call_stack = [
109 |         inspect.getframeinfo(call_stack[i][0], context=19)
110 |         for i in range(len(call_stack))
111 |     ]
112 |     call_stack_dicts = [
113 |         {
114 |             "call_fname": caller.filename,
115 |             "call_linenum": caller.lineno,
116 |             "function": caller.function,
117 |             "code_context": caller.code_context,
118 |         }
119 |         for caller in call_stack
120 |     ]
121 |
122 |     for call_stack_dict in call_stack_dicts:
123 |         if is_iterable(call_stack_dict['code_context']):
124 |             call_stack_dict['code_context_str'] = ''.join(call_stack_dict['code_context'])
125 |         else:
126 |             call_stack_dict['code_context_str'] = str(call_stack_dict['code_context'])
127 |
128 |     # Only start at the level of the first forward pass, going from shallow to deep.
129 |     tracking = False
130 |     filtered_dicts = []
131 |     for d in range(len(call_stack_dicts) - 1, -1, -1):
132 |         call_stack_dict = call_stack_dicts[d]
133 |         if any(
134 |             [
135 |                 call_stack_dict["call_fname"].endswith("model_history.py"),
136 |                 call_stack_dict["call_fname"].endswith("torchlens/helper_funcs.py"),
137 |                 call_stack_dict["call_fname"].endswith("torchlens/user_funcs.py"),
138 |                 call_stack_dict["call_fname"].endswith("torchlens/trace_model.py"),
139 |                 call_stack_dict["call_fname"].endswith("torchlens/logging_funcs.py"),
140 |                 call_stack_dict["call_fname"].endswith("torchlens/decorate_torch.py"),
141 |                 call_stack_dict["call_fname"].endswith("torchlens/model_funcs.py"),
142 |                 "_call_impl" in call_stack_dict["function"],
143 |             ]
144 |         ):
145 |             continue
146 |         if call_stack_dict["function"] == "forward":
147 |             tracking = True
148 |
149 |         if tracking:
150 |             filtered_dicts.append(call_stack_dict)
151 |
152 |     return filtered_dicts
153 |
154 |
155 | def is_iterable(obj: Any) -> bool:
156 |     """Checks if an object is iterable.
157 |
158 |     Args:
159 |         obj: Object to check.
160 |
161 |     Returns:
162 |         True if object is iterable, False otherwise.
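    Examples:
        Doctest-style sketch of the two outcomes:

        >>> is_iterable([1, 2, 3])
        True
        >>> is_iterable(5)
        False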
163 | """ 164 | try: 165 | iter(obj) 166 | return True 167 | except TypeError: 168 | return False 169 | 170 | 171 | def make_var_iterable(x): 172 | """Utility function to facilitate dealing with outputs: 173 | - If not a list, tuple, or dict, make it a list of length 1 174 | - If a dict, make it a list of the values 175 | - If a list or tuple, keep it. 176 | 177 | Args: 178 | x: Output of the function 179 | 180 | Returns: 181 | Iterable output 182 | """ 183 | if any([issubclass(type(x), cls) for cls in [list, tuple, set]]): 184 | return x 185 | elif issubclass(type(x), dict): 186 | return list(x.values()) 187 | else: 188 | return [x] 189 | 190 | 191 | def index_nested(x: Any, indices: List[int]) -> Any: 192 | """Utility function to index into a nested list or tuple. 193 | 194 | Args: 195 | x: Nested list or tuple. 196 | indices: List of indices to use. 197 | 198 | Returns: 199 | Indexed object. 200 | """ 201 | indices = make_var_iterable(indices) 202 | for i in indices: 203 | x = x[i] 204 | return x 205 | 206 | 207 | def remove_entry_from_list(list_: List, entry: Any): 208 | """Removes all instances of an entry from a list if present, in-place. 209 | 210 | Args: 211 | list_: the list 212 | entry: the entry to remove 213 | """ 214 | while entry in list_: 215 | list_.remove(entry) 216 | 217 | 218 | def tuple_tolerant_assign(obj_: Any, ind: int, new_value: any): 219 | """Utility function to assign an entry of a list, tuple, or dict to a new value. 220 | 221 | Args: 222 | obj_: Tuple to change. 223 | ind: Index to change. 224 | new_value: The new value. 225 | 226 | Returns: 227 | Tuple with the new value swapped out. 228 | """ 229 | if type(obj_) == tuple: 230 | list_ = list(obj_) 231 | list_[ind] = new_value 232 | return tuple(list_) 233 | 234 | obj_[ind] = new_value 235 | return obj_ 236 | 237 | 238 | def int_list_to_compact_str(int_list: List[int]) -> str: 239 | """Given a list of integers, returns a compact string representation of the list, where 240 | contiguous stretches of the integers are represented as ranges (e.g., [1 2 3 4] becomes "1-4"), 241 | and all such ranges are separated by commas. 242 | 243 | Args: 244 | int_list: List of integers. 245 | 246 | Returns: 247 | Compact string representation of the list. 248 | """ 249 | int_list = sorted(int_list) 250 | if len(int_list) == 0: 251 | return "" 252 | if len(int_list) == 1: 253 | return str(int_list[0]) 254 | ranges = [] 255 | start = int_list[0] 256 | end = int_list[0] 257 | for i in range(1, len(int_list)): 258 | if int_list[i] == end + 1: 259 | end = int_list[i] 260 | else: 261 | if start == end: 262 | ranges.append(str(start)) 263 | else: 264 | ranges.append(f"{start}-{end}") 265 | start = int_list[i] 266 | end = int_list[i] 267 | if start == end: 268 | ranges.append(str(start)) 269 | else: 270 | ranges.append(f"{start}-{end}") 271 | return ",".join(ranges) 272 | 273 | 274 | def get_vars_of_type_from_obj( 275 | obj: Any, 276 | which_type: Type, 277 | subclass_exceptions: Optional[List] = None, 278 | search_depth: int = 3, 279 | return_addresses=False, 280 | allow_repeats=False, 281 | ) -> List: 282 | """Recursively finds all tensors in an object, excluding specified subclasses (e.g., parameters) 283 | up to the given search depth. 284 | 285 | Args: 286 | obj: Object to search. 287 | which_type: Type of variable to pull out 288 | subclass_exceptions: subclasses that you don't want to pull out. 289 | search_depth: How many layers deep to search before giving up. 
290 |         return_addresses: if True, then returns list of tuples (object, address), where the
291 |             address is how you'd index to get the object
292 |         allow_repeats: whether to allow repeats of the same tensor
293 |
294 |     Returns:
295 |         List of objects of desired type found in the input object.
296 |     """
297 |     if subclass_exceptions is None:
298 |         subclass_exceptions = []
299 |     this_stack = [(obj, "", [])]
300 |     tensors_in_obj = []
301 |     tensor_addresses = []
302 |     tensor_addresses_full = []
303 |     tensor_ids_in_obj = []
304 |     for _ in range(search_depth):
305 |         this_stack = search_stack_for_vars_of_type(
306 |             this_stack,
307 |             which_type,
308 |             tensors_in_obj,
309 |             tensor_addresses,
310 |             tensor_addresses_full,
311 |             tensor_ids_in_obj,
312 |             subclass_exceptions,
313 |             allow_repeats,
314 |         )
315 |
316 |     if return_addresses:
317 |         return list(zip(tensors_in_obj, tensor_addresses, tensor_addresses_full))
318 |     else:
319 |         return tensors_in_obj
320 |
321 |
322 | def search_stack_for_vars_of_type(
323 |     current_stack: List,
324 |     which_type: Type,
325 |     tensors_in_obj: List,
326 |     tensor_addresses: List,
327 |     tensor_addresses_full: List,
328 |     tensor_ids_in_obj: List,
329 |     subclass_exceptions: List,
330 |     allow_repeats: bool,
331 | ):
332 |     """Helper function that searches current stack for vars of a given type, and
333 |     returns the next stack to search.
334 |
335 |     Args:
336 |         current_stack: The current stack.
337 |         which_type: Type of variable to pull out
338 |         tensors_in_obj: List of tensors found in the object so far
339 |         tensor_addresses: Addresses of the tensors found so far
340 |         tensor_addresses_full: explicit instructions for indexing the obj
341 |         tensor_ids_in_obj: List of tensor ids found in the object so far
342 |         subclass_exceptions: Subclasses of tensors not to use.
343 |         allow_repeats: whether to allow repeat tensors
344 |
345 |     Returns:
346 |         The next stack.
347 |     """
348 |     next_stack = []
349 |     if len(current_stack) == 0:
350 |         return current_stack
351 |     while len(current_stack) > 0:
352 |         item, address, address_full = current_stack.pop(0)
353 |         item_class = type(item)
354 |         if any(
355 |             [issubclass(item_class, subclass) for subclass in subclass_exceptions]
356 |         ) or ((id(item) in tensor_ids_in_obj) and not allow_repeats):
357 |             continue
358 |         if issubclass(item_class, which_type):
359 |             tensors_in_obj.append(item)
360 |             tensor_addresses.append(address)
361 |             tensor_addresses_full.append(address_full)
362 |             tensor_ids_in_obj.append(id(item))
363 |             continue
364 |         if item_class in [str, int, float, bool, np.ndarray, torch.Tensor]:  # the Tensor type, not the torch.tensor factory
365 |             continue
366 |         extend_search_stack_from_item(item, address, address_full, next_stack)
367 |     return next_stack
368 |
369 |
370 | def extend_search_stack_from_item(
371 |     item: Any, address: str, address_full, next_stack: List
372 | ):
373 |     """Utility function to iterate through a single item to populate the next stack to search for.
374 | 375 | Args: 376 | item: The item 377 | next_stack: Stack to add to 378 | """ 379 | if type(item) in [list, tuple, set]: 380 | if address == "": 381 | next_stack.extend( 382 | [(x, f"{i}", address_full + [("ind", i)]) for i, x in enumerate(item)] 383 | ) 384 | else: 385 | next_stack.extend( 386 | [ 387 | (x, f"{address}.{i}", address_full + [("ind", i)]) 388 | for i, x in enumerate(item) 389 | ] 390 | ) 391 | 392 | if issubclass(type(item), dict): 393 | if address == "": 394 | next_stack.extend( 395 | [(val, key, address_full + [("ind", key)]) for key, val in item.items()] 396 | ) 397 | else: 398 | next_stack.extend( 399 | [ 400 | (val, f"{address}.{key}", address_full + [("ind", key)]) 401 | for key, val in item.items() 402 | ] 403 | ) 404 | 405 | for attr_name in dir(item): 406 | if ((attr_name.startswith("__")) or 407 | (attr_name in ['T', 'mT', 'real', 'imag', 'H']) or 408 | ('grad' in attr_name)): 409 | continue 410 | try: 411 | with warnings.catch_warnings(): 412 | warnings.simplefilter("ignore") 413 | attr = getattr(item, attr_name) 414 | except: 415 | continue 416 | attr_cls = type(attr) 417 | if attr_cls in [str, int, float, bool, np.ndarray]: 418 | continue 419 | if callable(attr) and not issubclass(attr_cls, nn.Module): 420 | continue 421 | if address == "": 422 | next_stack.append( 423 | (attr, attr_name.strip("_"), address_full + [("attr", attr_name)]) 424 | ) 425 | else: 426 | next_stack.append( 427 | ( 428 | attr, 429 | f'{address}.{attr_name.strip("_")}', 430 | address_full + [("attr", attr_name)], 431 | ) 432 | ) 433 | 434 | 435 | def get_attr_values_from_tensor_list( 436 | tensor_list: List[torch.Tensor], field_name: str 437 | ) -> List[Any]: 438 | """For a list of tensors, gets the value of a given attribute from each tensor that has that attribute. 439 | 440 | Args: 441 | tensor_list: List of tensors to search. 442 | field_name: Name of the field to check in the tensor. 443 | 444 | Returns: 445 | List of marks from the tensors. 446 | """ 447 | marks = [] 448 | for tensor in tensor_list: 449 | mark = getattr(tensor, field_name, None) 450 | if mark is not None: 451 | marks.append(mark) 452 | return marks 453 | 454 | 455 | def nested_getattr(obj: Any, attr: str) -> Any: 456 | """Helper function that takes in an object, and a string of attributes separated by '.' and recursively 457 | returns the attribute. 458 | 459 | Args: 460 | obj: Any object, e.g. "torch" 461 | attr: String specifying the nested attribute, e.g. "nn.functional" 462 | 463 | Returns: 464 | The attribute specified by the string. 
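    Example:
        A doctest-style sketch using the names from the Args section:

        >>> import torch
        >>> nested_getattr(torch, "nn.functional") is torch.nn.functional
        True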
465 | """ 466 | if attr == "": 467 | return obj 468 | 469 | attributes = attr.split(".") 470 | for i, a in enumerate(attributes): 471 | if a in [ 472 | "volatile", 473 | "T", 474 | 'H', 475 | 'mH', 476 | 'mT' 477 | ]: # avoid annoying warning; if there's more, make a list 478 | with warnings.catch_warnings(): 479 | warnings.simplefilter("ignore") 480 | obj = getattr(obj, a) 481 | else: 482 | obj = getattr(obj, a) 483 | return obj 484 | 485 | 486 | def nested_assign(obj, addr, val): 487 | """Given object and an address in that object, assign value to that address.""" 488 | for i, (entry_type, entry_val) in enumerate(addr): 489 | if i == len(addr) - 1: 490 | if entry_type == "ind": 491 | obj[entry_val] = val 492 | elif entry_type == "attr": 493 | setattr(obj, entry_val, val) 494 | else: 495 | if entry_type == "ind": 496 | obj = obj[entry_val] 497 | elif entry_type == "attr": 498 | with warnings.catch_warnings(): 499 | warnings.simplefilter("ignore") 500 | obj = getattr(obj, entry_val) 501 | 502 | 503 | def iter_accessible_attributes(obj: Any, *, short_circuit: Optional[Callable[[Any, str], bool]] = None): 504 | for attr_name in dir(obj): 505 | if short_circuit and short_circuit(obj, attr_name): 506 | continue 507 | 508 | # Attribute access can fail for any number of reasons, especially when 509 | # working with objects that we don't know anything about. This 510 | # function makes a best-effort attempt to access every attribute, but 511 | # gracefully skips any that cause problems. 512 | 513 | with warnings.catch_warnings(): 514 | warnings.simplefilter("ignore") 515 | try: 516 | attr = getattr(obj, attr_name) 517 | except Exception: 518 | continue 519 | 520 | yield attr_name, attr 521 | 522 | 523 | def remove_attributes_starting_with_str(obj: Any, s: str): 524 | """Given an object removes, any attributes for that object beginning with a given 525 | substring. 526 | 527 | Args: 528 | obj: object 529 | s: string that marks fields to remove 530 | """ 531 | for field in dir(obj): 532 | if field.startswith(s): 533 | delattr(obj, field) 534 | 535 | 536 | def tensor_all_nan(t: torch.Tensor) -> bool: 537 | """Returns True if tensor is all nans, False otherwise.""" 538 | if torch.isnan(t).int().sum() == t.numel(): 539 | return True 540 | else: 541 | return False 542 | 543 | 544 | def tensor_nanequal(t1: torch.Tensor, 545 | t2: torch.Tensor, 546 | allow_tolerance=False) -> bool: 547 | """Returns True if the two tensors are equal, allowing for nans.""" 548 | if t1.shape != t2.shape: 549 | return False 550 | 551 | if t1.dtype != t2.dtype: 552 | return False 553 | 554 | if not torch.equal(t1.isinf(), t2.isinf()): 555 | return False 556 | 557 | t1_nonan = torch.nan_to_num(t1, 0.7234691827346) 558 | t2_nonan = torch.nan_to_num(t2, 0.7234691827346) 559 | 560 | if torch.equal(t1_nonan, t2_nonan): 561 | return True 562 | 563 | if ( 564 | allow_tolerance 565 | and (t1_nonan.dtype != torch.bool) 566 | and (t2_nonan.dtype != torch.bool) 567 | and ((t1_nonan - t2_nonan).abs().max() <= MAX_FLOATING_POINT_TOLERANCE) 568 | ): 569 | return True 570 | 571 | return False 572 | 573 | 574 | def safe_to(x: Any, device: str): 575 | """Moves object to device if it's a tensor, does nothing otherwise. 576 | 577 | Args: 578 | x: The object. 579 | device: which device to move to 580 | 581 | Returns: 582 | Object either moved to device if a tensor, same object if otherwise. 
583 | """ 584 | if type(x) == torch.Tensor: 585 | return clean_to(x, device) 586 | else: 587 | return x 588 | 589 | 590 | def get_tensor_memory_amount(t: torch.Tensor) -> int: 591 | """Returns the size of a tensor in bytes. 592 | 593 | Args: 594 | t: Tensor. 595 | 596 | Returns: 597 | Size of tensor in bytes. 598 | """ 599 | cpu_data = clean_cpu(t.data) 600 | if cpu_data.dtype == torch.bfloat16: 601 | cpu_data = clean_to(cpu_data, torch.float16) 602 | return getsizeof(np.array(clean_dense(cpu_data))) 603 | 604 | 605 | def human_readable_size(size: int, decimal_places: int = 1) -> str: 606 | """Utility function to convert a size in bytes to a human-readable format. 607 | 608 | Args: 609 | size: Number of bytes. 610 | decimal_places: Number of decimal places to use. 611 | 612 | Returns: 613 | String with human-readable size. 614 | """ 615 | for unit in ["B", "KB", "MB", "GB", "TB", "PB"]: 616 | if size < 1024.0 or unit == "PB": 617 | break 618 | size /= 1024.0 619 | if unit == "B": 620 | size = int(size) 621 | else: 622 | size = np.round(size, decimals=decimal_places) 623 | return f"{size} {unit}" 624 | 625 | 626 | clean_from_numpy = copy.deepcopy(torch.from_numpy) 627 | clean_new_param = copy.deepcopy(torch.nn.Parameter) 628 | clean_clone = copy.deepcopy(torch.clone) 629 | clean_cpu = copy.deepcopy(torch.Tensor.cpu) 630 | clean_cuda = copy.deepcopy(torch.Tensor.cuda) 631 | clean_to = copy.deepcopy(torch.Tensor.to) 632 | clean_dense = copy.deepcopy(torch.Tensor.to_dense) 633 | 634 | 635 | def print_override(t: torch.Tensor, func_name: str): 636 | """Overrides the __str__ and __repr__ methods of Tensor so as not to lead to any infinite recursion. 637 | 638 | Args: 639 | t: Tensor 640 | func_name: Either "__str__" or "__repr__" 641 | 642 | Returns: 643 | The string representation of the tensor. 644 | """ 645 | cpu_data = clean_cpu(t.data) 646 | if cpu_data.dtype == torch.bfloat16: 647 | cpu_data = clean_to(cpu_data, torch.float16) 648 | n = np.array(cpu_data) 649 | np_str = getattr(n, func_name)() 650 | np_str = np_str.replace("array", "tensor") 651 | np_str = np_str.replace("\n", "\n ") 652 | if t.grad_fn is not None: 653 | grad_fn_str = f", grad_fn={type(t.grad_fn).__name__})" 654 | np_str = np_str[0:-1] + grad_fn_str 655 | elif t.requires_grad: 656 | np_str = np_str[0:-1] + ", requires_grad=True)" 657 | return np_str 658 | 659 | 660 | def safe_copy(x, detach_tensor: bool = False): 661 | """Utility function to make a copy of a tensor or parameter when torch is in mutated mode, or just copy 662 | the thing if it's not a tensor. 663 | 664 | Args: 665 | x: Input 666 | detach_tensor: Whether to detach the cloned tensor from the computational graph or not. 
667 |
668 |     Returns:
669 |         Safely copied variant of the input with same values and same class, but different memory
670 |     """
671 |     if issubclass(type(x), (torch.Tensor, torch.nn.Parameter)):
672 |         if not detach_tensor:
673 |             return clean_clone(x)
674 |         vals_cpu = clean_cpu(x.data)
675 |         if vals_cpu.dtype == torch.bfloat16:
676 |             vals_cpu = clean_to(vals_cpu, torch.float16)
677 |         vals_np = vals_cpu.numpy()
678 |         vals_tensor = clean_from_numpy(vals_np)
679 |         if hasattr(x, "tl_tensor_label_raw"):
680 |             vals_tensor.tl_tensor_label_raw = x.tl_tensor_label_raw
681 |         if type(x) == torch.Tensor:
682 |             return vals_tensor
683 |         elif type(x) == torch.nn.Parameter:
684 |             return clean_new_param(vals_tensor)
685 |     else:
686 |         return copy.copy(x)
687 |
688 |
689 | def in_notebook():
690 |     try:
691 |         if "IPKernelApp" not in get_ipython().config:
692 |             return False
693 |     except ImportError:
694 |         return False
695 |     except AttributeError:
696 |         return False
697 |     return True
698 |
699 |
700 | def warn_parallel():
701 |     """
702 |     Utility function to raise an error if it's being run in parallel processing.
703 |     """
704 |     if mp.current_process().name != "MainProcess":
705 |         raise RuntimeError(
706 |             "WARNING: It looks like you are using parallel execution; only run "
707 |             "torchlens in the main process, since certain operations "
708 |             "depend on execution order."
709 |         )
710 |
--------------------------------------------------------------------------------
/torchlens/tensor_log.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import defaultdict
3 | from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Tuple, Union
4 |
5 | import torch
6 |
7 | from .constants import TENSOR_LOG_ENTRY_FIELD_ORDER
8 | from .helper_funcs import clean_to, get_tensor_memory_amount, human_readable_size, print_override, safe_copy
9 |
10 | if TYPE_CHECKING:
11 |     from .model_history import ModelHistory
12 |
13 |
14 | class TensorLogEntry:
15 |     def __init__(self, fields_dict: Dict):
16 |         """Object that stores information about a single tensor operation in the forward pass,
17 |         including metadata and the tensor itself (if specified). Initialized by passing in a dictionary with
18 |         values for all fields.
19 |         Args:
20 |             fields_dict: Dict with values for all fields in TensorLogEntry.
21 |         """
22 |         # Note: this all has to be tediously initialized instead of a for-loop in order for
23 |         # autocomplete features to work well. But, this also serves as a reference for all attributes
24 |         # of a tensor log entry.
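        # A (heavily abridged, hypothetical) illustration of the expected input;
        # in practice fields_dict is assembled internally during the forward pass,
        # with exactly one entry per name in TENSOR_LOG_ENTRY_FIELD_ORDER:
        #     fields_dict = {"tensor_label_raw": ..., "layer_label": ...,
        #                    "tensor_contents": ..., "parent_layers": [...], ...}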
25 | 26 | # Check that fields_dict contains all fields for TensorLogEntry: 27 | field_order_set = set(TENSOR_LOG_ENTRY_FIELD_ORDER) 28 | fields_dict_key_set = set(fields_dict.keys()) 29 | if fields_dict_key_set != field_order_set: 30 | error_str = "Error initializing TensorLogEntry:" 31 | missing_fields = field_order_set - fields_dict_key_set 32 | extra_fields = fields_dict_key_set - field_order_set 33 | if len(missing_fields) > 0: 34 | error_str += f"\n\t- Missing fields {', '.join(missing_fields)}" 35 | if len(extra_fields) > 0: 36 | error_str += f"\n\t- Extra fields {', '.join(extra_fields)}" 37 | raise ValueError(error_str) 38 | 39 | # General info: 40 | self.tensor_label_raw = fields_dict["tensor_label_raw"] 41 | self.layer_label_raw = fields_dict["layer_label_raw"] 42 | self.operation_num = fields_dict["operation_num"] 43 | self.realtime_tensor_num = fields_dict["realtime_tensor_num"] 44 | self.source_model_history: "ModelHistory" = fields_dict["source_model_history"] 45 | self._pass_finished = fields_dict["_pass_finished"] 46 | 47 | # Label info: 48 | self.layer_label = fields_dict["layer_label"] 49 | self.layer_label_short = fields_dict["layer_label_short"] 50 | self.layer_label_w_pass = fields_dict["layer_label_w_pass"] 51 | self.layer_label_w_pass_short = fields_dict["layer_label_w_pass_short"] 52 | self.layer_label_no_pass = fields_dict["layer_label_no_pass"] 53 | self.layer_label_no_pass_short = fields_dict["layer_label_no_pass_short"] 54 | self.layer_type = fields_dict["layer_type"] 55 | self.layer_type_num = fields_dict["layer_type_num"] 56 | self.layer_total_num = fields_dict["layer_total_num"] 57 | self.pass_num = fields_dict["pass_num"] 58 | self.layer_passes_total = fields_dict["layer_passes_total"] 59 | self.lookup_keys = fields_dict["lookup_keys"] 60 | 61 | # Saved tensor info: 62 | self.tensor_contents = fields_dict["tensor_contents"] 63 | self.has_saved_activations = fields_dict["has_saved_activations"] 64 | self.output_device = fields_dict["output_device"] 65 | self.activation_postfunc = fields_dict["activation_postfunc"] 66 | self.detach_saved_tensor = fields_dict["detach_saved_tensor"] 67 | self.function_args_saved = fields_dict["function_args_saved"] 68 | self.creation_args = fields_dict["creation_args"] 69 | self.creation_kwargs = fields_dict["creation_kwargs"] 70 | self.tensor_shape = fields_dict["tensor_shape"] 71 | self.tensor_dtype = fields_dict["tensor_dtype"] 72 | self.tensor_fsize = fields_dict["tensor_fsize"] 73 | self.tensor_fsize_nice = fields_dict["tensor_fsize_nice"] 74 | 75 | # Dealing with getitem complexities 76 | self.was_getitem_applied = fields_dict["was_getitem_applied"] 77 | self.children_tensor_versions = fields_dict["children_tensor_versions"] 78 | 79 | # Saved gradient info 80 | self.grad_contents = fields_dict["grad_contents"] 81 | self.save_gradients = fields_dict["save_gradients"] 82 | self.has_saved_grad = fields_dict["has_saved_grad"] 83 | self.grad_shape = fields_dict["grad_shape"] 84 | self.grad_dtype = fields_dict["grad_dtype"] 85 | self.grad_fsize = fields_dict["grad_fsize"] 86 | self.grad_fsize_nice = fields_dict["grad_fsize_nice"] 87 | 88 | # Function call info: 89 | self.func_applied = fields_dict["func_applied"] 90 | self.func_applied_name = fields_dict["func_applied_name"] 91 | self.func_call_stack = fields_dict["func_call_stack"] 92 | self.func_time_elapsed = fields_dict["func_time_elapsed"] 93 | self.func_rng_states = fields_dict["func_rng_states"] 94 | self.func_argnames = fields_dict["func_argnames"] 95 | 
self.num_func_args_total = fields_dict["num_func_args_total"] 96 | self.num_position_args = fields_dict["num_position_args"] 97 | self.num_keyword_args = fields_dict["num_keyword_args"] 98 | self.func_position_args_non_tensor = fields_dict[ 99 | "func_position_args_non_tensor" 100 | ] 101 | self.func_keyword_args_non_tensor = fields_dict["func_keyword_args_non_tensor"] 102 | self.func_all_args_non_tensor = fields_dict["func_all_args_non_tensor"] 103 | self.function_is_inplace = fields_dict["function_is_inplace"] 104 | self.gradfunc = fields_dict["gradfunc"] 105 | self.is_part_of_iterable_output = fields_dict["is_part_of_iterable_output"] 106 | self.iterable_output_index = fields_dict["iterable_output_index"] 107 | 108 | # Param info: 109 | self.computed_with_params = fields_dict["computed_with_params"] 110 | self.parent_params = fields_dict["parent_params"] 111 | self.parent_param_barcodes = fields_dict["parent_param_barcodes"] 112 | self.parent_param_passes = fields_dict["parent_param_passes"] 113 | self.num_param_tensors = fields_dict["num_param_tensors"] 114 | self.parent_param_shapes = fields_dict["parent_param_shapes"] 115 | self.num_params_total = fields_dict["num_params_total"] 116 | self.parent_params_fsize = fields_dict["parent_params_fsize"] 117 | self.parent_params_fsize_nice = fields_dict["parent_params_fsize_nice"] 118 | 119 | # Corresponding layer info: 120 | self.operation_equivalence_type = fields_dict["operation_equivalence_type"] 121 | self.equivalent_operations = fields_dict["equivalent_operations"] 122 | self.same_layer_operations = fields_dict["same_layer_operations"] 123 | 124 | # Graph info: 125 | self.parent_layers = fields_dict["parent_layers"] 126 | self.has_parents = fields_dict["has_parents"] 127 | self.parent_layer_arg_locs = fields_dict["parent_layer_arg_locs"] 128 | self.orig_ancestors = fields_dict["orig_ancestors"] 129 | self.child_layers = fields_dict["child_layers"] 130 | self.has_children = fields_dict["has_children"] 131 | self.sibling_layers = fields_dict["sibling_layers"] 132 | self.has_siblings = fields_dict["has_siblings"] 133 | self.spouse_layers = fields_dict["spouse_layers"] 134 | self.has_spouses = fields_dict["has_spouses"] 135 | self.is_input_layer = fields_dict["is_input_layer"] 136 | self.has_input_ancestor = fields_dict["has_input_ancestor"] 137 | self.input_ancestors = fields_dict["input_ancestors"] 138 | self.min_distance_from_input = fields_dict["min_distance_from_input"] 139 | self.max_distance_from_input = fields_dict["max_distance_from_input"] 140 | self.is_output_layer = fields_dict["is_output_layer"] 141 | self.is_output_parent = fields_dict["is_output_parent"] 142 | self.is_last_output_layer = fields_dict["is_last_output_layer"] 143 | self.is_output_ancestor = fields_dict["is_output_ancestor"] 144 | self.output_descendents = fields_dict["output_descendents"] 145 | self.min_distance_from_output = fields_dict["min_distance_from_output"] 146 | self.max_distance_from_output = fields_dict["max_distance_from_output"] 147 | self.input_output_address = fields_dict["input_output_address"] 148 | self.is_buffer_layer = fields_dict["is_buffer_layer"] 149 | self.buffer_address = fields_dict["buffer_address"] 150 | self.buffer_pass = fields_dict["buffer_pass"] 151 | self.buffer_parent = fields_dict["buffer_parent"] 152 | self.initialized_inside_model = fields_dict["initialized_inside_model"] 153 | self.has_internally_initialized_ancestor = fields_dict[ 154 | "has_internally_initialized_ancestor" 155 | ] 156 | self.internally_initialized_parents 
= fields_dict[ 157 | "internally_initialized_parents" 158 | ] 159 | self.internally_initialized_ancestors = fields_dict[ 160 | "internally_initialized_ancestors" 161 | ] 162 | self.terminated_inside_model = fields_dict["terminated_inside_model"] 163 | 164 | # Conditional info 165 | self.is_terminal_bool_layer = fields_dict["is_terminal_bool_layer"] 166 | self.is_atomic_bool_layer = fields_dict["is_atomic_bool_layer"] 167 | self.atomic_bool_val = fields_dict["atomic_bool_val"] 168 | self.in_cond_branch = fields_dict["in_cond_branch"] 169 | self.cond_branch_start_children = fields_dict["cond_branch_start_children"] 170 | 171 | # Module info 172 | self.is_computed_inside_submodule = fields_dict["is_computed_inside_submodule"] 173 | self.containing_module_origin = fields_dict["containing_module_origin"] 174 | self.containing_modules_origin_nested = fields_dict[ 175 | "containing_modules_origin_nested" 176 | ] 177 | self.module_nesting_depth = fields_dict["module_nesting_depth"] 178 | self.modules_entered = fields_dict["modules_entered"] 179 | self.modules_entered_argnames = fields_dict["modules_entered_argnames"] 180 | self.module_passes_entered = fields_dict["module_passes_entered"] 181 | self.is_submodule_input = fields_dict["is_submodule_input"] 182 | self.modules_exited = fields_dict["modules_exited"] 183 | self.module_passes_exited = fields_dict["module_passes_exited"] 184 | self.is_submodule_output = fields_dict["is_submodule_output"] 185 | self.is_bottom_level_submodule_output = fields_dict[ 186 | "is_bottom_level_submodule_output" 187 | ] 188 | self.bottom_level_submodule_pass_exited = fields_dict[ 189 | "bottom_level_submodule_pass_exited" 190 | ] 191 | self.module_entry_exit_threads_inputs = fields_dict[ 192 | "module_entry_exit_threads_inputs" 193 | ] 194 | self.module_entry_exit_thread_output = fields_dict[ 195 | "module_entry_exit_thread_output" 196 | ] 197 | 198 | # ******************************************** 199 | # *********** User-Facing Functions ********** 200 | # ******************************************** 201 | 202 | def print_all_fields(self): 203 | """Print all data fields in the layer.""" 204 | fields_to_exclude = ["source_model_history", "func_rng_states"] 205 | 206 | for field in dir(self): 207 | attr = getattr(self, field) 208 | if not any( 209 | [field.startswith("_"), field in fields_to_exclude, callable(attr)] 210 | ): 211 | print(f"{field}: {attr}") 212 | 213 | # ******************************************** 214 | # ************* Logging Functions ************ 215 | # ******************************************** 216 | 217 | def copy(self): 218 | """Return a copy of itself. 219 | 220 | Returns: 221 | Copy of itself. 
222 | """ 223 | fields_dict = {} 224 | fields_not_to_deepcopy = [ 225 | "func_applied", 226 | "gradfunc", 227 | "source_model_history", 228 | "func_rng_states", 229 | "creation_args", 230 | "creation_kwargs", 231 | "parent_params", 232 | "tensor_contents", 233 | "children_tensor_versions", 234 | ] 235 | for field in TENSOR_LOG_ENTRY_FIELD_ORDER: 236 | if field not in fields_not_to_deepcopy: 237 | fields_dict[field] = copy.deepcopy(getattr(self, field)) 238 | else: 239 | fields_dict[field] = getattr(self, field) 240 | copied_entry = TensorLogEntry(fields_dict) 241 | return copied_entry 242 | 243 | def save_tensor_data( 244 | self, 245 | t: torch.Tensor, 246 | t_args: Union[List, Tuple], 247 | t_kwargs: Dict, 248 | save_function_args: bool, 249 | activation_postfunc: Optional[Callable] = None, 250 | ): 251 | """Saves the tensor data for a given tensor operation. 252 | 253 | Args: 254 | t: the tensor. 255 | t_args: tensor positional arguments for the operation 256 | t_kwargs: tensor keyword arguments for the operation 257 | save_function_args: whether to save the arguments to the function 258 | activation_postfunc: function to apply to activations before saving them 259 | """ 260 | # The tensor itself: 261 | self.tensor_contents = safe_copy(t, self.detach_saved_tensor) 262 | if self.output_device not in [str(self.tensor_contents.device), "same"]: 263 | self.tensor_contents = clean_to(self.tensor_contents, self.output_device) 264 | if activation_postfunc is not None: 265 | self.source_model_history._pause_logging = True 266 | self.tensor_contents = activation_postfunc(self.tensor_contents) 267 | self.source_model_history._pause_logging = False 268 | 269 | self.has_saved_activations = True 270 | 271 | # Tensor args and kwargs: 272 | if save_function_args: 273 | self.function_args_saved = True 274 | creation_args = [] 275 | for arg in t_args: 276 | if type(arg) == list: 277 | creation_args.append([safe_copy(a) for a in arg]) 278 | else: 279 | creation_args.append(safe_copy(arg)) 280 | 281 | creation_kwargs = {} 282 | for key, value in t_kwargs.items(): 283 | if type(value) == list: 284 | creation_kwargs[key] = [safe_copy(v) for v in value] 285 | else: 286 | creation_kwargs[key] = safe_copy(value) 287 | self.creation_args = creation_args 288 | self.creation_kwargs = creation_kwargs 289 | else: 290 | self.creation_args = None 291 | self.creation_kwargs = None 292 | 293 | def log_tensor_grad(self, grad: torch.Tensor): 294 | """Logs the gradient for a tensor to the log entry 295 | 296 | Args: 297 | grad: The gradient to save. 
298 |         """
299 |         self.grad_contents = grad
300 |         self.has_saved_grad = True
301 |         self.grad_shape = grad.shape
302 |         self.grad_dtype = grad.dtype
303 |         self.grad_fsize = get_tensor_memory_amount(grad)
304 |         self.grad_fsize_nice = human_readable_size(get_tensor_memory_amount(grad))
305 |
306 |     # ********************************************
307 |     # ************* Fetcher Functions ************
308 |     # ********************************************
309 |
310 |     def get_child_layers(self):
311 |         return [
312 |             self.source_model_history[child_label] for child_label in self.child_layers
313 |         ]
314 |
315 |     def get_parent_layers(self):
316 |         return [
317 |             self.source_model_history[parent_label]
318 |             for parent_label in self.parent_layers
319 |         ]
320 |
321 |     # ********************************************
322 |     # ************* Built-in Methods *************
323 |     # ********************************************
324 |
325 |     def __str__(self):
326 |         if self._pass_finished:
327 |             return self._str_after_pass()
328 |         else:
329 |             return self._str_during_pass()
330 |
331 |     def _str_during_pass(self):
332 |         s = f"Tensor {self.tensor_label_raw} (layer {self.layer_label_raw}) (PASS NOT FINISHED):"
333 |         s += f"\n\tPass: {self.pass_num}"
334 |         s += f"\n\tTensor info: shape {self.tensor_shape}, dtype {self.tensor_dtype}"
335 |         s += f"\n\tComputed from params: {self.computed_with_params}"
336 |         s += f"\n\tComputed in modules: {self.containing_modules_origin_nested}"
337 |         s += f"\n\tOutput of modules: {self.module_passes_exited}"
338 |         if self.is_bottom_level_submodule_output:
339 |             s += " (bottom-level submodule output)"
340 |         else:
341 |             s += " (not bottom-level submodule output)"
342 |         s += "\n\tFamily info:"
343 |         s += f"\n\t\tParents: {self.parent_layers}"
344 |         s += f"\n\t\tChildren: {self.child_layers}"
345 |         s += f"\n\t\tSpouses: {self.spouse_layers}"
346 |         s += f"\n\t\tSiblings: {self.sibling_layers}"
347 |         s += (
348 |             f"\n\t\tOriginal Ancestors: {self.orig_ancestors} "
349 |             f"(min dist {self.min_distance_from_input} nodes, max dist {self.max_distance_from_input} nodes)"
350 |         )
351 |         s += f"\n\t\tInput Ancestors: {self.input_ancestors}"
352 |         s += f"\n\t\tInternal Ancestors: {self.internally_initialized_ancestors}"
353 |         s += (
354 |             f"\n\t\tOutput Descendents: {self.output_descendents} "
355 |             f"(min dist {self.min_distance_from_output} nodes, max dist {self.max_distance_from_output} nodes)"
356 |         )
357 |         if self.tensor_contents is not None:
358 |             s += f"\n\tTensor contents: \n{print_override(self.tensor_contents, '__str__')}"
359 |         return s
360 |
361 |     def _str_after_pass(self):
362 |         if self.layer_passes_total > 1:
363 |             pass_str = f" (pass {self.pass_num}/{self.layer_passes_total}), "
364 |         else:
365 |             pass_str = ", "
366 |         s = (
367 |             f"Layer {self.layer_label_no_pass}"
368 |             f"{pass_str}operation {self.operation_num}/"
369 |             f"{self.source_model_history.num_operations}:"
370 |         )
371 |         s += f"\n\tOutput tensor: shape={self.tensor_shape}, dtype={self.tensor_dtype}, size={self.tensor_fsize_nice}"
372 |         if not self.has_saved_activations:
373 |             s += " (not saved)"
374 |         s += self._tensor_contents_str_helper()
375 |         s += self._tensor_family_str_helper()
376 |         if len(self.parent_param_shapes) > 0:
377 |             params_shapes_str = ", ".join(
378 |                 str(param_shape) for param_shape in self.parent_param_shapes
379 |             )
380 |             s += (
381 |                 f"\n\tParams: Computed from params with shape {params_shapes_str}; "
382 |                 f"{self.num_params_total} params total ({self.parent_params_fsize_nice})"
383 |             )
384 |         else:
385 |             s += "\n\tParams: no params used"
386 |         if self.containing_module_origin is None:
387 |             module_str = "\n\tComputed inside module: not computed inside a module"
388 |         else:
389 |             module_str = f"\n\tComputed inside module: {self.containing_module_origin}"
390 |         if not self.is_input_layer:
391 |             s += (
392 |                 f"\n\tFunction: {self.func_applied_name} (grad_fn: {self.gradfunc}) "
393 |                 f"{module_str}"
394 |             )
395 |             s += f"\n\tTime elapsed: {self.func_time_elapsed: .3E}s"
396 |         if len(self.modules_exited) > 0:
397 |             modules_exited_str = ", ".join(self.modules_exited)
398 |             s += f"\n\tOutput of modules: {modules_exited_str}"
399 |         else:
400 |             s += "\n\tOutput of modules: none"
401 |         if self.is_bottom_level_submodule_output:
402 |             s += f"\n\tOutput of bottom-level module: {self.bottom_level_submodule_pass_exited}"
403 |         lookup_keys_str = ", ".join([str(key) for key in self.lookup_keys])
404 |         s += f"\n\tLookup keys: {lookup_keys_str}"
405 |
406 |         return s
407 |
408 |     def _tensor_contents_str_helper(self) -> str:
409 |         """Returns a short, readable string for the tensor contents."""
410 |         if self.tensor_contents is None:
411 |             return ""
412 |         else:
413 |             s = ""
414 |             tensor_size_shown = 8
415 |             saved_shape = self.tensor_contents.shape
416 |             if len(saved_shape) == 0:
417 |                 tensor_slice = self.tensor_contents
418 |             elif len(saved_shape) == 1:
419 |                 num_dims = min(tensor_size_shown, saved_shape[0])
420 |                 tensor_slice = self.tensor_contents[0:num_dims]
421 |             elif len(saved_shape) == 2:
422 |                 num_dims = min([tensor_size_shown, saved_shape[-2], saved_shape[-1]])
423 |                 tensor_slice = self.tensor_contents[0:num_dims, 0:num_dims]
424 |             else:
425 |                 num_dims = min(
426 |                     [tensor_size_shown, self.tensor_shape[-2], self.tensor_shape[-1]]
427 |                 )
428 |                 tensor_slice = self.tensor_contents.data.clone()
429 |                 for _ in range(len(saved_shape) - 2):
430 |                     tensor_slice = tensor_slice[0]
431 |                 tensor_slice = tensor_slice[0:num_dims, 0:num_dims]
432 |             tensor_slice = tensor_slice.detach()
433 |             tensor_slice.requires_grad = False
434 |             s += f"\n\t\t{str(tensor_slice)}"
435 |             if (len(saved_shape) > 0) and (max(saved_shape) > tensor_size_shown):
436 |                 s += "..."
437 | return s 438 | 439 | def _tensor_family_str_helper(self) -> str: 440 | s = "\n\tRelated Layers:" 441 | if len(self.parent_layers) > 0: 442 | s += "\n\t\t- parent layers: " + ", ".join(self.parent_layers) 443 | else: 444 | s += "\n\t\t- no parent layers" 445 | 446 | if len(self.child_layers) > 0: 447 | s += "\n\t\t- child layers: " + ", ".join(self.child_layers) 448 | else: 449 | s += "\n\t\t- no child layers" 450 | 451 | if len(self.sibling_layers) > 0: 452 | s += "\n\t\t- shares parents with layers: " + ", ".join(self.sibling_layers) 453 | else: 454 | s += "\n\t\t- shares parents with no other layers" 455 | 456 | if len(self.spouse_layers) > 0: 457 | s += "\n\t\t- shares children with layers: " + ", ".join(self.spouse_layers) 458 | else: 459 | s += "\n\t\t- shares children with no other layers" 460 | 461 | if self.has_input_ancestor: 462 | s += "\n\t\t- descendent of input layers: " + ", ".join( 463 | self.input_ancestors 464 | ) 465 | else: 466 | s += "\n\t\t- tensor was created de novo inside the model (not computed from input)" 467 | 468 | if self.is_output_ancestor: 469 | s += "\n\t\t- ancestor of output layers: " + ", ".join( 470 | self.output_descendents 471 | ) 472 | else: 473 | s += "\n\t\t- tensor is not an ancestor of the model output; it terminates within the model" 474 | 475 | return s 476 | 477 | def __repr__(self): 478 | return self.__str__() 479 | 480 | 481 | class RolledTensorLogEntry: 482 | def __init__(self, source_entry: TensorLogEntry): 483 | """Stripped-down version TensorLogEntry that only encodes the information needed to plot the model 484 | in its rolled-up form. 485 | 486 | Args: 487 | source_entry: The source TensorLogEntry from which the rolled node is constructed 488 | """ 489 | # Label & general info 490 | self.layer_label = source_entry.layer_label_no_pass 491 | self.layer_type = source_entry.layer_type 492 | self.layer_type_num = source_entry.layer_type_num 493 | self.layer_total_num = source_entry.layer_total_num 494 | self.layer_passes_total = source_entry.layer_passes_total 495 | self.source_model_history = source_entry.source_model_history 496 | 497 | # Saved tensor info 498 | self.tensor_shape = source_entry.tensor_shape 499 | self.tensor_fsize_nice = source_entry.tensor_fsize_nice 500 | 501 | # Param info: 502 | self.computed_with_params = source_entry.computed_with_params 503 | self.parent_param_shapes = source_entry.parent_param_shapes 504 | self.num_param_tensors = source_entry.num_param_tensors 505 | 506 | # Graph info 507 | self.is_input_layer = source_entry.is_input_layer 508 | self.has_input_ancestor = source_entry.has_input_ancestor 509 | self.is_output_layer = source_entry.is_output_layer 510 | self.is_last_output_layer = source_entry.is_last_output_layer 511 | self.is_buffer_layer = source_entry.is_buffer_layer 512 | self.buffer_address = source_entry.buffer_address 513 | self.buffer_pass = source_entry.buffer_pass 514 | self.input_output_address = source_entry.input_output_address 515 | self.cond_branch_start_children = source_entry.cond_branch_start_children 516 | self.is_terminal_bool_layer = source_entry.is_terminal_bool_layer 517 | self.atomic_bool_val = source_entry.atomic_bool_val 518 | self.child_layers = [] 519 | self.parent_layers = [] 520 | self.orphan_layers = [] 521 | 522 | # Module info: 523 | self.containing_modules_origin_nested = ( 524 | source_entry.containing_modules_origin_nested 525 | ) 526 | self.modules_exited = source_entry.modules_exited 527 | self.module_passes_exited = source_entry.module_passes_exited 528 | 
self.is_bottom_level_submodule_output = False 529 | self.bottom_level_submodule_passes_exited = set() 530 | 531 | # Fields specific to rolled node to fill in: 532 | self.edges_vary_across_passes = False 533 | self.child_layers_per_pass = defaultdict(list) 534 | self.child_passes_per_layer = defaultdict(list) 535 | self.parent_layers_per_pass = defaultdict(list) 536 | self.parent_passes_per_layer = defaultdict(list) 537 | 538 | # Each one will now be a list of layers, since they can vary across passes. 539 | self.parent_layer_arg_locs = { 540 | "args": defaultdict(set), 541 | "kwargs": defaultdict(set), 542 | } 543 | 544 | def update_data(self, source_node: TensorLogEntry): 545 | """Updates the data as need be. 546 | Args: 547 | source_node: the source node 548 | """ 549 | if source_node.has_input_ancestor: 550 | self.has_input_ancestor = True 551 | if not any( 552 | [ 553 | self.input_output_address is None, 554 | source_node.input_output_address is None, 555 | ] 556 | ): 557 | self.input_output_address = "".join( 558 | [ 559 | char if (source_node.input_output_address[c] == char) else "*" 560 | for c, char in enumerate(self.input_output_address) 561 | ] 562 | ) 563 | if self.input_output_address[-1] == ".": 564 | self.input_output_address = self.input_output_address[:-1] 565 | if self.input_output_address[-1] == "*": 566 | self.input_output_address = self.input_output_address.strip("*") + "*" 567 | 568 | def add_pass_info(self, source_node: TensorLogEntry): 569 | """Adds information about another pass of the same layer: namely, mark information about what the 570 | child and parent layers are for each pass. 571 | 572 | Args: 573 | source_node: Information for the source pass 574 | """ 575 | # Label the layers for each pass 576 | child_layer_labels = [ 577 | self.source_model_history[child].layer_label_no_pass 578 | for child in source_node.child_layers 579 | ] 580 | for child_layer in child_layer_labels: 581 | if child_layer not in self.child_layers: 582 | self.child_layers.append(child_layer) 583 | if child_layer not in self.child_layers_per_pass[source_node.pass_num]: 584 | self.child_layers_per_pass[source_node.pass_num].append(child_layer) 585 | 586 | parent_layer_labels = [ 587 | self.source_model_history[parent].layer_label_no_pass 588 | for parent in source_node.parent_layers 589 | ] 590 | for parent_layer in parent_layer_labels: 591 | if parent_layer not in self.parent_layers: 592 | self.parent_layers.append(parent_layer) 593 | if parent_layer not in self.parent_layers_per_pass[source_node.pass_num]: 594 | self.parent_layers_per_pass[source_node.pass_num].append(parent_layer) 595 | 596 | # Label the passes for each layer, and indicate if any layers vary based on the pass. 597 | for child_layer in source_node.child_layers: 598 | child_layer_label = self.source_model_history[ 599 | child_layer 600 | ].layer_label_no_pass 601 | if ( 602 | source_node.pass_num 603 | not in self.child_passes_per_layer[child_layer_label] 604 | ): 605 | self.child_passes_per_layer[child_layer_label].append( 606 | source_node.pass_num 607 | ) 608 | 609 | for parent_layer in source_node.parent_layers: 610 | parent_layer_label = self.source_model_history[ 611 | parent_layer 612 | ].layer_label_no_pass 613 | if ( 614 | source_node.pass_num 615 | not in self.parent_passes_per_layer[parent_layer_label] 616 | ): 617 | self.parent_passes_per_layer[parent_layer_label].append( 618 | source_node.pass_num 619 | ) 620 | 621 | # Check if any edges vary across passes. 
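        # (Only checked once the final pass has been logged: if any neighboring
        # layer was connected on fewer than layer_passes_total passes, the edges
        # differ between passes, and edges_vary_across_passes is set accordingly.)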
622 | if source_node.pass_num == source_node.layer_passes_total: 623 | pass_lists = list(self.parent_passes_per_layer.values()) + list( 624 | self.child_passes_per_layer.values() 625 | ) 626 | pass_lens = [len(passes) for passes in pass_lists] 627 | if any( 628 | [pass_len < source_node.layer_passes_total for pass_len in pass_lens] 629 | ): 630 | self.edges_vary_across_passes = True 631 | else: 632 | self.edges_vary_across_passes = False 633 | 634 | # Add submodule info: 635 | if source_node.is_bottom_level_submodule_output: 636 | self.is_bottom_level_submodule_output = True 637 | self.bottom_level_submodule_passes_exited.add( 638 | source_node.bottom_level_submodule_pass_exited 639 | ) 640 | 641 | # For the parent arg locations, have a list of layers rather than single layer, since they can 642 | # vary across passes. 643 | 644 | for arg_type in ["args", "kwargs"]: 645 | for arg_key, layer_label in source_node.parent_layer_arg_locs[ 646 | arg_type 647 | ].items(): 648 | layer_label_no_pass = self.source_model_history[ 649 | layer_label 650 | ].layer_label_no_pass 651 | self.parent_layer_arg_locs[arg_type][arg_key].add(layer_label_no_pass) 652 | 653 | def __str__(self) -> str: 654 | fields_not_to_print = ["source_model_history"] 655 | s = "" 656 | for field in dir(self): 657 | attr = getattr(self, field) 658 | if ( 659 | not field.startswith("_") 660 | and field not in fields_not_to_print 661 | and not (callable(attr)) 662 | ): 663 | s += f"{field}: {attr}\n" 664 | return s 665 | 666 | def __repr__(self): 667 | return self.__str__() 668 | -------------------------------------------------------------------------------- /torchlens/validation.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Any, Dict, List, Set, TYPE_CHECKING, Union 3 | 4 | import torch 5 | 6 | from .tensor_log import TensorLogEntry 7 | 8 | if TYPE_CHECKING: 9 | pass 10 | 11 | from .helper_funcs import ( 12 | log_current_rng_states, 13 | set_rng_from_saved_states, 14 | tuple_tolerant_assign, 15 | tensor_nanequal, 16 | tensor_all_nan, 17 | ) 18 | 19 | FUNCS_NOT_TO_PERTURB_IN_VALIDATION = [ 20 | "expand_as", 21 | "new_zeros", 22 | "new_ones", 23 | "zero_", 24 | "copy_", 25 | "clamp", 26 | "fill_", 27 | "zeros_like", 28 | "ones_like", 29 | ] 30 | 31 | 32 | def validate_saved_activations( 33 | self, ground_truth_output_tensors: List[torch.Tensor], verbose: bool = False 34 | ) -> bool: 35 | """Starting from outputs and internally terminated tensors, checks whether computing their values from the saved 36 | values of their input tensors yields their actually saved values, and whether computing their values from 37 | their parent tensors yields their saved values. 38 | 39 | Returns: 40 | True if it passes the tests, False otherwise. 41 | """ 42 | # First check that the ground truth output tensors are accurate: 43 | for i, output_layer_label in enumerate(self.output_layers): 44 | output_layer = self[output_layer_label] 45 | if not tensor_nanequal( 46 | output_layer.tensor_contents, 47 | ground_truth_output_tensors[i], 48 | allow_tolerance=False, 49 | ): 50 | print( 51 | f"The {i}th output layer, {output_layer_label}, does not match the ground truth output tensor." 52 | ) 53 | return False 54 | 55 | # Validate the parents of each validated layer. 
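    # Sketch of the traversal: the output layers and internally terminated layers
    # start out as the trusted set; each layer whose saved parent activations
    # reproduce its saved value extends that set, and validation keeps popping
    # layers until every layer is reached or a mismatch is found.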
56 | validated_child_edges_for_each_layer = defaultdict(set) 57 | validated_layers = set(self.output_layers + self.internally_terminated_layers) 58 | layers_to_validate_parents_for = list(validated_layers) 59 | 60 | while len(layers_to_validate_parents_for) > 0: 61 | layer_to_validate_parents_for = layers_to_validate_parents_for.pop(0) 62 | parent_layers_valid = validate_parents_of_saved_layer( 63 | self, 64 | layer_to_validate_parents_for, 65 | validated_layers, 66 | validated_child_edges_for_each_layer, 67 | layers_to_validate_parents_for, 68 | verbose, 69 | ) 70 | if not parent_layers_valid: 71 | return False 72 | 73 | if len(validated_layers) < len(self.layer_labels): 74 | print( 75 | f"All saved activations were accurate, but some layers were not reached (check that " 76 | f"child args logged accurately): {set(self.layer_labels) - validated_layers}" 77 | ) 78 | return False 79 | 80 | return True 81 | 82 | 83 | def validate_parents_of_saved_layer( 84 | self, 85 | layer_to_validate_parents_for_label: str, 86 | validated_layers: Set[str], 87 | validated_child_edges_for_each_layer: Dict[str, Set[str]], 88 | layers_to_validate_parents_for: List[str], 89 | verbose: bool = False, 90 | ) -> bool: 91 | """Given a layer, checks that 1) all parent tensors appear properly in the saved arguments for that layer, 92 | 2) that executing the function for that layer with the saved parent layer activations yields the 93 | ground truth activation values for that layer, and 3) that plugging in "perturbed" values for each 94 | child layer yields values different from the saved activations for that layer. 95 | 96 | Args: 97 | layer_to_validate_parents_for_label: 98 | validated_layers: 99 | validated_child_edges_for_each_layer: 100 | layers_to_validate_parents_for: 101 | verbose: whether to print warning messages 102 | """ 103 | layer_to_validate_parents_for = self[layer_to_validate_parents_for_label] 104 | 105 | # Check that the arguments are logged correctly: 106 | if not _check_layer_arguments_logged_correctly( 107 | self, layer_to_validate_parents_for_label 108 | ): 109 | print( 110 | f"Parent arguments for layer {layer_to_validate_parents_for_label} are not logged properly; " 111 | f"either a parent wasn't logged as an argument, or was logged an extra time" 112 | ) 113 | return False 114 | 115 | # Check that executing the function based on the actual saved values of the parents yields the saved 116 | # values of the layer itself: 117 | 118 | if not _check_whether_func_on_saved_parents_yields_saved_tensor( 119 | self, layer_to_validate_parents_for_label, perturb=False 120 | ): 121 | return False 122 | 123 | # Check that executing the layer's function on the wrong version of the saved parent tensors 124 | # yields the wrong tensors, when each saved tensor is perturbed in turn: 125 | 126 | for perturb_layer in layer_to_validate_parents_for.parent_layers: 127 | if ( 128 | layer_to_validate_parents_for.func_applied_name 129 | in FUNCS_NOT_TO_PERTURB_IN_VALIDATION 130 | ): 131 | continue 132 | if not _check_whether_func_on_saved_parents_yields_saved_tensor( 133 | self, 134 | layer_to_validate_parents_for_label, 135 | perturb=True, 136 | layers_to_perturb=[perturb_layer], 137 | verbose=verbose, 138 | ): 139 | return False 140 | 141 | # Log that each parent layer has been validated for this source layer. 
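    # A parent only counts as fully validated once all of its child edges have
    # been confirmed; input layers and parentless buffer layers are not queued
    # further, since they have no upstream computation left to re-check.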
142 | 143 | for parent_layer_label in layer_to_validate_parents_for.parent_layers: 144 | parent_layer = self[parent_layer_label] 145 | validated_child_edges_for_each_layer[parent_layer_label].add( 146 | layer_to_validate_parents_for_label 147 | ) 148 | if validated_child_edges_for_each_layer[parent_layer_label] == set( 149 | parent_layer.child_layers 150 | ): 151 | validated_layers.add(parent_layer_label) 152 | if ((not parent_layer.is_input_layer) and 153 | not (parent_layer.is_buffer_layer and (parent_layer.buffer_parent is None))): 154 | layers_to_validate_parents_for.append(parent_layer_label) 155 | 156 | return True 157 | 158 | 159 | def _check_layer_arguments_logged_correctly(self, target_layer_label: str) -> bool: 160 | """Check whether the activations of the parent layers match the saved arguments of 161 | the target layer, and that the argument locations have been logged correctly. 162 | 163 | Args: 164 | target_layer_label: Layer to check 165 | 166 | Returns: 167 | True if arguments logged accurately, False otherwise 168 | """ 169 | target_layer = self[target_layer_label] 170 | 171 | # Make sure that all parent layers appear in at least one argument and that no extra layers appear: 172 | parent_layers_in_args = set() 173 | for arg_type in ["args", "kwargs"]: 174 | parent_layers_in_args.update( 175 | list(target_layer.parent_layer_arg_locs[arg_type].values()) 176 | ) 177 | if parent_layers_in_args != set(target_layer.parent_layers): 178 | return False 179 | 180 | argtype_dict = { 181 | "args": (enumerate, "creation_args"), 182 | "kwargs": (lambda x: x.items(), "creation_kwargs"), 183 | } 184 | 185 | # Check for each parent layer that it is logged as a saved argument when it matches an argument, and 186 | # is not logged when it does not match a saved argument. 
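    # As an illustration, parent_layer_arg_locs maps argument locations to parent layer
    # labels, e.g. (hypothetical labels): {"args": {0: "conv2d_1_1", (1, 0): "relu_2_1"},
    # "kwargs": {}}, where a tuple key like (1, 0) addresses element 0 of the iterable
    # passed as positional argument 1.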
187 | 188 | for parent_layer_label in target_layer.parent_layers: 189 | parent_layer = self[parent_layer_label] 190 | for arg_type in ["args", "kwargs"]: 191 | iterfunc, argtype_field = argtype_dict[arg_type] 192 | for key, val in iterfunc(getattr(target_layer, argtype_field)): 193 | validation_correct_for_arg_and_layer = ( 194 | _validate_layer_against_arg( 195 | self, target_layer, parent_layer, arg_type, key, val 196 | ) 197 | ) 198 | if not validation_correct_for_arg_and_layer: 199 | return False 200 | return True 201 | 202 | 203 | def _validate_layer_against_arg( 204 | self, target_layer, parent_layer, arg_type, key, val 205 | ): 206 | if type(val) in [list, tuple]: 207 | for v, subval in enumerate(val): 208 | argloc_key = (key, v) 209 | validation_correct_for_arg_and_layer = ( 210 | _check_arglocs_correct_for_arg( 211 | self, target_layer, parent_layer, arg_type, argloc_key, subval 212 | ) 213 | ) 214 | if not validation_correct_for_arg_and_layer: 215 | return False 216 | 217 | elif type(val) == dict: 218 | for subkey, subval in val.items(): 219 | argloc_key = (key, subkey) 220 | validation_correct_for_arg_and_layer = ( 221 | _check_arglocs_correct_for_arg( 222 | self, target_layer, parent_layer, arg_type, argloc_key, subval 223 | ) 224 | ) 225 | if not validation_correct_for_arg_and_layer: 226 | return False 227 | else: 228 | argloc_key = key 229 | validation_correct_for_arg_and_layer = _check_arglocs_correct_for_arg( 230 | self, target_layer, parent_layer, arg_type, argloc_key, val 231 | ) 232 | if not validation_correct_for_arg_and_layer: 233 | return False 234 | 235 | return True 236 | 237 | 238 | def _check_arglocs_correct_for_arg( 239 | self, 240 | target_layer: TensorLogEntry, 241 | parent_layer: TensorLogEntry, 242 | arg_type: str, 243 | argloc_key: Union[str, tuple], 244 | saved_arg_val: Any, 245 | ): 246 | """For a given layer and an argument to its child layer, checks that it is logged correctly: 247 | that is, that it's logged as an argument if it matches, and is not logged as an argument if it doesn't match. 
248 | """ 249 | target_layer_label = target_layer.layer_label 250 | parent_layer_label = parent_layer.layer_label 251 | if target_layer_label in parent_layer.children_tensor_versions: 252 | parent_activations = parent_layer.children_tensor_versions[ 253 | target_layer_label 254 | ] 255 | else: 256 | parent_activations = parent_layer.tensor_contents 257 | 258 | if type(saved_arg_val) == torch.Tensor: 259 | parent_layer_matches_arg = tensor_nanequal( 260 | saved_arg_val, parent_activations, allow_tolerance=False 261 | ) 262 | else: 263 | parent_layer_matches_arg = False 264 | parent_layer_logged_as_arg = ( 265 | argloc_key in target_layer.parent_layer_arg_locs[arg_type] 266 | ) and ( 267 | target_layer.parent_layer_arg_locs[arg_type][argloc_key] 268 | == parent_layer_label 269 | ) 270 | 271 | if ( 272 | parent_layer_matches_arg 273 | and (not parent_layer_logged_as_arg) 274 | and (parent_activations.numel() != 0) 275 | and (parent_activations.dtype != torch.bool) 276 | and (not tensor_all_nan(parent_activations)) 277 | and (parent_activations.abs().float().mean() != 0) 278 | and (parent_activations.abs().float().mean() != 1) 279 | and not any( 280 | [ 281 | torch.equal(parent_activations, self[other_parent].tensor_contents) 282 | for other_parent in target_layer.parent_layers 283 | if other_parent != parent_layer_label 284 | ] 285 | ) 286 | ): 287 | print( 288 | f"Parent {parent_layer_label} of {target_layer_label} has activations that match " 289 | f"{arg_type} {argloc_key} for {target_layer_label}, but is not logged as " 290 | f"such in parent_layer_arg_locs." 291 | ) 292 | return False 293 | 294 | if (not parent_layer_matches_arg) and parent_layer_logged_as_arg: 295 | print( 296 | f"Parent {parent_layer_label} of {target_layer_label} is logged as {arg_type} {argloc_key} to " 297 | f"{target_layer_label}, but its saved activations don't match the saved argument." 298 | ) 299 | return False 300 | 301 | return True 302 | 303 | 304 | def _check_whether_func_on_saved_parents_yields_saved_tensor( 305 | self, 306 | layer_to_validate_parents_for_label: str, 307 | perturb: bool = False, 308 | layers_to_perturb: List[str] = None, 309 | verbose: bool = False, 310 | ) -> bool: 311 | """Checks whether executing the saved function for a layer on the saved value of its parent layers 312 | in fact yields the saved activations for that layer. 
313 | 314 | Args: 315 | layer_to_validate_parents_for_label: label of the layer to check the saved activations 316 | perturb: whether to perturb the saved activations 317 | layers_to_perturb: layers for which to perturb the saved activations 318 | 319 | Returns: 320 | True if the activations match, False otherwise 321 | """ 322 | if layers_to_perturb is None: 323 | layers_to_perturb = [] 324 | 325 | layer_to_validate_parents_for = self[layer_to_validate_parents_for_label] 326 | 327 | if ( 328 | perturb 329 | and (layer_to_validate_parents_for.func_applied_name == "__getitem__") 330 | and (type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor) 331 | and torch.equal( 332 | self[layers_to_perturb[0]].tensor_contents, 333 | layer_to_validate_parents_for.creation_args[1], 334 | ) 335 | ): 336 | return True 337 | elif ( 338 | perturb 339 | and (layer_to_validate_parents_for.func_applied_name == "__getitem__") 340 | and not torch.equal( 341 | self[layers_to_perturb[0]].tensor_contents, 342 | layer_to_validate_parents_for.creation_args[0], 343 | ) 344 | ): 345 | return True 346 | elif layer_to_validate_parents_for.func_applied_name == 'empty_like': 347 | return True 348 | elif ( 349 | perturb 350 | and (layer_to_validate_parents_for.func_applied_name == "__setitem__") 351 | and (type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor) 352 | and (layer_to_validate_parents_for.creation_args[1].dtype == torch.bool) 353 | and torch.equal( 354 | self[layers_to_perturb[0]].tensor_contents, 355 | layer_to_validate_parents_for.creation_args[1], 356 | ) 357 | ): 358 | return True 359 | elif ( 360 | perturb 361 | and (layer_to_validate_parents_for.func_applied_name == "cross_entropy") 362 | and torch.equal( 363 | self[layers_to_perturb[0]].tensor_contents, 364 | layer_to_validate_parents_for.creation_args[1], 365 | ) 366 | ): 367 | return True 368 | elif ( 369 | perturb 370 | and (layer_to_validate_parents_for.func_applied_name == "__setitem__") 371 | and (type(layer_to_validate_parents_for.creation_args[1]) == tuple) 372 | and ( 373 | type(layer_to_validate_parents_for.creation_args[1][0]) == torch.Tensor 374 | ) 375 | and (layer_to_validate_parents_for.creation_args[1][0].dtype == torch.bool) 376 | and torch.equal( 377 | self[layers_to_perturb[0]].tensor_contents, 378 | layer_to_validate_parents_for.creation_args[1][0], 379 | ) 380 | ): 381 | return True 382 | elif ( 383 | perturb 384 | and (layer_to_validate_parents_for.func_applied_name == "index_select") 385 | and torch.equal( 386 | self[layers_to_perturb[0]].tensor_contents, 387 | layer_to_validate_parents_for.creation_args[2], 388 | ) 389 | ): 390 | return True 391 | elif ( 392 | perturb 393 | and (layer_to_validate_parents_for.func_applied_name == "lstm") 394 | and (torch.equal( 395 | self[layers_to_perturb[0]].tensor_contents, 396 | layer_to_validate_parents_for.creation_args[1][0]) or 397 | torch.equal( 398 | self[layers_to_perturb[0]].tensor_contents, 399 | layer_to_validate_parents_for.creation_args[1][1]) or 400 | torch.equal( 401 | self[layers_to_perturb[0]].tensor_contents, 402 | layer_to_validate_parents_for.creation_args[2][0]) or 403 | torch.equal( 404 | self[layers_to_perturb[0]].tensor_contents, 405 | layer_to_validate_parents_for.creation_args[2][1]) or 406 | ((type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor) and 407 | torch.equal( 408 | self[layers_to_perturb[0]].tensor_contents, 409 | layer_to_validate_parents_for.creation_args[1]) 410 | ))): 411 | return True 412 | elif ( 413 | perturb 
414 |         and (layer_to_validate_parents_for.func_applied_name == "_pad_packed_sequence")
415 |         and torch.equal(
416 |             self[layers_to_perturb[0]].tensor_contents,
417 |             layer_to_validate_parents_for.creation_args[1]
418 |         )):
419 |         return True
420 |     elif (
421 |         perturb
422 |         and (layer_to_validate_parents_for.func_applied_name == "masked_fill_")
423 |         and torch.equal(
424 |             self[layers_to_perturb[0]].tensor_contents,
425 |             layer_to_validate_parents_for.creation_args[1]
426 |         )):
427 |         return True
428 |     elif (
429 |         perturb
430 |         and (layer_to_validate_parents_for.func_applied_name == "scatter_")
431 |         and torch.equal(
432 |             self[layers_to_perturb[0]].tensor_contents,
433 |             layer_to_validate_parents_for.creation_args[2]
434 |         )):
435 |         return True
436 |     elif (
437 |         perturb
438 |         and (layer_to_validate_parents_for.func_applied_name == "interpolate")
439 |         and ((('scale_factor' in layer_to_validate_parents_for.creation_kwargs)
440 |               and (layer_to_validate_parents_for.creation_kwargs['scale_factor'] is not None)
441 |               and torch.equal(
442 |                   self[layers_to_perturb[0]].tensor_contents,
443 |                   torch.tensor(layer_to_validate_parents_for.creation_kwargs['scale_factor'])))
444 |              or ((len(layer_to_validate_parents_for.creation_args) >= 3)
445 |                  and torch.equal(
446 |                      self[layers_to_perturb[0]].tensor_contents,
447 |                      layer_to_validate_parents_for.creation_args[2])))
448 |     ):
449 |         return True
450 | 
451 |     # Prepare input arguments: swap in the saved parent activations, perturbing any parents selected for perturbation.
452 | 
453 |     input_args = _prepare_input_args_for_validating_layer(
454 |         self, layer_to_validate_parents_for, layers_to_perturb
455 |     )
456 | 
457 |     # Re-run the function under the RNG states saved from the original forward pass:
458 |     layer_func = layer_to_validate_parents_for.func_applied
459 |     current_rng_states = log_current_rng_states()
460 |     set_rng_from_saved_states(layer_to_validate_parents_for.func_rng_states)
461 |     try:
462 |         recomputed_output = layer_func(*input_args["args"], **input_args["kwargs"])
463 |     except Exception as e:
464 |         raise RuntimeError(f"Invalid perturbed arguments for layer {layer_to_validate_parents_for_label}") from e
465 |     set_rng_from_saved_states(current_rng_states)
466 | 
467 |     if layer_func.__name__ in [
468 |         "__setitem__",
469 |         "zero_",
470 |         "__delitem__",
471 |     ]:  # TODO: fix this
472 |         recomputed_output = input_args["args"][0]
473 | 
474 |     if isinstance(recomputed_output, (list, tuple)):
475 |         recomputed_output = recomputed_output[
476 |             layer_to_validate_parents_for.iterable_output_index
477 |         ]
478 | 
479 |     if (
480 |         not (
481 |             tensor_nanequal(
482 |                 recomputed_output,
483 |                 layer_to_validate_parents_for.tensor_contents,
484 |                 allow_tolerance=True,
485 |             )
486 |         )
487 |         and not perturb
488 |     ):
489 |         print(
490 |             f"Saved activations for layer {layer_to_validate_parents_for_label} do not match the "
491 |             f"values computed based on the parent layers {layer_to_validate_parents_for.parent_layers}."
492 | ) 493 | return False 494 | 495 | if ( 496 | tensor_nanequal( 497 | recomputed_output, 498 | layer_to_validate_parents_for.tensor_contents, 499 | allow_tolerance=False, 500 | ) 501 | and perturb 502 | ): 503 | return _posthoc_perturb_check( 504 | self, layer_to_validate_parents_for, layers_to_perturb, verbose 505 | ) 506 | 507 | return True 508 | 509 | 510 | def _prepare_input_args_for_validating_layer( 511 | self, 512 | layer_to_validate_parents_for: TensorLogEntry, 513 | layers_to_perturb: List[str], 514 | ) -> Dict: 515 | """Prepares the input arguments for validating the saved activations of a layer. 516 | 517 | Args: 518 | layer_to_validate_parents_for: Layer being checked. 519 | layers_to_perturb: Layers for which to perturb the saved activations. 520 | 521 | Returns: 522 | Dict of input arguments. 523 | """ 524 | input_args = { 525 | "args": list(layer_to_validate_parents_for.creation_args[:]), 526 | "kwargs": layer_to_validate_parents_for.creation_kwargs.copy(), 527 | } 528 | input_args = _copy_validation_args(input_args) 529 | 530 | # Swap in saved parent activations: 531 | 532 | for arg_type in ["args", "kwargs"]: 533 | for ( 534 | key, 535 | parent_layer_arg, 536 | ) in layer_to_validate_parents_for.parent_layer_arg_locs[arg_type].items(): 537 | parent_layer = self[parent_layer_arg] 538 | if ( 539 | layer_to_validate_parents_for.layer_label 540 | in parent_layer.children_tensor_versions 541 | ): 542 | parent_values = parent_layer.children_tensor_versions[ 543 | layer_to_validate_parents_for.layer_label 544 | ] 545 | else: 546 | parent_values = parent_layer.tensor_contents 547 | parent_values = parent_values.detach().clone() 548 | 549 | if parent_layer_arg in layers_to_perturb: 550 | parent_layer_func_values = _perturb_layer_activations( 551 | parent_values, layer_to_validate_parents_for.tensor_contents 552 | ) 553 | else: 554 | parent_layer_func_values = parent_values 555 | 556 | if type(key) != tuple: 557 | input_args[arg_type][key] = parent_layer_func_values 558 | else: 559 | input_args[arg_type][key[0]] = tuple_tolerant_assign( 560 | input_args[arg_type][key[0]], key[1], parent_layer_func_values 561 | ) 562 | 563 | return input_args 564 | 565 | 566 | def _copy_validation_args(input_args: Dict): 567 | new_args = [] 568 | for i, val in enumerate(input_args["args"]): 569 | if type(val) == torch.Tensor: 570 | new_args.append(val.detach().clone()) 571 | elif type(val) in [list, tuple, set]: 572 | new_iter = [] 573 | for i2, val2 in enumerate(val): 574 | if type(val2) == torch.Tensor: 575 | new_iter.append(val2.detach().clone()) 576 | else: 577 | new_iter.append(val2) 578 | new_args.append(type(val)(new_iter)) 579 | else: 580 | new_args.append(val) 581 | input_args["args"] = new_args 582 | 583 | new_kwargs = {} 584 | for key, val in input_args["kwargs"].items(): 585 | if type(val) == torch.Tensor: 586 | new_kwargs[key] = val.detach().clone() 587 | elif type(val) in [list, tuple, set]: 588 | new_iter = [] 589 | for i2, val2 in enumerate(val): 590 | if type(val2) == torch.Tensor: 591 | new_iter.append(val2.detach().clone()) 592 | else: 593 | new_iter.append(val2) 594 | new_kwargs[key] = type(val)(new_iter) 595 | else: 596 | new_kwargs[key] = val 597 | input_args["kwargs"] = new_kwargs 598 | return input_args 599 | 600 | 601 | def _perturb_layer_activations( 602 | parent_activations: torch.Tensor, output_activations: torch.Tensor 603 | ) -> torch.Tensor: 604 | """ 605 | Perturbs the values of a saved tensor. 
606 | 607 | Args: 608 | parent_activations: Tensor of activation values for the parent tensor 609 | output_activations: Tensor of activation values for the tensor whose parents are being tested (the output) 610 | 611 | Returns: 612 | Perturbed version of saved tensor 613 | """ 614 | device = parent_activations.device 615 | if parent_activations.numel() == 0: 616 | return parent_activations.detach().clone() 617 | 618 | if parent_activations.dtype in [ 619 | torch.int, 620 | torch.long, 621 | torch.short, 622 | torch.uint8, 623 | torch.int8, 624 | torch.int16, 625 | torch.int32, 626 | torch.int64, 627 | ]: 628 | tensor_unique_vals = torch.unique(parent_activations) 629 | if len(tensor_unique_vals) > 1: 630 | perturbed_activations = parent_activations.detach().clone() 631 | while torch.equal(perturbed_activations, parent_activations): 632 | perturbed_activations = torch.randint( 633 | parent_activations.min(), 634 | parent_activations.max() + 1, 635 | size=parent_activations.shape, 636 | device=device, 637 | ).type(parent_activations.dtype) 638 | else: 639 | perturbed_activations = parent_activations.detach().clone() 640 | while torch.equal(perturbed_activations, parent_activations): 641 | if torch.min(parent_activations) < 0: 642 | perturbed_activations = torch.randint( 643 | -10, 11, size=parent_activations.shape, device=device 644 | ).type(parent_activations.dtype) 645 | else: 646 | perturbed_activations = torch.randint( 647 | 0, 11, size=parent_activations.shape, device=device 648 | ).type(parent_activations.dtype) 649 | 650 | elif parent_activations.dtype == torch.bool: 651 | perturbed_activations = parent_activations.detach().clone() 652 | while torch.equal(perturbed_activations, parent_activations): 653 | perturbed_activations = torch.randint( 654 | 0, 2, size=parent_activations.shape, device=device 655 | ).bool() 656 | else: 657 | mean_output_sqrt = output_activations.detach().float().abs().mean() 658 | mean_output_sqrt += torch.rand(mean_output_sqrt.shape) * 100 659 | mean_output_sqrt *= torch.rand(mean_output_sqrt.shape) 660 | mean_output_sqrt.requires_grad = False 661 | perturbed_activations = torch.randn_like( 662 | parent_activations.float(), device=device 663 | ) * mean_output_sqrt.to(device) 664 | perturbed_activations = perturbed_activations.type(parent_activations.dtype) 665 | 666 | return perturbed_activations 667 | 668 | 669 | def _posthoc_perturb_check( 670 | self, 671 | layer_to_validate_parents_for: TensorLogEntry, 672 | layers_to_perturb: List[str], 673 | verbose: bool = False, 674 | ) -> bool: 675 | """If a layer fails the "perturbation check"--that is, if perturbing the values of parent 676 | layers doesn't change the values relative to the layer's saved values--checks whether one of the 677 | remaining arguments is a "special" tensor, such as all-ones or all-zeros, such that perturbing a tensor 678 | wouldn't necessarily change the output of the layer. 679 | 680 | Args: 681 | layer_to_validate_parents_for: layer being checked. 682 | layers_to_perturb: parent layers being perturbed 683 | 684 | Returns: 685 | True if there's an "excuse" for the perturbation failing, False otherwise. 
686 | """ 687 | # Check if the tensor is all nans or all infinite: 688 | if layer_to_validate_parents_for.tensor_dtype == torch.bool: 689 | return True 690 | elif ( 691 | (layer_to_validate_parents_for.func_applied_name == "to") 692 | and (len(layer_to_validate_parents_for.creation_args) > 1) 693 | and (type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor) 694 | ): 695 | return True 696 | elif ( 697 | (layer_to_validate_parents_for.func_applied_name == "__setitem__") 698 | and (type(layer_to_validate_parents_for.creation_args[2]) == torch.Tensor) 699 | and ( 700 | layer_to_validate_parents_for.creation_args[0].shape 701 | == layer_to_validate_parents_for.creation_args[2].shape 702 | ) 703 | ): 704 | return True 705 | elif ( 706 | layer_to_validate_parents_for.func_applied_name in ["__getitem__", "unbind"] 707 | ) and ( 708 | layer_to_validate_parents_for.tensor_contents.numel() < 20 709 | ): # some elements can be the same by chance 710 | return True 711 | elif ( 712 | (layer_to_validate_parents_for.func_applied_name == "__getitem__") 713 | and (type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor) 714 | and (layer_to_validate_parents_for.creation_args[1].unique() < 20) 715 | ): 716 | return True 717 | elif (layer_to_validate_parents_for.func_applied_name == "max") and len( 718 | layer_to_validate_parents_for.creation_args 719 | ) > 1: 720 | return True 721 | elif ( 722 | layer_to_validate_parents_for.func_applied_name == "max" 723 | ) and not torch.is_floating_point( 724 | layer_to_validate_parents_for.creation_args[0] 725 | ): 726 | return True 727 | else: 728 | num_inf = ( 729 | torch.isinf(layer_to_validate_parents_for.tensor_contents.abs()) 730 | .int() 731 | .sum() 732 | ) 733 | num_nan = ( 734 | torch.isnan(layer_to_validate_parents_for.tensor_contents.abs()) 735 | .int() 736 | .sum() 737 | ) 738 | if (num_inf == layer_to_validate_parents_for.tensor_contents.numel()) or ( 739 | num_nan == layer_to_validate_parents_for.tensor_contents.numel() 740 | ): 741 | return True 742 | 743 | arg_type_dict = { 744 | "args": (enumerate, "creation_args"), 745 | "kwargs": (lambda x: x.items(), "creation_kwargs"), 746 | } 747 | 748 | layer_to_validate_parents_for_label = layer_to_validate_parents_for.layer_label 749 | for arg_type in ["args", "kwargs"]: 750 | iterfunc, fieldname = arg_type_dict[arg_type] 751 | for key, val in iterfunc(getattr(layer_to_validate_parents_for, fieldname)): 752 | # Skip if it's the argument itself: 753 | if ( 754 | key in layer_to_validate_parents_for.parent_layer_arg_locs[arg_type] 755 | ) and ( 756 | layer_to_validate_parents_for.parent_layer_arg_locs[arg_type][key] 757 | ) in layers_to_perturb: 758 | continue 759 | arg_is_special = _check_if_arg_is_special_val(val) 760 | if arg_is_special: 761 | if verbose: 762 | print( 763 | f"Activations for layer {layer_to_validate_parents_for_label} do not change when " 764 | f"values for {layers_to_perturb} are changed (out of parent " 765 | f"layers {layer_to_validate_parents_for.parent_layers}), but {arg_type[:-1]} {key} is " 766 | f"all zeros or all-ones, so validation still succeeds..." 767 | ) 768 | return True 769 | 770 | print( 771 | f"Activations for layer {layer_to_validate_parents_for_label} do not change when " 772 | f"values for {layers_to_perturb} are changed (out of parent " 773 | f"layers {layer_to_validate_parents_for.parent_layers}), and the other " 774 | f'arguments are not "special" (all-ones or all-zeros) tensors.' 
776 |     )
777 |     return False
778 | 
779 | 
780 | def _check_if_arg_is_special_val(val: Union[torch.Tensor, Any]) -> bool:
781 |     # If the value isn't already a tensor, try converting it; then check if it's all zeros or all ones:
782 |     if type(val) != torch.Tensor:
783 |         try:
784 |             val = torch.Tensor(val)
785 |         except Exception:
786 |             return True
787 |     if torch.all(torch.eq(val, 0)) or torch.all(torch.eq(val, 1)) or (val.numel() == 0):
788 |         return True
789 |     else:
790 |         return False
791 | 
--------------------------------------------------------------------------------
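Taken together, validation.py recomputes every logged layer from its saved parents and checks both fidelity (the recomputation matches the saved activations) and sensitivity (perturbing a parent changes the result). A minimal sketch of how one might drive these checks through the user-facing `validate_saved_activations` wrapper follows; the exact wrapper signature (argument names and ordering) is an assumption for illustration, not taken from this file.

# Minimal usage sketch (assumed wrapper signature; adjust to the actual API).
import torch
import torchvision

import torchlens as tl

model = torchvision.models.alexnet()
x = torch.rand(1, 3, 224, 224)

# Traces a forward pass, then verifies that every layer's saved activations can be
# recomputed from its parents and that perturbing each parent changes the output.
is_valid = tl.validate_saved_activations(model, x, verbose=True)
print(f"Saved activations valid: {is_valid}")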