├── .gitignore
├── Makefile
├── README.md
├── lcc
    ├── __init__.py
    ├── aio.py
    ├── common.py
    ├── threads.py
    └── unified.py
└── tests
    ├── conftest.py
    └── test_basic.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pyo
3 | __pycache__
4 | .DS_Store
5 | build
6 | .cache
7 | *.egg-info
8 | venv
9 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | test:
2 | 	PYTHONPATH=`pwd` pytest tests --tb=short -vv
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # python-logical-call-context
2 |
3 | A repository for experiments with logical call contexts in Python.
4 | The idea is that this would eventually move into `contextlib`, which
5 | is why the examples import from there (`import lcc` patches it in).
6 |
7 | Examples:
8 |
9 | ```python
10 | from contextlib import get_call_context
11 |
12 | # Returns the current call context
13 | ctx = get_call_context()
14 |
15 | # Isolate some calls
16 | from contextlib import isolated_call_context
17 | with isolated_call_context():
18 |     ...
19 |
20 | # Register a new context provider
21 | from contextlib import register_context_provider, new_call_context
22 | @register_context_provider
23 | def get_my_context(create=False):
24 |     rv = get_current_context_object()
25 |     if rv is not None:
26 |         return rv
27 |     if not create:
28 |         return
29 |     rv = new_call_context(parent=...)
30 |     bind_the_new_context(rv)
31 |     return rv
32 | ```
33 |
34 | What you can do with the call context:
35 |
36 | ```python
37 | # Sets some data
38 | ctx.set_data('key', 'value')
39 |
40 | # Sets some data but marks it so that it cannot cross a thread
41 | # or something similar which would require external synchronization.
42 | ctx.set_data('key', 'value', sync=False) 43 | 44 | # Set some data so that it does not pass over to isolated contexts 45 | # (these contexts are created with `isolated_call_context` and set up 46 | # a new logical call context. 47 | ctx.set_data('werkzeug.request', ..., local=True) 48 | 49 | # Looks up stored data (or raise a LookupError) 50 | ctx.get_data('key') 51 | 52 | # Looks up stored data or return a default 53 | ctx.get_data('key', default=None) 54 | 55 | # Deletes some data 56 | ctx.del_data('key') 57 | 58 | # Return the current logical key (a hashable object) 59 | ctx.logical_key 60 | 61 | # Return the current concurrency key (a hashable object) 62 | ctx.key 63 | 64 | # Nest the context (throws away local modifications later) 65 | with ctx.nested(): 66 | ... 67 | ``` 68 | 69 | Other things patched: 70 | 71 | ```python 72 | from threading import get_thread_call_context 73 | from asyncio import get_task_call_context 74 | ``` 75 | -------------------------------------------------------------------------------- /lcc/__init__.py: -------------------------------------------------------------------------------- 1 | def __patch(): 2 | from .common import patch_contextlib 3 | from .threads import patch_threads 4 | from .aio import patch_asyncio 5 | patch_contextlib() 6 | patch_threads() 7 | patch_asyncio() 8 | 9 | 10 | __patch() 11 | del __patch 12 | -------------------------------------------------------------------------------- /lcc/aio.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import threading 3 | import contextlib 4 | 5 | from weakref import WeakKeyDictionary 6 | 7 | from .threads import get_thread_call_context 8 | 9 | 10 | _local = threading.local() 11 | 12 | 13 | def set_task_call_context(task, ctx): 14 | """Binds the given call context to the current asyncio task.""" 15 | try: 16 | contexts = _local.asyncio_contexts 17 | except AttributeError: 18 | _local.asyncio_contexts = contexts = 
WeakKeyDictionary() 19 | contexts[task] = ctx 20 | return ctx 21 | 22 | 23 | def get_task_call_context(create=False): 24 | """Returns the call context associated with the current task.""" 25 | try: 26 | loop = asyncio.get_event_loop() 27 | except (AssertionError, RuntimeError): 28 | return 29 | 30 | task = asyncio.Task.current_task(loop=loop) 31 | if task is None: 32 | return 33 | 34 | try: 35 | return _local.asyncio_contexts[task] 36 | except (AttributeError, LookupError): 37 | ctx = None 38 | 39 | if not create: 40 | return 41 | 42 | ctx = contextlib.new_call_context(parent=get_thread_call_context()) 43 | return set_task_call_context(task, ctx) 44 | 45 | 46 | def patch_asyncio(): 47 | # asyncio support 48 | ensure_future = asyncio.ensure_future 49 | 50 | def better_ensure_future(coro_or_future, *, loop=None): 51 | ctx = contextlib.get_call_context() 52 | task = ensure_future(coro_or_future, loop=loop) 53 | new_ctx = contextlib.new_call_context(name='Task-0x%x' % id(task), parent=ctx) 54 | set_task_call_context(task, new_ctx) 55 | return task 56 | 57 | asyncio.ensure_future = better_ensure_future 58 | asyncio.tasks.ensure_future = better_ensure_future 59 | asyncio.get_task_call_context = get_task_call_context 60 | -------------------------------------------------------------------------------- /lcc/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import asyncio 4 | import _thread 5 | import threading 6 | import contextlib 7 | 8 | from contextlib import contextmanager 9 | from collections import MutableMapping 10 | from weakref import ref as weakref 11 | 12 | 13 | _missing = object() 14 | 15 | 16 | class CallContextKey(object): 17 | """An immutable, hashable and comparable object to uniquely identify 18 | the call context. This can be used by code to map data to a specific 19 | call context. 
20 | """ 21 | 22 | def __init__(self, name=None): 23 | self.name = name 24 | self.pid = os.getpid() 25 | self.tid = _thread.get_ident() 26 | 27 | def __repr__(self): 28 | return '' % ( 29 | self.name, 30 | self.pid, 31 | self.tid, 32 | ) 33 | 34 | 35 | class LogicalCallContextKey(object): 36 | """Uniquely identifies a logical call context.""" 37 | 38 | 39 | class _ContextData(object): 40 | 41 | def __init__(self, value, key, logical_key, sync=True, local=False): 42 | self.value = value 43 | self.key = key 44 | self.logical_key = logical_key 45 | self.sync = sync 46 | self.local = local 47 | 48 | def unsafe_context_crossing(self, call_context): 49 | # Synchronous objects always cross safely 50 | if self.sync: 51 | return False 52 | # If we cross over to another process the crossing is safe 53 | if self.key.pid != call_context.key.pid: 54 | return False 55 | # Otherwise the crossing is only safe if we are on the same 56 | # thread 57 | return self.key.tid != call_context.key.tid 58 | 59 | 60 | class CallContext(object): 61 | """Represents the call context.""" 62 | 63 | def __init__(self, name, parent=None): 64 | logical_key = None 65 | if parent is not None: 66 | data = parent._data.copy() 67 | if not parent.isolates: 68 | logical_key = parent.logical_key 69 | else: 70 | data = {} 71 | 72 | if logical_key is None: 73 | logical_key = LogicalCallContextKey() 74 | 75 | self.key = CallContextKey(name) 76 | self.logical_key = logical_key 77 | self.isolates = False 78 | self._data = data 79 | self._backup = None 80 | 81 | def __repr__(self): 82 | return '' % ( 83 | self.key.name, 84 | id(self), 85 | ) 86 | 87 | def __eq__(self, other): 88 | return self.__class__ is other.__class__ and \ 89 | self.key == other.key 90 | 91 | def __ne__(self, other): 92 | return not self.__eq__(other) 93 | 94 | def get_data(self, name, *, default=_missing): 95 | """Returns the data for the given key. By default if the key cannot 96 | be found a `LookupError` is raised. 
If a default is provided it's 97 | returned instead. 98 | """ 99 | try: 100 | cd = self._data.get(name) 101 | if cd is None: 102 | raise KeyError(name) 103 | 104 | # If the key is local pretend it never exists in this context 105 | if cd.local and cd.logical_key != self.logical_key: 106 | raise KeyError(name) 107 | 108 | # Do not let non sync values cross contexts 109 | if cd.unsafe_context_crossing(self): 110 | raise LookupError('The stored context data was created for ' 111 | 'a different context and cannot be shared ' 112 | 'because it was not marked as synchronous.') 113 | 114 | return cd.value 115 | except LookupError: 116 | if default is not _missing: 117 | return default 118 | raise 119 | 120 | def set_data(self, name, value, *, sync=True, local=False): 121 | """Sets a key to a given value. By default the value is set nonlocal 122 | and sync which means that it shows up in any derived context. If the 123 | value is set to ``sync=False`` the value will not be travelling to a 124 | context that would require external synchronization (eg: a different 125 | thread). If the value is set to local with ``local=True`` the value 126 | will not travel to a context belonging to a different logical call 127 | context. 128 | """ 129 | if self._backup is not None and name not in self._backup: 130 | self._backup[name] = self._data.get(name) 131 | self._data[name] = _ContextData(value, self.key, self.logical_key, 132 | sync=sync, local=local) 133 | 134 | def del_data(self, name): 135 | """Deletes a key""" 136 | self._data[name] = None 137 | 138 | @contextmanager 139 | def nested(self): 140 | """Helper context manager to """ 141 | backup = self._backup 142 | self._backup = {} 143 | try: 144 | yield 145 | finally: 146 | self._data.update(self._backup) 147 | self._backup = backup 148 | 149 | 150 | def new_call_context(name=None, parent=None): 151 | """Creates a new call context which optionally is created from a given 152 | parent. 
153 | """ 154 | if name is None: 155 | name = threading.current_thread().name 156 | return CallContext(name, parent) 157 | 158 | 159 | @contextmanager 160 | def isolated_call_context(isolate=True): 161 | """Context manager that temporarily isolates the call context. This means 162 | that new contexts created out of the current context until the end of the 163 | context manager will be created isolated from the current one. All values 164 | that are marked as "local" will be unavailable in the newly created call 165 | context. 166 | 167 | When a context is created while the parent is isolated a new logical call 168 | context will be created. 169 | 170 | Example:: 171 | 172 | import contextlib 173 | 174 | with contextlib.isolated_call_context(): 175 | ... 176 | """ 177 | ctx = contextlib.get_call_context() 178 | old = ctx.isolates 179 | ctx.isolates = isolate 180 | try: 181 | yield 182 | finally: 183 | ctx.isolates = old 184 | 185 | 186 | def patch_contextlib(): 187 | """Injects us to where we expect to live.""" 188 | from . 
import unified 189 | contextlib.get_call_context = unified.get_call_context 190 | contextlib.new_call_context = new_call_context 191 | contextlib.isolated_call_context = isolated_call_context 192 | contextlib.register_context_provider = unified.register 193 | -------------------------------------------------------------------------------- /lcc/threads.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import contextlib 3 | 4 | 5 | _local = threading.local() 6 | 7 | 8 | def set_thread_call_context(ctx): 9 | """Binds the given call context to the current thread.""" 10 | _local.context = ctx 11 | return ctx 12 | 13 | 14 | def get_thread_call_context(create=False): 15 | """Returns the current thread's call context.""" 16 | rv = getattr(_local, 'context', None) 17 | if rv is not None: 18 | return rv 19 | if not create: 20 | return 21 | return set_thread_call_context(contextlib.new_call_context()) 22 | 23 | 24 | def patch_threads(): 25 | thread_init = threading.Thread.__init__ 26 | thread_bootstrap = threading.Thread._bootstrap 27 | 28 | def better_thread_init(self, *args, **kwargs): 29 | self.__outer_call_ctx = contextlib.get_call_context() 30 | return thread_init(self, *args, **kwargs) 31 | 32 | def better_thread_bootstrap(self): 33 | set_thread_call_context(contextlib.new_call_context( 34 | name=self.name, parent=self.__outer_call_ctx)) 35 | return thread_bootstrap(self) 36 | 37 | threading.Thread.__init__ = better_thread_init 38 | threading.Thread._bootstrap = better_thread_bootstrap 39 | threading.get_thread_call_context = get_thread_call_context 40 | -------------------------------------------------------------------------------- /lcc/unified.py: -------------------------------------------------------------------------------- 1 | from .aio import get_task_call_context 2 | from .threads import get_thread_call_context 3 | 4 | 5 | providers = [ 6 | get_task_call_context, 7 | ] 8 | 9 | 10 | def register(func): 
11 | """Registers an unified provider.""" 12 | providers.append(func) 13 | return func 14 | 15 | 16 | def get_call_context(): 17 | """Returns the current call context.""" 18 | for provider in providers: 19 | rv = provider(create=True) 20 | if rv is not None: 21 | return rv 22 | return get_thread_call_context(create=True) 23 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import lcc 2 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import asyncio 3 | import threading 4 | import contextlib 5 | 6 | 7 | def test_basic_context(): 8 | ctx = contextlib.get_call_context() 9 | ctx.set_data('foo', 42) 10 | assert ctx.get_data('foo') == 42 11 | ctx.del_data('foo') 12 | 13 | with pytest.raises(LookupError): 14 | ctx.get_data('foo') 15 | 16 | 17 | def test_context_nesting(): 18 | ctx = contextlib.get_call_context() 19 | with ctx.nested(): 20 | ctx.set_data('bar', 23) 21 | 22 | with pytest.raises(LookupError): 23 | ctx.get_data('bar') 24 | 25 | 26 | def test_local_context_behavior(): 27 | ctx = contextlib.get_call_context() 28 | ctx.set_data('foo', 23, local=True) 29 | 30 | with ctx.nested(): 31 | assert ctx.get_data('foo') == 23 32 | 33 | rv = [] 34 | def test_thread(): 35 | rv.append(contextlib.get_call_context().get_data('foo', default=None)) 36 | 37 | t = threading.Thread(target=test_thread) 38 | t.start() 39 | t.join() 40 | 41 | with contextlib.isolated_call_context(): 42 | t = threading.Thread(target=test_thread) 43 | t.start() 44 | t.join() 45 | 46 | assert rv == [23, None] 47 | 48 | 49 | def test_sync_data(): 50 | ctx = contextlib.get_call_context() 51 | ctx.set_data('foo', 23, sync=False) 52 | 53 | rv = [] 54 | rv.append(contextlib.get_call_context().get_data('foo', 
default=None)) 55 | 56 | def test_thread(): 57 | rv.append(contextlib.get_call_context().get_data('foo', default=None)) 58 | 59 | t = threading.Thread(target=test_thread) 60 | t.start() 61 | t.join() 62 | 63 | assert rv == [23, None] 64 | 65 | 66 | def test_async(): 67 | ctx = contextlib.get_call_context() 68 | ctx.set_data('__locale__', 'en_US') 69 | 70 | rv = [] 71 | 72 | async def x(val): 73 | ctx = contextlib.get_call_context() 74 | rv.append(ctx.get_data('__locale__')) 75 | ctx.set_data('__locale__', val) 76 | rv.append(ctx.get_data('__locale__')) 77 | 78 | asyncio.get_event_loop().run_until_complete(x('de_DE')) 79 | asyncio.get_event_loop().run_until_complete(x('fr_FR')) 80 | rv.append(ctx.get_data('__locale__')) 81 | 82 | assert rv == ['en_US', 'de_DE', 'en_US', 'fr_FR', 'en_US'] 83 | --------------------------------------------------------------------------------