├── pyccolo
│   ├── py.typed
│   ├── _fast
│   │   ├── __init__.py
│   │   ├── fast_ast.py
│   │   └── misc_ast_utils.py
│   ├── handler.py
│   ├── examples
│   │   ├── __init__.py
│   │   ├── quasiquote.py
│   │   ├── coverage.py
│   │   ├── optional_chaining.py
│   │   ├── quick_lambda.py
│   │   ├── future_tracer.py
│   │   └── lazy_imports.py
│   ├── version.py
│   ├── extra_builtins.py
│   ├── utils.py
│   ├── fast
│   │   ├── __init__.py
│   │   └── __init__.pyi
│   ├── __main__.py
│   ├── stmt_mapper.py
│   ├── predicate.py
│   ├── ast_bookkeeping.py
│   ├── trace_stack.py
│   ├── emit_event.py
│   ├── trace_events.py
│   ├── syntax_augmentation.py
│   ├── __init__.py
│   └── ast_rewriter.py
├── test
│   ├── __init__.py
│   ├── foo.py
│   ├── uses_optional_chaining.py
│   ├── lazy_import_test_module.py
│   ├── _test_lazy_imports.py
│   ├── test_future_tracer.py
│   ├── test_import_hooks.py
│   ├── test_script_entrypoint.py
│   ├── test_no_prints.py
│   ├── test_instrumented_functions.py
│   ├── test_predicate.py
│   ├── test_local_guards.py
│   ├── test_optional_chaining.py
│   ├── test_syntax_augmentation.py
│   ├── test_stack.py
│   └── test_pipeline_tracer.py
├── .gitattributes
├── .git-blame-ignore-revs
├── MANIFEST.in
├── setup.py
├── scripts
│   ├── blacken.sh
│   ├── bump-version.py
│   └── deploy.sh
├── .gitignore
├── Makefile
├── pyproject.toml
├── docs
│   ├── LICENSE.txt
│   ├── CODE_OF_CONDUCT.md
│   └── HISTORY.rst
├── setup.cfg
├── .github
│   └── workflows
│       └── ci.yml
└── README.md
--------------------------------------------------------------------------------
/pyccolo/py.typed:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | pyccolo/_version.py export-subst
2 | 
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | 1c98f15212d131b46dec2d6d7777c91fe99f594d
2 | 
--------------------------------------------------------------------------------
/test/foo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | x = 41
4 | assert x == 42
5 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md docs/HISTORY.rst
2 | recursive-exclude test *
3 | include versioneer.py
4 | include pyccolo/_version.py
5 | 
--------------------------------------------------------------------------------
/test/uses_optional_chaining.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ruff: noqa
3 | # nopycln: file
4 | 
5 | foo = None
6 | assert foo?.bar is None
7 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import setuptools
4 | 
5 | 
6 | if __name__ == "__main__":
7 |     setuptools.setup()
8 | 
--------------------------------------------------------------------------------
/scripts/blacken.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # ref: https://vaneyckt.io/posts/safer_bash_scripts_with_set_euxo_pipefail/
4 | set -euxo pipefail
5 | 
6 | DIRS="./pyccolo 
./test" 7 | black $DIRS $@ 8 | -------------------------------------------------------------------------------- /test/lazy_import_test_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | 4 | import numpy as np 5 | 6 | assert list(np.arange(5)) == list(range(5)) 7 | print(len([mod for mod in sys.modules if "numpy" in mod])) 8 | -------------------------------------------------------------------------------- /pyccolo/_fast/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pyccolo._fast.fast_ast import FastAst 3 | from pyccolo._fast.misc_ast_utils import ( 4 | EmitterMixin, 5 | copy_ast, 6 | make_composite_condition, 7 | make_test, 8 | subscript_to_slice, 9 | ) 10 | 11 | __all__ = [ 12 | "copy_ast", 13 | "EmitterMixin", 14 | "FastAst", 15 | "make_composite_condition", 16 | "make_test", 17 | "subscript_to_slice", 18 | ] 19 | -------------------------------------------------------------------------------- /test/_test_lazy_imports.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pyccolo.examples import LazyImportTracer 3 | 4 | 5 | class TestTracer(LazyImportTracer): 6 | def should_instrument_file(self, filename: str) -> bool: 7 | return not filename.endswith("lazy_imports.py") 8 | 9 | 10 | def test_simple(): 11 | with TestTracer.instance(): 12 | import lazy_import_test_module # noqa: F401 13 | 14 | 15 | if __name__ == "__main__": 16 | test_simple() 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | **/.vscode 3 | **/__pycache__ 4 | **/.mypy_cache 5 | **/.ipynb_checkpoints 6 | *.bundle.* 7 | *.tsbuildinfo 8 | frontend/**/lib/ 9 | frontend/**/node_modules/ 10 | frontend/labextension/package-lock.json 11 | build/ 12 | dist/ 13 | *.egg-info 14 | *.pyc 15 | *.ipynb 16 | !**/notebooks/*.ipynb 17 | **/[uU]ntitled* 18 | MANIFEST 19 | nbsafety/resources/nbextension/index.js 20 | nbsafety/resources/nbextension/index.js.map 21 | .coverage 22 | coverage.xml 23 | htmlcov/** 24 | .hypothesis 25 | -------------------------------------------------------------------------------- /pyccolo/handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | from typing import Any, Callable, NamedTuple, Optional 4 | 5 | from pyccolo.predicate import Predicate 6 | 7 | 8 | class HandlerSpec(NamedTuple): 9 | handler: Callable[..., Any] 10 | use_raw_node_id: bool 11 | reentrant: bool 12 | predicate: Predicate 13 | guard: Optional[Callable[[ast.AST], str]] 14 | exempt_from_guards: bool 15 | 16 | @classmethod 17 | def empty(cls): 18 | return cls(None, False, False, Predicate(lambda *_: True), None, False) # type: ignore 19 | -------------------------------------------------------------------------------- /pyccolo/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .coverage import CoverageTracer 3 | from .future_tracer import FutureTracer 4 | from .lazy_imports import LazyImportTracer 5 | from .optional_chaining import OptionalChainer 6 | from .pipeline_tracer import PipelineTracer 7 | from .quasiquote import Quasiquoter 8 | from .quick_lambda import QuickLambdaTracer 9 | 10 | 
__all__ = [ 11 | "CoverageTracer", 12 | "FutureTracer", 13 | "LazyImportTracer", 14 | "OptionalChainer", 15 | "PipelineTracer", 16 | "Quasiquoter", 17 | "QuickLambdaTracer", 18 | ] 19 | -------------------------------------------------------------------------------- /scripts/bump-version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import argparse 4 | import subprocess 5 | import sys 6 | 7 | from pyccolo.version import make_version_tuple 8 | 9 | 10 | def main(*_): 11 | components = list(make_version_tuple()) 12 | components[-1] += 1 13 | version = ".".join(str(c) for c in components) 14 | subprocess.check_output(["git", "tag", version]) 15 | return 0 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser(description="Bump version and create git tag.") 20 | args = parser.parse_args() 21 | sys.exit(main(args)) 22 | -------------------------------------------------------------------------------- /pyccolo/version.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pyccolo._version import get_versions 3 | 4 | __version__ = get_versions()["version"] 5 | del get_versions 6 | 7 | 8 | def make_version_tuple(vstr=None): 9 | if vstr is None: 10 | vstr = __version__ 11 | if vstr[0] == "v": 12 | vstr = vstr[1:] 13 | components = [] 14 | for component in vstr.split("+")[0].split("."): 15 | try: 16 | components.append(int(component)) 17 | except ValueError: 18 | break 19 | return tuple(components) 20 | 21 | 22 | version = ".".join(str(d) for d in make_version_tuple()) 23 | -------------------------------------------------------------------------------- /pyccolo/extra_builtins.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | from typing import Union 4 | 5 | PYCCOLO_BUILTIN_PREFIX = "_X5ix" 6 | EMIT_EVENT = f"{PYCCOLO_BUILTIN_PREFIX}_PYCCOLO_EVT_EMIT" 7 | TRACE_LAMBDA = f"{PYCCOLO_BUILTIN_PREFIX}_PYCCOLO_TRACE_LAM" 8 | EXEC_SAVED_THUNK = f"{PYCCOLO_BUILTIN_PREFIX}_PYCCOLO_EXEC_SAVED_THUNK" 9 | TRACING_ENABLED = f"{PYCCOLO_BUILTIN_PREFIX}_PYCCOLO_TRACING_ENABLED" 10 | FUNCTION_TRACING_ENABLED = f"{PYCCOLO_BUILTIN_PREFIX}_PYCCOLO_FUNCTION_TRACING_ENABLED" 11 | 12 | 13 | def make_guard_name(node: Union[int, ast.AST]): 14 | node_id = node if isinstance(node, int) else id(node) 15 | return f"{PYCCOLO_BUILTIN_PREFIX}_PYCCOLO_GUARD_{node_id}" 16 | -------------------------------------------------------------------------------- /test/test_future_tracer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pyccolo as pyc 3 | from pyccolo.examples import FutureTracer 4 | 5 | 6 | def test_simple(): 7 | with FutureTracer.instance(): 8 | pyc.exec( 9 | """ 10 | def foo(): 11 | return 0 12 | x = foo() 13 | y = x + 1 14 | z = y + 2 15 | assert y == 1, "got %s" % y 16 | assert z == 3, "got %s" % z 17 | """ 18 | ) 19 | FutureTracer.clear_instance() 20 | assert not any(isinstance(tracer, FutureTracer) for tracer in pyc._TRACER_STACK), ( 21 | "got %s" % pyc._TRACER_STACK 22 | ) 23 | 24 | 25 | if __name__ == "__main__": 26 | test_simple() 27 | -------------------------------------------------------------------------------- /test/test_import_hooks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pyccolo as pyc 3 | 
4 | 5 | def test_basic_instrumented_import(): 6 | class IncrementsAssignValue(pyc.BaseTracer): 7 | def should_instrument_file(self, filename: str) -> bool: 8 | return filename.endswith("foo.py") 9 | 10 | @pyc.register_handler(pyc.after_assign_rhs) 11 | def handle_assign(self, ret, node, *_, **__): 12 | node_id = id(node) 13 | assert self.ast_node_by_id[node_id] is node 14 | assert node_id in self.containing_ast_by_id 15 | assert node_id in self.containing_stmt_by_id 16 | return ret + 1 17 | 18 | with IncrementsAssignValue.instance().tracing_enabled(): 19 | import test.foo # noqa 20 | -------------------------------------------------------------------------------- /test/test_script_entrypoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | 4 | from pyccolo.__main__ import make_parser, run 5 | 6 | if sys.version_info >= (3, 8): # noqa 7 | 8 | def test_entrypoint_with_script(): 9 | # just make sure it doesn't raise 10 | run( 11 | make_parser().parse_args( 12 | "./test/uses_optional_chaining.py -t pyccolo.examples.optional_chaining.ScriptOptionalChainer".split() 13 | ) 14 | ) 15 | 16 | def test_entrypoint_with_module(): 17 | # just make sure it doesn't raise 18 | run( 19 | make_parser().parse_args( 20 | "-m test.uses_optional_chaining -t pyccolo.examples.optional_chaining.ScriptOptionalChainer".split() 21 | ) 22 | ) 23 | -------------------------------------------------------------------------------- /scripts/deploy.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # ref: https://vaneyckt.io/posts/safer_bash_scripts_with_set_euxo_pipefail/ 4 | set -euxo pipefail 5 | 6 | if ! git diff-index --quiet HEAD --; then 7 | echo "dirty working tree; please clean or commit changes" 8 | exit 1 9 | fi 10 | 11 | if ! git describe --exact-match --tags HEAD > /dev/null; then 12 | echo "current revision not tagged; please deploy from a tagged revision" 13 | exit 1 14 | fi 15 | 16 | current="$(python -c 'import versioneer; print(versioneer.get_version())')" 17 | [[ $? -eq 1 ]] && exit 1 18 | 19 | latest="$(git describe --tags $(git rev-list --tags --max-count=1))" 20 | [[ $? 
-eq 1 ]] && exit 1 21 | 22 | if [[ "$current" != "$latest" ]]; then 23 | echo "current revision is not the latest version; please deploy from latest version" 24 | exit 1 25 | fi 26 | 27 | expect <= 48", 4 | "wheel >= 0.30.0", 5 | "setuptools-git-versioning", 6 | ] 7 | build-backend = 'setuptools.build_meta' 8 | 9 | [tool.setuptools-git-versioning] 10 | enabled = true 11 | 12 | [tool.black] 13 | line-length = 88 14 | target-version = ['py39'] 15 | extend-exclude = '(^/pyccolo/__init__|^/versioneer|_version|.*uses_optional_chaining)\.py' 16 | 17 | [tool.isort] 18 | profile = 'black' 19 | extend_skip_glob = [ 20 | '**/pyccolo/__init__.py', 21 | '**/versioneer.py', 22 | '**/_version.py', 23 | '**/setup.py', 24 | ] 25 | 26 | [tool.pytest.ini_options] 27 | markers = ['integration: mark a test as an integration test.'] 28 | filterwarnings = [ 29 | 'ignore::DeprecationWarning', 30 | 'ignore::pytest.PytestAssertRewriteWarning', 31 | ] 32 | 33 | [tool.coverage.run] 34 | source = ['pyccolo'] 35 | omit = ['pyccolo/_version.py', 'pyccolo/version.py', 'pyccolo/examples/**'] 36 | 37 | [tool.coverage.report] 38 | exclude_lines = [ 39 | 'pragma: no cover *$', 40 | '^ *if TYPE_CHECKING:', 41 | '^ *except Exception', 42 | '^ *raise', 43 | '^ *\.\.\.', 44 | ] 45 | -------------------------------------------------------------------------------- /pyccolo/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import importlib 3 | from contextlib import ExitStack, contextmanager 4 | from types import FunctionType 5 | from typing import TYPE_CHECKING, Any, Dict, Iterable, Set, Type, TypeVar, Union 6 | 7 | if TYPE_CHECKING: 8 | from pyccolo.tracer import BaseTracer 9 | 10 | 11 | def resolve_tracer(ref: str) -> Type["BaseTracer"]: 12 | module, attr = ref.rsplit(".", 1) 13 | return getattr(importlib.import_module(module), attr) 14 | 15 | 16 | @contextmanager 17 | def multi_context(cms): 18 | with ExitStack() as stack: 19 | yield [stack.enter_context(mgr) for mgr in cms] 20 | 21 | 22 | def clone_function(func: FunctionType) -> FunctionType: 23 | local_env: Dict[str, Any] = {} 24 | exec( 25 | f"def {func.__name__}(*args, **kwargs): pass", 26 | func.__globals__, 27 | local_env, 28 | ) 29 | cloned_func = local_env[func.__name__] 30 | cloned_func.__code__ = func.__code__ 31 | return cloned_func 32 | 33 | 34 | K = TypeVar("K") 35 | 36 | 37 | def clear_keys(d: Union[Dict[K, Any], Set[K]], keys: Iterable[K]) -> None: 38 | if isinstance(d, dict): 39 | for key in keys: 40 | d.pop(key, None) 41 | elif isinstance(d, set): 42 | d.difference_update(keys) 43 | else: 44 | raise TypeError(f"Unsupported type: {type(d)}") 45 | -------------------------------------------------------------------------------- /pyccolo/fast/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import warnings 4 | 5 | from pyccolo._fast import ( 6 | EmitterMixin, 7 | FastAst, 8 | copy_ast, 9 | make_composite_condition, 10 | make_test, 11 | subscript_to_slice, 12 | ) 13 | 14 | backcompat_helpers = ( 15 | FastAst.Str.__name__, 16 | FastAst.Num.__name__, 17 | FastAst.Bytes.__name__, 18 | FastAst.NameConstant.__name__, 19 | FastAst.Ellipsis.__name__, 20 | ) 21 | location_of = FastAst.location_of 22 | location_of_arg = FastAst.location_of_arg 23 | kw = FastAst.kw 24 | kwargs = FastAst.kwargs 25 | iter_arguments = FastAst.iter_arguments 26 | with warnings.catch_warnings(): 27 | warnings.simplefilter("ignore", 
DeprecationWarning) 28 | for name in dir(FastAst): 29 | if hasattr(ast, name) or name in backcompat_helpers: 30 | globals()[name] = getattr(FastAst, name) 31 | 32 | 33 | __all__ = [ 34 | "copy_ast", 35 | "EmitterMixin", 36 | "FastAst", 37 | "make_composite_condition", 38 | "make_test", 39 | "subscript_to_slice", 40 | # now all the ast helper functions 41 | "location_of", 42 | "location_of_arg", 43 | "kw", 44 | "kwargs", 45 | "iter_arguments", 46 | ] 47 | 48 | 49 | __all__.extend( 50 | {name for name in dir(FastAst) if hasattr(ast, name)} | set(backcompat_helpers) 51 | ) 52 | -------------------------------------------------------------------------------- /docs/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2021 Stephen Macke 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /test/test_no_prints.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | The idea of this file is to make sure that no debugging "print" statements 4 | make it into production. 
5 | """ 6 | import ast 7 | import os 8 | 9 | import pyccolo 10 | 11 | join = os.path.join 12 | root = join(os.curdir, pyccolo.__name__) 13 | 14 | 15 | _EXCEPTED_FILES = { 16 | join(root, "_version.py"), 17 | } 18 | 19 | 20 | class ContainsPrintVisitor(ast.NodeVisitor): 21 | def __init__(self): 22 | self._found_print_call = False 23 | 24 | def __call__(self, filename: str) -> bool: 25 | with open(filename, "r") as f: 26 | self.visit(ast.parse(f.read())) 27 | ret = self._found_print_call 28 | self._found_print_call = False 29 | return ret 30 | 31 | def visit_Call(self, node: ast.Call): 32 | self.generic_visit(node) 33 | if isinstance(node.func, ast.Name) and node.func.id == "print": 34 | self._found_print_call = True 35 | 36 | 37 | def test_no_prints(): 38 | contains_print = ContainsPrintVisitor() 39 | for path, _, files in os.walk(root): 40 | for filename in files: 41 | if not filename.endswith(".py") or filename in _EXCEPTED_FILES: 42 | continue 43 | filename = os.path.join(path, filename) 44 | if filename in _EXCEPTED_FILES: 45 | continue 46 | assert not contains_print( 47 | filename 48 | ), f"file {filename} had a print statement!" 49 | -------------------------------------------------------------------------------- /test/test_instrumented_functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | 4 | import pyccolo as pyc 5 | 6 | 7 | def test_basic_decorator(): 8 | class IncrementsAssignValue(pyc.BaseTracer): 9 | @pyc.register_handler(ast.Assign) 10 | def handle_assign(self, ret, *_, **__): 11 | return ret + 1 12 | 13 | tracer = IncrementsAssignValue.instance() 14 | 15 | @tracer 16 | def f(): 17 | x = 41 18 | return x 19 | 20 | assert f() == 42 21 | 22 | 23 | def test_decorated_tracing_decorator(): 24 | class IncrementsAssignValue(pyc.BaseTracer): 25 | @pyc.register_handler(ast.Assign) 26 | def handle_assign(self, ret, *_, **__): 27 | return ret + 1 28 | 29 | tracer = IncrementsAssignValue.instance() 30 | 31 | def twice(f): 32 | def new_f(): 33 | return f() * 2 34 | 35 | return new_f 36 | 37 | @twice 38 | @tracer 39 | def f(): 40 | x = 41 41 | return x 42 | 43 | assert f() == 84 44 | 45 | 46 | def test_multiple_tracing_decorators(): 47 | class IncrementsAssignValue1(pyc.BaseTracer): 48 | @pyc.register_handler(ast.Assign) 49 | def handle_assign(self, ret, *_, **__): 50 | return ret + 1 51 | 52 | class IncrementsAssignValue2(pyc.BaseTracer): 53 | @pyc.register_handler(ast.Assign) 54 | def handle_assign(self, ret, *_, **__): 55 | return ret + 2 56 | 57 | tracer1 = IncrementsAssignValue1.instance() 58 | tracer2 = IncrementsAssignValue2.instance() 59 | 60 | @pyc.instrumented((tracer1, tracer2)) 61 | def f(): 62 | x = 41 63 | return x 64 | 65 | assert f() == 44 66 | -------------------------------------------------------------------------------- /test/test_predicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | 4 | from pyccolo.predicate import CompositePredicate, Predicate 5 | 6 | 7 | def _rand(threshold=0.5, **kwargs): 8 | return Predicate(lambda *_: random.random() < threshold, **kwargs) 9 | 10 | 11 | def test_static_coalescing(): 12 | assert CompositePredicate.any([Predicate.TRUE, _rand(), _rand()]) is Predicate.TRUE 13 | assert ( 14 | CompositePredicate.any([Predicate.FALSE, _rand(), _rand()]) 15 | is not Predicate.FALSE 16 | ) 17 | assert ( 18 | CompositePredicate.any([Predicate.FALSE, Predicate.FALSE, 
Predicate.FALSE]) 19 | is Predicate.FALSE 20 | ) 21 | assert CompositePredicate.any([]) is Predicate.TRUE 22 | assert ( 23 | CompositePredicate.all([Predicate.FALSE, _rand(), _rand()]) is Predicate.FALSE 24 | ) 25 | assert ( 26 | CompositePredicate.all([Predicate.TRUE, _rand(), _rand()]) is not Predicate.TRUE 27 | ) 28 | assert ( 29 | CompositePredicate.any([Predicate.TRUE, Predicate.TRUE, Predicate.TRUE]) 30 | is Predicate.TRUE 31 | ) 32 | assert CompositePredicate.all([]) is Predicate.TRUE 33 | 34 | 35 | def test_dynamic_behavior(): 36 | assert Predicate.TRUE(None) 37 | assert not Predicate.FALSE(None) 38 | assert Predicate.TRUE.dynamic_call(None) 39 | assert not Predicate.FALSE.dynamic_call(None) 40 | 41 | static_false = CompositePredicate.any( 42 | [_rand(0, static=True), _rand(0, static=True), _rand(0, static=True)] 43 | ) 44 | assert not static_false(None) 45 | assert static_false.dynamic_call(None) # none of the filters kick in dynamically 46 | 47 | dynamic_false = CompositePredicate.any( 48 | [_rand(0, static=True), _rand(0, static=True), _rand(0)] 49 | ) 50 | assert not dynamic_false(None) 51 | assert not dynamic_false.dynamic_call(None) 52 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # See the docstring in versioneer.py for instructions. Note that you must 2 | # re-run 'python versioneer.py setup' after changing this section, and 3 | # commit the resulting files. 4 | 5 | [versioneer] 6 | VCS = git 7 | style = pep440 8 | versionfile_source = pyccolo/_version.py 9 | versionfile_build = pyccolo/_version.py 10 | tag_prefix = 11 | parentdir_prefix = pyccolo- 12 | 13 | [metadata] 14 | name = pyccolo 15 | history = file: docs/HISTORY.rst 16 | description = Declarative instrumentation for Python 17 | long_description = file: README.md 18 | long_description_content_type = text/markdown; charset=UTF-8 19 | url = https://github.com/smacke/pyccolo 20 | author = Stephen Macke 21 | author_email = stephen.macke@gmail.com 22 | license = BSD-3-Clause 23 | license_files = docs/LICENSE.txt 24 | classifiers = 25 | Development Status :: 3 - Alpha 26 | Intended Audience :: Developers 27 | License :: OSI Approved :: BSD License 28 | Natural Language :: English 29 | Programming Language :: Python :: 3.6 30 | Programming Language :: Python :: 3.7 31 | Programming Language :: Python :: 3.8 32 | Programming Language :: Python :: 3.9 33 | Programming Language :: Python :: 3.10 34 | Programming Language :: Python :: 3.11 35 | Programming Language :: Python :: 3.12 36 | Programming Language :: Python :: 3.13 37 | Programming Language :: Python :: 3.14 38 | 39 | [options] 40 | zip_safe = False 41 | packages = find: 42 | platforms = any 43 | python_requires = >= 3.6 44 | install_requires = 45 | traitlets 46 | typing_extensions 47 | 48 | [bdist_wheel] 49 | universal = 1 50 | 51 | [options.entry_points] 52 | console_scripts = 53 | pyc = pyccolo.__main__:main 54 | pyccolo = pyccolo.__main__:main 55 | 56 | [options.extras_require] 57 | test = 58 | black 59 | hypothesis 60 | isort 61 | mypy 62 | pytest 63 | pytest-cov 64 | ruff==0.1.9 65 | dev = 66 | build 67 | pycln 68 | twine 69 | versioneer 70 | %(test)s 71 | 72 | [tool:pytest] 73 | filterwarnings = ignore::DeprecationWarning 74 | 75 | [mypy] 76 | ignore_missing_imports = True 77 | 78 | [mypy-pyccolo._version] 79 | ignore_errors = True 80 | 81 | -------------------------------------------------------------------------------- 
/pyccolo/__main__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Allows running scripts and modules with Pyccolo instrumentation enabled. 4 | """ 5 | import argparse 6 | import os 7 | import sys 8 | from pathlib import Path 9 | from runpy import run_module 10 | from typing import List, Type 11 | 12 | import pyccolo as pyc 13 | 14 | 15 | def get_script_as_module(script: str) -> str: 16 | # ref: https://nvbn.github.io/2016/08/17/ast-import/ 17 | script_path = Path(script) 18 | script_dir = script_path.parent.as_posix() 19 | module_name = os.path.splitext(script_path.name)[0] 20 | sys.path.insert(0, script_dir) 21 | return module_name 22 | 23 | 24 | def make_parser() -> argparse.ArgumentParser: 25 | parser = argparse.ArgumentParser(description="Pyccolo command line tool.") 26 | parser.add_argument("script", nargs="?", help="Script to run with instrumentation.") 27 | parser.add_argument( 28 | "-m", "--module", help="The module to run, if `script` not specified." 29 | ) 30 | parser.add_argument( 31 | "-t", 32 | "--tracer", 33 | nargs="+", 34 | help="Tracers to use for instrumentation.", 35 | required=True, 36 | ) 37 | return parser 38 | 39 | 40 | def validate_args(args: argparse.Namespace) -> None: 41 | if args.script is None and args.module is None: 42 | raise ValueError("must specify script, either as file or module") 43 | if args.script is not None and args.module is not None: 44 | raise ValueError("only one of `script` or `module` may be specified") 45 | 46 | 47 | def run(args: argparse.Namespace) -> None: 48 | validate_args(args) 49 | tracers: List[Type[pyc.BaseTracer]] = [] 50 | for tracer_ref in args.tracer: 51 | tracers.append(pyc.resolve_tracer(tracer_ref)) 52 | if args.module is not None: 53 | module_to_run = args.module 54 | else: 55 | module_to_run = get_script_as_module(args.script) 56 | with pyc.multi_context([tracer.instance() for tracer in tracers]): 57 | run_module(module_to_run) 58 | 59 | 60 | def main() -> int: 61 | run(make_parser().parse_args()) 62 | return 0 63 | 64 | 65 | if __name__ == "__main__": 66 | sys.exit(main()) 67 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: pyccolo 2 | 3 | on: [push, pull_request, workflow_dispatch] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ${{ matrix.os }} 9 | 10 | strategy: 11 | matrix: 12 | os: [ 'ubuntu-22.04', 'windows-latest' ] 13 | python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13', '3.14' ] 14 | include: 15 | - python-version: '3.7' 16 | os: 'macos-15-intel' 17 | - python-version: '3.8' 18 | os: 'macos-15-intel' 19 | - python-version: '3.9' 20 | os: 'macos-15-intel' 21 | - python-version: '3.10' 22 | os: 'macos-latest' 23 | - python-version: '3.11' 24 | os: 'macos-latest' 25 | - python-version: '3.12' 26 | os: 'macos-latest' 27 | - python-version: '3.13' 28 | os: 'macos-latest' 29 | - python-version: '3.14' 30 | os: 'macos-latest' 31 | 32 | steps: 33 | - uses: actions/checkout@v4 34 | with: 35 | fetch-depth: 1 36 | - name: Set up Python 37 | uses: actions/setup-python@v5 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | - name: Install dependencies 41 | run: | 42 | python -m pip install --upgrade pip 43 | pip install -e .[test] 44 | - name: Lint with ruff 45 | run: make lint 46 | - name: Run unit tests with pytest (skip typechecking, coverage) 47 | if: matrix.os == 
'windows-latest' 48 | run: pytest 49 | - name: Run typechecking with mypy and unit tests with pytest (skip coverage) 50 | if: ${{ matrix.os == 'macos-latest' || matrix.os == 'macos-15-intel' }} 51 | run: | 52 | make typecheck 53 | pytest 54 | - name: Run typechecking with mypy and unit tests with pytest (including coverage) 55 | if: matrix.os == 'ubuntu-22.04' 56 | run: make check_ci 57 | - name: Upload coverage report 58 | if: matrix.os == 'ubuntu-22.04' 59 | uses: codecov/codecov-action@v1 60 | with: 61 | token: '${{ secrets.CODECOV_TOKEN }}' 62 | files: ./coverage.xml 63 | env_vars: PYTHON 64 | name: codecov-umbrella 65 | fail_ci_if_error: true 66 | verbose: true 67 | -------------------------------------------------------------------------------- /pyccolo/stmt_mapper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import logging 4 | from typing import TYPE_CHECKING, Dict, List, Optional, TypeVar 5 | 6 | from pyccolo import fast 7 | from pyccolo.emit_event import _TRACER_STACK 8 | 9 | if TYPE_CHECKING: 10 | from pyccolo.tracer import BaseTracer 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.WARNING) 15 | 16 | 17 | _T = TypeVar("_T", bound=ast.AST) 18 | 19 | 20 | class StatementMapper(ast.NodeVisitor): 21 | def __init__(self, tracers: Optional[List["BaseTracer"]] = None): 22 | self._tracers: List["BaseTracer"] = ( 23 | _TRACER_STACK if tracers is None else tracers 24 | ) 25 | self.traversal: List[ast.AST] = [] 26 | 27 | @classmethod 28 | def augmentation_propagating_copy(cls, node: _T) -> _T: 29 | return cls()(node)[id(node)] # type: ignore[return-value] 30 | 31 | def _handle_augmentations(self, no: ast.AST, nc: ast.AST) -> None: 32 | for tracer in self._tracers: 33 | for spec in tracer.get_augmentations(id(no)): 34 | tracer.augmented_node_ids_by_spec[spec].add(id(nc)) 35 | 36 | def __call__( 37 | self, 38 | node: ast.AST, 39 | copy_node: Optional[ast.AST] = None, 40 | ) -> Dict[int, ast.AST]: 41 | # for some bizarre reason we need to visit once to clear empty nodes apparently 42 | self.traversal.clear() 43 | self.visit(node) 44 | self.traversal.clear() 45 | 46 | self.visit(node) 47 | orig_traversal = self.traversal 48 | self.traversal = [] 49 | self.visit(copy_node or fast.copy_ast(node)) 50 | copy_traversal = self.traversal 51 | orig_to_copy_mapping = {} 52 | for no, nc in zip(orig_traversal, copy_traversal): 53 | orig_to_copy_mapping[id(no)] = nc 54 | if hasattr(nc, "lineno"): 55 | self._handle_augmentations(no, nc) 56 | return orig_to_copy_mapping 57 | 58 | def visit(self, node: ast.AST) -> None: 59 | self.traversal.append(node) 60 | for name, field in ast.iter_fields(node): 61 | if isinstance(field, ast.AST): 62 | self.visit(field) 63 | elif isinstance(field, list): 64 | for inner_node in field: 65 | if isinstance(inner_node, ast.AST): 66 | self.visit(inner_node) 67 | -------------------------------------------------------------------------------- /test/test_local_guards.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | from collections import Counter 4 | from types import FrameType 5 | 6 | import pyccolo as pyc 7 | 8 | 9 | def test_local_guard_activation_prevents_future_handlers(): 10 | class LoadNameCounter(pyc.BaseTracer): 11 | counter_by_name = Counter() 12 | 13 | @pyc.init_module 14 | def init_module(self, _ret, node, frame: FrameType, *_, **__): 15 | assert node is not None 16 | 
for guard in self.local_guards_by_module_id.get(id(node), []): 17 | frame.f_globals[guard] = False 18 | 19 | @pyc.load_name(guard=lambda node: f"{pyc.PYCCOLO_BUILTIN_PREFIX}_{node.id}") 20 | def load_name( 21 | self, _ret, node: ast.Name, frame: FrameType, _evt, guard, *_, **__ 22 | ): 23 | self.counter_by_name[node.id] += 1 24 | assert guard is not None 25 | frame.f_globals[guard] = True 26 | 27 | with LoadNameCounter.instance(): 28 | pyc.exec( 29 | """ 30 | w = 0 31 | x = w + 1 32 | y = w + x + 1 33 | z = w + x + y + 1 34 | """ 35 | ) 36 | for var in ("w", "x", "y"): 37 | assert LoadNameCounter.counter_by_name[var] == 1, var 38 | 39 | 40 | def test_subscript_local_guard_activation(): 41 | class SubscriptCounter(pyc.BaseTracer): 42 | counter_by_subscript = Counter() 43 | 44 | @pyc.init_module 45 | def init_module(self, _ret, node: ast.Module, frame: FrameType, *_, **__): 46 | assert node is not None 47 | for guard in self.local_guards_by_module_id.get(id(node), []): 48 | frame.f_globals[guard] = False 49 | 50 | @pyc.before_subscript_load( 51 | guard=lambda node: f"{pyc.PYCCOLO_BUILTIN_PREFIX}_{node.value.id}" 52 | ) 53 | def before_subscript_load(self, _ret, node, *_, **__): 54 | self.counter_by_subscript[node.value.id] += 1 55 | 56 | @pyc.after_subscript_load( 57 | guard=lambda node: f"{pyc.PYCCOLO_BUILTIN_PREFIX}_{node.value.id}" 58 | ) 59 | def after_subscript_load( 60 | self, _ret, node, frame: FrameType, _evt, guard, *_, **__ 61 | ): 62 | self.counter_by_subscript[node.value.id] += 1 63 | assert guard is not None 64 | frame.f_globals[guard] = True 65 | 66 | with SubscriptCounter.instance(): 67 | pyc.exec( 68 | """ 69 | lst = [0] 70 | x = lst[0] + 1 71 | y = lst[0] + x + 1 72 | z = lst[0] + x + y + 1 73 | assert z == 4 74 | """ 75 | ) 76 | for var in ("lst",): 77 | assert SubscriptCounter.counter_by_subscript[var] == 2, var 78 | -------------------------------------------------------------------------------- /test/test_optional_chaining.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import textwrap 4 | 5 | import pyccolo as pyc 6 | from pyccolo.examples.optional_chaining import ScriptOptionalChainer as OptionalChainer 7 | 8 | if sys.version_info >= (3, 8): # noqa 9 | 10 | def test_optional_chaining_simple(): 11 | OptionalChainer.instance().exec( 12 | """ 13 | class Foo: 14 | def __init__(self, x): 15 | self.x = x 16 | foo = Foo(Foo(Foo(None))) 17 | try: 18 | bar = foo.x.x.x.x 19 | except: 20 | pass 21 | else: 22 | assert False 23 | assert foo.x.x.x?.x is None 24 | assert foo.x.x.x?.x() is None 25 | assert foo.x.x.x?.x?.whatever is None 26 | assert isinstance(foo?.x?.x, Foo) 27 | assert isinstance(foo.x?.x, Foo) 28 | assert isinstance(foo?.x.x, Foo) 29 | """ 30 | ) 31 | 32 | def test_permissive_attr_vs_optional_attr_qualifier(): 33 | with OptionalChainer: 34 | try: 35 | pyc.exec("foo = object(); assert foo?.bar is None") 36 | except AttributeError: 37 | pass 38 | else: 39 | assert False 40 | 41 | with OptionalChainer: 42 | try: 43 | pyc.exec("foo = object(); assert foo.?bar.baz is None") 44 | except AttributeError: 45 | pass 46 | else: 47 | assert False 48 | 49 | with OptionalChainer: 50 | pyc.exec("foo = object(); assert foo.?bar is None") 51 | 52 | with OptionalChainer: 53 | pyc.exec("foo = object(); assert foo.?bar?.baz is None") 54 | 55 | with OptionalChainer: 56 | pyc.exec("foo = object(); assert foo.?bar?.baz.bam() is None") 57 | 58 | def test_call_on_optional(): 59 | with OptionalChainer: 60 | 
pyc.exec("foo = None; assert foo?.() is None") 61 | 62 | def test_nullish_coalescing(): 63 | with OptionalChainer: 64 | pyc.exec("None ?? None") 65 | pyc.exec("foo = ''; assert (foo ?? None) == ''") 66 | pyc.exec("foo = ''; assert (foo ?? None) == ''") 67 | pyc.exec("foo = ''; assert (foo ?? None) == ''") 68 | pyc.exec("foo = ''; assert (foo ?? None) == ''") 69 | assert pyc.eval("'' ?? None") == "" 70 | assert pyc.eval("''??None") == "" 71 | assert pyc.eval("0 ?? None") == 0 72 | assert pyc.eval("None ?? 0 ?? None") == 0 73 | assert pyc.eval("None or 0 ?? None") == 0 74 | assert pyc.eval("None and 0 ?? None") is None 75 | assert pyc.eval("0 or None ?? False") is False 76 | 77 | def test_multiline_nullish_coalescing(): 78 | with OptionalChainer: 79 | assert ( 80 | pyc.eval( 81 | textwrap.dedent( 82 | """ 83 | ( 84 | "" 85 | ?? 86 | None 87 | ) 88 | """.strip( 89 | "\n" 90 | ) 91 | ) 92 | ) 93 | == "" 94 | ) 95 | -------------------------------------------------------------------------------- /docs/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. 
Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at stephen.macke@gmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /pyccolo/predicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | from typing import Callable, Optional, Sequence, Union, overload 4 | 5 | from typing_extensions import Literal 6 | 7 | 8 | class Predicate: 9 | TRUE: "Predicate" 10 | FALSE: "Predicate" 11 | 12 | @overload 13 | def __init__( 14 | self, 15 | condition: Callable[[ast.AST], bool], 16 | use_raw_node_id: Literal[False], 17 | static: bool = True, 18 | ) -> None: ... 19 | 20 | @overload 21 | def __init__( 22 | self, 23 | condition: Callable[[int], bool], 24 | use_raw_node_id: Literal[True], 25 | static: bool = False, 26 | ) -> None: ... 27 | 28 | @overload 29 | def __init__( 30 | self, 31 | condition: Callable[..., bool], 32 | use_raw_node_id: bool = False, 33 | static: bool = False, 34 | ) -> None: ... 
35 | 36 | def __init__( 37 | self, 38 | condition: Callable[..., bool], 39 | use_raw_node_id: bool = False, 40 | static: bool = False, 41 | ) -> None: 42 | self.condition = condition 43 | self.use_raw_node_id = use_raw_node_id 44 | self.static = static 45 | 46 | def __call__(self, node: Union[ast.AST, int]) -> bool: 47 | node_or_id = ( 48 | id(node) if self.use_raw_node_id and isinstance(node, ast.AST) else node 49 | ) 50 | return self.condition(node_or_id) # type: ignore 51 | 52 | def dynamic_call(self, node: Union[ast.AST, int]) -> bool: 53 | return True if self.static else self(node) 54 | 55 | 56 | Predicate.TRUE = Predicate(lambda *_: True) 57 | Predicate.FALSE = Predicate(lambda *_: False) 58 | 59 | 60 | class CompositePredicate(Predicate): 61 | def __init__(self, base_predicates: Sequence[Predicate], reducer=any) -> None: 62 | self.base_predicates = list(base_predicates) 63 | self.dynamic_base_predicates = [ 64 | pred for pred in base_predicates if not pred.static 65 | ] 66 | self.static = len(self.dynamic_base_predicates) == 0 67 | self.use_raw_node_id = all(pred.use_raw_node_id for pred in base_predicates) 68 | self.reducer = reducer 69 | 70 | def __call__( 71 | self, 72 | node: Union[ast.AST, int], 73 | predicates: Optional[Sequence[Predicate]] = None, 74 | ) -> bool: 75 | predicates = self.base_predicates if predicates is None else predicates 76 | assert len(predicates) > 0 77 | return self.reducer(pred(node) for pred in predicates) 78 | 79 | def dynamic_call(self, node: Union[ast.AST, int]) -> bool: 80 | return ( 81 | True if self.static else self(node, predicates=self.dynamic_base_predicates) 82 | ) 83 | 84 | @classmethod 85 | def _create(cls, base_predicates: Sequence[Predicate], reducer) -> Predicate: 86 | assert len(base_predicates) > 0 87 | return cls(base_predicates, reducer=reducer) 88 | 89 | @classmethod 90 | def any(cls, base_predicates: Sequence[Predicate]) -> Predicate: 91 | if len(base_predicates) == 0 or any( 92 | pred is Predicate.TRUE for pred in base_predicates 93 | ): 94 | return Predicate.TRUE 95 | if all(pred is Predicate.FALSE for pred in base_predicates): 96 | return Predicate.FALSE 97 | return cls._create(base_predicates, reducer=any) 98 | 99 | @classmethod 100 | def all(cls, base_predicates: Sequence[Predicate]) -> Predicate: 101 | if any(pred is Predicate.FALSE for pred in base_predicates): 102 | return Predicate.FALSE 103 | if all(pred is Predicate.TRUE for pred in base_predicates): 104 | return Predicate.TRUE 105 | return cls._create(base_predicates, reducer=all) 106 | -------------------------------------------------------------------------------- /pyccolo/_fast/fast_ast.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import functools 4 | import sys 5 | import textwrap 6 | import warnings 7 | from contextlib import contextmanager 8 | from typing import TYPE_CHECKING, Callable, Generator, List, Optional 9 | 10 | 11 | class FastAst: 12 | _LOCATION_OF_NODE: Optional[ast.AST] = None 13 | 14 | if TYPE_CHECKING: 15 | 16 | @staticmethod 17 | def keyword(arg: str, value: ast.AST) -> ast.keyword: ... 18 | @staticmethod 19 | def Str(*args, **kwargs) -> ast.Constant: ... 20 | @staticmethod 21 | def Num(*args, **kwargs) -> ast.Constant: ... 22 | @staticmethod 23 | def Bytes(*args, **kwargs) -> ast.Constant: ... 24 | @staticmethod 25 | def NameConstant(*args, **kwargs) -> ast.Constant: ... 26 | @staticmethod 27 | def Ellipsis(*args, **kwargs) -> ast.Constant: ... 
28 | 
29 |     @staticmethod
30 |     @contextmanager
31 |     def location_of(node: ast.AST) -> Generator[None, None, None]:
32 |         """
33 |         All nodes created like `fast.AST(...)` instead of
34 |         `ast.AST(...)` will inherit location info from `node`.
35 |         """
36 |         old_location_of_node = FastAst._LOCATION_OF_NODE
37 |         FastAst._LOCATION_OF_NODE = node
38 |         try:
39 |             yield
40 |         finally:
41 |             FastAst._LOCATION_OF_NODE = old_location_of_node
42 | 
43 |     @classmethod
44 |     def location_of_arg(cls, func: Callable[..., ast.AST]) -> Callable[..., ast.AST]:
45 |         @functools.wraps(func)
46 |         def wrapped_node_transform(*args) -> ast.AST:
47 |             with cls.location_of(args[-1]):
48 |                 return func(*args)
49 | 
50 |         return wrapped_node_transform
51 | 
52 |     @classmethod
53 |     def kw(cls, arg, value) -> ast.keyword:
54 |         return cls.keyword(arg=arg, value=value)
55 | 
56 |     @classmethod
57 |     def kwargs(cls, **kwargs) -> List[ast.keyword]:
58 |         return [cls.keyword(arg=arg, value=value) for arg, value in kwargs.items()]
59 | 
60 |     @classmethod
61 |     def parse(cls, code: str, *args, **kwargs) -> ast.AST:
62 |         ret = ast.parse(textwrap.dedent(code), *args, **kwargs)
63 |         if cls._LOCATION_OF_NODE is not None:
64 |             ast.copy_location(ret, cls._LOCATION_OF_NODE)
65 |         return ret
66 | 
67 |     @classmethod
68 |     def Call(cls, func, args=None, keywords=None, **kwargs) -> ast.Call:
69 |         args = args or []
70 |         keywords = keywords or []
71 |         ret = ast.Call(func, args, keywords, **kwargs)
72 |         if cls._LOCATION_OF_NODE is not None:
73 |             ast.copy_location(ret, cls._LOCATION_OF_NODE)
74 |         return ret
75 | 
76 |     @staticmethod
77 |     def iter_arguments(args: ast.arguments) -> Generator[ast.arg, None, None]:
78 |         yield from getattr(args, "posonlyargs", [])
79 |         yield from args.args
80 |         if args.vararg is not None:
81 |             yield args.vararg
82 |         yield from getattr(args, "kwonlyargs", [])
83 |         if args.kwarg is not None:
84 |             yield args.kwarg
85 | 
86 | 
87 | def _make_func(new_name, old_name=None):
88 |     def ctor(*args, **kwargs):
89 |         ret = getattr(ast, new_name)(*args, **kwargs)
90 |         if FastAst._LOCATION_OF_NODE is not None:
91 |             ast.copy_location(ret, FastAst._LOCATION_OF_NODE)
92 |         return ret
93 | 
94 |     ctor.__name__ = old_name or new_name
95 |     return ctor
96 | 
97 | 
98 | with warnings.catch_warnings():
99 |     warnings.simplefilter("ignore", DeprecationWarning)
100 |     for ctor_name in ast.__dict__:
101 |         if ctor_name.startswith("_") or hasattr(FastAst, ctor_name):
102 |             continue
103 |         setattr(FastAst, ctor_name, staticmethod(_make_func(ctor_name)))
104 | 
105 | 
106 | if sys.version_info >= (3, 8):
107 |     for old_name in ("Str", "Num", "Bytes", "NameConstant", "Ellipsis"):
108 |         setattr(FastAst, old_name, staticmethod(_make_func("Constant", old_name)))
109 | 
--------------------------------------------------------------------------------
/pyccolo/examples/quasiquote.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Example of a quasiquoter for Pyccolo, similar to MacroPy's.
4 | Ref: https://macropy3.readthedocs.io/en/latest/reference.html#quasiquote 5 | 6 | Example: 7 | ``` 8 | a = 10 9 | b = 2 10 | with Quasiquoter: 11 | pyc.eval("q[1 + u[a + b]]") 12 | >>> BinOp(Add, left=Num(1), right=Num(12)) 13 | ``` 14 | """ 15 | import ast 16 | import copy 17 | from typing import Callable, Tuple, Union 18 | 19 | import pyccolo as pyc 20 | 21 | 22 | class _QuasiquoteTransformer(ast.NodeTransformer): 23 | def __init__(self, global_env, local_env): 24 | self._global_env = global_env 25 | self._local_env = local_env 26 | self._quoter = Quasiquoter.instance() 27 | 28 | def visit_Subscript(self, node: ast.Subscript): 29 | if isinstance(node.value, ast.Name) and node.value.id in self._quoter.macros - { 30 | "q" 31 | }: 32 | return self._quoter.eval(node, self._global_env, self._local_env) 33 | else: 34 | return node 35 | 36 | 37 | def is_macro(name_or_names: Union[str, Tuple[str, ...]]) -> Callable[[ast.AST], bool]: 38 | if isinstance(name_or_names, tuple): 39 | names = name_or_names 40 | else: 41 | names = (name_or_names,) 42 | return ( 43 | lambda node: isinstance(node, ast.Subscript) 44 | and isinstance(node.value, ast.Name) 45 | and node.value.id in names 46 | ) 47 | 48 | 49 | class _IdentitySubscript: 50 | def __getitem__(self, item): 51 | return item 52 | 53 | 54 | _identity_subscript = _IdentitySubscript() 55 | 56 | 57 | class Quasiquoter(pyc.BaseTracer): 58 | allow_reentrant_events = True 59 | 60 | def __init__(self, *args, **kwargs): 61 | super().__init__(*args, **kwargs) 62 | self.macros = {"q", "u", "name", "ast_literal", "ast_list"} 63 | self._extra_builtins = set() 64 | 65 | def enter_tracing_hook(self) -> None: 66 | import builtins 67 | 68 | # need to create dummy reference to avoid NameError 69 | for macro in self.macros: 70 | if not hasattr(builtins, macro): 71 | self._extra_builtins.add(macro) 72 | setattr(builtins, macro, None) 73 | 74 | def exit_tracing_hook(self) -> None: 75 | import builtins 76 | 77 | for macro in self._extra_builtins: 78 | if hasattr(builtins, macro): 79 | delattr(builtins, macro) 80 | self._extra_builtins.clear() 81 | 82 | @pyc.before_subscript_slice(when=is_macro("q"), reentrant=True) 83 | def quote_handler(self, _ret, node, frame, *_, **__): 84 | to_visit = node.slice 85 | if isinstance(node.slice, ast.Index): 86 | to_visit = to_visit.value 87 | return lambda: _QuasiquoteTransformer(frame.f_globals, frame.f_locals).visit( 88 | copy.deepcopy(to_visit) 89 | ) 90 | 91 | @pyc.after_subscript_slice(when=is_macro("u"), reentrant=True) 92 | def unquote_handler(self, ret, *_, **__): 93 | return self.eval(f"q[{repr(ret)}]") 94 | 95 | @pyc.after_subscript_slice(when=is_macro("name"), reentrant=True) 96 | def name_handler(self, ret, *_, **__): 97 | assert isinstance(ret, str) 98 | return self.eval(f"q[{ret}]") 99 | 100 | @pyc.after_subscript_slice(when=is_macro("ast_literal"), reentrant=True) 101 | def ast_literal_handler(self, ret, *_, **__): 102 | # technically we could get away without even having this handler 103 | assert isinstance(ret, ast.AST) 104 | return ret 105 | 106 | @pyc.after_subscript_slice(when=is_macro("ast_list"), reentrant=True) 107 | def ast_list_handler(self, ret, *_, **__): 108 | return ast.List(elts=list(ret)) 109 | 110 | def is_any_macro(self, node): 111 | return is_macro(tuple(self.macros))(node) 112 | 113 | @pyc.before_subscript_load(when=is_any_macro, reentrant=True) 114 | def load_macro_result(self, _ret, *_, **__): 115 | return _identity_subscript 116 | 
--------------------------------------------------------------------------------
/pyccolo/ast_bookkeeping.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import ast
3 | from typing import Dict, NamedTuple, Optional, Tuple
4 | 
5 | 
6 | class AstBookkeeper(NamedTuple):
7 |     path: str
8 |     module_id: int
9 |     ast_node_by_id: Dict[int, ast.AST]
10 |     containing_ast_by_id: Dict[int, ast.AST]
11 |     containing_stmt_by_id: Dict[int, ast.stmt]
12 |     parent_stmt_by_id: Dict[int, ast.stmt]
13 |     stmt_by_lineno: Dict[int, ast.stmt]
14 | 
15 |     def remap(self, new_module_id: int) -> Tuple["AstBookkeeper", Dict[int, int]]:
16 |         """
17 |         After unpickling, the ast nodes will have different ids than before.
18 |         This method will compute a new bookkeeper to reflect the new ids, as well
19 |         as return a mapping from the old ids to the new ids.
20 |         """
21 |         remapping: Dict[int, int] = {self.module_id: new_module_id}
22 |         new_ast_node_by_id: Dict[int, ast.AST] = {}
23 |         new_containing_ast_by_id: Dict[int, ast.AST] = {}
24 |         new_containing_stmt_by_id: Dict[int, ast.stmt] = {}
25 |         new_parent_stmt_by_id: Dict[int, ast.stmt] = {}
26 |         for old_id, ast_node in self.ast_node_by_id.items():
27 |             new_id = id(ast_node)
28 |             remapping[old_id] = new_id
29 |             new_ast_node_by_id[new_id] = ast_node
30 |             if old_id in self.containing_ast_by_id:
31 |                 new_containing_ast_by_id[new_id] = self.containing_ast_by_id[old_id]
32 |             if old_id in self.containing_stmt_by_id:
33 |                 new_containing_stmt_by_id[new_id] = self.containing_stmt_by_id[old_id]
34 |             if old_id in self.parent_stmt_by_id:
35 |                 new_parent_stmt_by_id[new_id] = self.parent_stmt_by_id[old_id]
36 |         return (
37 |             AstBookkeeper(
38 |                 self.path,
39 |                 new_module_id,
40 |                 new_ast_node_by_id,
41 |                 new_containing_ast_by_id,
42 |                 new_containing_stmt_by_id,
43 |                 new_parent_stmt_by_id,
44 |                 self.stmt_by_lineno,
45 |             ),
46 |             remapping,
47 |         )
48 | 
49 |     @classmethod
50 |     def create(cls, path: str, module_id: int) -> "AstBookkeeper":
51 |         return cls(path, module_id, {}, {}, {}, {}, {})
52 | 
53 | 
54 | class BookkeepingVisitor(ast.NodeVisitor):
55 |     def __init__(self, bookkeeper: AstBookkeeper) -> None:
56 |         self.ast_node_by_id = bookkeeper.ast_node_by_id
57 |         self.containing_ast_by_id = bookkeeper.containing_ast_by_id
58 |         self.containing_stmt_by_id = bookkeeper.containing_stmt_by_id
59 |         self.parent_stmt_by_id = bookkeeper.parent_stmt_by_id
60 |         self.stmt_by_lineno = bookkeeper.stmt_by_lineno
61 |         self._current_containing_stmt: Optional[ast.stmt] = None
62 | 
63 |     def generic_visit(self, node: ast.AST):
64 |         if isinstance(node, ast.stmt):
65 |             self._current_containing_stmt = node
66 |         if self._current_containing_stmt is not None:
67 |             self.containing_stmt_by_id.setdefault(
68 |                 id(node), self._current_containing_stmt
69 |             )
70 |         self.ast_node_by_id[id(node)] = node
71 |         if isinstance(node, ast.stmt):
72 |             self.stmt_by_lineno[node.lineno] = node
73 |             # workaround for python >= 3.8 wherein function calls seem
74 |             # to yield trace frames that use the lineno of the first decorator
75 |             for decorator in getattr(node, "decorator_list", []):
76 |                 self.stmt_by_lineno[decorator.lineno] = node
77 |         for name, field in ast.iter_fields(node):
78 |             if isinstance(field, ast.AST):
79 |                 self.containing_ast_by_id[id(field)] = node
80 |                 if self._current_containing_stmt is not None:
81 |                     self.containing_stmt_by_id[id(field)] = (
82 |                         self._current_containing_stmt
83 |                     )
84 |             elif isinstance(field, list):
85 |                 for subfield in field:
86 |                     if 
isinstance(subfield, ast.AST): 87 | self.containing_ast_by_id[id(subfield)] = node 88 | if isinstance(node, ast.stmt): 89 | self.parent_stmt_by_id[id(subfield)] = node 90 | super().generic_visit(node) 91 | -------------------------------------------------------------------------------- /test/test_syntax_augmentation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import sys 4 | 5 | import pyccolo as pyc 6 | 7 | if sys.version_info >= (3, 8): # noqa 8 | add_42_spec = pyc.AugmentationSpec( 9 | aug_type=pyc.AugmentationType.binop, token="++", replacement="+" 10 | ) 11 | 12 | def test_augmented_plus(): 13 | class Add42(pyc.BaseTracer): 14 | @classmethod 15 | def syntax_augmentation_specs(cls): 16 | return [add_42_spec] 17 | 18 | @pyc.after_binop(when=lambda node: isinstance(node.op, ast.Add)) 19 | def handle_add(self, ret, node, *_, **__): 20 | if add_42_spec in self.get_augmentations(id(node)): 21 | return ret + 42 22 | else: 23 | return ret 24 | 25 | tracer = Add42.instance() 26 | env = tracer.exec("x = 21 ++ 21") 27 | assert env["x"] == 84, "got %s" % env["x"] 28 | 29 | env = tracer.exec("x = x + 21", local_env=env) 30 | assert env["x"] == 105 31 | 32 | coalesce_dot_spec = pyc.AugmentationSpec( 33 | aug_type=pyc.AugmentationType.dot_suffix, token="?.", replacement="." 34 | ) 35 | 36 | prefix_spec = pyc.AugmentationSpec( 37 | aug_type=pyc.AugmentationType.prefix, token="$", replacement="" 38 | ) 39 | 40 | suffix_spec = pyc.AugmentationSpec( 41 | aug_type=pyc.AugmentationType.suffix, token="$$", replacement="" 42 | ) 43 | 44 | def test_prefix_suffix(): 45 | class IncrementAugmentedTracer(pyc.BaseTracer): 46 | def __init__(self, *args, **kwargs): 47 | super().__init__(*args, **kwargs) 48 | self._delayed_increment = 0 49 | 50 | @classmethod 51 | def syntax_augmentation_specs(cls): 52 | return [suffix_spec, prefix_spec] 53 | 54 | @pyc.load_name 55 | def handle_name(self, ret, node, *_, **__): 56 | self._delayed_increment = 0 57 | node_id = id(node) 58 | augs = self.get_augmentations(node_id) 59 | assert not (prefix_spec in augs and suffix_spec in augs) 60 | if prefix_spec in augs: 61 | offset = 1 62 | elif suffix_spec in augs: 63 | offset = 2 64 | else: 65 | offset = 0 66 | if isinstance(ret, int): 67 | return ret + offset 68 | else: 69 | self._delayed_increment = offset 70 | return ret 71 | 72 | @pyc.after_attribute_load 73 | def handle_attr(self, ret, node, *_, **__): 74 | if isinstance(node.value, ast.Name): 75 | ret += self._delayed_increment 76 | self._delayed_increment = 0 77 | augs = self.get_augmentations(id(node)) 78 | assert not (prefix_spec in augs and suffix_spec in augs) 79 | if prefix_spec in augs: 80 | return ret + 1 81 | elif suffix_spec in augs: 82 | return ret + 2 83 | else: 84 | return ret 85 | 86 | with IncrementAugmentedTracer.instance(): 87 | assert ( 88 | pyc.exec( 89 | """ 90 | class Foo: 91 | y = 4 92 | x = 3 93 | foo = Foo() 94 | z = $x + foo.y$$ 95 | """ 96 | )["z"] 97 | == 10 98 | ) 99 | 100 | assert ( 101 | pyc.exec( 102 | """ 103 | class Foo: 104 | y = 4 105 | x = 3 106 | foo = Foo() 107 | z = $x + $foo.y 108 | """ 109 | )["z"] 110 | == 9 111 | ) 112 | 113 | assert ( 114 | pyc.exec( 115 | """ 116 | class Foo: 117 | y = 4 118 | x = 3 119 | foo = Foo() 120 | z = $x + $foo.y$$ 121 | """ 122 | )["z"] 123 | == 11 124 | ) 125 | -------------------------------------------------------------------------------- /pyccolo/trace_stack.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import itertools 3 | from contextlib import contextmanager 4 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple 5 | 6 | if TYPE_CHECKING: 7 | from pyccolo.tracer import _InternalBaseTracer 8 | 9 | 10 | class TraceStack: 11 | def __init__(self, manager: "_InternalBaseTracer"): 12 | self._manager = manager 13 | self._stack: List[Tuple[Any, ...]] = [] 14 | self._stack_item_initializers: Dict[str, Callable[[], Any]] = {} 15 | self._stack_items_with_manual_initialization: Set[str] = set() 16 | self._registering_stack_state_context = False 17 | self._field_mapping: Dict[str, int] = {} 18 | 19 | def _stack_item_names(self): 20 | return itertools.chain( 21 | self._stack_item_initializers.keys(), 22 | self._stack_items_with_manual_initialization, 23 | ) 24 | 25 | def get_field( 26 | self, field: str, depth: Optional[int] = None, height: Optional[int] = None 27 | ) -> Any: 28 | height = -(1 if depth is None else depth) if height is None else height 29 | return self._stack[height][self._field_mapping[field]] 30 | 31 | @staticmethod 32 | def _make_initer_from_val(init_val: Any) -> Callable[[], Any]: 33 | return lambda: init_val 34 | 35 | @contextmanager 36 | def register_stack_state(self): 37 | self._registering_stack_state_context = True 38 | original_state = set(self._manager.__dict__.keys()) 39 | yield 40 | self._registering_stack_state_context = False 41 | stack_item_names = set(self._manager.__dict__.keys() - original_state) 42 | for stack_item_name in ( 43 | stack_item_names - self._stack_items_with_manual_initialization 44 | ): 45 | stack_item = self._manager.__dict__[stack_item_name] 46 | if isinstance(stack_item, TraceStack): 47 | self._stack_item_initializers[stack_item_name] = stack_item._clone 48 | elif stack_item is None: 49 | self._stack_item_initializers[stack_item_name] = lambda: None 50 | elif isinstance(stack_item, (int, bool, str, float)): 51 | init_val = type(stack_item)(stack_item) 52 | self._stack_item_initializers[stack_item_name] = ( 53 | self._make_initer_from_val(init_val) 54 | ) 55 | else: 56 | self._stack_item_initializers[stack_item_name] = type(stack_item) 57 | for i, stack_item_name in enumerate(self._stack_item_names()): 58 | self._field_mapping[stack_item_name] = i 59 | 60 | @contextmanager 61 | def needing_manual_initialization(self): 62 | assert self._registering_stack_state_context 63 | original_state = set(self._manager.__dict__.keys()) 64 | yield 65 | self._stack_items_with_manual_initialization = set( 66 | self._manager.__dict__.keys() - original_state 67 | ) 68 | 69 | @contextmanager 70 | def push(self): 71 | """ 72 | Checks at the end of the context that everything requiring manual init was manually inited. 
73 | """ 74 | self._stack.append( 75 | tuple( 76 | self._manager.__dict__[stack_item] 77 | for stack_item in self._stack_item_names() 78 | ) 79 | ) 80 | for stack_item, initializer in self._stack_item_initializers.items(): 81 | self._manager.__dict__[stack_item] = initializer() 82 | for stack_item in self._stack_items_with_manual_initialization: 83 | del self._manager.__dict__[stack_item] 84 | yield 85 | uninitialized_items = [] 86 | for stack_item in self._stack_items_with_manual_initialization: 87 | if stack_item not in self._manager.__dict__: 88 | uninitialized_items.append(stack_item) 89 | if len(uninitialized_items) > 0: 90 | raise ValueError( 91 | "Stack item(s) %s requiring manual initialization were not initialized" 92 | % uninitialized_items 93 | ) 94 | 95 | def _clone(self): 96 | new_tracing_stack = TraceStack(self._manager) 97 | new_tracing_stack.__dict__ = dict(self.__dict__) 98 | return new_tracing_stack 99 | 100 | def pop(self) -> "TraceStack": 101 | for stack_item_name, stack_item in zip( 102 | self._stack_item_names(), self._stack.pop() 103 | ): 104 | self._manager.__dict__[stack_item_name] = stack_item 105 | return self 106 | 107 | def clear(self): 108 | self._stack = self._stack[:1] 109 | if len(self._stack) > 0: 110 | self.pop() 111 | 112 | def __len__(self): 113 | return len(self._stack) 114 | -------------------------------------------------------------------------------- /pyccolo/fast/__init__.pyi: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import typing 4 | from contextlib import contextmanager 5 | from typing import TYPE_CHECKING, Callable, Generator 6 | 7 | if TYPE_CHECKING: 8 | import ast 9 | 10 | from pyccolo._fast import * # noqa: F403 11 | 12 | @contextmanager 13 | def location_of(node: ast.AST) -> "Generator[None, None, None]": ... 14 | def location_of_arg(func: Callable[..., ast.AST]) -> "Callable[..., ast.AST]": ... 15 | def kw(arg: str, value: "ast.expr") -> "ast.keyword": ... 16 | def kwargs(**kwargs: "ast.expr") -> "typing.List[ast.keyword]": ... 17 | def iter_arguments(args: "ast.arguments") -> "Generator[ast.arg, None, None]": ... 18 | def parse(*args, **kwargs) -> "ast.Module": ... 19 | def AnnAssign(*args, **kwargs) -> "ast.AnnAssign": ... 20 | def Assert(*args, **kwargs) -> "ast.Assert": ... 21 | def Assign(*args, **kwargs) -> "ast.Assign": ... 22 | def AsyncFor(*args, **kwargs) -> "ast.AsyncFor": ... 23 | def AsyncFunctionDef(*args, **kwargs) -> "ast.AsyncFunctionDef": ... 24 | def AsyncWith(*args, **kwargs) -> "ast.AsyncWith": ... 25 | def Attribute(*args, **kwargs) -> "ast.Attribute": ... 26 | def AugAssign(*args, **kwargs) -> "ast.AugAssign": ... 27 | def Await(*args, **kwargs) -> "ast.Await": ... 28 | def BinOp(*args, **kwargs) -> "ast.BinOp": ... 29 | def BoolOp(*args, **kwargs) -> "ast.BoolOp": ... 30 | def Call(*args, **kwargs) -> "ast.Call": ... 31 | def ClassDef(*args, **kwargs) -> "ast.ClassDef": ... 32 | def Compare(*args, **kwargs) -> "ast.Compare": ... 33 | def Delete(*args, **kwargs) -> "ast.Delete": ... 34 | def Dict(*args, **kwargs) -> "ast.Dict": ... 35 | def DictComp(*args, **kwargs) -> "ast.DictComp": ... 36 | def ExceptHandler(*args, **kwargs) -> "ast.ExceptHandler": ... 37 | def Expr(*args, **kwargs) -> "ast.Expr": ... 38 | def For(*args, **kwargs) -> "ast.For": ... 39 | def FormattedValue(*args, **kwargs) -> "ast.FormattedValue": ... 40 | def FunctionDef(*args, **kwargs) -> "ast.FunctionDef": ... 
41 | def GeneratorExpr(*args, **kwargs) -> "ast.GeneratorExp": ... 42 | def Global(*args, **kwargs) -> "ast.Global": ... 43 | def If(*args, **kwargs) -> "ast.If": ... 44 | def IfExp(*args, **kwargs) -> "ast.IfExp": ... 45 | def Import(*args, **kwargs) -> "ast.Import": ... 46 | def ImportFrom(*args, **kwargs) -> "ast.ImportFrom": ... 47 | def JoinedStr(*args, **kwargs) -> "ast.JoinedStr": ... 48 | def Lambda(*args, **kwargs) -> "ast.Lambda": ... 49 | def List(*args, **kwargs) -> "ast.List": ... 50 | def ListComp(*args, **kwargs) -> "ast.ListComp": ... 51 | def Name(*args, **kwargs) -> "ast.Name": ... 52 | def Nonlocal(*args, **kwargs) -> "ast.Nonlocal": ... 53 | def Pass(*args, **kwargs) -> "ast.Pass": ... 54 | def Raise(*args, **kwargs) -> "ast.Raise": ... 55 | def Return(*args, **kwargs) -> "ast.Return": ... 56 | def Set(*args, **kwargs) -> "ast.Set": ... 57 | def SetComp(*args, **kwargs) -> "ast.SetComp": ... 58 | def Slice(*args, **kwargs) -> "ast.Slice": ... 59 | def Starred(*args, **kwargs) -> "ast.Starred": ... 60 | def Subscript(*args, **kwargs) -> "ast.Subscript": ... 61 | def Try(*args, **kwargs) -> "ast.Try": ... 62 | def Tuple(*args, **kwargs) -> "ast.Tuple": ... 63 | def UnaryOp(*args, **kwargs) -> "ast.UnaryOp": ... 64 | def While(*args, **kwargs) -> "ast.While": ... 65 | def With(*args, **kwargs) -> "ast.With": ... 66 | def Yield(*args, **kwargs) -> "ast.Yield": ... 67 | def YieldFrom(*args, **kwargs) -> "ast.YieldFrom": ... 68 | def alias(*args, **kwargs) -> "ast.alias": ... 69 | def arg(*args, **kwargs) -> "ast.arg": ... 70 | def comprehension(*args, **kwargs) -> "ast.comprehension": ... 71 | def excepthandler(*args, **kwargs) -> "ast.excepthandler": ... 72 | def keyword(*args, **kwargs) -> "ast.keyword": ... 73 | def withitem(*args, **kwargs) -> "ast.withitem": ... 74 | 75 | if sys.version_info < (3, 9): 76 | def ExtSlice(*args, **kwargs) -> "ast.ExtSlice": ... 77 | def Index(*args, **kwargs) -> "ast.Index": ... 78 | 79 | if sys.version_info < (3, 8): 80 | def Num(*args, **kwargs) -> "ast.Num": ... 81 | def Str(*args, **kwargs) -> "ast.Str": ... 82 | def Bytes(*args, **kwargs) -> "ast.Bytes": ... 83 | def NameConstant(*args, **kwargs) -> "ast.NameConstant": ... 84 | def Ellipsis(*args, **kwargs) -> "ast.Ellipsis": ... 85 | 86 | else: 87 | def Constant(*args, **kwargs) -> "ast.Constant": ... 88 | def Num(*args, **kwargs) -> "ast.Constant": ... 89 | def Str(*args, **kwargs) -> "ast.Constant": ... 90 | def Bytes(*args, **kwargs) -> "ast.Constant": ... 91 | def NameConstant(*args, **kwargs) -> "ast.Constant": ... 92 | def Ellipsis(*args, **kwargs) -> "ast.Constant": ... 93 | 94 | if sys.version_info >= (3, 10): 95 | def Match(*args, **kwargs) -> "ast.Match": ... 96 | def MatchAs(*args, **kwargs) -> "ast.MatchAs": ... 97 | def MatchClass(*args, **kwargs) -> "ast.MatchClass": ... 98 | def MatchMapping(*args, **kwargs) -> "ast.MatchMapping": ... 99 | def MatchOr(*args, **kwargs) -> "ast.MatchOr": ... 100 | def MatchSequence(*args, **kwargs) -> "ast.MatchSequence": ... 101 | def MatchSingleton(*args, **kwargs) -> "ast.MatchSingleton": ... 102 | def MatchStar(*args, **kwargs) -> "ast.MatchStar": ... 103 | def MatchValue(*args, **kwargs) -> "ast.MatchValue": ... 104 | 105 | if sys.version_info >= (3, 11): 106 | def TryStar(*args, **kwargs) -> "ast.TryStar": ... 
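# A minimal usage sketch of these wrappers (hedged; the actual runtime behavior lives in
# pyccolo._fast.fast_ast, which is not reproduced here): each constructor mirrors its `ast`
# counterpart, and inside `location_of(node)` newly built nodes appear to inherit `node`'s
# source position, so a rewriter can do something like the following, assuming `node` is an
# existing `ast.AST` that already carries a location:
#
#     with location_of(node):
#         call = Call(
#             func=Name("print", ast.Load()),
#             args=[Str("hello")],
#             keywords=kwargs(flush=NameConstant(True)),
#         )
#
# leaving `call` and its children with usable lineno / col_offset attributes.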
107 | -------------------------------------------------------------------------------- /pyccolo/examples/coverage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Example of simple code coverage implemented using Pyccolo. 4 | 5 | Run as `python pyccolo/examples/coverage.py` from the repository root. 6 | """ 7 | import ast 8 | import logging 9 | import os 10 | import sys 11 | from collections import Counter 12 | 13 | import pyccolo as pyc 14 | from pyccolo.import_hooks import patch_meta_path 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | join = os.path.join 20 | 21 | 22 | EXCEPTED_FILES = { 23 | pyc.SANDBOX_FNAME, 24 | "version.py", 25 | "_version.py", 26 | # weird shit happens if we instrument _emit_event and import_hooks, so exclude them. 27 | # can be removed for coverage of non-pyccolo projects. 28 | "emit_event.py", 29 | "import_hooks.py", 30 | } 31 | 32 | 33 | class CountStatementsVisitor(ast.NodeVisitor): 34 | def __init__(self): 35 | self.num_stmts = 0 36 | 37 | def generic_visit(self, node): 38 | if isinstance(node, ast.stmt): 39 | if not isinstance(node, ast.Raise): 40 | self.num_stmts += 1 41 | if ( 42 | isinstance(node, ast.If) 43 | and isinstance(node.test, ast.Name) 44 | and node.test.id == "TYPE_CHECKING" 45 | ): 46 | return 47 | super().generic_visit(node) 48 | 49 | 50 | class CoverageTracer(pyc.BaseTracer): 51 | bytecode_caching_allowed = False 52 | 53 | def __init__(self, *args, **kwargs): 54 | super().__init__(*args, **kwargs) 55 | self.seen_stmts = set() 56 | self.stmt_count_by_fname = Counter() 57 | self.count_static_statements_visitor = CountStatementsVisitor() 58 | 59 | def count_statements(self, path: str) -> int: 60 | with open(path, "r") as f: 61 | contents = f.read() 62 | try: 63 | self.count_static_statements_visitor.visit(ast.parse(contents)) 64 | except SyntaxError: 65 | # this means that we must have some other tracer in there, 66 | # that should be capable of parsing some augmented syntax 67 | self.count_static_statements_visitor.visit(self.parse(contents)) 68 | ret = self.count_static_statements_visitor.num_stmts 69 | self.count_static_statements_visitor.num_stmts = 0 70 | return ret 71 | 72 | def should_instrument_file(self, filename: str) -> bool: 73 | if "test/" in filename or "examples" in filename: 74 | # filter out tests and self 75 | return False 76 | 77 | return "pyccolo" in filename and not any( 78 | filename.endswith(excepted) for excepted in EXCEPTED_FILES 79 | ) 80 | 81 | @pyc.register_raw_handler(pyc.before_stmt) 82 | def handle_stmt(self, _ret, stmt_id, frame, *_, **__): 83 | fname = frame.f_code.co_filename 84 | if fname.startswith(pyc.SANDBOX_FNAME_PREFIX): 85 | # filter these out. not necessary for non-pyccolo coverage 86 | return 87 | if stmt_id not in self.seen_stmts: 88 | self.stmt_count_by_fname[fname] += 1 89 | self.seen_stmts.add(stmt_id) 90 | 91 | def exit_tracing_hook(self) -> None: 92 | total_stmts = 0 93 | for fname in sorted(self.stmt_count_by_fname.keys()): 94 | if fname.startswith(pyc.SANDBOX_FNAME_PREFIX): 95 | continue 96 | shortened = "." 
+ fname.split(".", 1)[-1] 97 | seen = self.stmt_count_by_fname[fname] 98 | total_in_file = self.count_statements(fname) 99 | total_stmts += total_in_file 100 | logger.warning( 101 | "[%-40s]: seen=%4d, total=%4d, ratio=%.3f", 102 | shortened, 103 | seen, 104 | total_in_file, 105 | float(seen) / total_in_file, 106 | ) 107 | num_seen_stmts = len(self.seen_stmts) 108 | logger.warning("num stmts seen: %s", num_seen_stmts) 109 | logger.warning("num stmts total: %s", total_stmts) 110 | if total_stmts == 0: 111 | logger.error("Counted 0 total statements; saw %d", num_seen_stmts) 112 | else: 113 | logger.warning("ratio: %.3f", float(num_seen_stmts) / total_stmts) 114 | 115 | 116 | def remove_pyccolo_modules(): 117 | to_delete = [] 118 | for mod in sys.modules: 119 | if mod.startswith("pyccolo"): 120 | to_delete.append(mod) 121 | for mod in to_delete: 122 | del sys.modules[mod] 123 | 124 | 125 | if __name__ == "__main__": 126 | import pytest 127 | 128 | sys.path.insert(0, ".") 129 | # now clear pyccolo modules so that they get reimported, and instrumented 130 | # can be omitted for non-pyccolo projects 131 | orig_pyc = pyc 132 | remove_pyccolo_modules() 133 | tracer = CoverageTracer.instance() 134 | with tracer: 135 | import pyccolo as pyc 136 | 137 | # we just cleared the original tracer stack when we deleted all the imports, so 138 | # we need to put it back 139 | # (can be omitted for non-pyccolo projects) 140 | pyc._TRACER_STACK.append(tracer) 141 | with patch_meta_path(pyc._TRACER_STACK): 142 | exit_code = pytest.console_main() 143 | sys.exit(exit_code) 144 | -------------------------------------------------------------------------------- /pyccolo/emit_event.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import sys 4 | import threading 5 | from contextlib import contextmanager 6 | from typing import TYPE_CHECKING, List 7 | 8 | from pyccolo.trace_events import BEFORE_EXPR_EVENTS 9 | 10 | if TYPE_CHECKING: 11 | from pyccolo.tracer import BaseTracer, _InternalBaseTracer 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | _BEFORE_EXPR_EVENT_NAMES = {evt.value for evt in BEFORE_EXPR_EVENTS} 18 | _TRACER_STACK: "List[BaseTracer]" = [] 19 | _allow_event_handling = True 20 | _allow_reentrant_event_handling = False 21 | 22 | 23 | @contextmanager 24 | def allow_reentrant_event_handling(): 25 | global _allow_reentrant_event_handling 26 | orig_allow_reentrant_handling = _allow_reentrant_event_handling 27 | _allow_reentrant_event_handling = True 28 | try: 29 | yield 30 | finally: 31 | _allow_reentrant_event_handling = orig_allow_reentrant_handling 32 | 33 | 34 | def _make_ret(event, ret): 35 | if event in _BEFORE_EXPR_EVENT_NAMES and not callable(ret): 36 | return lambda *_: ret 37 | else: 38 | return ret 39 | 40 | 41 | SkipAll = object() 42 | _main_thread_id = threading.main_thread().ident 43 | SANDBOX_FNAME = "" 44 | SANDBOX_FNAME_PREFIX = " bool: 48 | return False 49 | 50 | 51 | def _should_instrument_file_impl(tracer, filename: str) -> bool: 52 | if ( 53 | tracer.instrument_all_files 54 | or filename in tracer._tracing_enabled_files 55 | or filename.startswith(SANDBOX_FNAME_PREFIX) 56 | ): 57 | return True 58 | for clazz in tracer.__class__.mro(): 59 | if clazz.__name__ == "BaseTracer": 60 | break 61 | should_instrument_file = clazz.__dict__.get("should_instrument_file") 62 | if should_instrument_file is not None: 63 | return should_instrument_file(tracer, filename) 64 | return 
_should_instrument_file(tracer, filename) 65 | 66 | 67 | def _file_passes_filter_for_event( 68 | tracer: "_InternalBaseTracer", evt: str, filename: str 69 | ) -> bool: 70 | return True 71 | 72 | 73 | def _file_passes_filter_impl( 74 | tracer: "_InternalBaseTracer", evt: str, filename: str, is_reentrant: bool = False 75 | ) -> bool: 76 | if filename == tracer._current_sandbox_fname and tracer.has_sys_trace_events: 77 | ret = tracer._num_sandbox_calls_seen >= 2 78 | tracer._num_sandbox_calls_seen += evt == "call" 79 | return ret 80 | if not ( 81 | evt 82 | in ( 83 | "before_import", 84 | "init_module", 85 | "after_import", 86 | ) 87 | or _should_instrument_file_impl(tracer, filename) 88 | ): 89 | return False 90 | for clazz in tracer.__class__.mro(): 91 | if clazz.__name__ == "BaseTracer": 92 | break 93 | file_passes_filter_for_event = clazz.__dict__.get( 94 | "file_passes_filter_for_event" 95 | ) 96 | if file_passes_filter_for_event is not None: 97 | return file_passes_filter_for_event(tracer, evt, filename) 98 | return _file_passes_filter_for_event(tracer, evt, filename) 99 | 100 | 101 | def _emit_tracer_loop( 102 | event, 103 | node_id, 104 | frame, 105 | kwargs, 106 | ): 107 | global _allow_reentrant_event_handling 108 | global _allow_event_handling 109 | current_thread_id = threading.current_thread().ident 110 | is_reentrant = not _allow_event_handling 111 | reentrant_handlers_only = is_reentrant and not _allow_reentrant_event_handling 112 | _allow_event_handling = False 113 | for tracer in _TRACER_STACK: 114 | if current_thread_id != _main_thread_id and not tracer.multiple_threads_allowed: 115 | continue 116 | if ( 117 | is_reentrant 118 | and not tracer.allow_reentrant_events 119 | and not _allow_reentrant_event_handling 120 | ): 121 | continue 122 | if not _file_passes_filter_impl( 123 | tracer, event, frame.f_code.co_filename, is_reentrant=is_reentrant 124 | ): 125 | continue 126 | new_ret = tracer._emit_event( 127 | event, 128 | node_id, 129 | frame, 130 | reentrant_handlers_only=reentrant_handlers_only, 131 | **kwargs, 132 | ) 133 | if isinstance(new_ret, tuple) and len(new_ret) > 1 and new_ret[0] is SkipAll: 134 | kwargs["ret"] = new_ret[1] 135 | break 136 | else: 137 | kwargs["ret"] = new_ret 138 | 139 | 140 | def _emit_event(event, node_id, **kwargs): 141 | global _allow_event_handling 142 | global _allow_reentrant_event_handling 143 | __debuggerskip__ = True # noqa: F841 144 | frame = sys._getframe().f_back 145 | if frame.f_code.co_filename == __file__: 146 | # weird shit happens if we instrument this file, so exclude it. 
147 | return _make_ret(event, kwargs.get("ret", None)) 148 | orig_allow_event_handling = _allow_event_handling 149 | orig_allow_reentrant_event_handling = _allow_reentrant_event_handling 150 | if len(_TRACER_STACK) > 0: 151 | remapping = _TRACER_STACK[-1].node_id_remapping_by_fname.get( 152 | frame.f_code.co_filename 153 | ) 154 | if remapping is not None: 155 | node_id = remapping.get(node_id, node_id) 156 | try: 157 | _emit_tracer_loop( 158 | event, 159 | node_id, 160 | frame, 161 | kwargs, 162 | ) 163 | finally: 164 | _allow_event_handling = orig_allow_event_handling 165 | _allow_reentrant_event_handling = orig_allow_reentrant_event_handling 166 | return _make_ret(event, kwargs.get("ret")) 167 | -------------------------------------------------------------------------------- /pyccolo/examples/optional_chaining.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Example of optional chaining and nullish coalescing implemented with Pyccolo; 4 | e.g., foo?.bar resolves to `None` when `foo` is `None`. 5 | """ 6 | import ast 7 | from typing import Any, Optional 8 | 9 | import pyccolo as pyc 10 | 11 | 12 | def parent_is_or_boolop(node_id: int) -> bool: 13 | parent = pyc.BaseTracer.containing_ast_by_id.get(node_id) 14 | return isinstance(parent, ast.BoolOp) and isinstance(parent.op, ast.Or) 15 | 16 | 17 | class OptionalChainer(pyc.BaseTracer): 18 | class ResolvesToNone: 19 | def __init__(self, eventually: bool) -> None: 20 | self.__eventually = eventually 21 | 22 | def __getattr__(self, _item: str): 23 | if self.__eventually: 24 | return self 25 | else: 26 | return None 27 | 28 | def __call__(self, *_, **__): 29 | return self 30 | 31 | resolves_to_none_eventually = ResolvesToNone(eventually=True) 32 | resolves_to_none_immediately = ResolvesToNone(eventually=False) 33 | 34 | call_optional_chaining_spec = pyc.AugmentationSpec( 35 | aug_type=pyc.AugmentationType.dot_suffix, token="?.(", replacement="(" 36 | ) 37 | 38 | optional_chaining_spec = pyc.AugmentationSpec( 39 | aug_type=pyc.AugmentationType.dot_suffix, token="?.", replacement="." 40 | ) 41 | 42 | permissive_attr_dereference_spec = pyc.AugmentationSpec( 43 | aug_type=pyc.AugmentationType.dot_suffix, token=".?", replacement="." 
44 | ) 45 | 46 | nullish_coalescing_spec = pyc.AugmentationSpec( 47 | aug_type=pyc.AugmentationType.boolop, token="??", replacement=" or " 48 | ) 49 | 50 | def __init__(self, *args, **kwargs) -> None: 51 | super().__init__(*args, **kwargs) 52 | self._saved_ret_expr = None 53 | self.lexical_call_stack: pyc.TraceStack = self.make_stack() 54 | with self.lexical_call_stack.register_stack_state(): 55 | # TODO: pop this the right number of times if an exception occurs 56 | self.cur_call_is_none_resolver: bool = False 57 | self.lexical_nullish_stack: pyc.TraceStack = self.make_stack() 58 | with self.lexical_nullish_stack.register_stack_state(): 59 | self.cur_boolop_has_nullish_coalescer = False 60 | self.coalesced_value: Optional[Any] = None 61 | 62 | @pyc.register_raw_handler(pyc.after_stmt) 63 | def handle_after_stmt(self, ret, *_, **__): 64 | self._saved_ret_expr = ret 65 | 66 | @pyc.register_raw_handler(pyc.after_module_stmt) 67 | def handle_after_module_stmt(self, *_, **__): 68 | while len(self.lexical_call_stack) > 0: 69 | self.lexical_call_stack.pop() 70 | ret = self._saved_ret_expr 71 | self._saved_ret_expr = None 72 | return ret 73 | 74 | @pyc.register_handler(pyc.before_attribute_load) 75 | def handle_before_attr(self, obj, node: ast.Attribute, *_, **__): 76 | if ( 77 | self.optional_chaining_spec in self.get_augmentations(id(node)) 78 | and obj is None 79 | ): 80 | return self.resolves_to_none_eventually 81 | elif self.permissive_attr_dereference_spec in self.get_augmentations( 82 | id(node) 83 | ) and not hasattr(obj, node.attr): 84 | return self.resolves_to_none_immediately 85 | else: 86 | return obj 87 | 88 | @pyc.register_handler(pyc.before_call) 89 | def handle_before_call(self, func, node: ast.Call, *_, **__): 90 | if func is None and self.call_optional_chaining_spec in self.get_augmentations( 91 | id(node.func) 92 | ): 93 | func = self.resolves_to_none_eventually 94 | with self.lexical_call_stack.push(): 95 | self.cur_call_is_none_resolver = func is self.resolves_to_none_eventually 96 | return func 97 | 98 | @pyc.register_raw_handler(pyc.before_argument) 99 | def handle_before_arg(self, arg_lambda, *_, **__): 100 | if self.cur_call_is_none_resolver: 101 | return lambda: None 102 | else: 103 | return arg_lambda 104 | 105 | @pyc.register_raw_handler(pyc.after_call) 106 | def handle_after_call(self, *_, **__): 107 | self.lexical_call_stack.pop() 108 | 109 | @pyc.register_raw_handler(pyc.after_load_complex_symbol) 110 | def handle_after_load_complex_symbol(self, ret, *_, **__): 111 | if isinstance(ret, self.ResolvesToNone): 112 | return pyc.Null 113 | else: 114 | return ret 115 | 116 | @pyc.register_handler( 117 | pyc.before_boolop, when=lambda node: isinstance(node.op, ast.Or) 118 | ) 119 | def before_or_boolop(self, ret, node: ast.BoolOp, *_, **__): 120 | with self.lexical_nullish_stack.push(): 121 | self.cur_boolop_has_nullish_coalescer = any( 122 | self.nullish_coalescing_spec in self.get_augmentations(id(val)) 123 | for val in node.values 124 | ) 125 | self.coalesced_value = None 126 | return ret 127 | 128 | @pyc.register_handler( 129 | pyc.after_boolop, when=lambda node: isinstance(node.op, ast.Or) 130 | ) 131 | def after_or_boolop(self, *_, **__): 132 | self.lexical_nullish_stack.pop() 133 | 134 | def _maybe_compute_nullish_coalesced_value(self, ret, node_id: int) -> None: 135 | if self.coalesced_value is not None: 136 | return 137 | val = ret() 138 | if self.nullish_coalescing_spec in self.get_augmentations(node_id): 139 | self.coalesced_value = None if val is None else val 
140 | else: 141 | self.coalesced_value = val or None 142 | 143 | @pyc.register_raw_handler(pyc.before_boolop_arg, when=parent_is_or_boolop) 144 | def before_or_boolup(self, ret, node_id: int, *_, is_last: bool, **__): 145 | if self.cur_boolop_has_nullish_coalescer: 146 | if is_last: 147 | if self.coalesced_value is None: 148 | return ret 149 | else: 150 | return lambda: self.coalesced_value 151 | else: 152 | self._maybe_compute_nullish_coalesced_value(ret, node_id) 153 | return lambda: None 154 | else: 155 | return ret 156 | 157 | 158 | class ScriptOptionalChainer(OptionalChainer): 159 | def should_instrument_file(self, filename: str) -> bool: 160 | return True 161 | -------------------------------------------------------------------------------- /test/test_stack.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pyccolo as pyc 3 | 4 | 5 | def test_basic_stack(): 6 | class FunctionTracer(pyc.BaseTracer): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self.stack = self.make_stack() 10 | with self.stack.register_stack_state(): 11 | with self.stack.needing_manual_initialization(): 12 | self.name = "" 13 | self.dummy = None 14 | 15 | def should_propagate_handler_exception(self, _evt, exc: Exception) -> bool: 16 | return True 17 | 18 | @pyc.register_handler(pyc.before_call) 19 | def before_call(self, fun, *_, **__): 20 | with self.stack.push(): 21 | self.name = fun.__name__ 22 | 23 | @pyc.register_handler(pyc.after_call) 24 | def after_call(self, *_, **__): 25 | self.stack.pop() 26 | 27 | tracer = FunctionTracer.instance() 28 | 29 | with tracer.tracing_enabled(): 30 | # note: everything below is off by 1 because the 31 | # asserts themselves will call fns that push / pop 32 | pyc.exec( 33 | """ 34 | assert tracer.name == "" 35 | assert len(tracer.stack) == 1 36 | assert tracer.stack.get_field("name") == "" 37 | def f(): 38 | assert tracer.name == f.__name__ 39 | assert len(tracer.stack) == 2 40 | assert tracer.stack.get_field("name") == f.__name__ 41 | assert tracer.stack.get_field("name", depth=2) == "" 42 | def ggg(): 43 | assert tracer.name == ggg.__name__ 44 | assert len(tracer.stack) == 3 45 | assert tracer.stack.get_field("name") == ggg.__name__ 46 | assert tracer.stack.get_field("name", depth=2) == f.__name__ 47 | assert tracer.stack.get_field("name", depth=3) == "" 48 | def hhhhh(): 49 | assert tracer.name == hhhhh.__name__ 50 | assert len(tracer.stack) == 4 51 | assert tracer.stack.get_field("name") == hhhhh.__name__ 52 | assert tracer.stack.get_field("name", depth=2) == ggg.__name__ 53 | assert tracer.stack.get_field("name", depth=3) == f.__name__ 54 | assert tracer.stack.get_field("name", depth=4) == "" 55 | hhhhh() 56 | assert tracer.name == ggg.__name__ 57 | assert len(tracer.stack) == 3 58 | assert tracer.stack.get_field("name") == ggg.__name__ 59 | assert tracer.stack.get_field("name", depth=2) == f.__name__ 60 | assert tracer.stack.get_field("name", depth=3) == "" 61 | ggg() 62 | assert tracer.name == f.__name__ 63 | assert len(tracer.stack) == 2 64 | assert tracer.stack.get_field("name") == f.__name__ 65 | assert tracer.stack.get_field("name", depth=2) == "" 66 | f() 67 | assert tracer.name == "" 68 | assert len(tracer.stack) == 1 69 | assert tracer.stack.get_field("name") == "" 70 | """ 71 | ) 72 | 73 | 74 | class NestedTracer(pyc.BaseTracer): 75 | def __init__(self, *args, **kwargs): 76 | super().__init__(*args, **kwargs) 77 | self.stack = self.make_stack() 78 | with 
self.stack.register_stack_state(): 79 | self.list_stack = self.make_stack() 80 | with self.list_stack.register_stack_state(): 81 | self.running_length = 0 82 | 83 | def should_propagate_handler_exception(self, _evt, exc: Exception) -> bool: 84 | return True 85 | 86 | @pyc.register_handler(pyc.before_call) 87 | def before_call(self, *_, **__): 88 | with self.stack.push(): 89 | pass 90 | 91 | @pyc.register_handler(pyc.after_call) 92 | def after_call(self, *_, **__): 93 | self.stack.pop() 94 | 95 | @pyc.register_handler(pyc.before_list_literal) 96 | def before_list_literal(self, *_, **__): 97 | with self.list_stack.push(): 98 | pass 99 | 100 | @pyc.register_handler(pyc.after_list_literal) 101 | def after_list_literal(self, *_, **__): 102 | self.list_stack.pop() 103 | 104 | @pyc.register_handler(pyc.list_elt) 105 | def list_elt(self, *_, **__): 106 | self.running_length += 1 107 | 108 | 109 | def test_nested_stack(): 110 | tracer = NestedTracer.instance() 111 | with tracer.tracing_enabled(): 112 | assert ( 113 | pyc.exec( 114 | """ 115 | lst = [ 116 | tracer.running_length, 117 | tracer.running_length, 118 | [ 119 | tracer.running_length, 120 | tracer.running_length, 121 | tracer.running_length, 122 | ], 123 | tracer.running_length, 124 | tracer.running_length, 125 | ] 126 | """ 127 | )["lst"] 128 | == [0, 1, [0, 1, 2], 3, 4] 129 | ) 130 | 131 | assert ( 132 | pyc.exec( 133 | """ 134 | assert len(tracer.stack) == 1 135 | def f(): 136 | assert len(tracer.stack) == 2 137 | return [ 138 | tracer.running_length, 139 | tracer.running_length, 140 | [ 141 | tracer.running_length, 142 | tracer.running_length, 143 | tracer.running_length, 144 | ], 145 | tracer.running_length, 146 | tracer.running_length, 147 | ] 148 | lst = [tracer.running_length, f(), tracer.running_length] 149 | """ 150 | )["lst"] 151 | == [0, [0, 1, [0, 1, 2], 3, 4], 2] 152 | ) 153 | 154 | 155 | def test_clear(): 156 | tracer = NestedTracer.instance() 157 | 158 | def clear_one_level_up(): 159 | tracer.stack.get_field("list_stack").clear() 160 | return -1 161 | 162 | with tracer.tracing_enabled(): 163 | pyc.exec( 164 | """ 165 | try: 166 | lst = [clear_one_level_up()] 167 | except IndexError: 168 | pass 169 | else: 170 | assert False 171 | """ 172 | ) 173 | -------------------------------------------------------------------------------- /pyccolo/examples/quick_lambda.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Implementation of quick lambdas in Pyccolo, similar to MacroPy's. 
4 | Ref: https://macropy3.readthedocs.io/en/latest/quick_lambda.html#quicklambda 5 | 6 | Example: 7 | ``` 8 | with QuickLambdaTracer: 9 | pyc.eval("f[_ + _](3, 4)") 10 | >>> 7 11 | ``` 12 | """ 13 | import ast 14 | import builtins 15 | from functools import reduce 16 | from types import FrameType 17 | from typing import Any, Dict, List, Tuple, cast 18 | 19 | import pyccolo as pyc 20 | from pyccolo import fast 21 | from pyccolo.examples.pipeline_tracer import PipelineTracer, SingletonArgCounterMixin 22 | from pyccolo.examples.quasiquote import Quasiquoter, is_macro 23 | from pyccolo.stmt_mapper import StatementMapper 24 | from pyccolo.trace_events import TraceEvent 25 | 26 | 27 | class _ArgReplacer(ast.NodeVisitor, SingletonArgCounterMixin): 28 | def __init__(self) -> None: 29 | super().__init__() 30 | self.placeholder_names: Dict[str, None] = {} 31 | 32 | def visit_Subscript(self, node: ast.Subscript) -> None: 33 | if ( 34 | isinstance(node.value, ast.Name) 35 | and node.value.id in QuickLambdaTracer.lambda_macros 36 | ): 37 | # defer visiting nested quick lambdas 38 | return 39 | self.generic_visit(node) 40 | 41 | def visit_Name(self, node: ast.Name) -> None: 42 | if ( 43 | node.id != "_" 44 | and id(node) 45 | not in PipelineTracer.augmented_node_ids_by_spec[ 46 | PipelineTracer.arg_placeholder_spec 47 | ] 48 | ): 49 | return 50 | # quick lambda will interpret this node as placeholder without any aug spec necessary 51 | PipelineTracer.augmented_node_ids_by_spec[ 52 | PipelineTracer.arg_placeholder_spec 53 | ].discard(id(node)) 54 | assert node.id.startswith("_") 55 | if node.id == "_": 56 | node.id = f"_{self.arg_ctr}" 57 | self.arg_ctr += 1 58 | else: 59 | if node.id[1].isalpha(): 60 | node.id = node.id[1:] 61 | self.placeholder_names[node.id] = None 62 | 63 | def visit_BinOp(self, node: ast.BinOp) -> None: 64 | if isinstance(node.op, ast.BitOr) and PipelineTracer.get_augmentations( 65 | id(node) 66 | ): 67 | return 68 | self.generic_visit(node) 69 | 70 | def get_placeholder_names(self, node: ast.AST) -> List[str]: 71 | self.placeholder_names.clear() 72 | self.visit(node) 73 | return list(self.placeholder_names) 74 | 75 | 76 | class QuickLambdaTracer(Quasiquoter): 77 | lambda_macros = ("f", "filter", "ifilter", "map", "imap", "reduce") 78 | 79 | def __init__(self, *args, **kwargs) -> None: 80 | super().__init__(*args, **kwargs) 81 | for macro in self.lambda_macros: 82 | self.macros.add(macro) 83 | self._arg_replacer = _ArgReplacer() 84 | builtins.reduce = reduce # type: ignore[attr-defined] 85 | builtins.imap = map # type: ignore[attr-defined] 86 | self.lambda_cache: Dict[Tuple[int, int, TraceEvent], Any] = {} 87 | 88 | _not_found = object() 89 | 90 | @pyc.before_subscript_slice(when=is_macro(lambda_macros), reentrant=True) 91 | def handle_quick_lambda( 92 | self, _ret, node: ast.Subscript, frame: FrameType, evt: TraceEvent, *_, **__ 93 | ): 94 | lambda_cache_key = (id(node), id(frame), evt) 95 | cached_lambda = self.lambda_cache.get(lambda_cache_key, self._not_found) 96 | if cached_lambda is not self._not_found: 97 | return cached_lambda 98 | __hide_pyccolo_frame__ = True 99 | orig_ctr = self._arg_replacer.arg_ctr 100 | orig_lambda_body: ast.expr = node.slice # type: ignore[assignment] 101 | if isinstance(orig_lambda_body, ast.Index): 102 | orig_lambda_body = orig_lambda_body.value # type: ignore[attr-defined] 103 | lambda_body = StatementMapper.augmentation_propagating_copy(orig_lambda_body) 104 | placeholder_names = self._arg_replacer.get_placeholder_names(lambda_body) 105 | if 
self._arg_replacer.arg_ctr == orig_ctr and len(placeholder_names) == 0: 106 | ast_lambda = lambda_body 107 | else: 108 | ast_lambda = SingletonArgCounterMixin.create_placeholder_lambda( 109 | placeholder_names, orig_ctr, lambda_body, frame.f_globals 110 | ) 111 | ast_lambda.body = lambda_body 112 | func = cast(ast.Name, node.value).id 113 | if func in ("filter", "ifilter", "map", "imap", "reduce"): 114 | with fast.location_of(ast_lambda): 115 | arg = f"_{self._arg_replacer.arg_ctr}" 116 | self._arg_replacer.arg_ctr += 1 117 | inner_func = func 118 | if func == "ifilter": 119 | inner_func = "filter" 120 | elif func == "imap": 121 | inner_func = "map" 122 | lambda_body_str = f"{inner_func}(None, {arg})" 123 | functor_lambda_body = cast( 124 | ast.Call, 125 | cast( 126 | ast.Expr, 127 | fast.parse(lambda_body_str).body[0], 128 | ).value, 129 | ) 130 | functor_lambda_body.args[0] = ast_lambda 131 | if func in ("filter", "map"): 132 | id_arg = f"_{self._arg_replacer.arg_ctr}" 133 | self._arg_replacer.arg_ctr += 1 134 | lambda_body_str = f"(list if type({arg}) is list else lambda {id_arg}: {id_arg})(None)" 135 | functor_lambda_outer_body = cast( 136 | ast.Call, 137 | cast( 138 | ast.Expr, 139 | fast.parse(lambda_body_str).body[0], 140 | ).value, 141 | ) 142 | functor_lambda_outer_body.args[0] = functor_lambda_body 143 | functor_lambda_body = functor_lambda_outer_body 144 | functor_lambda = cast( 145 | ast.Lambda, 146 | cast(ast.Expr, fast.parse(f"lambda {arg}: None").body[0]).value, 147 | ) 148 | functor_lambda.body = functor_lambda_body 149 | ast_lambda = functor_lambda 150 | evaluated_lambda = pyc.eval(ast_lambda, frame.f_globals, frame.f_locals) 151 | ret = lambda: __hide_pyccolo_frame__ and evaluated_lambda # noqa: E731 152 | self.lambda_cache[lambda_cache_key] = ret 153 | return ret 154 | -------------------------------------------------------------------------------- /pyccolo/trace_events.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import sys 4 | import warnings 5 | from enum import Enum 6 | 7 | from pyccolo import fast 8 | 9 | 10 | class TraceEvent(Enum): 11 | before_import = "before_import" 12 | init_module = "init_module" 13 | exit_module = "exit_module" 14 | after_import = "after_import" 15 | 16 | before_stmt = "before_stmt" 17 | after_stmt = "after_stmt" 18 | after_module_stmt = "after_module_stmt" 19 | after_expr_stmt = "after_expr_stmt" 20 | _load_saved_expr_stmt_ret = "_load_saved_expr_stmt_ret" 21 | 22 | load_name = "load_name" 23 | 24 | after_bool = "after_bool" 25 | after_bytes = "after_bytes" 26 | after_complex = "after_complex" 27 | after_float = "after_float" 28 | after_int = "after_int" 29 | after_none = "after_none" 30 | after_string = "after_string" 31 | 32 | before_fstring = "before_fstring" 33 | after_fstring = "after_fstring" 34 | 35 | before_for_loop_body = "before_for_loop_body" 36 | after_for_loop_iter = "after_for_loop_iter" 37 | before_while_loop_body = "before_while_loop_body" 38 | after_while_loop_iter = "after_while_loop_iter" 39 | 40 | before_for_iter = "before_for_iter" 41 | after_for_iter = "after_for_iter" 42 | 43 | before_attribute_load = "before_attribute_load" 44 | before_attribute_store = "before_attribute_store" 45 | before_attribute_del = "before_attribute_del" 46 | after_attribute_load = "after_attribute_load" 47 | before_subscript_load = "before_subscript_load" 48 | before_subscript_store = "before_subscript_store" 49 | before_subscript_del = 
"before_subscript_del" 50 | after_subscript_load = "after_subscript_load" 51 | 52 | before_subscript_slice = "before_subscript_slice" 53 | after_subscript_slice = "after_subscript_slice" 54 | _load_saved_slice = "_load_saved_slice" 55 | 56 | before_load_complex_symbol = "before_load_complex_symbol" 57 | after_load_complex_symbol = "after_load_complex_symbol" 58 | 59 | after_if_test = "after_if_test" 60 | after_while_test = "after_while_test" 61 | 62 | before_lambda = "before_lambda" 63 | after_lambda = "after_lambda" 64 | 65 | decorator = "decorator" 66 | before_call = "before_call" 67 | after_call = "after_call" 68 | before_argument = "before_argument" 69 | after_argument = "after_argument" 70 | before_return = "before_return" 71 | after_return = "after_return" 72 | 73 | before_dict_literal = "before_dict_literal" 74 | after_dict_literal = "after_dict_literal" 75 | before_list_literal = "before_list_literal" 76 | after_list_literal = "after_list_literal" 77 | before_set_literal = "before_set_literal" 78 | after_set_literal = "after_set_literal" 79 | before_tuple_literal = "before_tuple_literal" 80 | after_tuple_literal = "after_tuple_literal" 81 | 82 | dict_key = "dict_key" 83 | dict_value = "dict_value" 84 | list_elt = "list_elt" 85 | set_elt = "set_elt" 86 | tuple_elt = "tuple_elt" 87 | 88 | before_assign_rhs = "before_assign_rhs" 89 | after_assign_rhs = "after_assign_rhs" 90 | before_augassign_rhs = "before_augassign_rhs" 91 | after_augassign_rhs = "after_augassign_rhs" 92 | 93 | before_assert = "before_assert" 94 | after_assert = "after_assert" 95 | 96 | before_function_body = "before_function_body" 97 | after_function_execution = "after_function_execution" 98 | 99 | before_lambda_body = "before_lambda_body" 100 | after_lambda_body = "after_lambda_body" 101 | 102 | before_left_binop_arg = "before_left_binop_arg" 103 | after_left_binop_arg = "after_left_binop_arg" 104 | before_right_binop_arg = "before_right_binop_arg" 105 | after_right_binop_arg = "after_right_binop_arg" 106 | before_binop = "before_binop" 107 | after_binop = "after_binop" 108 | 109 | before_boolop_arg = "before_boolop_arg" 110 | after_boolop_arg = "after_boolop_arg" 111 | before_boolop = "before_boolop" 112 | after_boolop = "after_boolop" 113 | 114 | left_compare_arg = "left_compare_arg" 115 | compare_arg = "compare_arg" 116 | before_compare = "before_compare" 117 | after_compare = "after_compare" 118 | 119 | after_comprehension_if = "after_comprehension_if" 120 | after_comprehension_elt = "after_comprehension_elt" 121 | after_dict_comprehension_key = "after_dict_comprehension_key" 122 | after_dict_comprehension_value = "after_dict_comprehension_value" 123 | 124 | exception_handler_type = "exception_handler_type" 125 | 126 | ellipsis = "ellipsis" 127 | 128 | line = "line" 129 | call = "call" 130 | return_ = "return" 131 | exception = "exception" 132 | opcode = "opcode" 133 | 134 | # these are included for completeness but will probably not be used 135 | c_call = "c_call" 136 | c_return = "c_return" 137 | c_exception = "c_exception" 138 | 139 | def __str__(self): 140 | return self.value 141 | 142 | def __repr__(self): 143 | return "<" + str(self) + ">" 144 | 145 | def __call__(self, handler=None, **kwargs): 146 | # this will be filled by tracer.py 147 | ... 
148 | 149 | if sys.version_info < (3, 8): 150 | 151 | def to_ast(self): 152 | return fast.Str(self.name) 153 | 154 | else: 155 | 156 | def to_ast(self): 157 | return fast.Constant(self.name) 158 | 159 | 160 | SYS_TRACE_EVENTS = { 161 | TraceEvent.line, 162 | TraceEvent.call, 163 | TraceEvent.return_, 164 | TraceEvent.exception, 165 | TraceEvent.opcode, 166 | } 167 | 168 | 169 | BEFORE_EXPR_EVENTS = { 170 | TraceEvent.before_argument, 171 | TraceEvent.before_assign_rhs, 172 | TraceEvent.before_augassign_rhs, 173 | TraceEvent.before_binop, 174 | TraceEvent.before_boolop, 175 | TraceEvent.before_boolop_arg, 176 | TraceEvent.before_compare, 177 | TraceEvent.before_dict_literal, 178 | TraceEvent.before_for_iter, 179 | TraceEvent.before_fstring, 180 | TraceEvent.before_lambda, 181 | TraceEvent.before_left_binop_arg, 182 | TraceEvent.before_list_literal, 183 | TraceEvent.before_load_complex_symbol, 184 | TraceEvent.before_return, 185 | TraceEvent.before_right_binop_arg, 186 | TraceEvent.before_set_literal, 187 | TraceEvent.before_subscript_slice, 188 | TraceEvent.before_tuple_literal, 189 | } 190 | 191 | 192 | AST_TO_EVENT_MAPPING = { 193 | ast.arg: TraceEvent.after_argument, 194 | ast.stmt: TraceEvent.after_stmt, 195 | ast.Assign: TraceEvent.after_assign_rhs, 196 | ast.Module: TraceEvent.init_module, 197 | ast.Name: TraceEvent.load_name, 198 | ast.Attribute: TraceEvent.after_attribute_load, 199 | ast.Subscript: TraceEvent.after_subscript_load, 200 | ast.Call: TraceEvent.after_call, 201 | ast.Dict: TraceEvent.after_dict_literal, 202 | ast.List: TraceEvent.after_list_literal, 203 | ast.Tuple: TraceEvent.after_tuple_literal, 204 | ast.Set: TraceEvent.after_set_literal, 205 | ast.Return: TraceEvent.after_return, 206 | ast.BinOp: TraceEvent.after_binop, 207 | ast.Compare: TraceEvent.after_compare, 208 | } 209 | 210 | EVT_TO_EVENT_MAPPING = { 211 | TraceEvent.before_assert: TraceEvent.before_stmt, 212 | TraceEvent.after_assert: TraceEvent.after_stmt, 213 | } 214 | 215 | 216 | with warnings.catch_warnings(): 217 | warnings.simplefilter("ignore", DeprecationWarning) 218 | ast_Ellipsis = getattr(ast, "Ellipsis", None) 219 | if ast_Ellipsis is not None: 220 | AST_TO_EVENT_MAPPING[ast_Ellipsis] = TraceEvent.ellipsis 221 | -------------------------------------------------------------------------------- /docs/HISTORY.rst: -------------------------------------------------------------------------------- 1 | History 2 | ======= 3 | 4 | 0.0.73 (2025-11-21) 5 | ------------------- 6 | * convenience events for before / after assert; 7 | * make compatible with Python 3.14; 8 | 9 | 0.0.72 (2025-04-10) 10 | ------------------- 11 | * improve default tracer options; 12 | * bugfixes for coverage example; 13 | * for guard exempt handlers, only instrument exprs; 14 | 15 | 0.0.71 (2025-03-16) 16 | ------------------- 17 | * suppress false positive deprecation warnings; 18 | 19 | 0.0.70 (2025-01-04) 20 | ------------------- 21 | * add before / after for iter ast events 22 | 23 | 0.0.69 (2024-12-19) 24 | ------------------- 25 | * Automatic bookkeeping for return expressions from `after_module_stmt` events; 26 | 27 | 0.0.68 (2024-12-12) 28 | ------------------- 29 | * Always track containing statement during bookkeeping; 30 | * Improve recursion efficiency; 31 | 32 | 0.0.67 (2024-10-13) 33 | ------------------- 34 | * Make AST bookkeeping garbage collection optional; 35 | 36 | 0.0.66 (2024-10-06) 37 | ------------------- 38 | * Small configuration simplification; 39 | 40 | 0.0.65 (2024-08-25) 41 | ------------------- 42 | * 
Implement garbage collection for invalidated AST bookkeeping; 43 | 44 | 0.0.64 (2024-07-21) 45 | ------------------- 46 | * Implement guard exemption for event handlers; 47 | * Bugfix to ensure only applicable tracers transform ASTs of files instrumented in import hooks; 48 | 49 | 0.0.63 (2024-07-19) 50 | ------------------- 51 | * Add py.typed; 52 | 53 | 0.0.58 (2024-07-12) 54 | ------------------- 55 | * Support exception type instrumentation; 56 | * Add support for bytecode caching; 57 | 58 | 0.0.55 (2024-06-24) 59 | ------------------- 60 | * Remove bare except; 61 | * Be robust to undefined builtins; 62 | 63 | 0.0.54 (2024-02-29) 64 | ------------------- 65 | * Hoist global / nonlocal declarations above instrumentation branch; 66 | 67 | 0.0.53 (2024-02-24) 68 | ------------------- 69 | * Add decorator trace event; 70 | 71 | 0.0.52 (2023-12-21) 72 | ------------------- 73 | * Skip pyccolo frames in pdb; 74 | 75 | 0.0.51 (2023-12-19) 76 | ------------------- 77 | * Improve ergonomics of trace stack get_field() with height param; 78 | * Add option to disable tracing for non-main threads; 79 | * Add is_initial_frame_stmt utility; 80 | 81 | 0.0.50 (2023-12-07) 82 | ------------------- 83 | * Fix a bug in the optional chainer example; 84 | 85 | 0.0.49 (2023-08-19) 86 | ------------------- 87 | * Fix a bug around trace stack field initialization; 88 | 89 | 0.0.48 (2023-07-15) 90 | ------------------- 91 | * Fixes for 3.12 compat; 92 | * Improve composability of sys tracing; 93 | 94 | 0.0.45 (2023-01-02) 95 | ------------------- 96 | * Environment var for developer mode with more verbose logging; 97 | 98 | 0.0.44 (2022-12-14) 99 | ------------------- 100 | * Bugfix for circular deps; 101 | * Add events for various literal comprehensions; 102 | 103 | 0.0.43 (2022-12-06) 104 | ------------------- 105 | * Bugfix to ensure that patched sys.settrace does not override behavior on new threads; 106 | * Bugfix to ensure that custom finder does not override behavior on new threads; 107 | 108 | 0.0.41 (2022-12-04) 109 | ------------------- 110 | * Fix degraded performance of deepcopy when ast nodes have parent pointers; 111 | 112 | 0.0.39 (2022-11-19) 113 | ------------------- 114 | * Use Python tokenizer for syntax extensions; 115 | 116 | 0.0.38 (2022-10-30) 117 | ------------------- 118 | * Better meta_path fallback behavior; 119 | 120 | 0.0.37 (2022-10-26) 121 | ------------------- 122 | * Bugfix for 3.11; 123 | 124 | 0.0.36 (2022-10-24) 125 | ------------------- 126 | * Preserve global / nonlocal declarations in copy source; 127 | 128 | 0.0.35 (2022-08-06) 129 | ------------------- 130 | * Record when passed argument is the last one; 131 | 132 | 0.0.34 (2022-07-18) 133 | ------------------- 134 | * Add enable / disable non-context convenience classmethods; 135 | 136 | 0.0.33 (2022-07-12) 137 | ------------------- 138 | * Allow before_import handler to overwrite source_path; 139 | * Allow tracer classes to themselves be used as context managers; 140 | 141 | 0.0.32 (2022-07-02) 142 | ------------------- 143 | * Disable tracing in import_hooks where applicable; 144 | 145 | 0.0.31 (2022-07-02) 146 | ------------------- 147 | * Support for before / after import events; 148 | 149 | 0.0.30 (2022-07-02) 150 | ------------------- 151 | * Better version handling; 152 | * Add NoopTracer just for use with exec / eval; 153 | 154 | 0.0.28 (2022-05-30) 155 | ------------------- 156 | * Add 'before_argument' event; 157 | * OptionalChainer improvements; 158 | 159 | 0.0.27 (2022-05-30) 160 | ------------------- 161 | * 
NullCoalescer -> OptionalChainer; 162 | 163 | 0.0.26 (2022-05-21) 164 | ------------------- 165 | * Get rid of phantom dependency on pytest; 166 | 167 | 0.0.25 (2022-04-18) 168 | ------------------- 169 | * Allow prefix / suffix augmentations for importfrom statements; 170 | 171 | 0.0.24 (2022-04-18) 172 | ------------------- 173 | * Allow prefix / suffix augmentations for import statements; 174 | 175 | 0.0.23 (2022-03-18) 176 | ------------------- 177 | * Support configuring whether global guards enabled; 178 | * Lazy importer: support unwrapping lazy symbols that result from subscripts; 179 | 180 | 0.0.22 (2022-03-17) 181 | ------------------- 182 | * Preserve docstring in function definitions; 183 | * Perform __future__ imports first; 184 | * Add local guard functionality; 185 | * Add lazy import example; 186 | 187 | 0.0.21 (2022-03-02) 188 | ------------------- 189 | * Bugfixes and improvements to FutureTracer example; 190 | 191 | 0.0.20 (2022-02-14) 192 | ------------------- 193 | * Provide non-context manager variants of tracing-related contexts; 194 | 195 | 0.0.19 (2022-02-14) 196 | ------------------- 197 | * Add 'exit_module' event; 198 | * Use deferred evaluation variants for all 'before expr' events; 199 | * Improve AST bookkeeping; 200 | * Add FutureTracer under pyccolo.examples; 201 | * Fix bug where starred expressions weren't traced if used as literal elements; 202 | 203 | 0.0.17 (2022-02-03) 204 | ------------------- 205 | * Fix packaging issue after new configuration; 206 | 207 | 0.0.14 (2022-02-02) 208 | ------------------- 209 | * Move configuration out of setup.py; 210 | 211 | 0.0.13 (2022-01-31) 212 | ------------------- 213 | * Default to all tracers in stack for package-level tracing enabled / disabled context managers; 214 | * Omit instrumenting the AST of statements underneath "with pyc.tracing_disabled()" blocks; 215 | * Add SkipAll return value; 216 | * Improve reentrancy for sys events; 217 | 218 | 0.0.12 (2022-01-30) 219 | ------------------- 220 | * Expose logic for resolving tracer class based on module path; 221 | 222 | 0.0.11 (2022-01-30) 223 | ------------------- 224 | * Expanded predicate functionality; 225 | * New events for after if / while test, after expr stmts, after lambda body, before / after augassign rhs; 226 | * Disambiguate between user and generated lambdas (e.g. 
used for before expr events); 227 | 228 | 0.0.10 (2022-01-26) 229 | ------------------- 230 | * Simplify binop events; 231 | * Add compare events; 232 | 233 | 0.0.9 (2022-01-24) 234 | ------------------ 235 | * Allow per-handler reentrancy; 236 | 237 | 0.0.8 (2022-01-23) 238 | ------------------ 239 | * Add eval helper; 240 | * Add syntactic macro examples (quasiquotes and quick lambdas); 241 | * Add support for conditional handlers; 242 | 243 | 0.0.7 (2022-01-06) 244 | ------------------ 245 | * Add cli; 246 | * Add basic readme documentation; 247 | * Allow returning pyc.Skip for skipping subsequent handlers for same event; 248 | * Misc improvements to file filter hooks; 249 | * Allow returning lambdas for before_expr events; 250 | 251 | 0.0.6 (2022-01-06) 252 | ------------------ 253 | * Misc ergonomics improvements; 254 | * Enable for Python 3.10; 255 | * Enable linting and fix package-level imports; 256 | 257 | 0.0.5 (2021-12-29) 258 | ------------------ 259 | * Get rid of future-annotations dependency; 260 | * Fix memory leak in sandbox exec; 261 | 262 | 0.0.4 (2021-12-26) 263 | ------------------ 264 | * Misc composability improvements and fixes; 265 | * Improve file filter handling; 266 | 267 | 0.0.3 (2021-12-23) 268 | ------------------ 269 | * Misc ergonomics improvements; 270 | * Misc composability improvements and fixes; 271 | 272 | 0.0.2 (2021-12-22) 273 | ------------------ 274 | * Initial internal release; 275 | 276 | 0.0.1 (2020-10-25) 277 | ------------------ 278 | * Initial placeholder release; 279 | -------------------------------------------------------------------------------- /pyccolo/_fast/misc_ast_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import builtins 4 | import pickle 5 | import sys 6 | from contextlib import contextmanager 7 | from typing import ( 8 | TYPE_CHECKING, 9 | Callable, 10 | DefaultDict, 11 | Dict, 12 | Generator, 13 | Iterable, 14 | List, 15 | Optional, 16 | Set, 17 | Union, 18 | ) 19 | 20 | from pyccolo.extra_builtins import EMIT_EVENT, PYCCOLO_BUILTIN_PREFIX, TRACE_LAMBDA 21 | from pyccolo.stmt_mapper import StatementMapper 22 | from pyccolo.trace_events import BEFORE_EXPR_EVENTS, TraceEvent 23 | 24 | if TYPE_CHECKING: 25 | from pyccolo import fast 26 | from pyccolo.ast_rewriter import GUARD_DATA_T 27 | from pyccolo.tracer import BaseTracer 28 | else: 29 | from pyccolo._fast.fast_ast import FastAst as fast 30 | 31 | if sys.version_info < (3, 8): 32 | NumConst = ast.Num 33 | else: 34 | NumConst = ast.Constant 35 | 36 | 37 | class _SaveParentsVisitor(ast.NodeVisitor): 38 | has_parent = f"{PYCCOLO_BUILTIN_PREFIX}_has_parent_Xix54321" 39 | 40 | def generic_visit(self, node: ast.AST) -> None: 41 | if hasattr(node, "parent"): 42 | node.parent = self.has_parent # type: ignore 43 | super().generic_visit(node) 44 | 45 | def reinject(self, tree: ast.AST) -> None: 46 | for node in ast.walk(tree): 47 | for child in ast.iter_child_nodes(node): 48 | if getattr(child, "parent", None) == self.has_parent: 49 | child.parent = node # type: ignore 50 | 51 | 52 | @contextmanager 53 | def _save_parents(node: ast.AST) -> Generator[_SaveParentsVisitor, None, None]: 54 | visitor = _SaveParentsVisitor() 55 | visitor.visit(node) 56 | yield visitor 57 | 58 | 59 | def copy_ast(node: ast.AST) -> ast.AST: 60 | with _save_parents(node) as parents: 61 | node_copy = pickle.loads(pickle.dumps(node)) 62 | parents.reinject(node_copy) 63 | return node_copy 64 | 65 | 66 | def 
make_test(var_name: str, negate: bool = False) -> ast.expr: 67 | ret: ast.expr = fast.Name(var_name, ast.Load()) 68 | if negate: 69 | ret = fast.UnaryOp(operand=ret, op=ast.Not()) 70 | return ret 71 | 72 | 73 | def make_composite_condition( 74 | nullable_conditions: Iterable[Optional[Union[str, ast.expr]]], 75 | op: Optional[ast.boolop] = None, 76 | ) -> ast.expr: 77 | conditions = [ 78 | fast.Name(cond, ast.Load()) if isinstance(cond, str) else cond 79 | for cond in nullable_conditions 80 | if cond is not None 81 | ] 82 | if len(conditions) == 1: 83 | return conditions[0] 84 | op = op or fast.And() # type: ignore 85 | return fast.BoolOp(op=op, values=conditions) 86 | 87 | 88 | def subscript_to_slice(node: ast.Subscript) -> ast.expr: 89 | if isinstance(node.slice, ast.Index): 90 | return node.slice.value # type: ignore 91 | else: 92 | return node.slice # type: ignore 93 | 94 | 95 | class EmitterMixin: 96 | def __init__( 97 | self, 98 | tracers: "List[BaseTracer]", 99 | mapper: StatementMapper, 100 | orig_to_copy_mapping: Dict[int, ast.AST], 101 | handler_predicate_by_event: DefaultDict[TraceEvent, Callable[..., bool]], 102 | guard_exempt_handler_predicate_by_event: DefaultDict[ 103 | TraceEvent, Callable[..., bool] 104 | ], 105 | handler_guards_by_event: DefaultDict[TraceEvent, List["GUARD_DATA_T"]], 106 | ): 107 | self.tracers = tracers 108 | self.mapper = mapper 109 | self.orig_to_copy_mapping = orig_to_copy_mapping 110 | self._handler_predicate_by_event = handler_predicate_by_event 111 | self._guard_exempt_handler_predicate_by_event = ( 112 | guard_exempt_handler_predicate_by_event 113 | ) 114 | self.handler_guards_by_event = handler_guards_by_event 115 | self.guards: Set[str] = tracers[-1].guards 116 | self.global_guards_enabled = any( 117 | tracer.global_guards_enabled for tracer in tracers 118 | ) 119 | self._is_guard_exempt_context = False 120 | 121 | @contextmanager 122 | def guard_exempt_context( 123 | self, node: ast.AST, guard_exempt_node: ast.AST 124 | ) -> Generator[None, None, None]: 125 | new_orig_to_copy_mapping = self.mapper( 126 | guard_exempt_node, self.orig_to_copy_mapping[id(node)] 127 | ) 128 | orig_is_guard_exempt_context = self._is_guard_exempt_context 129 | orig_orig_to_copy_mapping = self.orig_to_copy_mapping 130 | try: 131 | self._is_guard_exempt_context = True 132 | self.orig_to_copy_mapping = new_orig_to_copy_mapping 133 | yield 134 | finally: 135 | self._is_guard_exempt_context = orig_is_guard_exempt_context 136 | self.orig_to_copy_mapping = orig_orig_to_copy_mapping 137 | 138 | @property 139 | def handler_predicate_by_event( 140 | self, 141 | ) -> DefaultDict[TraceEvent, Callable[..., bool]]: 142 | if self._is_guard_exempt_context: 143 | return self._guard_exempt_handler_predicate_by_event 144 | else: 145 | return self._handler_predicate_by_event 146 | 147 | @staticmethod 148 | def is_tracing_disabled_context(node: ast.AST): 149 | if not isinstance(node, ast.With): 150 | return False 151 | if len(node.items) != 1: 152 | return False 153 | expr = node.items[0].context_expr 154 | if not isinstance(expr, ast.Call): 155 | return False 156 | func = expr.func 157 | if not isinstance(func, ast.Attribute): 158 | return False 159 | return ( 160 | isinstance(func.value, ast.Name) 161 | and func.value.id == "pyc" 162 | and func.attr == "tracing_disabled" 163 | ) 164 | 165 | def register_guard(self, guard: str) -> None: 166 | self.guards.add(guard) 167 | setattr(builtins, guard, True) 168 | 169 | @staticmethod 170 | def make_func_name(name=EMIT_EVENT) -> ast.Name: 171 | 
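    # Rough shape of the instrumentation that emit() (below) generates, hedged since the
    # actual builtin name comes from EMIT_EVENT and the guard names are generated at
    # rewrite time: an expression is replaced by a call along the lines of
    #
    #     <EMIT_EVENT>("after_binop", <id of the copied node>, ret=..., guards_by_handler_spec_id=...)
    #
    # and, when local guards are registered for the event, wrapped in an IfExp so that the
    # original copied expression is evaluated whenever any of the guards is truthy:
    #
    #     <original expr> if (<guard_1> or <guard_2>) else <EMIT_EVENT>(...)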
return fast.Name(name, ast.Load()) 172 | 173 | def get_copy_node(self, orig_node_id: Union[int, ast.AST]) -> ast.AST: 174 | if not isinstance(orig_node_id, int): 175 | orig_node_id = id(orig_node_id) 176 | return self.orig_to_copy_mapping[orig_node_id] 177 | 178 | def get_copy_id_ast(self, orig_node_id: Union[int, ast.AST]) -> NumConst: 179 | return fast.Num(id(self.get_copy_node(orig_node_id))) 180 | 181 | def make_lambda( 182 | self, body: ast.expr, args: Optional[List[ast.arg]] = None 183 | ) -> ast.Call: 184 | return fast.Call( 185 | func=self.make_func_name(TRACE_LAMBDA), 186 | args=[ 187 | fast.Lambda( 188 | body=body, 189 | args=ast.arguments( 190 | args=[] if args is None else args, 191 | defaults=[], 192 | kwonlyargs=[], 193 | kw_defaults=[], 194 | posonlyargs=[], 195 | ), 196 | ) 197 | ], 198 | ) 199 | 200 | def emit( 201 | self, 202 | evt: TraceEvent, 203 | node_or_id: Union[int, ast.AST], 204 | args=None, 205 | before_expr_args=None, 206 | **kwargs, 207 | ) -> Union[ast.Call, ast.IfExp]: 208 | args = args or [] 209 | before_expr_args = before_expr_args or [] 210 | if evt in BEFORE_EXPR_EVENTS and "ret" in kwargs: 211 | kwargs_ret = kwargs["ret"] 212 | if ( 213 | not isinstance(kwargs_ret, ast.Call) 214 | or not isinstance(kwargs_ret.func, ast.Name) 215 | or kwargs_ret.func.id != TRACE_LAMBDA 216 | ): 217 | kwargs["ret"] = self.make_lambda(kwargs_ret) 218 | local_guard_makers = self.handler_guards_by_event.get(evt, None) 219 | local_guards = {} 220 | if local_guard_makers is not None: 221 | for spec, maker in local_guard_makers: 222 | guardval = maker(node_or_id) 223 | if guardval is not None: 224 | local_guards[id(spec)] = guardval 225 | if len(local_guards) == 0: 226 | kwargs["guards_by_handler_spec_id"] = fast.NameConstant(None) 227 | else: 228 | kwargs["guards_by_handler_spec_id"] = fast.Dict( 229 | keys=[fast.Num(k) for k in local_guards.keys()], 230 | values=[fast.Str(v) for v in local_guards.values()], 231 | ) 232 | ret: Union[ast.Call, ast.IfExp] = fast.Call( 233 | func=self.make_func_name(), 234 | args=[evt.to_ast(), self.get_copy_id_ast(node_or_id)] + args, 235 | keywords=fast.kwargs(**kwargs), 236 | ) 237 | if evt in BEFORE_EXPR_EVENTS: 238 | ret = fast.Call(func=ret, args=before_expr_args) 239 | if len(local_guards) > 0: 240 | for guard in local_guards.values(): 241 | self.tracers[-1].register_local_guard(guard) 242 | ret = fast.IfExp( 243 | test=make_composite_condition(local_guards.values(), op=ast.Or()), 244 | body=self.get_copy_node(node_or_id), 245 | orelse=ret, 246 | ) 247 | return ret 248 | -------------------------------------------------------------------------------- /pyccolo/syntax_augmentation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import itertools 3 | import tokenize 4 | import warnings 5 | from collections import Counter, defaultdict 6 | from enum import Enum 7 | from io import StringIO 8 | from typing import TYPE_CHECKING, Callable, Dict, List, NamedTuple, Set, Tuple, Union 9 | 10 | if TYPE_CHECKING: 11 | from pyccolo.ast_rewriter import AstRewriter 12 | 13 | CodeType = Union[str, List[str]] 14 | 15 | 16 | class AugmentationType(Enum): 17 | prefix = "prefix" 18 | suffix = "suffix" 19 | dot_prefix = "dot_prefix" 20 | dot_suffix = "dot_suffix" 21 | binop = "binop" 22 | boolop = "boolop" 23 | 24 | 25 | class AugmentationSpec(NamedTuple): 26 | aug_type: AugmentationType 27 | token: str 28 | replacement: str 29 | 30 | 31 | def fix_positions( 32 | pos_by_spec: 
Dict[AugmentationSpec, Set[Tuple[int, int]]], 33 | spec_order: Tuple[AugmentationSpec, ...], 34 | ) -> Dict[AugmentationSpec, Set[Tuple[int, int]]]: 35 | grouped_by_line: Dict[int, List[Tuple[int, AugmentationSpec]]] = defaultdict(list) 36 | fixed_pos_by_spec: Dict[AugmentationSpec, Set[Tuple[int, int]]] = {} 37 | for spec, positions in pos_by_spec.items(): 38 | fixed_pos_by_spec[spec] = set() 39 | for line, col in positions: 40 | grouped_by_line[line].append((col, spec)) 41 | 42 | for line, cols_with_spec in grouped_by_line.items(): 43 | total_offset_by_spec: Dict[AugmentationSpec, int] = Counter() 44 | offset_by_spec: Dict[AugmentationSpec, int] = Counter() 45 | cols_with_spec.sort(key=lambda x: x[0]) 46 | for col, spec in cols_with_spec: 47 | offset = len(spec.token) - len(spec.replacement) 48 | for prev_applied in spec_order: 49 | # the offsets will only be messed up for specs that 50 | # were applied earlier 51 | total_offset_by_spec[prev_applied] += offset 52 | if prev_applied == spec: 53 | break 54 | offset_by_spec[spec] += offset 55 | new_col = col - (total_offset_by_spec[spec] - offset_by_spec[spec]) 56 | fixed_pos_by_spec[spec].add((line, new_col)) 57 | 58 | return fixed_pos_by_spec 59 | 60 | 61 | def replace_tokens_and_get_augmented_positions( 62 | tokenizable: Union[str, List[tokenize.TokenInfo]], spec: AugmentationSpec 63 | ) -> Tuple[str, List[Tuple[int, int]]]: 64 | if isinstance(tokenizable, str): 65 | tokens = list(make_tokens_by_line([tokenizable]))[0] 66 | else: 67 | tokens = tokenizable 68 | transformed = StringIO() 69 | match = StringIO() 70 | cur_match_start = (-1, -1) 71 | col_offset = 0 72 | token_before_match_start = None 73 | token_before_match_start_col_offset = None 74 | 75 | def _flush_match(force: bool = False) -> None: 76 | nonlocal cur_match_start 77 | num_to_increment = 0 78 | while True: 79 | # TODO: this is super inefficient 80 | cur = match.getvalue() 81 | if cur == "" or (not force and spec.token.startswith(cur)): 82 | break 83 | match.seek(0) 84 | transformed.write(match.read(1)) 85 | num_to_increment += 1 86 | remaining = match.read() 87 | match.seek(0) 88 | match.truncate() 89 | match.write(remaining) 90 | cur_match_start = (cur_match_start[0], cur_match_start[1] + num_to_increment) 91 | 92 | def _write_match(tok: Union[str, tokenize.TokenInfo]) -> None: 93 | nonlocal cur_match_start 94 | nonlocal col_offset 95 | nonlocal token_before_match_start 96 | nonlocal token_before_match_start_col_offset 97 | if isinstance(tok, tokenize.TokenInfo): 98 | if match.getvalue() == "": 99 | cur_match_start = tok.start 100 | token_before_match_start = prev_non_whitespace_token 101 | token_before_match_start_col_offset = ( 102 | prev_non_whitespace_token_col_offset 103 | ) 104 | to_write = tok.string 105 | else: 106 | to_write = tok 107 | match.write(to_write) 108 | _flush_match() 109 | if spec.token != match.getvalue(): 110 | return 111 | if spec.aug_type in (AugmentationType.binop, AugmentationType.boolop): 112 | # for binop / boolop, we use left operand's end_col_offset to locate the position of the op 113 | if ( 114 | token_before_match_start is not None 115 | and token_before_match_start_col_offset is not None 116 | ): 117 | positions.append( 118 | ( 119 | token_before_match_start.end[0], 120 | token_before_match_start.end[1] 121 | + token_before_match_start_col_offset, 122 | ) 123 | ) 124 | else: 125 | match_pos_col_offset = cur_match_start[1] + col_offset 126 | match_pos_col_offset += len(spec.token) - len(spec.token.strip()) 127 | match_pos_col_offset += 
len(spec.token) - len(spec.token.lstrip()) 128 | positions.append((cur_match_start[0], match_pos_col_offset)) 129 | col_offset += len(spec.replacement) - len(spec.token) 130 | transformed.write(spec.replacement) 131 | cur_match_start = ( 132 | cur_match_start[0], 133 | cur_match_start[1] + len(match.getvalue()), 134 | ) 135 | match.seek(0) 136 | match.truncate() 137 | 138 | positions: List[Tuple[int, int]] = [] 139 | prev = None 140 | prev_non_whitespace_token = None 141 | prev_non_whitespace_token_col_offset = None 142 | for cur in tokens: 143 | if prev is not None and prev.end[0] == cur.start[0]: 144 | if match.getvalue() == "": 145 | cur_match_start = (prev.end[0], prev.end[1]) 146 | for _ in range(cur.start[1] - prev.end[1]): 147 | _write_match(" ") 148 | else: 149 | col_offset = 0 150 | _flush_match(force=True) 151 | cur_match_start = (cur.start[0], 0) 152 | for _ in range(cur.start[1]): 153 | _write_match(" ") 154 | _write_match(cur) 155 | prev = cur 156 | if cur.string.strip() != "": 157 | prev_non_whitespace_token = cur 158 | prev_non_whitespace_token_col_offset = col_offset 159 | 160 | _flush_match(force=True) 161 | return transformed.getvalue(), positions 162 | 163 | 164 | # copied from IPython to avoid bringing it in as a dependency 165 | # fine since it's BSD licensed 166 | def make_tokens_by_line(lines: List[str]) -> List[List[tokenize.TokenInfo]]: 167 | """Tokenize a series of lines and group tokens by line. 168 | 169 | The tokens for a multiline Python string or expression are grouped as one 170 | line. All lines except the last lines should keep their line ending ('\\n', 171 | '\\r\\n') for this to properly work. Use `.splitlines(keepends=True)` 172 | for example when passing block of text to this function. 173 | 174 | """ 175 | # NL tokens are used inside multiline expressions, but also after blank 176 | # lines or comments. This is intentional - see https://bugs.python.org/issue17061 177 | # We want to group the former case together but split the latter, so we 178 | # track parentheses level, similar to the internals of tokenize. 179 | NEWLINE, NL = tokenize.NEWLINE, tokenize.NL 180 | tokens_by_line: List[List[tokenize.TokenInfo]] = [[]] 181 | if len(lines) > 1 and not lines[0].endswith(("\n", "\r", "\r\n", "\x0b", "\x0c")): 182 | warnings.warn( 183 | "`make_tokens_by_line` received a list of lines which do not have " 184 | + "lineending markers ('\\n', '\\r', '\\r\\n', '\\x0b', '\\x0c'), " 185 | + "behavior will be unspecified" 186 | ) 187 | parenlev = 0 188 | try: 189 | for token in tokenize.generate_tokens(iter(lines).__next__): 190 | tokens_by_line[-1].append(token) 191 | if (token.type == NEWLINE) or ((token.type == NL) and (parenlev <= 0)): 192 | tokens_by_line.append([]) 193 | elif token.string in {"(", "[", "{"}: 194 | parenlev += 1 195 | elif token.string in {")", "]", "}"}: 196 | if parenlev > 0: 197 | parenlev -= 1 198 | except tokenize.TokenError: 199 | # Input ended in a multiline string or expression. That's OK for us. 
200 | pass 201 | 202 | if not tokens_by_line[-1]: 203 | tokens_by_line.pop() 204 | 205 | return tokens_by_line 206 | 207 | 208 | def make_syntax_augmenter( 209 | rewriter: "AstRewriter", aug_spec: AugmentationSpec 210 | ) -> "Callable[[CodeType], CodeType]": 211 | def _input_transformer(lines: "CodeType") -> "CodeType": 212 | if isinstance(lines, list): 213 | code_lines: List[str] = lines 214 | else: 215 | code_lines = lines.splitlines(keepends=True) 216 | tokens = list(itertools.chain(*make_tokens_by_line(code_lines))) 217 | transformed, positions = replace_tokens_and_get_augmented_positions( 218 | tokens, aug_spec 219 | ) 220 | for pos in positions: 221 | rewriter.register_augmented_position(aug_spec, *pos) 222 | if isinstance(lines, list): 223 | return transformed.splitlines(keepends=True) 224 | else: 225 | return transformed 226 | 227 | return _input_transformer 228 | -------------------------------------------------------------------------------- /pyccolo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Pyccolo: declarative, composable, portable instrumentation embedded directly in Python source. 4 | 5 | Pyccolo brings metaprogramming to everybody via general event-emitting AST transformations. 6 | """ 7 | import ast 8 | import functools 9 | import inspect 10 | import textwrap 11 | import types 12 | from contextlib import contextmanager 13 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union 14 | 15 | from pyccolo.ast_rewriter import AstRewriter 16 | from pyccolo.emit_event import _TRACER_STACK, SANDBOX_FNAME, SANDBOX_FNAME_PREFIX, allow_reentrant_event_handling 17 | from pyccolo.extra_builtins import PYCCOLO_BUILTIN_PREFIX, make_guard_name 18 | from pyccolo.predicate import Predicate 19 | from pyccolo.syntax_augmentation import AugmentationSpec, AugmentationType 20 | from pyccolo.trace_events import TraceEvent 21 | from pyccolo.trace_stack import TraceStack 22 | from pyccolo.tracer import ( 23 | BaseTracer, 24 | NoopTracer, 25 | register_handler, 26 | register_raw_handler, 27 | skip_when_tracing_disabled, 28 | ) 29 | from pyccolo.utils import multi_context, resolve_tracer 30 | 31 | 32 | if TYPE_CHECKING: 33 | class Null: ... 34 | class Pass: ... 35 | class Skip: ... 36 | class SkipAll: ... 
37 | else: 38 | from pyccolo.emit_event import SkipAll 39 | from pyccolo.tracer import Null, Pass, Skip 40 | 41 | event = TraceEvent 42 | 43 | 44 | call = TraceEvent.call 45 | return_ = TraceEvent.return_ 46 | exception = TraceEvent.exception 47 | before_import = TraceEvent.before_import 48 | init_module = TraceEvent.init_module 49 | exit_module = TraceEvent.exit_module 50 | after_import = TraceEvent.after_import 51 | before_stmt = TraceEvent.before_stmt 52 | after_stmt = TraceEvent.after_stmt 53 | after_module_stmt = TraceEvent.after_module_stmt 54 | after_expr_stmt = TraceEvent.after_expr_stmt 55 | before_assert = TraceEvent.before_assert 56 | after_assert = TraceEvent.after_assert 57 | load_name = TraceEvent.load_name 58 | after_bool = TraceEvent.after_bool 59 | after_bytes = TraceEvent.after_bytes 60 | after_complex = TraceEvent.after_complex 61 | after_float = TraceEvent.after_float 62 | after_int = TraceEvent.after_int 63 | after_none = TraceEvent.after_none 64 | after_string = TraceEvent.after_string 65 | before_fstring = TraceEvent.before_fstring 66 | after_fstring = TraceEvent.after_fstring 67 | before_for_loop_body = TraceEvent.before_for_loop_body 68 | after_for_loop_iter = TraceEvent.after_for_loop_iter 69 | before_while_loop_body = TraceEvent.before_while_loop_body 70 | after_while_loop_iter = TraceEvent.after_while_loop_iter 71 | before_for_iter = TraceEvent.before_for_iter 72 | after_for_iter = TraceEvent.after_for_iter 73 | before_attribute_load = TraceEvent.before_attribute_load 74 | before_attribute_store = TraceEvent.before_attribute_store 75 | before_attribute_del = TraceEvent.before_attribute_del 76 | after_attribute_load = TraceEvent.after_attribute_load 77 | before_subscript_load = TraceEvent.before_subscript_load 78 | before_subscript_store = TraceEvent.before_subscript_store 79 | before_subscript_del = TraceEvent.before_subscript_del 80 | after_subscript_load = TraceEvent.after_subscript_load 81 | before_subscript_slice = TraceEvent.before_subscript_slice 82 | after_subscript_slice = TraceEvent.after_subscript_slice 83 | before_load_complex_symbol = TraceEvent.before_load_complex_symbol 84 | after_load_complex_symbol = TraceEvent.after_load_complex_symbol 85 | after_if_test = TraceEvent.after_if_test 86 | after_while_test = TraceEvent.after_while_test 87 | before_lambda = TraceEvent.before_lambda 88 | after_lambda = TraceEvent.after_lambda 89 | decorator = TraceEvent.decorator 90 | before_call = TraceEvent.before_call 91 | after_call = TraceEvent.after_call 92 | before_argument = TraceEvent.before_argument 93 | after_argument = TraceEvent.after_argument 94 | before_return = TraceEvent.before_return 95 | after_return = TraceEvent.after_return 96 | before_dict_literal = TraceEvent.before_dict_literal 97 | after_dict_literal = TraceEvent.after_dict_literal 98 | before_list_literal = TraceEvent.before_list_literal 99 | after_list_literal = TraceEvent.after_list_literal 100 | before_set_literal = TraceEvent.before_set_literal 101 | after_set_literal = TraceEvent.after_set_literal 102 | before_tuple_literal = TraceEvent.before_tuple_literal 103 | after_tuple_literal = TraceEvent.after_tuple_literal 104 | dict_key = TraceEvent.dict_key 105 | dict_value = TraceEvent.dict_value 106 | list_elt = TraceEvent.list_elt 107 | set_elt = TraceEvent.set_elt 108 | tuple_elt = TraceEvent.tuple_elt 109 | before_assign_rhs = TraceEvent.before_assign_rhs 110 | after_assign_rhs = TraceEvent.after_assign_rhs 111 | before_augassign_rhs = TraceEvent.before_augassign_rhs 112 | 
after_augassign_rhs = TraceEvent.after_augassign_rhs 113 | before_function_body = TraceEvent.before_function_body 114 | after_function_execution = TraceEvent.after_function_execution 115 | before_lambda_body = TraceEvent.before_lambda_body 116 | after_lambda_body = TraceEvent.after_lambda_body 117 | before_left_binop_arg = TraceEvent.before_left_binop_arg 118 | after_left_binop_arg = TraceEvent.after_left_binop_arg 119 | before_right_binop_arg = TraceEvent.before_right_binop_arg 120 | after_right_binop_arg = TraceEvent.after_right_binop_arg 121 | before_binop = TraceEvent.before_binop 122 | after_binop = TraceEvent.after_binop 123 | before_boolop_arg = TraceEvent.before_boolop_arg 124 | after_boolop_arg = TraceEvent.after_boolop_arg 125 | before_boolop = TraceEvent.before_boolop 126 | after_boolop = TraceEvent.after_boolop 127 | left_compare_arg = TraceEvent.left_compare_arg 128 | compare_arg = TraceEvent.compare_arg 129 | before_compare = TraceEvent.before_compare 130 | after_compare = TraceEvent.after_compare 131 | after_comprehension_if = TraceEvent.after_comprehension_if 132 | after_comprehension_elt = TraceEvent.after_comprehension_elt 133 | after_dict_comprehension_key = TraceEvent.after_dict_comprehension_key 134 | after_dict_comprehension_value = TraceEvent.after_dict_comprehension_value 135 | exception_handler_type = TraceEvent.exception_handler_type 136 | ellipses = TraceEvent.ellipsis 137 | 138 | 139 | # redundant; do this just in case we forgot to add stubs in trace_events.py 140 | for evt in TraceEvent: 141 | globals()[evt.name] = evt 142 | 143 | 144 | # convenience functions for managing tracer singleton 145 | def tracer() -> BaseTracer: 146 | if len(_TRACER_STACK) > 0: 147 | return _TRACER_STACK[-1] 148 | else: 149 | return NoopTracer.instance() 150 | 151 | 152 | def instance() -> BaseTracer: 153 | return tracer() 154 | 155 | 156 | def parse(code: str, mode: str = "exec") -> Union[ast.Module, ast.Expression]: 157 | return tracer().parse(code, mode=mode) 158 | 159 | 160 | def eval(code: Union[str, ast.expr, ast.Expression], *args, **kwargs) -> Any: 161 | return tracer().eval( 162 | code, 163 | *args, 164 | num_extra_lookback_frames=kwargs.pop("num_extra_lookback_frames", 0) + 1, 165 | **kwargs, 166 | ) 167 | 168 | 169 | def exec(code: Union[str, ast.Module, ast.stmt], *args, **kwargs) -> Dict[str, Any]: 170 | return tracer().exec( 171 | code, 172 | *args, 173 | num_extra_lookback_frames=kwargs.pop("num_extra_lookback_frames", 0) + 1, 174 | **kwargs, 175 | ) 176 | 177 | 178 | def execute(*args, **kwargs) -> Dict[str, Any]: 179 | return exec(*args, **kwargs) 180 | 181 | 182 | def instrumented(tracers: List[BaseTracer]) -> Callable[[Callable[..., Any]], Callable[..., Any]]: 183 | def decorator(f: Callable[..., Any]) -> Callable[..., Any]: 184 | f_defined_file = f.__code__.co_filename 185 | with multi_context([tracer.tracing_disabled() for tracer in tracers]): 186 | code = ast.parse(textwrap.dedent(inspect.getsource(f))) 187 | code.body[0] = tracers[-1].make_ast_rewriter(path=f.__code__.co_filename).visit(code.body[0]) 188 | compiled: types.CodeType = compile(code, f.__code__.co_filename, "exec") 189 | for const in compiled.co_consts: 190 | if ( 191 | isinstance(const, types.CodeType) 192 | and const.co_name == f.__code__.co_name 193 | ): 194 | f.__code__ = const 195 | break 196 | 197 | @functools.wraps(f) 198 | def instrumented_f(*args, **kwargs) -> Any: 199 | with multi_context( 200 | [ 201 | tracer.tracing_enabled(tracing_enabled_file=f_defined_file) 202 | for tracer in 
tracers 203 | ] 204 | ): 205 | return f(*args, **kwargs) 206 | 207 | return instrumented_f 208 | 209 | return decorator 210 | 211 | 212 | @contextmanager 213 | def tracing_context(tracers=None, *args, **kwargs): 214 | tracers = _TRACER_STACK if tracers is None else tracers 215 | with multi_context([tracer.tracing_context(*args, **kwargs) for tracer in tracers]): 216 | yield 217 | 218 | 219 | @contextmanager 220 | def tracing_enabled(tracers=None, **kwargs): 221 | tracers = _TRACER_STACK if tracers is None else tracers 222 | if len(tracers) == 0: 223 | raise ValueError("Expected at least one tracer to enable") 224 | with multi_context([tracer.tracing_enabled(**kwargs) for tracer in tracers]): 225 | yield 226 | 227 | 228 | @contextmanager 229 | def tracing_disabled(tracers=None, **kwargs): 230 | if tracers is None: 231 | tracers = _TRACER_STACK 232 | if len(tracers) == 0: 233 | tracers = [NoopTracer.instance()] 234 | with multi_context([tracer.tracing_disabled(**kwargs) for tracer in tracers]): 235 | yield 236 | 237 | 238 | is_outer_stmt = BaseTracer.is_outer_stmt 239 | 240 | 241 | __all__ = [ 242 | "__version__", 243 | "AstRewriter", 244 | "AugmentationSpec", 245 | "AugmentationType", 246 | "BaseTracer", 247 | "NoopTracer", 248 | "Null", 249 | "PYCCOLO_BUILTIN_PREFIX", 250 | "Pass", 251 | "Predicate", 252 | "SANDBOX_FNAME", 253 | "SANDBOX_FNAME_PREFIX", 254 | "Skip", 255 | "SkipAll", 256 | "TraceStack", 257 | "allow_reentrant_event_handling", 258 | "event", 259 | "exec", 260 | "execute", 261 | "instance", 262 | "instrumented", 263 | "is_outer_stmt", 264 | "make_guard_name", 265 | "multi_context", 266 | "parse", 267 | "register_handler", 268 | "register_raw_handler", 269 | "resolve_tracer", 270 | "skip_when_tracing_disabled", 271 | "tracer", 272 | "tracing_context", 273 | "tracing_disabled", 274 | "tracing_enabled", 275 | ] 276 | 277 | 278 | # all the public events now 279 | __all__.extend(evt.name for evt in TraceEvent if evt not in (TraceEvent._load_saved_expr_stmt_ret, TraceEvent._load_saved_slice)) 280 | 281 | from pyccolo import _version # noqa: E402 282 | __version__ = _version.get_versions()['version'] 283 | -------------------------------------------------------------------------------- /pyccolo/examples/future_tracer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import builtins 4 | import copy 5 | import logging 6 | import sys 7 | import threading 8 | import traceback 9 | from collections import Counter, defaultdict 10 | from concurrent.futures import Future, ThreadPoolExecutor 11 | from typing import Dict, List, Optional, Set, Tuple, Union 12 | 13 | import pyccolo as pyc 14 | import pyccolo.fast as fast 15 | from pyccolo.extra_builtins import PYCCOLO_BUILTIN_PREFIX 16 | 17 | try: 18 | from IPython import get_ipython 19 | except Exception: 20 | 21 | def get_ipython(): 22 | return None 23 | 24 | 25 | try: 26 | from ipyflow.singletons import flow 27 | except Exception: 28 | 29 | def flow(): 30 | return None 31 | 32 | 33 | logger = logging.getLogger(__name__) 34 | logger.setLevel(logging.WARNING) 35 | 36 | 37 | _UNWRAP_FUTURE_EXTRA_BUILTIN = f"{PYCCOLO_BUILTIN_PREFIX}_PYCCOLO_UNWRAP_FUTURE" 38 | _FUT_TAB_EXTRA_BUILTIN = f"{PYCCOLO_BUILTIN_PREFIX}_PYCCOLO_FUTURE_TABLE" 39 | 40 | 41 | class FutureUnwrapper(ast.NodeTransformer): 42 | def __init__( 43 | self, 44 | async_vars: Dict[str, int], 45 | future_by_name_and_version: Dict[Tuple[str, int], Future], 46 | ) -> None: 47 | self._async_vars = async_vars 
48 | self._future_by_name_and_version = future_by_name_and_version 49 | self._deps: List[Future] = [] 50 | self._async_var: Optional[str] = None 51 | 52 | def __call__( 53 | self, node: ast.AST, async_var: str 54 | ) -> Tuple[ast.Expression, List[Future]]: 55 | self._async_var = async_var 56 | transformed_node = self.visit(copy.deepcopy(node)) 57 | deps, self._deps = self._deps, [] 58 | return ast.Expression(transformed_node), deps 59 | 60 | def visit_Name(self, node: ast.Name): 61 | assert isinstance(node.ctx, ast.Load) 62 | current_version = self._async_vars.get(node.id) 63 | if current_version is None: 64 | return node 65 | else: 66 | if node.id != self._async_var: 67 | # exclude usages of same var since we don't want to wait on self 68 | self._deps.append( 69 | self._future_by_name_and_version[node.id, current_version] 70 | ) 71 | with fast.location_of(node): 72 | slc: Union[ast.Tuple, ast.Index] = fast.Tuple( 73 | elts=[fast.Str(node.id), fast.Num(current_version)], 74 | ctx=ast.Load(), 75 | ) 76 | if sys.version_info < (3, 9): 77 | slc = fast.Index(slc, ctx=ast.Load()) 78 | return fast.Call( 79 | func=fast.Name(_UNWRAP_FUTURE_EXTRA_BUILTIN, ast.Load()), 80 | args=[ 81 | fast.Subscript( 82 | value=fast.Name(_FUT_TAB_EXTRA_BUILTIN, ast.Load()), 83 | slice=slc, 84 | ctx=ast.Load(), 85 | ) 86 | ], 87 | ) 88 | 89 | 90 | def _jump_to_non_internal_frame(tb): 91 | while tb is not None: 92 | pyccolo_seen = "pyccolo" in tb.tb_frame.f_code.co_filename 93 | concurrent_seen = "concurrent" in tb.tb_frame.f_code.co_filename 94 | if pyccolo_seen or concurrent_seen: 95 | tb = tb.tb_next 96 | else: 97 | break 98 | return tb 99 | 100 | 101 | def _unwrap_exception(ex: Exception) -> Exception: 102 | tb = _jump_to_non_internal_frame(ex.__traceback__) 103 | if tb is None: 104 | return ex 105 | prev_tb_next = None 106 | while tb is not None and tb.tb_next is not None and tb.tb_next is not prev_tb_next: 107 | prev_tb_next = tb.tb_next 108 | tb.tb_next = _jump_to_non_internal_frame(prev_tb_next) 109 | return ex.with_traceback(tb) 110 | 111 | 112 | class FutureTracer(pyc.BaseTracer): 113 | _MAX_WORKERS = 10 114 | 115 | def should_propagate_handler_exception(self, *_) -> bool: 116 | return True 117 | 118 | def __init__(self, *args, **kwargs) -> None: 119 | super().__init__(*args, **kwargs) 120 | with self.persistent_fields(): 121 | self._executor = ThreadPoolExecutor(max_workers=self._MAX_WORKERS) 122 | self._async_variable_version_by_name: Dict[str, int] = Counter() 123 | self._future_by_name_and_version: Dict[Tuple[str, int], Future] = {} 124 | self._waiters_by_future_id: Dict[int, Set[Future]] = defaultdict(set) 125 | self._exec_counter_by_future_id: Dict[int, int] = {} 126 | self._version_lock = threading.Lock() 127 | self._future_unwrapper = FutureUnwrapper( 128 | self._async_variable_version_by_name, self._future_by_name_and_version 129 | ) 130 | self._current_job_timestamp: int = 0 131 | self._timestamp_by_future_id: Dict[int, int] = {} 132 | self._threadlocal_state = threading.local() 133 | self._threadlocal_state.current_fut = None 134 | setattr(builtins, _UNWRAP_FUTURE_EXTRA_BUILTIN, self._unwrap_future) 135 | setattr(builtins, _FUT_TAB_EXTRA_BUILTIN, self._future_by_name_and_version) 136 | 137 | def __del__(self): 138 | self._executor.shutdown() 139 | 140 | def _unwrap_future(self, fut): 141 | if isinstance(fut, Future): 142 | if not fut.done(): 143 | current_ts = self._timestamp_by_future_id.get( 144 | id(self._threadlocal_state.current_fut) 145 | ) 146 | for waiter in 
self._waiters_by_future_id.get(id(fut), []): 147 | if ( 148 | current_ts is not None 149 | and current_ts >= self._timestamp_by_future_id[id(waiter)] 150 | ): 151 | continue 152 | else: 153 | self._unwrap_future(waiter) 154 | return fut.result() 155 | else: 156 | return fut 157 | 158 | @pyc.load_name(reentrant=True) 159 | def handle_load_name(self, ret, node, *_, **__): 160 | if node.id in self._async_variable_version_by_name: 161 | try: 162 | return self._unwrap_future(ret) 163 | except Exception as ex: 164 | ex = _unwrap_exception(ex) 165 | relevant_cell = self._exec_counter_by_future_id.get(id(ret)) 166 | if relevant_cell is not None: 167 | logger.error("Exception occurred in cell %d:", relevant_cell) 168 | logger.error( 169 | "".join(traceback.format_exception(type(ex), ex, ex.__traceback__)) 170 | ) 171 | return ret 172 | else: 173 | return ret 174 | 175 | @pyc.before_assign_rhs( 176 | when=lambda node: pyc.is_outer_stmt(node, exclude_outer_stmt_types={ast.Try}) 177 | ) 178 | def handle_assign_rhs(self, ret, node, frame, *_, **__): 179 | stmt = self.containing_stmt_by_id[id(node)] 180 | if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name): 181 | return ret 182 | async_var = stmt.targets[0].id 183 | unwrap_futures_expr, deps = self._future_unwrapper(node, async_var) 184 | unwrap_futures_code = compile( 185 | unwrap_futures_expr, "", mode="eval" 186 | ) 187 | with self._version_lock: 188 | self._async_variable_version_by_name[async_var] += 1 189 | current_version = self._async_variable_version_by_name[async_var] 190 | fut_cv = threading.Condition() 191 | fut = None 192 | 193 | def assign_rhs_job(): 194 | with fut_cv: 195 | # wait until we have a reference to the future indicated by 196 | # this very job 197 | while fut is None: 198 | fut_cv.wait() 199 | old_fut = self._future_by_name_and_version.get( 200 | (async_var, current_version - 1) 201 | ) 202 | for waiter in self._waiters_by_future_id.get(id(old_fut), []): 203 | # first, wait on everything that depends on the previous value to finish 204 | try: 205 | self._unwrap_future(waiter) 206 | except Exception: 207 | pass 208 | self._threadlocal_state.current_fut = fut 209 | retval = eval(unwrap_futures_code, frame.f_globals, frame.f_locals) 210 | self._threadlocal_state.current_fut = None 211 | # next, garbage collect the previous value 212 | self._waiters_by_future_id.pop(id(old_fut), None) 213 | self._timestamp_by_future_id.pop(id(old_fut), None) 214 | self._exec_counter_by_future_id.pop(id(old_fut), None) 215 | self._future_by_name_and_version.pop((async_var, current_version - 1), None) 216 | try: 217 | flow_ = flow() 218 | except Exception: 219 | flow_ = None 220 | with self._version_lock: 221 | if self._async_variable_version_by_name[async_var] == current_version: 222 | # by using 'is_outer_stmt', we can be sure 223 | # that setting the global is the right thing 224 | frame.f_globals[async_var] = retval 225 | if flow_ is not None: 226 | aliases = list(flow_.aliases.get(id(fut), [])) 227 | for alias in aliases: 228 | alias.update_obj_ref(retval) 229 | return retval 230 | 231 | ipy = get_ipython() 232 | current_cell = None if ipy is None else ipy.execution_count 233 | del ipy 234 | 235 | with fut_cv: 236 | fut = self._executor.submit(assign_rhs_job) 237 | self._current_job_timestamp += 1 238 | self._timestamp_by_future_id[id(fut)] = self._current_job_timestamp 239 | self._future_by_name_and_version[async_var, current_version] = fut 240 | if current_cell is not None: 241 | self._exec_counter_by_future_id[id(fut)] = 
current_cell 242 | for dep in deps: 243 | self._waiters_by_future_id[id(dep)].add(fut) 244 | fut_cv.notify() 245 | return lambda: fut 246 | 247 | @pyc.before_stmt(when=lambda node: isinstance(node, ast.AugAssign)) 248 | def handle_augassign(self, _ret, node, frame, *_, **__): 249 | async_var = node.target.id 250 | with self._version_lock: 251 | version = self._async_variable_version_by_name.get(async_var) 252 | if version is None: 253 | return 254 | else: 255 | fut = self._future_by_name_and_version[async_var, version] 256 | try: 257 | frame.f_globals[async_var] = self._unwrap_future(fut) 258 | except Exception: 259 | pass 260 | -------------------------------------------------------------------------------- /pyccolo/examples/lazy_imports.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import copy 4 | import importlib 5 | import logging 6 | import sys 7 | from types import FrameType 8 | from typing import Any, List, Optional, Set, Union 9 | 10 | import pyccolo as pyc 11 | from pyccolo._fast.misc_ast_utils import subscript_to_slice 12 | from pyccolo.extra_builtins import PYCCOLO_BUILTIN_PREFIX 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | _unresolved = object() 18 | 19 | 20 | ast_Constant = getattr(ast, "Constant", type(None)) 21 | ast_Num = getattr(ast, "Num", type(None)) 22 | ast_Str = getattr(ast, "Str", type(None)) 23 | 24 | 25 | class _LazySymbol: 26 | non_modules: Set[str] = set() 27 | blocklist_packages: Set[str] = set() 28 | 29 | def __init__(self, spec: Union[ast.Import, ast.ImportFrom]): 30 | self.spec = spec 31 | self.value = _unresolved 32 | 33 | @property 34 | def qualified_module(self) -> str: 35 | node = self.spec 36 | name = node.names[0].name 37 | if isinstance(node, ast.Import): 38 | return name 39 | else: 40 | return f"{node.module}.{name}" 41 | 42 | @staticmethod 43 | def top_level_package(module: str) -> str: 44 | return module.split(".", 1)[0] 45 | 46 | @classmethod 47 | def _unwrap_module(cls, module: str) -> Any: 48 | if module in sys.modules: 49 | return sys.modules[module] 50 | exc = None 51 | if module not in cls.non_modules: 52 | try: 53 | with pyc.allow_reentrant_event_handling(): 54 | return importlib.import_module(module) 55 | except ImportError as e: 56 | cls.non_modules.add(module) 57 | exc = e 58 | except Exception: 59 | logger.error("fatal error trying to import module %s", module) 60 | raise 61 | module_symbol = module.rsplit(".", 1) 62 | if len(module_symbol) != 2: 63 | raise ValueError("invalid module %s" % module) from exc 64 | else: 65 | module, symbol = module_symbol 66 | ret = getattr(cls._unwrap_module(module), symbol) 67 | if isinstance(ret, _LazySymbol): 68 | ret = ret.unwrap() 69 | return ret 70 | 71 | def _unwrap_helper(self) -> Any: 72 | return self._unwrap_module(self.qualified_module) 73 | 74 | def unwrap(self) -> Any: 75 | if self.value is not _unresolved: 76 | return self.value 77 | ret = self._unwrap_helper() 78 | self.value = ret 79 | return ret 80 | 81 | def __call__(self, *args, **kwargs): 82 | raise TypeError("cant call _LazyName for spec %s" % ast.unparse(self.spec)) 83 | 84 | def __getattr__(self, item): 85 | raise TypeError( 86 | "cant __getattr__ on _LazyName for spec %s" % ast.unparse(self.spec) 87 | ) 88 | 89 | 90 | class _GetLazyNames(ast.NodeVisitor): 91 | def __init__(self) -> None: 92 | self.lazy_names: Optional[Set[str]] = set() 93 | 94 | def visit_Import(self, node: ast.Import) -> None: 95 | if self.lazy_names is None: 96 | 
return 97 | for alias in node.names: 98 | if alias.asname is None: 99 | return 100 | for alias in node.names: 101 | assert alias.asname is not None 102 | self.lazy_names.add(alias.asname) 103 | 104 | def visit_ImportFrom(self, node: ast.ImportFrom) -> None: 105 | if self.lazy_names is None: 106 | return 107 | for alias in node.names: 108 | if alias.name == "*": 109 | self.lazy_names = None 110 | return 111 | for alias in node.names: 112 | self.lazy_names.add(alias.asname or alias.name) 113 | 114 | @classmethod 115 | def compute(cls, node: ast.Module) -> Set[str]: 116 | inst = cls() 117 | inst.visit(node) 118 | return inst.lazy_names or set() 119 | 120 | 121 | def _make_attr_guard_helper(node: ast.Attribute) -> Optional[str]: 122 | if isinstance(node.value, ast.Name): 123 | return f"{node.value.id}_O_{node.attr}" 124 | elif isinstance(node.value, ast.Attribute): 125 | prefix = _make_attr_guard_helper(node.value) 126 | if prefix is None: 127 | return None 128 | else: 129 | return f"{prefix}_O_{node.attr}" 130 | else: 131 | return None 132 | 133 | 134 | def _make_attr_guard(node: ast.Attribute) -> Optional[str]: 135 | suffix = _make_attr_guard_helper(node) 136 | if suffix is None: 137 | return None 138 | else: 139 | return f"{PYCCOLO_BUILTIN_PREFIX}_{suffix}" 140 | 141 | 142 | def _make_subscript_guard_helper(node: ast.Subscript) -> Optional[str]: 143 | slice_val = subscript_to_slice(node) 144 | if isinstance(slice_val, (ast_Constant, ast_Str, ast_Num, ast.Name)): 145 | if isinstance(slice_val, ast.Name): 146 | subscript = slice_val.id 147 | elif hasattr(slice_val, "s"): 148 | subscript = f"_{slice_val.s}_" # type: ignore 149 | elif hasattr(slice_val, "n"): 150 | subscript = f"_{slice_val.n}_" # type: ignore 151 | else: 152 | return None 153 | else: 154 | return None 155 | if isinstance(node.value, ast.Name): 156 | return f"{node.value.id}_S_{subscript}" 157 | elif isinstance(node.value, ast.Subscript): 158 | prefix = _make_subscript_guard_helper(node.value) 159 | if prefix is None: 160 | return None 161 | else: 162 | return f"{prefix}_S_{subscript}" 163 | else: 164 | return None 165 | 166 | 167 | def _make_subscript_guard(node: ast.Subscript) -> Optional[str]: 168 | suffix = _make_subscript_guard_helper(node) 169 | if suffix is None: 170 | return None 171 | else: 172 | return f"{PYCCOLO_BUILTIN_PREFIX}_{suffix}" 173 | 174 | 175 | class LazyImportTracer(pyc.BaseTracer): 176 | def __init__(self, *args, **kwargs) -> None: 177 | super().__init__(*args, **kwargs) 178 | self.cur_module_lazy_names: Set[str] = set() 179 | self.saved_attributes: List[Any] = [] 180 | self.saved_subscripts: List[Any] = [] 181 | self.saved_slices: List[Any] = [] 182 | 183 | def _is_name_lazy_load(self, node: Union[ast.Attribute, ast.Name]) -> bool: 184 | if self.cur_module_lazy_names is None: 185 | return True 186 | elif isinstance(node, ast.Name): 187 | return node.id in self.cur_module_lazy_names 188 | elif isinstance(node, (ast.Attribute, ast.Subscript)): 189 | return self._is_name_lazy_load(node.value) # type: ignore 190 | elif isinstance(node, ast.Call): 191 | return self._is_name_lazy_load(node.func) 192 | else: 193 | return False 194 | 195 | def static_init_module(self, node: ast.Module) -> None: 196 | self.cur_module_lazy_names = _GetLazyNames.compute(node) 197 | 198 | @staticmethod 199 | def _convert_relative_to_absolute( 200 | package: str, module: Optional[str], level: int 201 | ) -> str: 202 | prefix = package.rsplit(".", level - 1)[0] 203 | if not module: 204 | return prefix 205 | else: 206 | return 
f"{prefix}.{module}" 207 | 208 | @pyc.init_module 209 | def init_module( 210 | self, _ret: None, node: ast.Module, frame: FrameType, *_, **__ 211 | ) -> None: 212 | assert node is not None 213 | for guard in self.local_guards_by_module_id.get(id(node), []): 214 | frame.f_globals[guard] = False 215 | 216 | @pyc.before_stmt( 217 | when=pyc.Predicate( 218 | lambda node: isinstance(node, (ast.Import, ast.ImportFrom)) 219 | and pyc.is_outer_stmt(node), 220 | static=True, 221 | ) 222 | ) 223 | def before_stmt( 224 | self, 225 | _ret: None, 226 | node: Union[ast.Import, ast.ImportFrom], 227 | frame: FrameType, 228 | *_, 229 | **__, 230 | ) -> Any: 231 | is_import = isinstance(node, ast.Import) 232 | for alias in node.names: 233 | if alias.name == "*": 234 | return None 235 | elif is_import and alias.asname is None: 236 | return None 237 | package = frame.f_globals["__package__"] 238 | level = getattr(node, "level", 0) 239 | if is_import: 240 | module = None 241 | else: 242 | module = node.module # type: ignore 243 | if level > 0: 244 | module = self._convert_relative_to_absolute(package, module, level) 245 | for alias in node.names: 246 | node_cpy = copy.deepcopy(node) 247 | node_cpy.names = [alias] 248 | if module is not None: 249 | node_cpy.module = module # type: ignore 250 | node_cpy.level = 0 # type: ignore 251 | frame.f_globals[alias.asname or alias.name] = _LazySymbol(spec=node_cpy) 252 | return pyc.Pass 253 | 254 | @pyc.before_attribute_load( 255 | when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_attr_guard 256 | ) 257 | def before_attr_load(self, ret: Any, *_, **__) -> Any: 258 | self.saved_attributes.append(ret) 259 | return ret 260 | 261 | @pyc.after_attribute_load( 262 | when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_attr_guard 263 | ) 264 | def after_attr_load( 265 | self, ret: Any, node: ast.Attribute, frame: FrameType, _evt, guard, *_, **__ 266 | ) -> Any: 267 | if guard is not None: 268 | frame.f_globals[guard] = True 269 | saved_attr_obj = self.saved_attributes.pop() 270 | if isinstance(ret, _LazySymbol): 271 | ret = ret.unwrap() 272 | setattr(saved_attr_obj, node.attr, ret) 273 | return pyc.Null if ret is None else ret 274 | 275 | @pyc.before_subscript_load( 276 | when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_subscript_guard 277 | ) 278 | def before_subscript_load(self, ret: Any, *_, attr_or_subscript: Any, **__) -> Any: 279 | self.saved_subscripts.append(ret) 280 | self.saved_slices.append(attr_or_subscript) 281 | return ret 282 | 283 | @pyc.after_subscript_load( 284 | when=pyc.Predicate(_is_name_lazy_load, static=True), guard=_make_subscript_guard 285 | ) 286 | def after_subscript_load( 287 | self, ret: Any, _node, frame: FrameType, _evt, guard, *_, **__ 288 | ) -> Any: 289 | if guard is not None: 290 | frame.f_globals[guard] = True 291 | saved_subscript_obj = self.saved_subscripts.pop() 292 | saved_slice_obj = self.saved_slices.pop() 293 | if isinstance(ret, _LazySymbol): 294 | ret = ret.unwrap() 295 | saved_subscript_obj[saved_slice_obj] = ret 296 | return pyc.Null if ret is None else ret 297 | 298 | @pyc.load_name( 299 | when=pyc.Predicate(_is_name_lazy_load, static=True), 300 | guard=lambda node: f"{PYCCOLO_BUILTIN_PREFIX}_{node.id}_guard", 301 | ) 302 | def load_name( 303 | self, ret: Any, node: ast.Name, frame: FrameType, _evt, guard, *_, **__ 304 | ) -> Any: 305 | if guard is not None: 306 | frame.f_globals[guard] = True 307 | if isinstance(ret, _LazySymbol): 308 | ret = ret.unwrap() 309 | frame.f_globals[node.id] = 
ret 310 | return pyc.Null if ret is None else ret 311 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Pyccolo 2 | ======= 3 | 4 | [![CI Status](https://github.com/smacke/pyccolo/workflows/pyccolo/badge.svg)](https://github.com/smacke/pyccolo/actions) 5 | [![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/) 6 | [![codecov](https://codecov.io/gh/smacke/pyccolo/branch/master/graph/badge.svg?token=MGORH1IXLO)](https://codecov.io/gh/smacke/pyccolo) 7 | [![License: BSD3](https://img.shields.io/badge/License-BSD3-maroon.svg)](https://opensource.org/licenses/BSD-3-Clause) 8 | [![Python Versions](https://img.shields.io/pypi/pyversions/pyccolo.svg)](https://pypi.org/project/pyccolo) 9 | [![PyPI Version](https://img.shields.io/pypi/v/pyccolo.svg)](https://pypi.org/project/pyccolo) 10 | 11 | Pyccolo (pronounced like the instrument "piccolo") is a library for declarative 12 | instrumentation in Python; i.e., it lets you specify the *what* of the 13 | instrumentation you wish to perform, and takes care of the *how* for you. It 14 | aims to be *ergonomic*, *composable*, and *portable*, by providing an intuitive 15 | interface, making it easy to layer multiple levels of instrumentation, and 16 | allowing the same code to work across multiple versions of Python (3.6 to 17 | 3.12), with few exceptions. Portability across versions is accomplished by 18 | embedding instrumentation at the level of source code (as opposed to 19 | bytecode-level instrumentation). 20 | 21 | Pyccolo can be used (and has been used) to implement various kinds of dynamic analysis 22 | tools and other instrumentation: 23 | - Code coverage (see [pyccolo/examples/coverage.py](https://github.com/smacke/pyccolo/blob/master/pyccolo/examples/coverage.py)) 24 | - Syntactic macros such as quasiquotes (like [MacroPy's](https://macropy3.readthedocs.io/en/latest/reference.html#quasiquote)) or quick lambdas; see [pyccolo/examples/quasiquote.py](https://github.com/smacke/pyccolo/blob/master/pyccolo/examples/quasiquote.py) and [pyccolo/examples/quick_lambda.py](https://github.com/smacke/pyccolo/blob/master/pyccolo/examples/quick_lambda.py) 25 | - Syntax-augmented Python (3.8 and up, see [pyccolo/examples/optional_chaining.py](https://github.com/smacke/pyccolo/blob/master/pyccolo/examples/optional_chaining.py)) 26 | - Dynamic dataflow analysis performed by [ipyflow](https://github.com/ipyflow/ipyflow) 27 | - Tools to perform (most) imports lazily (see [pyccolo/examples/lazy_imports.py](https://github.com/smacke/pyccolo/blob/master/pyccolo/examples/lazy_imports.py)) 28 | - Tools to uncover [semantic memory leaks](http://ithare.com/java-vs-c-trading-ub-for-semantic-memory-leaks-same-problem-different-punishment-for-failure/) 29 | - \ 30 | 31 | ## Install 32 | 33 | ```bash 34 | pip install pyccolo 35 | ``` 36 | 37 | ## Hello World 38 | 39 | Below is a simple script that uses Pyccolo to print "Hello, world!" before 40 | every statement that executes: 41 | 42 | ```python 43 | import pyccolo as pyc 44 | 45 | 46 | class HelloTracer(pyc.BaseTracer): 47 | @pyc.before_stmt 48 | def handle_stmt(self, *_, **__): 49 | print("Hello, world!") 50 | 51 | 52 | if __name__ == "__main__": 53 | with HelloTracer: 54 | # prints "Hello, world!" 
11 times
55 |         pyc.exec("for _ in range(10): pass")
56 | ```
57 | 
58 | Instrumentation is provided by a *tracer class* that inherits from
59 | `pyccolo.BaseTracer`. This class rewrites Python source code with
60 | instrumentation that triggers whenever events of interest occur, such as when a
61 | statement is about to execute. By registering a handler with the associated
62 | event (with the `@pyc.before_stmt` decorator, in this case), we can enrich our
63 | programs with additional observability, or even alter their behavior
64 | altogether.
65 | 
66 | ### What is up with `pyc.exec(...)`?
67 | 
68 | A program's abstract syntax tree is fixed at import / compile time, and when
69 | our script initially started running, the tracer was not active, so unquoted
70 | Python in the same file will lack instrumentation. It is possible to instrument
71 | modules at import time, but only when the imports are performed inside a
72 | tracing context. Thus, we must quote any code appearing in the same module
73 | where the tracer class was defined in order to instrument it.
74 | 
75 | ## Composing tracers
76 | 
77 | A core feature of Pyccolo is that its instrumentation is *composable*. It's
78 | usually tricky to use two or more `ast.NodeTransformer` classes simultaneously
79 | --- sometimes you can just have one inherit from the other, but if they both
80 | define `visit` methods for the same AST node type, then typically you would
81 | need to define a bespoke node transformer that uses logic from each base
82 | transformer, handling corner cases to resolve incompatibilities. With Pyccolo,
83 | you simply compose the context managers of each tracer class whose
84 | instrumentation you wish to use, and everything usually Just
85 | Works™:
86 | 
87 | ```python
88 | with tracer1:
89 |     with tracer2:
90 |         pyc.exec(...)
91 | ```
92 | 
93 | ## Compatibility with sys.settrace(...)
94 | 
95 | Pyccolo is designed to support not only AST-level instrumentation, but also
96 | instrumentation involving Python's [built-in tracing
97 | utilities](https://docs.python.org/3/library/sys.html#sys.settrace).
98 | To use it, you simply register handlers for one of the corresponding
99 | Pyccolo events (`call`, `line`, `return_`, `exception`, or `opcode`).
100 | Here's a minimal example:
101 | 
102 | ```python
103 | import pyccolo as pyc
104 | 
105 | 
106 | class SysTracer(pyc.BaseTracer):
107 |     @pyc.call
108 |     def handle_call(self, *_, **__):
109 |         print("Pushing a stack frame!")
110 | 
111 |     @pyc.return_
112 |     def handle_return(self, *_, **__):
113 |         print("Popping a stack frame!")
114 | 
115 | 
116 | if __name__ == "__main__":
117 |     with SysTracer:
118 |         def f():
119 |             def g():
120 |                 return 42
121 |             return g()
122 |         # push, push, pop, pop
123 |         answer_to_life_universe_everything = f()
124 | ```
125 | 
126 | Note that we didn't need to use `pyc.exec(...)` in the above example, because Python's built-in
127 | tracing does not involve any AST-level transformations. If, however, we had registered handlers
128 | for other events, such as `pyc.before_stmt`, we would need to use `pyc.exec(...)` to ensure those
129 | handlers get called when running code in the same file where our tracer class is defined.
130 | 
131 | ### What if I'm already using sys.settrace(...) with my own tracing function?
132 | 
133 | Pyccolo is designed to be *composable*, and should execute both your tracing function as well
134 | as any handlers defined in any active Pyccolo tracers, as in the rough sketch below.
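
To make that concrete, here is a rough, untested sketch of what mixing your own `sys.settrace(...)` function with a Pyccolo tracer might look like. The names `my_settrace`, `MySysTracer`, and `f` are made up for illustration; the only Pyccolo pieces assumed are `pyc.BaseTracer`, the `@pyc.call` handler, and the tracer context manager shown earlier in this README.

```python
import sys

import pyccolo as pyc


def my_settrace(frame, event, arg):
    # an ordinary sys.settrace-style tracing function you may already have installed
    if event == "call":
        print("my_settrace saw a call into", frame.f_code.co_name)
    return my_settrace


class MySysTracer(pyc.BaseTracer):
    @pyc.call
    def handle_call(self, *_, **__):
        print("Pyccolo saw the call too!")


if __name__ == "__main__":
    sys.settrace(my_settrace)  # your own tracing function stays registered
    with MySysTracer:  # Pyccolo layers its handlers on top
        def f():
            return 42

        f()
    sys.settrace(None)
```

If the composition works as described above, both `my_settrace` and `handle_call` should fire for the call to `f`.
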
For example, Pyccolo's unit tests for
135 | `call` and `return` events work even when [coverage.py](https://coverage.readthedocs.io/),
136 | which also uses Python's built-in tracing utilities, is active (and without breaking it).
137 | 
138 | ## Instrumenting Imported Modules
139 | 
140 | Instrumentation is opt-in for modules imported within tracing contexts. To determine whether
141 | a module gets instrumented, the method `should_instrument_file(...)` is called with the module's
142 | corresponding filename as input. For example:
143 | 
144 | ```python
145 | class MyTracer(pyc.BaseTracer):
146 |     def should_instrument_file(self, filename: str) -> bool:
147 |         return filename.endswith("foo.py")
148 | 
149 |     # handlers, etc. defined below
150 |     ...
151 | 
152 | with MyTracer:
153 |     import foo  # contents of `foo` module get instrumented
154 |     import bar  # contents of `bar` module do not get instrumented
155 | ```
156 | 
157 | Imports are instrumented by registering a custom finder / loader with `sys.meta_path`.
158 | This loader ignores cached bytecode (which may possibly be uninstrumented), and avoids
159 | generating *new* cached bytecode (which would be instrumented, possibly causing confusion
160 | later when instrumentation is not desired).
161 | 
162 | ## Command Line Interface
163 | 
164 | You can execute arbitrary scripts with instrumentation enabled with the `pyc` command line tool.
165 | For example, to use the `OptionalChainer` tracer defined in [pyccolo/examples/optional_chaining.py](https://github.com/smacke/pyccolo/blob/master/pyccolo/examples/optional_chaining.py),
166 | you can call `pyc` as follows, given some example script `bar.py`:
167 | 
168 | ```python
169 | # bar.py
170 | bar = None
171 | # prints `None` since bar?.foo coalesces to `None`
172 | print(bar?.foo)
173 | ```
174 | 
175 | ```bash
176 | > pyc bar.py -t pyccolo.examples.OptionalChainer
177 | ```
178 | 
179 | You can also run `bar` as a module (indeed, `pyc` performs this internally when provided a file):
180 | 
181 | ```bash
182 | > pyc -m bar -t pyccolo.examples.OptionalChainer
183 | ```
184 | 
185 | Note that you can specify multiple tracer classes after the `-t` argument;
186 | in case you were not already aware, Pyccolo is composable! :)
187 | 
188 | The above example demonstrates a tracer class that performs syntax augmentation on its
189 | instrumented Python source to modify the default Python syntax. This feature is available
190 | only on Python >= 3.8 for now and is lacking documentation for the moment, but you can
191 | see some examples in the [test_syntax_augmentation.py](https://github.com/smacke/pyccolo/blob/master/test/test_syntax_augmentation.py) unit tests.
192 | 
193 | ## More Events
194 | 
195 | Pyccolo handlers can be registered for many kinds of events. Some of the more common ones are:
196 | - `pyc.before_stmt`, emitted before a statement executes;
197 | - `pyc.after_stmt`, emitted after a statement executes;
198 | - `pyc.before_attribute_load`, emitted in [load contexts](https://docs.python.org/3/library/ast.html#ast.Load) before an attribute is accessed;
199 | - `pyc.after_attribute_load`, emitted in load contexts after an attribute is accessed;
200 | - `pyc.load_name`, emitted when a variable is used in a load context (e.g. `foo` in `bar = foo.baz`);
201 | - `pyc.call` and `pyc.return_`, two non-AST trace events built into Python.
202 | 
203 | There are many different Pyccolo events, and more are always being added; a small combined sketch follows below.
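
As a small illustration of how a couple of these events might be combined in one tracer, here is a rough sketch (not taken from Pyccolo's docs or tests); the class name, handler names, and the snippet passed to `pyc.exec(...)` are made up, and the handler signature follows the four positional arguments described under "Handler Interface" below.

```python
import pyccolo as pyc


class NameAndAttrLogger(pyc.BaseTracer):
    @pyc.load_name
    def handle_load_name(self, ret, node, frame, event, *_, **__):
        # `ret` is the value the name resolved to; returning nothing leaves it unchanged
        print(f"loaded name {node.id!r} -> {ret!r}")

    @pyc.after_attribute_load
    def handle_attr_load(self, ret, node, frame, event, *_, **__):
        print(f"loaded attribute {node.attr!r}")


if __name__ == "__main__":
    with NameAndAttrLogger:
        pyc.exec("import math; x = math.pi")
```
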
See
204 | [pyccolo/trace_events.py](https://github.com/smacke/pyccolo/blob/master/pyccolo/trace_events.py)
205 | for a full list.
206 | 
207 | Note that, for AST events, Python source is only transformed to emit some event
208 | when there is at least one tracer active that has at least one handler
209 | registered for that event. This prevents the transformed source from becoming
210 | extremely bloated when only a few events are needed.
211 | 
212 | ## Handler Interface
213 | 
214 | Every Pyccolo handler is passed four positional arguments:
215 | 1. The return value, for instrumented expressions;
216 | 2. The AST node (or node id, if using `register_raw_handler(...)`, or `None`, for `sys` events);
217 | 3. The stack frame, at the point where instrumentation kicks in;
218 | 4. The event (useful when the same handler is registered for multiple events).
219 | 
220 | Some events pass additional keyword arguments, which I'm still in the process
221 | of documenting, but the above four tend to suffice for most use cases.
222 | 
223 | Not every handler receives a return value; for example, this argument is always
224 | `None` for `pyc.after_stmt` handlers. For certain handlers, the return value
225 | can be overridden. For example, by returning a value in a
226 | `pyc.before_attribute_load` handler, we override the object whose attribute is
227 | accessed. If we return nothing or `None`, then we do not override this object.
228 | (If we actually want to override it as `None` for some reason, then we can
229 | return `pyc.Null`.) For a particular event, handler return values compose with
230 | other handlers defined on the same tracer class as well as with handlers
231 | defined on other tracer classes.
232 | 
233 | ## Performance
234 | 
235 | Pyccolo instrumentation adds significant overhead to Python. In some
236 | cases, this overhead can be partially mitigated if, for example, you only need
237 | instrumentation the first time a statement runs. In such cases, you can
238 | deactivate instrumentation after, e.g., the first time a function executes, or
239 | after the first iteration in a loop for that respective function or loop, so
240 | that further calls (iterations, respectively) use uninstrumented code with all
241 | the mighty performance of native Python. This is implemented by activating
242 | "guards" associated with the function or loop, as in the below example:
243 | 
244 | ```python
245 | class TracesOnce(pyc.BaseTracer):
246 |     @pyc.register_raw_handler((pyc.after_for_loop_iter, pyc.after_while_loop_iter))
247 |     def after_loop_iter(self, *_, guard, **__):
248 |         self.activate_guard(guard)
249 | 
250 |     @pyc.register_raw_handler(pyc.after_function_execution)
251 |     def after_function_exec(self, *_, guard, **__):
252 |         self.activate_guard(guard)
253 | ```
254 | 
255 | Subsequent calls / iterations will be instrumented only after calling
256 | `self.deactivate_guard(...)` on the associated function / loop guard.
257 | 
258 | ## License
259 | Code in this project is licensed under the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
260 | -------------------------------------------------------------------------------- /test/test_pipeline_tracer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import textwrap 4 | 5 | import pyccolo as pyc 6 | from pyccolo.examples import OptionalChainer, PipelineTracer, QuickLambdaTracer 7 | 8 | if sys.version_info >= (3, 8): # noqa 9 | 10 | def test_simple_pipeline(): 11 | with PipelineTracer: 12 | assert pyc.eval("(1, 2, 3) |> list") == [1, 2, 3] 13 | 14 | def test_value_first_partial_apply_then_apply(): 15 | with PipelineTracer: 16 | assert pyc.eval("5 $> isinstance <| int") is True 17 | 18 | def test_fake_infix(): 19 | with PipelineTracer: 20 | assert pyc.eval("5 $>isinstance<| int") is True 21 | 22 | def test_value_first_partial_tuple_apply_then_apply(): 23 | with PipelineTracer: 24 | assert pyc.eval("(1, 2) *$> (lambda a, b, c: a + b + c) <| 3") == 6 25 | 26 | def test_value_first_partial_tuple_apply_then_apply_quick_lambda(): 27 | with PipelineTracer: 28 | with QuickLambdaTracer: 29 | assert pyc.eval("(1, 2) *$> f[_ + _ + _] <| 3") == 6 30 | 31 | def test_function_first_partial_apply_then_apply(): 32 | with PipelineTracer: 33 | assert pyc.eval("isinstance <$ 5 <| int") is True 34 | 35 | def test_function_first_partial_tuple_apply_then_apply(): 36 | with PipelineTracer: 37 | assert pyc.eval("(lambda a, b, c: a + b + c) <$* (1, 2) <| 3") == 6 38 | 39 | def test_function_first_partial_tuple_apply_then_apply_quick_lambda(): 40 | with PipelineTracer: 41 | with QuickLambdaTracer: 42 | assert pyc.eval("f[_ + _ + _] <$* (1, 2) <| 3") == 6 43 | 44 | def test_pipe_into_value_first_partial_apply(): 45 | with PipelineTracer: 46 | assert pyc.eval("int |> (5 $> isinstance)") is True 47 | 48 | def test_pipe_into_function_first_partial_apply(): 49 | with PipelineTracer: 50 | assert pyc.eval("int |> (isinstance <$ 5)") is True 51 | 52 | def test_simple_pipeline_with_quick_lambda_map(): 53 | with PipelineTracer: 54 | with QuickLambdaTracer: 55 | assert pyc.eval("(1, 2, 3) |> f[map(f[_ + 1], _)] |> list") == [2, 3, 4] 56 | 57 | def test_pipeline_assignment(): 58 | with PipelineTracer: 59 | with QuickLambdaTracer: 60 | assert pyc.eval( 61 | "(1, 2, 3) |> list |>> result |> f[map(f[_ + 1], _)] |> list |> f[result + _]" 62 | ) == [1, 2, 3, 2, 3, 4] 63 | 64 | def test_pipeline_methods(): 65 | with PipelineTracer: 66 | assert pyc.eval("(1, 2, 3) |> list |> $.index(2)") == 1 67 | 68 | def test_pipeline_methods_nonstandard_whitespace(): 69 | with PipelineTracer: 70 | assert pyc.eval("(1, 2, 3) |> list |> $.index(2)") == 1 71 | 72 | def test_left_tuple_apply(): 73 | with PipelineTracer: 74 | assert pyc.eval("(5, int) *|> isinstance") is True 75 | 76 | def test_right_tuple_apply(): 77 | with PipelineTracer: 78 | assert pyc.eval("isinstance <|* (5, int)") is True 79 | 80 | def test_compose_op(): 81 | with PipelineTracer: 82 | assert pyc.eval("((lambda x: x * 5) . (lambda x: x + 2))(10)") == 60 83 | 84 | def test_tuple_compose_op(): 85 | with PipelineTracer: 86 | assert ( 87 | pyc.eval("((lambda x, y: x * 5 + y) .* (lambda x: (x, x + 2)))(10)") 88 | == 62 89 | ) 90 | 91 | def test_compose_op_no_space(): 92 | with PipelineTracer: 93 | assert pyc.eval("((lambda x: x * 5). (lambda x: x + 2))(10)") == 60 94 | 95 | def test_compose_op_extra_space(): 96 | with PipelineTracer: 97 | assert pyc.eval("((lambda x: x * 5) . 
(lambda x: x + 2))(10)") == 60 98 | 99 | def test_compose_op_with_parenthesized_quick_lambdas(): 100 | with PipelineTracer: 101 | with QuickLambdaTracer: 102 | assert pyc.eval("((f[_ * 5]) . (f[_ + 2]))(10)") == 60 103 | 104 | def test_compose_op_with_quick_lambdas(): 105 | with PipelineTracer: 106 | with QuickLambdaTracer: 107 | assert pyc.eval("(f[_ * 5] . f[_ + 2])(10)") == 60 108 | 109 | def test_pipeline_inside_quick_lambda(): 110 | with PipelineTracer: 111 | with QuickLambdaTracer: 112 | assert pyc.eval("2 |> f[$ |> $ + 2]") == 4 113 | assert pyc.eval("2 |> f[$ |> f[_ + 2]]") == 4 114 | 115 | def test_pipeline_dot_op_with_optional_chain(): 116 | with PipelineTracer: 117 | with OptionalChainer: 118 | assert ( 119 | pyc.eval( 120 | "(3, 1, 2) |> (list . reversed . sorted) |> $.index(2).?foo" 121 | ) 122 | is None 123 | ) 124 | 125 | def test_function_placeholder(): 126 | with PipelineTracer: 127 | with QuickLambdaTracer: 128 | # TODO: the commented out ones don't work due to an issue in how NamedExpr values don't get 129 | # bound to lambda closures, which is a weakness in pyccolo BEFORE_EXPR_EVENTS. Technically 130 | # BEFORE_EXPR_EVENTS should all be using the default value binding trick. 131 | # assert pyc.eval("(add := (lambda x, y: x + y)) and (add1 := add($, 1)) and add1(42)") == 43 132 | # assert pyc.eval("(add := (lambda x, y: x + y)) and add(42, 1)") == 43 133 | pyc.exec("(add := (lambda x, y: x + y)); assert add(42, 1) == 43") 134 | pyc.exec( 135 | "(add := (lambda x, y: x + y)); assert (lambda y: add(42, y))(1)" 136 | ) 137 | pyc.exec( 138 | "(add := (lambda x, y: x + y)); assert (lambda y: add(42, y)) <| 1 == 43" 139 | ) 140 | pyc.exec("(add := (lambda x, y: x + y)); assert add(42, $) <| 1 == 43") 141 | pyc.exec("(add := f[$ + $]); assert (add($, 1) <| 1) == 2") 142 | pyc.exec("(add := f[$ + $]); assert 1 |> add($, 1) == 2") 143 | pyc.exec("add = f[$ + $]; add1 = add($, 1); assert add1(42) == 43") 144 | pyc.exec("add = f[$ + $]; assert add($, 42) <| 1 == 43") 145 | assert pyc.eval("(f[$ + $] |>> add) and add($, 1) <| 1") == 2 146 | assert pyc.eval("(f[$ + $] |>> add) and 1 |> add($, 1)") == 2 147 | 148 | def test_tuple_unpack_with_placeholders(): 149 | with PipelineTracer: 150 | with QuickLambdaTracer: 151 | assert pyc.eval("($, $) *|> $ + $ <|* (1, 2)") == 3 152 | assert pyc.eval("($, $) *|> $ + $ <|* (1, 2) |> $.real") == 3 153 | assert pyc.eval("($, $) *|> $ + $ <|* (1, 2) |> $.imag") == 0 154 | assert pyc.eval("($, $) *|> $ + $ <|* (1, 2) |> $ + 1") == 4 155 | assert pyc.eval("(1, 2) *|> ($, $) *|> $ + $") == 3 156 | 157 | def test_placeholder_with_kwarg(): 158 | with PipelineTracer: 159 | pyc.exec("def add(x, y): return x + y; assert 1 |> add($, y=42) == 43") 160 | pyc.exec("42 |> print($, end=' ')") 161 | 162 | def test_keyword_placeholder(): 163 | with PipelineTracer: 164 | pyc.exec( 165 | "func = sorted([1, 3, 2], reverse=$); assert func(False) == [1, 2, 3]; assert func(True) == [3, 2, 1]" 166 | ) 167 | 168 | def test_named_placeholders_simple(): 169 | with PipelineTracer: 170 | with QuickLambdaTracer: 171 | assert pyc.eval("reduce[$x + $y]([1, 2, 3])") == 6 172 | assert pyc.eval("sorted($lst, reverse=True)([1, 2, 3])") == [3, 2, 1] 173 | 174 | def test_named_placeholders_complex(): 175 | with PipelineTracer: 176 | with QuickLambdaTracer: 177 | assert ( 178 | pyc.eval( 179 | "zip(['*', '+', '+'], [[2, 3, 4], [1, 2, 3], [4, 5, 6]]) " 180 | "|> map[$ *|> reduce({'*': f[$ * $], '+': f[$ + $]}[$op], $row)] " 181 | "|> sum" 182 | ) 183 | == 45 184 | ) 185 | assert ( 186 
| pyc.eval( 187 | "zip(['*', '+', '+'], [[2, 3, 4], [1, 2, 3], [4, 5, 6]]) " 188 | "|> map[$ *|> reduce({'*': f[$x * $y], '+': f[$x + $y]}[$op], $row)] " 189 | "|> sum" 190 | ) 191 | == 45 192 | ) 193 | assert ( 194 | pyc.eval( 195 | "zip(['*', '+', '+'], [[2, 3, 4], [1, 2, 3], [4, 5, 6]]) " 196 | "|> map[$ *|> ($op, $row) *|> reduce({'*': f[$x * $y], '+': f[$x + $y]}[$op], $row)] " 197 | "|> sum" 198 | ) 199 | == 45 200 | ) 201 | assert ( 202 | pyc.eval( 203 | "zip(['*', '+', '+'], [[2, 3, 4], [1, 2, 3], [4, 5, 6]]) " 204 | "|> map[$ *|> ($op, $row) *|> reduce({'*': f[$ * $], '+': f[$ + $]}[$op], $row)] " 205 | "|> sum" 206 | ) 207 | == 45 208 | ) 209 | 210 | def test_dict_operators(): 211 | with PipelineTracer: 212 | assert pyc.eval("{'a': 1, 'b': 2} **|> dict") == {"a": 1, "b": 2} 213 | assert pyc.eval("{'a': 1, 'b': 2} **$> dict <|** {'c': 3, 'd': 4}") == { 214 | "a": 1, 215 | "b": 2, 216 | "c": 3, 217 | "d": 4, 218 | } 219 | assert pyc.eval("{'a': 1, 'b': 2} **|> (dict <$** {'c': 3, 'd': 4})") == { 220 | "a": 1, 221 | "b": 2, 222 | "c": 3, 223 | "d": 4, 224 | } 225 | assert pyc.eval("[('a',1), ('b', 2)] |> (list . dict .** dict)") == [ 226 | "a", 227 | "b", 228 | ] 229 | 230 | def test_augmentation_spec_order(): 231 | assert PipelineTracer.syntax_augmentation_specs() == [ 232 | PipelineTracer.pipeline_dict_op_spec, 233 | PipelineTracer.pipeline_tuple_op_spec, 234 | PipelineTracer.pipeline_op_assign_spec, 235 | PipelineTracer.pipeline_op_spec, 236 | PipelineTracer.value_first_left_partial_apply_dict_op_spec, 237 | PipelineTracer.value_first_left_partial_apply_tuple_op_spec, 238 | PipelineTracer.value_first_left_partial_apply_op_spec, 239 | PipelineTracer.function_first_left_partial_apply_dict_op_spec, 240 | PipelineTracer.function_first_left_partial_apply_tuple_op_spec, 241 | PipelineTracer.function_first_left_partial_apply_op_spec, 242 | PipelineTracer.apply_dict_op_spec, 243 | PipelineTracer.apply_tuple_op_spec, 244 | PipelineTracer.apply_op_spec, 245 | PipelineTracer.compose_dict_op_spec, 246 | PipelineTracer.compose_tuple_op_spec, 247 | PipelineTracer.compose_op_spec, 248 | PipelineTracer.arg_placeholder_spec, 249 | ] 250 | 251 | def test_multiline_pipeline(): 252 | with PipelineTracer: 253 | pyc.exec( 254 | textwrap.dedent( 255 | """ 256 | add1 = ( 257 | $ 258 | |> $ + 1 259 | ) 260 | assert 1 |> add1 == 2 261 | """.strip( 262 | "\n" 263 | ) 264 | ) 265 | ) 266 | 267 | def test_multistep_multiline_pipeline(): 268 | with PipelineTracer: 269 | pyc.exec( 270 | textwrap.dedent( 271 | """ 272 | add_stuff = $ |> $ + 1 |> $ + 2 |> $ + 3 273 | assert 1 |> add_stuff == 7 274 | """.strip( 275 | "\n" 276 | ) 277 | ) 278 | ) 279 | pyc.exec( 280 | textwrap.dedent( 281 | """ 282 | add_stuff = ( 283 | $ 284 | |> $ + 1 285 | |> $ + 2 286 | |> $ + 3 287 | ) 288 | assert 1 |> add_stuff == 7 289 | """.strip( 290 | "\n" 291 | ) 292 | ) 293 | ) 294 | 295 | def test_comprehension_placeholder(): 296 | with PipelineTracer: 297 | assert pyc.eval( 298 | "'1-2,5-6,3-4'.strip().split(',') " 299 | "|> [v.strip().split('-') for v in $] " 300 | "|> [[int(v1), int(v2)] for v1, v2 in $] " 301 | "|> sorted |> sum($, [])" 302 | ) == [1, 2, 3, 4, 5, 6] 303 | 304 | def test_chain_with_placeholder(): 305 | with PipelineTracer: 306 | assert pyc.eval("[3, 2, 1] |> sorted($).index(1)") == 0 307 | 308 | def test_immediately_evaluated_placeholder(): 309 | with PipelineTracer: 310 | assert pyc.eval("sorted($, reverse=True)([2, 1, 3])") == [3, 2, 1] 311 | 312 | def test_quick_maps(): 313 | with PipelineTracer: 314 | with 
QuickLambdaTracer: 315 | assert pyc.eval("['1', '2', '3'] |> map[int]") == [1, 2, 3] 316 | assert pyc.eval("['1', '2', '3'] |> map[int($)]") == [1, 2, 3] 317 | assert pyc.eval("['1', '2', '3'] |> map[int] |> map[$ % 2==0]") == [ 318 | False, 319 | True, 320 | False, 321 | ] 322 | assert ( 323 | pyc.eval( 324 | "zip(['*', '+', '+'], [[2, 3, 4], [1, 2, 3], [4, 5, 6]]) " 325 | "|> map[$ *|> reduce({'*': f[$ * $], '+': f[$ + $]}[$], $)] " 326 | "|> sum" 327 | ) 328 | == 45 329 | ) 330 | 331 | def test_pipeline_map_with_quick_lambda_applied(): 332 | with PipelineTracer: 333 | with QuickLambdaTracer: 334 | assert pyc.eval("[[1, 2], [3, 4]] |> map[f[$ + $](*$)]") == [ 335 | 3, 336 | 7, 337 | ] 338 | 339 | def test_quick_reduce(): 340 | with PipelineTracer: 341 | with QuickLambdaTracer: 342 | assert pyc.eval("reduce[$ + $]([1, 2, 3, 4])") == 10 343 | assert pyc.eval("reduce[f[$ + $]]([1, 2, 3, 4])") == 10 344 | assert pyc.eval("reduce[$ + $] <| [1, 2, 3, 4]") == 10 345 | assert pyc.eval("reduce[f[$ + $]] <| [1, 2, 3, 4]") == 10 346 | assert pyc.eval("reduce[$ + $ |> $] <| [1, 2, 3, 4]") == 10 347 | assert pyc.eval("reduce[$ + $ |> 2*$] <| [1, 2, 3, 4]") == 44 348 | 349 | def test_quick_filter(): 350 | with PipelineTracer: 351 | with QuickLambdaTracer: 352 | assert pyc.eval("filter[$ % 2 == 0]([1, 2, 3, 4, 5])") == [2, 4] 353 | assert pyc.eval("filter[$ % 2 == 1]([1, 2, 3, 4, 5])") == [1, 3, 5] 354 | assert pyc.eval("filter[$ % 2 == 0](range(5)) |> list") == [0, 2, 4] 355 | assert pyc.eval("filter[$ % 2 == 1](range(5)) |> list") == [1, 3] 356 | 357 | def test_named_unpack(): 358 | with PipelineTracer: 359 | with QuickLambdaTracer: 360 | assert pyc.eval( 361 | "'a: b c d' |> $.strip().split(': ') *|> ($, $.split())" 362 | ) == ("a", ["b", "c", "d"]) 363 | assert pyc.eval( 364 | "'a: b c d' |> $.strip().split(': ') *|> ($node, $adj.split())" 365 | ) == ("a", ["b", "c", "d"]) 366 | assert pyc.eval( 367 | "'a: b c d' |> $.strip().split(': ') *|> ($node, $adj.split()) *|> ($adj, $node)" 368 | ) == (["b", "c", "d"], "a") 369 | -------------------------------------------------------------------------------- /pyccolo/ast_rewriter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ast 3 | import logging 4 | from collections import defaultdict 5 | from contextlib import contextmanager 6 | from typing import ( 7 | TYPE_CHECKING, 8 | Callable, 9 | DefaultDict, 10 | Dict, 11 | Generator, 12 | List, 13 | Optional, 14 | Set, 15 | Tuple, 16 | TypeVar, 17 | Union, 18 | ) 19 | 20 | from pyccolo.ast_bookkeeping import AstBookkeeper, BookkeepingVisitor 21 | from pyccolo.expr_rewriter import ExprRewriter 22 | from pyccolo.handler import HandlerSpec 23 | from pyccolo.predicate import CompositePredicate, Predicate 24 | from pyccolo.stmt_inserter import StatementInserter 25 | from pyccolo.stmt_mapper import StatementMapper 26 | from pyccolo.syntax_augmentation import ( 27 | AugmentationSpec, 28 | AugmentationType, 29 | fix_positions, 30 | ) 31 | from pyccolo.trace_events import TraceEvent 32 | 33 | if TYPE_CHECKING: 34 | from pyccolo.tracer import BaseTracer 35 | 36 | 37 | logger = logging.getLogger(__name__) 38 | logger.setLevel(logging.WARNING) 39 | 40 | 41 | _T = TypeVar("_T") 42 | GUARD_DATA_T = Tuple[HandlerSpec, Callable[[Union[int, ast.AST]], str]] 43 | 44 | 45 | class AstRewriter(ast.NodeTransformer): 46 | gc_bookkeeping = True 47 | 48 | def __init__( 49 | self, 50 | tracers: "List[BaseTracer]", 51 | path: str, 52 | module_id: Optional[int] = 
None, 53 | ) -> None: 54 | self._tracers = tracers 55 | self._path = path 56 | self._module_id = module_id 57 | self._augmented_positions_by_spec: Dict[ 58 | AugmentationSpec, Set[Tuple[int, int]] 59 | ] = defaultdict(set) 60 | self.orig_to_copy_mapping: Optional[Dict[int, ast.AST]] = None 61 | 62 | @contextmanager 63 | def tracer_override_context( 64 | self, tracers: List["BaseTracer"], path: str 65 | ) -> Generator[None, None, None]: 66 | orig_tracers = self._tracers 67 | orig_path = self._path 68 | self._tracers = tracers 69 | self._path = path 70 | try: 71 | yield 72 | finally: 73 | self._tracers = orig_tracers 74 | self._path = orig_path 75 | 76 | def _get_order_of_specs_applied(self) -> Tuple[AugmentationSpec, ...]: 77 | specs = [] 78 | for tracer in self._tracers: 79 | for spec in tracer.syntax_augmentation_specs(): 80 | if spec not in specs: 81 | specs.append(spec) 82 | return tuple(specs) 83 | 84 | def register_augmented_position( 85 | self, aug_spec: AugmentationSpec, lineno: int, col_offset: int 86 | ) -> None: 87 | self._augmented_positions_by_spec[aug_spec].add((lineno, col_offset)) 88 | 89 | def _make_node_copy_flyweight( 90 | self, predicate: Callable[..., _T] 91 | ) -> Callable[..., _T]: 92 | return lambda node_or_id: predicate( 93 | (self.orig_to_copy_mapping or {}).get( 94 | node_or_id if isinstance(node_or_id, int) else id(node_or_id), 95 | node_or_id, 96 | ) 97 | ) 98 | 99 | def should_instrument_with_tracer(self, tracer: "BaseTracer") -> bool: 100 | return self._path is None or tracer._should_instrument_file_impl(self._path) 101 | 102 | @staticmethod 103 | def _get_prefix_position_for(node: ast.AST) -> Tuple[Optional[int], Optional[int]]: 104 | if isinstance(node, ast.Name): 105 | return node.lineno, node.col_offset 106 | elif isinstance(node, ast.Attribute): 107 | return node.lineno, getattr(node.value, "end_col_offset", -2) + 1 108 | elif isinstance(node, ast.FunctionDef): 109 | # TODO: can be different if more spaces between 'def' and function name 110 | return node.lineno, node.col_offset + 4 111 | elif isinstance(node, ast.ClassDef): 112 | # TODO: can be different if more spaces between 'class' and class name 113 | return node.lineno, node.col_offset + 6 114 | elif isinstance(node, ast.AsyncFunctionDef): 115 | # TODO: can be different if more spaces between 'async', 'def', and function name 116 | return node.lineno, node.col_offset + 10 117 | elif isinstance(node, (ast.Import, ast.ImportFrom)) and len(node.names) == 1: 118 | # "import " vs "from import " 119 | base_offset = ( 120 | 7 if isinstance(node, ast.Import) else 13 + len(node.module or "") 121 | ) 122 | name = node.names[0] 123 | return node.lineno, ( 124 | node.col_offset 125 | + base_offset 126 | + (0 if name.asname is None else len(name.name) + 1) 127 | ) 128 | else: 129 | return None, None 130 | 131 | @staticmethod 132 | def _get_suffix_position_for(node: ast.AST) -> Tuple[Optional[int], Optional[int]]: 133 | if isinstance(node, ast.Name): 134 | return node.lineno, node.col_offset + len(node.id) 135 | elif isinstance(node, ast.Attribute): 136 | return ( 137 | node.lineno, 138 | getattr(node.value, "end_col_offset", -1) + len(node.attr) + 1, 139 | ) 140 | elif isinstance(node, ast.FunctionDef): 141 | # TODO: can be different if more spaces between 'def' and function name 142 | return node.lineno, node.col_offset + 4 + len(node.name) 143 | elif isinstance(node, ast.ClassDef): 144 | # TODO: can be different if more spaces between 'class' and class name 145 | return node.lineno, node.col_offset + 6 + 
len(node.name) 146 | elif isinstance(node, ast.AsyncFunctionDef): 147 | # TODO: can be different if more spaces between 'async', 'def', and function name 148 | return node.lineno, node.col_offset + 10 + len(node.name) 149 | elif isinstance(node, (ast.Import, ast.ImportFrom)) and len(node.names) == 1: 150 | name = node.names[0] 151 | # "import " vs "from import " 152 | base_offset = ( 153 | 7 if isinstance(node, ast.Import) else 13 + len(node.module or "") 154 | ) 155 | col_offset = node.col_offset + base_offset 156 | if name.asname is None: 157 | col_offset += len(name.name) 158 | else: 159 | col_offset += len(name.name) + 1 + len(name.asname) 160 | return node.lineno, col_offset 161 | else: 162 | return None, None 163 | 164 | @staticmethod 165 | def _get_dot_suffix_position_for( 166 | node: ast.AST, 167 | ) -> Tuple[Optional[int], Optional[int]]: 168 | if isinstance(node, ast.Name): 169 | return getattr(node, "end_lineno", None), getattr( 170 | node, "end_col_offset", None 171 | ) 172 | elif isinstance(node, ast.Attribute): 173 | return getattr(node.value, "end_lineno", None), getattr( 174 | node.value, "end_col_offset", None 175 | ) 176 | else: 177 | return None, None 178 | 179 | @staticmethod 180 | def _get_dot_prefix_position_for( 181 | node: ast.AST, 182 | ) -> Tuple[Optional[int], Optional[int]]: 183 | if isinstance(node, ast.Name): 184 | return node.lineno, node.col_offset 185 | elif isinstance(node, ast.Attribute): 186 | return node.value.lineno, node.value.col_offset 187 | else: 188 | return None, None 189 | 190 | @staticmethod 191 | def _get_binop_position_for(node: ast.AST) -> Tuple[Optional[int], Optional[int]]: 192 | if isinstance(node, ast.BinOp): 193 | left_end_lineno = getattr(node.left, "end_lineno", None) 194 | left_end_col_offset = getattr(node.left, "end_col_offset", None) 195 | if left_end_col_offset is None: 196 | return None, None 197 | else: 198 | return ( 199 | left_end_lineno, 200 | node.left.col_offset - node.col_offset + left_end_col_offset, 201 | ) 202 | else: 203 | return None, None 204 | 205 | def _get_boolop_position_for( 206 | self, node: ast.AST 207 | ) -> Tuple[Optional[int], Optional[int]]: 208 | if not hasattr(node, "col_offset"): 209 | return None, None 210 | parent = self._tracers[-1].containing_ast_by_id.get(id(node)) 211 | if not isinstance(parent, ast.BoolOp): 212 | return None, None 213 | end_lineno = getattr(node, "end_lineno", None) 214 | end_col_offset = getattr(node, "end_col_offset", None) 215 | return end_lineno, end_col_offset 216 | 217 | def _get_position_for( 218 | self, aug_type: AugmentationType, node: ast.AST 219 | ) -> Tuple[Optional[int], Optional[int]]: 220 | if aug_type == AugmentationType.prefix: 221 | return self._get_prefix_position_for(node) 222 | elif aug_type == AugmentationType.suffix: 223 | return self._get_suffix_position_for(node) 224 | elif aug_type == AugmentationType.dot_suffix: 225 | return self._get_dot_suffix_position_for(node) 226 | elif aug_type == AugmentationType.dot_prefix: 227 | return self._get_dot_prefix_position_for(node) 228 | elif aug_type == AugmentationType.binop: 229 | return self._get_binop_position_for(node) 230 | elif aug_type == AugmentationType.boolop: 231 | return self._get_boolop_position_for(node) 232 | else: 233 | raise NotImplementedError() 234 | 235 | def _handle_augmentations_for_node( 236 | self, 237 | augmented_positions_by_spec: Dict[AugmentationSpec, Set[Tuple[int, int]]], 238 | nc: ast.AST, 239 | ) -> None: 240 | for spec, mod_positions in augmented_positions_by_spec.items(): 241 | 
lineno, col_offset = self._get_position_for(spec.aug_type, nc) 242 | if lineno is None or col_offset is None: 243 | continue 244 | if (lineno, col_offset) not in mod_positions: # type: ignore[attr-defined] 245 | continue 246 | for tracer in self._tracers: 247 | if spec in tracer.syntax_augmentation_specs(): 248 | tracer.augmented_node_ids_by_spec[spec].add(id(nc)) 249 | 250 | def _handle_all_augmentations( 251 | self, orig_to_copy_mapping: Dict[int, ast.AST] 252 | ) -> None: 253 | augmented_positions_by_spec = fix_positions( 254 | self._augmented_positions_by_spec, 255 | spec_order=self._get_order_of_specs_applied(), 256 | ) 257 | for nc in orig_to_copy_mapping.values(): 258 | self._handle_augmentations_for_node(augmented_positions_by_spec, nc) 259 | 260 | def visit(self, node: ast.AST): 261 | assert isinstance( 262 | node, (ast.Expression, ast.Module, ast.FunctionDef, ast.AsyncFunctionDef) 263 | ) 264 | assert self._path is not None 265 | mapper = StatementMapper(self._tracers) 266 | orig_to_copy_mapping = mapper(node) 267 | last_tracer = self._tracers[-1] 268 | old_bookkeeper = last_tracer.ast_bookkeeper_by_fname.get(self._path) 269 | module_id = id(node) if self._module_id is None else self._module_id 270 | 271 | # garbage collect any stale references to aug specs once they have been propagated 272 | cleanup_bookkeeper = AstBookkeeper.create(self._path, module_id) 273 | BookkeepingVisitor(cleanup_bookkeeper).visit(node) 274 | last_tracer.remove_bookkeeping(cleanup_bookkeeper, module_id) 275 | 276 | new_bookkeeper = last_tracer.ast_bookkeeper_by_fname[self._path] = ( 277 | AstBookkeeper.create(self._path, module_id) 278 | ) 279 | if old_bookkeeper is not None and self.gc_bookkeeping: 280 | last_tracer.remove_bookkeeping(old_bookkeeper, module_id) 281 | BookkeepingVisitor(new_bookkeeper).visit(orig_to_copy_mapping[id(node)]) 282 | last_tracer.add_bookkeeping(new_bookkeeper, module_id) 283 | self.orig_to_copy_mapping = orig_to_copy_mapping 284 | self._handle_all_augmentations(orig_to_copy_mapping) 285 | raw_handler_predicates_by_event: DefaultDict[TraceEvent, List[Predicate]] = ( 286 | defaultdict(list) 287 | ) 288 | raw_guard_exempt_handler_predicates_by_event: DefaultDict[ 289 | TraceEvent, List[Predicate] 290 | ] = defaultdict(list) 291 | 292 | for tracer in self._tracers: 293 | if not self.should_instrument_with_tracer(tracer): 294 | continue 295 | for evt in tracer.events_with_registered_handlers: 296 | # this is to deal with the tests in test_trace_events.py, 297 | # which patch events_with_registered_handlers but not _event_handlers 298 | handler_data = tracer._event_handlers.get( 299 | evt, [HandlerSpec.empty()] # type: ignore 300 | ) 301 | for handler_spec in handler_data: 302 | raw_handler_predicates_by_event[evt].append(handler_spec.predicate) 303 | if handler_spec.exempt_from_guards: 304 | raw_guard_exempt_handler_predicates_by_event[evt].append( 305 | handler_spec.predicate 306 | ) 307 | handler_predicate_by_event: DefaultDict[ 308 | TraceEvent, Callable[..., bool] 309 | ] = defaultdict( 310 | lambda: (lambda *_: False) # type: ignore 311 | ) 312 | guard_exempt_handler_prediate_by_event: DefaultDict[ 313 | TraceEvent, Callable[..., bool] 314 | ] = defaultdict( 315 | lambda: (lambda *_: False) # type: ignore 316 | ) 317 | for evt, raw_predicates in raw_handler_predicates_by_event.items(): 318 | handler_predicate_by_event[evt] = self._make_node_copy_flyweight( 319 | CompositePredicate.any(raw_predicates) 320 | ) 321 | for evt, raw_predicates in 
raw_guard_exempt_handler_predicates_by_event.items(): 322 | guard_exempt_handler_prediate_by_event[evt] = ( 323 | self._make_node_copy_flyweight(CompositePredicate.any(raw_predicates)) 324 | ) 325 | handler_guards_by_event: DefaultDict[TraceEvent, List[GUARD_DATA_T]] = ( 326 | defaultdict(list) 327 | ) 328 | for tracer in self._tracers: 329 | for evt, handler_specs in tracer._event_handlers.items(): 330 | handler_guards_by_event[evt].extend( 331 | (spec, self._make_node_copy_flyweight(spec.guard)) 332 | for spec in handler_specs 333 | if spec.guard is not None 334 | ) 335 | if isinstance(node, ast.Module): 336 | for tracer in self._tracers: 337 | tracer._static_init_module_impl( 338 | orig_to_copy_mapping.get(id(node), node) # type: ignore 339 | ) 340 | # very important that the eavesdropper does not create new ast nodes for ast.stmt (but just 341 | # modifies existing ones), since StatementInserter relies on being able to map these 342 | expr_rewriter = ExprRewriter( 343 | self._tracers, 344 | mapper, 345 | orig_to_copy_mapping, 346 | handler_predicate_by_event, 347 | guard_exempt_handler_prediate_by_event, 348 | handler_guards_by_event, 349 | ) 350 | if isinstance(node, ast.Expression): 351 | node = expr_rewriter.visit(node) 352 | else: 353 | for i in range(len(node.body)): 354 | node.body[i] = expr_rewriter.visit(node.body[i]) 355 | node = StatementInserter( 356 | self._tracers, 357 | mapper, 358 | orig_to_copy_mapping, 359 | handler_predicate_by_event, 360 | guard_exempt_handler_prediate_by_event, 361 | handler_guards_by_event, 362 | expr_rewriter, 363 | ).visit(node) 364 | if not any(tracer.requires_ast_bookkeeping for tracer in self._tracers): 365 | last_tracer.remove_bookkeeping(new_bookkeeper, module_id) 366 | return node 367 | --------------------------------------------------------------------------------