├── tests ├── __init__.py ├── providers │ ├── __init__.py │ ├── test_object.py │ ├── test_factories.py │ ├── test_collections.py │ ├── test_inject_factories.py │ ├── test_main_providers.py │ ├── test_state.py │ ├── test_selector.py │ ├── test_singleton.py │ ├── test_resources.py │ ├── test_attr_getter.py │ └── test_local_singleton.py ├── experimental │ ├── __init__.py │ ├── test_container_1.py │ ├── test_lazy_provider.py │ └── test_container_2.py ├── integrations │ ├── __init__.py │ ├── fastapi │ │ ├── __init__.py │ │ ├── test_fastapi_di_pass_request.py │ │ ├── test_fastapi_di.py │ │ └── test_fastapi_route_class.py │ ├── litestar │ │ ├── __init__.py │ │ ├── test_litestar_di_simple.py │ │ └── test_litestar_di.py │ └── faststream │ │ ├── __init__.py │ │ ├── test_faststream_di.py │ │ └── test_faststream_di_pass_message.py ├── conftest.py ├── test_utils.py ├── test_meta.py ├── test_dynamic_container.py ├── test_resolver.py ├── creators.py ├── test_multiple_containers.py ├── container.py └── test_container.py ├── that_depends ├── py.typed ├── entities │ ├── __init__.py │ └── resource_context.py ├── integrations │ ├── __init__.py │ ├── fastapi.py │ └── faststream.py ├── experimental │ └── __init__.py ├── exceptions.py ├── utils.py ├── __init__.py ├── providers │ ├── __init__.py │ ├── mixin.py │ ├── object.py │ ├── state.py │ ├── collection.py │ ├── selector.py │ ├── resources.py │ ├── local_singleton.py │ └── factories.py └── container.py ├── examples └── benchmark │ ├── __init__.py │ ├── RESULTS.md │ └── injection.py ├── docs ├── requirements.txt ├── css │ └── code.css ├── overrides │ └── main.html ├── providers │ ├── object.md │ ├── collections.md │ ├── state.md │ ├── selector.md │ ├── resources.md │ ├── factories.md │ └── singleton.md ├── dev │ ├── main-decisions.md │ └── contributing.md ├── testing │ ├── fixture.md │ └── provider-overriding.md ├── introduction │ ├── multiple-containers.md │ ├── ioc-container.md │ ├── type-based-injection.md │ ├── string-injection.md 
│ ├── tear-down.md │ ├── generator-injection.md │ └── injection.md ├── integrations │ ├── litestar.md │ └── faststream.md ├── experimental │ └── lazy.md ├── index.md └── migration │ └── v3.md ├── .readthedocs.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .github └── workflows │ ├── publish.yml │ └── ci.yml ├── Justfile ├── LICENSE ├── README.md ├── pyproject.toml └── mkdocs.yml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /that_depends/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/providers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | """Benchmarks.""" 2 | -------------------------------------------------------------------------------- /that_depends/entities/__init__.py: -------------------------------------------------------------------------------- 1 | """Entities.""" 2 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-material 3 | mkdocs-llmstxt 4 | 
-------------------------------------------------------------------------------- /that_depends/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | """Integrations with other frameworks.""" 2 | -------------------------------------------------------------------------------- /tests/integrations/fastapi/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | pytest.importorskip("fastapi") 5 | -------------------------------------------------------------------------------- /tests/integrations/litestar/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | pytest.importorskip("litestar") 5 | -------------------------------------------------------------------------------- /tests/integrations/faststream/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | pytest.importorskip("faststream") 5 | pytest.importorskip("nats") 6 | -------------------------------------------------------------------------------- /that_depends/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Experimental features.""" 2 | 3 | from that_depends.experimental.providers import LazyProvider 4 | 5 | 6 | __all__ = [ 7 | "LazyProvider", 8 | ] 9 | -------------------------------------------------------------------------------- /docs/css/code.css: -------------------------------------------------------------------------------- 1 | /* Round the corners of code blocks */ 2 | .md-typeset .highlight code { 3 | border-radius: 10px !important; 4 | } 5 | 6 | @media (max-width: 600px) { 7 | .md-typeset code { 8 | border-radius: 4px; 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /.readthedocs.yaml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.10" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt 11 | 12 | 13 | mkdocs: 14 | configuration: mkdocs.yml 15 | -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block announce %} 4 | that-depends 3.0 just released! If you are upgrading, check out the migration guide. 5 | {% endblock %} 6 | -------------------------------------------------------------------------------- /that_depends/exceptions.py: -------------------------------------------------------------------------------- 1 | class TypeNotBoundError(Exception): 2 | """Exception raised when a type is not bound to a provider.""" 3 | 4 | 5 | class StateNotInitializedError(Exception): 6 | """Exception raised when the state is not initialized.""" 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generic things 2 | *.pyc 3 | *~ 4 | __pycache__/* 5 | *.swp 6 | *.sqlite3 7 | *.map 8 | .vscode 9 | .idea 10 | .DS_Store 11 | .env 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | .coverage* 16 | htmlcov/ 17 | coverage.xml 18 | pytest.xml 19 | dist/ 20 | .python-version 21 | .venv 22 | uv.lock 23 | -------------------------------------------------------------------------------- /docs/providers/object.md: -------------------------------------------------------------------------------- 1 | # Object 2 | 3 | Object provider returns an object “as is”. 
4 | 5 | ```python 6 | from that_depends import BaseContainer, providers 7 | 8 | 9 | class DIContainer(BaseContainer): 10 | object_provider = providers.Object(1) 11 | 12 | 13 | assert DIContainer.object_provider() == 1 14 | ``` 15 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import pytest 4 | 5 | from tests import container 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | async def _clear_di_container() -> typing.AsyncIterator[None]: 10 | try: 11 | yield 12 | finally: 13 | container.DIContainer.reset_override_sync() 14 | await container.DIContainer.tear_down() 15 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: lint 5 | name: lint 6 | entry: just 7 | args: [lint] 8 | language: system 9 | types: [python] 10 | pass_filenames: false 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v5.0.0 13 | hooks: 14 | - id: end-of-file-fixer 15 | - id: mixed-line-ending 16 | -------------------------------------------------------------------------------- /examples/benchmark/RESULTS.md: -------------------------------------------------------------------------------- 1 | # Injection 2 | Based on this [benchmark](injection.py): 3 | 4 | | version / iterations | 10^4 | 10^5 | 10^6 | 5 | |----------------------|--------|--------|---------| 6 | | 3.2.0 | 0.1039 | 1.0563 | 10.3576 | 7 | | 3.0.0.a2 | 0.0870 | 0.9430 | 8.8815 | 8 | | 3.0.0.a1 | 0.1399 | 1.4136 | 14.1829 | 9 | | 2.2.0 | 0.1465 | 1.4927 | 14.9044 | 10 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from that_depends.utils 
import UNSET, Unset, is_set 2 | 3 | 4 | def test_is_set_with_unset() -> None: 5 | assert not is_set(UNSET) 6 | 7 | 8 | def test_is_set_with_value() -> None: 9 | assert is_set(42) 10 | assert is_set("hello") 11 | assert is_set([1, 2, 3]) 12 | assert is_set(None) 13 | 14 | 15 | def test_unset_is_singleton() -> None: 16 | assert isinstance(UNSET, Unset) 17 | assert UNSET is Unset() 18 | -------------------------------------------------------------------------------- /tests/providers/test_object.py: -------------------------------------------------------------------------------- 1 | from that_depends import BaseContainer, providers 2 | 3 | 4 | instance = "some string" 5 | 6 | 7 | class DIContainer(BaseContainer): 8 | alias = "object_container" 9 | instance = providers.Object(instance) 10 | 11 | 12 | async def test_object_provider() -> None: 13 | instance1 = await DIContainer.instance() 14 | instance2 = DIContainer.instance.resolve_sync() 15 | 16 | assert instance1 is instance2 is instance 17 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Package 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: extractions/setup-just@v2 14 | - uses: astral-sh/setup-uv@v5 15 | with: 16 | cache-dependency-glob: "**/pyproject.toml" 17 | - run: just publish 18 | env: 19 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 20 | -------------------------------------------------------------------------------- /tests/test_meta.py: -------------------------------------------------------------------------------- 1 | from that_depends import ContextScopes 2 | from that_depends.meta import BaseContainerMeta 3 | 4 | 5 | def test_base_container_meta_has_correct_default_scope() -> None: 6 | class _Test(metaclass=BaseContainerMeta): 7 | 
pass 8 | 9 | assert _Test.get_scope() == ContextScopes.ANY 10 | 11 | 12 | def test_base_container_meta_has_correct_default_name() -> None: 13 | class _Test(metaclass=BaseContainerMeta): 14 | pass 15 | 16 | assert _Test.name() == "_Test" 17 | -------------------------------------------------------------------------------- /docs/dev/main-decisions.md: -------------------------------------------------------------------------------- 1 | # Main decisions 2 | 1. Dependency resolving is async by default: 3 | - framework was developed mostly for usage with async Python applications; 4 | - sync resolving is also possible, but it will fail at runtime if there are unresolved async dependencies; 5 | 2. Container is global: 6 | - it's needed for injections without wiring to work; 7 | - this way most of the logic stays in providers; 8 | 3. Focus on maximum compatibility with mypy: 9 | - no need for `# type: ignore` 10 | - no need for `typing.cast` 11 | -------------------------------------------------------------------------------- /tests/experimental/test_container_1.py: -------------------------------------------------------------------------------- 1 | from tests.experimental.test_container_2 import Container2 2 | from that_depends import BaseContainer, providers 3 | from that_depends.experimental import LazyProvider 4 | 5 | 6 | class Container1(BaseContainer): 7 | """Test Container 1.""" 8 | 9 | alias = "container_1" 10 | obj_1 = providers.Object(1) 11 | obj_2 = LazyProvider(module_string="tests.experimental.test_container_2", provider_string="Container2.obj_2") 12 | 13 | 14 | def test_lazy_provider_resolution_sync() -> None: 15 | assert Container2.obj_2.resolve_sync() == 2 # noqa: PLR2004 16 | -------------------------------------------------------------------------------- /tests/test_dynamic_container.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import pytest 4 | 5 | from tests import container 6 | from 
that_depends import BaseContainer, providers 7 | 8 | 9 | class DIContainer(BaseContainer): 10 | alias = "dynamic_container" 11 | sync_resource: providers.Resource[datetime.datetime] 12 | 13 | 14 | async def test_dynamic_container_not_supported() -> None: 15 | new_provider = providers.Resource(container.create_sync_resource) 16 | with pytest.raises(AttributeError): 17 | DIContainer.sync_resource = new_provider 18 | 19 | with pytest.raises(AttributeError): 20 | DIContainer.something_new = new_provider 21 | -------------------------------------------------------------------------------- /that_depends/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, ClassVar, TypeGuard, TypeVar 2 | 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | class _Singleton(type): 8 | _instances: ClassVar[dict[type, Any]] = {} 9 | 10 | def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 11 | if cls not in cls._instances: 12 | cls._instances[cls] = super().__call__(*args, **kwargs) 13 | return cls._instances[cls] 14 | 15 | 16 | class Unset(metaclass=_Singleton): 17 | """Represents an unset value.""" 18 | 19 | 20 | UNSET = Unset() 21 | 22 | 23 | def is_set(value: T | Unset) -> TypeGuard[T]: 24 | """Check if a value is set (not UNSET).""" 25 | return value is not UNSET 26 | -------------------------------------------------------------------------------- /docs/dev/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | `that-depends` is an open-source project, and we are open to new contributors. 3 | 4 | ## Getting started 5 | 1. Make sure that you have [uv](https://docs.astral.sh/uv/) and [just](https://just.systems/) installed. 6 | 2. Clone the project: 7 | ``` 8 | git clone git@github.com:modern-python/that-depends.git 9 | cd that-depends 10 | ``` 11 | 3. 
Install dependencies by running `just install` 12 | 13 | ## Running linters 14 | `Ruff` and `mypy` are used for static analysis. 15 | 16 | Run all checks by command `just lint` 17 | 18 | ## Running tests 19 | Run all tests by command `just test` 20 | 21 | ## Building and running the documentation 22 | Host the documentation locally by running `just docs` 23 | -------------------------------------------------------------------------------- /docs/testing/fixture.md: -------------------------------------------------------------------------------- 1 | # Fixture 2 | ## Dependencies teardown 3 | When using dependency injection in tests, it's important to properly tear down resources after tests complete. 4 | 5 | Without proper teardown, the Python event loop might close before resources have a chance to shut down properly, leading to errors like `RuntimeError: Event loop is closed`. 6 | 7 | You can set up automatic teardown using a pytest fixture: 8 | 9 | ```python 10 | import pytest_asyncio 11 | from typing import AsyncGenerator 12 | 13 | from my_project import DIContainer 14 | 15 | @pytest_asyncio.fixture(autouse=True) 16 | async def di_container_teardown() -> AsyncGenerator[None, None]: 17 | try: 18 | yield 19 | finally: 20 | await DIContainer.tear_down() 21 | ``` 22 | -------------------------------------------------------------------------------- /Justfile: -------------------------------------------------------------------------------- 1 | default: install lint test 2 | 3 | install: 4 | uv lock --upgrade 5 | uv sync --all-extras --frozen 6 | @just hook 7 | 8 | lint: 9 | uv run ruff format 10 | uv run ruff check --fix 11 | uv run mypy . 12 | 13 | lint-ci: 14 | uv run ruff format --check 15 | uv run ruff check --no-fix 16 | uv run mypy . 
17 | 18 | test *args: 19 | uv run --no-sync pytest {{ args }} 20 | 21 | publish: 22 | rm -rf dist 23 | uv version $GITHUB_REF_NAME 24 | uv build 25 | uv publish --token $PYPI_TOKEN 26 | 27 | hook: 28 | uv run pre-commit install --install-hooks --overwrite 29 | 30 | unhook: 31 | uv run pre-commit uninstall 32 | 33 | docs: 34 | uv pip install -r docs/requirements.txt 35 | uv run mkdocs serve 36 | -------------------------------------------------------------------------------- /that_depends/__init__.py: -------------------------------------------------------------------------------- 1 | """Dependency injection framework for Python.""" 2 | 3 | from that_depends import providers 4 | from that_depends.container import BaseContainer 5 | from that_depends.injection import Provide, inject 6 | from that_depends.providers import container_context 7 | from that_depends.providers.context_resources import ( 8 | ContextScope, 9 | ContextScopes, 10 | fetch_context_item, 11 | fetch_context_item_by_type, 12 | get_current_scope, 13 | ) 14 | 15 | 16 | __all__ = [ 17 | "BaseContainer", 18 | "ContextScope", 19 | "ContextScopes", 20 | "Provide", 21 | "container_context", 22 | "fetch_context_item", 23 | "fetch_context_item_by_type", 24 | "get_current_scope", 25 | "inject", 26 | "providers", 27 | ] 28 | -------------------------------------------------------------------------------- /tests/test_resolver.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import pytest 4 | 5 | from tests import container 6 | from tests.container import DIContainer 7 | 8 | 9 | @dataclasses.dataclass(kw_only=True, slots=True) 10 | class WrongFactory: 11 | some_dep: int = 1 12 | not_existing_name: container.DependentFactory 13 | sync_resource: str 14 | 15 | 16 | async def test_dependency_resolver() -> None: 17 | resolver = DIContainer.resolver(container.FreeFactory) 18 | dep1 = await resolver() 19 | dep2 = await 
DIContainer.resolve(container.FreeFactory) 20 | 21 | assert dep1 22 | assert dep2 23 | assert dep1 is not dep2 24 | 25 | 26 | async def test_dependency_resolver_failed() -> None: 27 | resolver = DIContainer.resolver(WrongFactory) 28 | with pytest.raises(RuntimeError, match=r"Provider is not found, field_name='not_existing_name'"): 29 | await resolver() 30 | -------------------------------------------------------------------------------- /tests/integrations/litestar/test_litestar_di_simple.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from litestar import Litestar, get 4 | from litestar.di import Provide 5 | from litestar.status_codes import HTTP_200_OK 6 | from litestar.testing import TestClient 7 | 8 | from tests import container 9 | 10 | 11 | @get("/") 12 | async def index(injected: datetime.datetime) -> datetime.datetime: 13 | return injected 14 | 15 | 16 | app = Litestar([index], dependencies={"injected": Provide(container.DIContainer.async_resource)}) 17 | 18 | 19 | async def test_litestar_di() -> None: 20 | with TestClient(app=app) as client: 21 | response = client.get("/") 22 | assert response.status_code == HTTP_200_OK, response.text 23 | assert ( 24 | datetime.datetime.fromisoformat(response.json().replace("Z", "+00:00")) 25 | == await container.DIContainer.async_resource() 26 | ) 27 | -------------------------------------------------------------------------------- /docs/introduction/multiple-containers.md: -------------------------------------------------------------------------------- 1 | # Usage with multiple containers 2 | 3 | You can use providers from other containers as follows: 4 | ```python 5 | from tests import container 6 | from that_depends import BaseContainer, providers 7 | 8 | 9 | class InnerContainer(BaseContainer): 10 | sync_resource = providers.Resource(container.create_sync_resource) 11 | async_resource = providers.Resource(container.create_async_resource) 12 | 13 | 14 
| class OuterContainer(BaseContainer): 15 | sequence = providers.List(InnerContainer.sync_resource, InnerContainer.async_resource) 16 | ``` 17 | 18 | But this way you have to manage `InnerContainer` lifecycle: 19 | 20 | ```python 21 | await InnerContainer.tear_down() 22 | ``` 23 | 24 | Or you can connect sub-containers to the main container: 25 | 26 | ```python 27 | OuterContainer.connect_containers(InnerContainer) 28 | 29 | 30 | # this will init resources for `InnerContainer` also 31 | await OuterContainer.init_resources() 32 | 33 | # and this will tear down resources for `InnerContainer` also 34 | await OuterContainer.tear_down() 35 | ``` 36 | -------------------------------------------------------------------------------- /docs/providers/collections.md: -------------------------------------------------------------------------------- 1 | # Collections 2 | There are several collection providers: `List` and `Dict` 3 | 4 | ## List 5 | - List provider contains other providers. 6 | - Resolves into list of dependencies. 7 | 8 | ```python 9 | import random 10 | from that_depends import BaseContainer, providers 11 | 12 | 13 | class DIContainer(BaseContainer): 14 | random_number = providers.Factory(random.random) 15 | numbers_sequence = providers.List(random_number, random_number) 16 | 17 | 18 | DIContainer.numbers_sequence.resolve_sync() 19 | # [0.3035656170071561, 0.8280498192037787] 20 | ``` 21 | 22 | ## Dict 23 | - Dict provider is a collection of named providers. 24 | - Resolves into dict of dependencies. 
25 | 26 | ```python 27 | import random 28 | from that_depends import BaseContainer, providers 29 | 30 | 31 | class DIContainer(BaseContainer): 32 | random_number = providers.Factory(random.random) 33 | numbers_map = providers.Dict(key1=random_number, key2=random_number) 34 | 35 | 36 | DIContainer.numbers_map.resolve_sync() 37 | # {'key1': 0.6851384528299208, 'key2': 0.41044920948045294} 38 | ``` 39 | -------------------------------------------------------------------------------- /docs/integrations/litestar.md: -------------------------------------------------------------------------------- 1 | # Usage with `Litestar` 2 | 3 | ```python 4 | import typing 5 | import fastapi 6 | import contextlib 7 | from litestar import Litestar, get 8 | from litestar.di import Provide 9 | from litestar.status_codes import HTTP_200_OK 10 | from litestar.testing import TestClient 11 | 12 | from tests import container 13 | 14 | 15 | @get("/") 16 | async def index(injected: str) -> str: 17 | return injected 18 | 19 | 20 | @contextlib.asynccontextmanager 21 | async def lifespan_manager(_: fastapi.FastAPI) -> typing.AsyncIterator[None]: 22 | try: 23 | yield 24 | finally: 25 | await container.DIContainer.tear_down() 26 | 27 | 28 | app = Litestar( 29 | route_handlers=[index], 30 | dependencies={"injected": Provide(container.DIContainer.async_resource)}, 31 | lifespan=[lifespan_manager], 32 | ) 33 | 34 | 35 | def test_litestar_di() -> None: 36 | with (TestClient(app=app) as client): 37 | response = client.get("/") 38 | assert response.status_code == HTTP_200_OK, response.text 39 | assert response.text == "async resource" 40 | ``` 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 modern-python 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation 
files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /that_depends/providers/__init__.py: -------------------------------------------------------------------------------- 1 | """Providers.""" 2 | 3 | from that_depends.providers.base import AbstractProvider, AttrGetter 4 | from that_depends.providers.collection import Dict, List 5 | from that_depends.providers.context_resources import ( 6 | ContextResource, 7 | DIContextMiddleware, 8 | container_context, 9 | ) 10 | from that_depends.providers.factories import AsyncFactory, Factory 11 | from that_depends.providers.local_singleton import ThreadLocalSingleton 12 | from that_depends.providers.object import Object 13 | from that_depends.providers.resources import Resource 14 | from that_depends.providers.selector import Selector 15 | from that_depends.providers.singleton import AsyncSingleton, Singleton 16 | from that_depends.providers.state import State 17 | 18 | 19 | __all__ = [ 20 | "AbstractProvider", 21 | "AsyncFactory", 22 | 
"AsyncSingleton", 23 | "AttrGetter", 24 | "ContextResource", 25 | "DIContextMiddleware", 26 | "Dict", 27 | "Factory", 28 | "List", 29 | "Object", 30 | "Resource", 31 | "Selector", 32 | "Singleton", 33 | "State", 34 | "ThreadLocalSingleton", 35 | "container_context", 36 | ] 37 | -------------------------------------------------------------------------------- /tests/experimental/test_lazy_provider.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from that_depends.experimental import LazyProvider 4 | 5 | 6 | def test_lazy_provider_incorrect_initialization() -> None: 7 | with pytest.raises( 8 | ValueError, 9 | match=r"You must provide either import_string " 10 | "OR both module_string AND provider_string, but not both or neither.", 11 | ): 12 | LazyProvider(module_string="3213") # type: ignore[call-overload] 13 | 14 | with pytest.raises(ValueError, match=r"Invalid import_string ''"): 15 | LazyProvider("") 16 | 17 | with pytest.raises(ValueError, match=r"Invalid provider_string ''"): 18 | LazyProvider(module_string="some.module", provider_string="") 19 | 20 | with pytest.raises(ValueError, match=r"Invalid module_string '.'"): 21 | LazyProvider(module_string=".", provider_string="SomeProvider") 22 | 23 | with pytest.raises(ValueError, match=r"Invalid import_string 'import.'"): 24 | LazyProvider("import.") 25 | 26 | 27 | def test_lazy_provider_incorrect_import_string() -> None: 28 | p = LazyProvider("some.random.path") 29 | with pytest.raises(ImportError): 30 | p.resolve_sync() 31 | -------------------------------------------------------------------------------- /that_depends/providers/mixin.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class CannotTearDownSyncError(RuntimeError): 5 | """Raised when attempting to tear down an async resource in sync mode.""" 6 | 7 | 8 | class SupportsTeardown(abc.ABC): 9 | """Interface for objects that support 
teardown.""" 10 | 11 | @abc.abstractmethod 12 | async def tear_down(self, propagate: bool = True) -> None: 13 | """Perform any necessary cleanup operations. 14 | 15 | This method is called when the object is no longer needed. 16 | """ 17 | 18 | @abc.abstractmethod 19 | def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: 20 | """Perform any necessary cleanup operations. 21 | 22 | This method is called when the object is no longer needed. 23 | """ 24 | 25 | 26 | class ProviderWithArguments(abc.ABC): 27 | """Interface for providers that require arguments.""" 28 | 29 | @abc.abstractmethod 30 | def _register_arguments(self) -> None: 31 | """Register arguments for the provider.""" 32 | 33 | @abc.abstractmethod 34 | def _deregister_arguments(self) -> None: 35 | """Deregister arguments for the provider.""" 36 | -------------------------------------------------------------------------------- /that_depends/providers/object.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from typing_extensions import override 4 | 5 | from that_depends.providers.base import AbstractProvider 6 | 7 | 8 | T_co = typing.TypeVar("T_co", covariant=True) 9 | P = typing.ParamSpec("P") 10 | 11 | 12 | class Object(AbstractProvider[T_co]): 13 | """Provides an object "as is" without any modification. 14 | 15 | This provider always returns the same object that was given during 16 | initialization. 17 | 18 | Example: 19 | ```python 20 | provider = Object(1) 21 | result = provider.resolve_sync() 22 | print(result) # 1 23 | ``` 24 | 25 | """ 26 | 27 | __slots__ = ("_obj",) 28 | 29 | def __init__(self, obj: T_co) -> None: 30 | """Initialize the provider with the given object. 31 | 32 | Args: 33 | obj (T_co): The object to be provided. 
34 | 35 | """ 36 | super().__init__() 37 | self._obj: typing.Final = obj 38 | 39 | @override 40 | async def resolve(self) -> T_co: 41 | return self.resolve_sync() 42 | 43 | @override 44 | def resolve_sync(self) -> T_co: 45 | if self._override is not None: 46 | return typing.cast(T_co, self._override) 47 | return self._obj 48 | -------------------------------------------------------------------------------- /examples/benchmark/injection.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import time 4 | import typing 5 | 6 | from that_depends import BaseContainer, ContextScopes, Provide, inject, providers 7 | 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def _iterator(x: float) -> typing.Iterator[float]: 14 | yield x 15 | 16 | 17 | class _Container(BaseContainer): 18 | grandparent = providers.Singleton(lambda: random.random()) 19 | parent = providers.Factory(lambda x: x, grandparent.cast) 20 | item = providers.ContextResource(_iterator, parent.cast).with_config(scope=ContextScopes.REQUEST) 21 | 22 | 23 | @inject(scope=ContextScopes.REQUEST) 24 | def _injected( 25 | x_1: float = Provide[_Container.grandparent], 26 | x_2: float = Provide[_Container.parent], 27 | x_3: float = Provide[_Container.item], 28 | ) -> float: 29 | return x_1 + x_2 + x_3 30 | 31 | 32 | def _bench(n_iterations: int) -> float: 33 | start = time.time() 34 | for _ in range(n_iterations): 35 | _injected() 36 | end = time.time() 37 | _Container.tear_down_sync() 38 | return end - start 39 | 40 | 41 | if __name__ == "__main__": 42 | for n in [10000, 100000, 1000000]: 43 | duration = _bench(n) 44 | logger.info(f"Injected {n} times in {duration:.4f} seconds") # noqa: G004 45 | -------------------------------------------------------------------------------- /tests/integrations/fastapi/test_fastapi_di_pass_request.py: 
-------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import fastapi 4 | from starlette import status 5 | from starlette.testclient import TestClient 6 | 7 | from tests import container 8 | from that_depends import BaseContainer, container_context, fetch_context_item, providers 9 | 10 | 11 | async def init_di_context(request: fastapi.Request) -> typing.AsyncIterator[None]: 12 | async with container_context(global_context={"request": request}): 13 | yield 14 | 15 | 16 | app = fastapi.FastAPI(dependencies=[fastapi.Depends(init_di_context)]) 17 | 18 | 19 | class DIContainer(BaseContainer): 20 | alias = "fastapi_container" 21 | context_request = providers.Factory( 22 | lambda: fetch_context_item("request"), 23 | ) 24 | 25 | 26 | @app.get("/") 27 | async def read_root( 28 | context_request: typing.Annotated[ 29 | container.DependentFactory, # incorrect type but available for backwards compatibility 30 | fastapi.Depends(DIContainer.context_request), 31 | ], 32 | request: fastapi.Request, 33 | ) -> None: 34 | assert isinstance(request, fastapi.Request) 35 | assert request == context_request 36 | 37 | 38 | client = TestClient(app) 39 | 40 | 41 | async def test_read_main() -> None: 42 | response = client.get("/") 43 | assert response.status_code == status.HTTP_200_OK 44 | -------------------------------------------------------------------------------- /tests/providers/test_factories.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from that_depends import BaseContainer, Provide, inject, providers 4 | 5 | 6 | async def test_async_factory_with_sync_creator() -> None: 7 | return_value = random.random() 8 | f = providers.AsyncFactory(lambda: return_value) 9 | 10 | assert await f.resolve() == return_value 11 | 12 | 13 | async def test_async_factory_with_sync_creator_multiple_parents() -> None: 14 | """Dependencies of async factory get resolved correctly with a sync 
creator.""" 15 | _return_value_1 = 32 16 | _return_value_2 = 12 17 | 18 | async def _async_creator() -> int: 19 | return 32 20 | 21 | def _sync_creator() -> int: 22 | return _return_value_2 23 | 24 | class _Adder: 25 | def __init__(self, x: int, y: int) -> None: 26 | self.x = x 27 | self.y = y 28 | 29 | def result(self) -> int: 30 | return self.x + self.y 31 | 32 | class _Container(BaseContainer): 33 | p1 = providers.AsyncFactory(_async_creator) 34 | p2 = providers.Factory(_sync_creator) 35 | p3 = providers.AsyncFactory(_Adder, x=p1.cast, y=p2.cast) 36 | 37 | @inject 38 | async def _injected(adder: _Adder = Provide[_Container.p3]) -> int: 39 | return adder.result() 40 | 41 | assert await _injected() == _return_value_1 + _return_value_2 42 | -------------------------------------------------------------------------------- /docs/providers/state.md: -------------------------------------------------------------------------------- 1 | # State 2 | 3 | The `State` provider stores a value as part of a context. 4 | 5 | It is useful when you want to pass a value into your Container that other providers depend on. 6 | 7 | 8 | ## Creating a state provider 9 | 10 | The `State` provider does not accept any arguments when it is created. 11 | ```python 12 | from that_depends import BaseContainer, providers 13 | class Container(BaseContainer): 14 | my_state: providers.State[int] = providers.State() 15 | ``` 16 | 17 | ## Initializing state 18 | 19 | === "async" 20 | ```python 21 | 22 | with Container.my_state.init(42): 23 | print(await Container.my_state.resolve()) # 42 24 | 25 | ``` 26 | 27 | === "sync" 28 | ```python 29 | with Container.my_state.init(42): 30 | print(Container.my_state.resolve_sync()) # 42 31 | 32 | ``` 33 | 34 | > Note: If you try to resolve a `State` provider without initializing it first, it will raise a `StateNotInitializedError`. 35 | 36 | 37 | ## Nested state 38 | 39 | The `State` provider will always resolve the last initialized value. 
40 | 41 | ```python 42 | with Container.my_state.init(1): 43 | print(Container.my_state.resolve_sync()) # 1 44 | 45 | with Container.my_state.init(2): 46 | print(Container.my_state.resolve_sync()) # 2 47 | 48 | print(Container.my_state.resolve_sync()) # 1 49 | ``` 50 | -------------------------------------------------------------------------------- /docs/introduction/ioc-container.md: -------------------------------------------------------------------------------- 1 | # The Dependency Injection Container 2 | 3 | Containers serve as a central place to store and manage providers. You also define 4 | your dependency graph in the containers. 5 | 6 | While providers can be defined outside of containers with `that-depends`, this is not recommended 7 | if you want to use any [context features](../providers/context-resources.md). 8 | 9 | 10 | ## Quickstart 11 | 12 | Define a container by subclassing `BaseContainer` and define your providers as class attributes. 13 | ```python 14 | from that_depends import BaseContainer 15 | 16 | class Container(BaseContainer): 17 | # define your providers here 18 | ``` 19 | 20 | Then you can build your dependency graph within the container: 21 | 22 | ```python 23 | from that_depends import providers 24 | 25 | class Container(BaseContainer): 26 | config = providers.Singleton(Config) 27 | session = providers.Factory(create_db_session, config=config.db) # (1)! 28 | 29 | user_repository = providers.Factory( 30 | UserRepository, 31 | config.users, 32 | session=session.cast, # (3)! 33 | ) # (2)! 34 | 35 | 36 | ``` 37 | 38 | 1. The configuration will be resolved and then the `.db` attribute will be passed to the `create_db_session` creator 39 | as a keyword argument when resolving the `session` provider. 40 | 2. Depends on both the session and configuration providers. 41 | 3. Providers have the `cast` property that will change their type to the return type of their creator; use it to prevent type errors. 
42 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: main 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: {} 8 | 9 | concurrency: 10 | group: ${{ github.head_ref || github.run_id }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | lint: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v5 18 | - uses: extractions/setup-just@v3 19 | - uses: astral-sh/setup-uv@v7 20 | with: 21 | cache-dependency-glob: "**/pyproject.toml" 22 | - run: uv python install 3.10 23 | - run: just install lint-ci 24 | 25 | pytest: 26 | runs-on: ubuntu-latest 27 | strategy: 28 | fail-fast: false 29 | matrix: 30 | python-version: 31 | - "3.10" 32 | - "3.11" 33 | - "3.12" 34 | - "3.13" 35 | - "3.14" 36 | faststream-version: 37 | - "<0.6.0" 38 | - ">=0.6.0" 39 | steps: 40 | - uses: actions/checkout@v5 41 | - uses: extractions/setup-just@v3 42 | - uses: astral-sh/setup-uv@v7 43 | with: 44 | cache-dependency-glob: "**/pyproject.toml" 45 | - run: uv python install ${{ matrix.python-version }} 46 | - run: just install 47 | - run: uv run --with "faststream${{ matrix.faststream-version }}" pytest . --cov=. --cov-report xml 48 | - uses: codecov/codecov-action@v5.4.3 49 | env: 50 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 51 | with: 52 | files: ./coverage.xml 53 | flags: unittests 54 | name: codecov-${{ matrix.python-version }} 55 | -------------------------------------------------------------------------------- /docs/experimental/lazy.md: -------------------------------------------------------------------------------- 1 | # Lazy Provider 2 | 3 | The `LazyProvider` enables you to reference other providers without explicitly 4 | importing them into your module. 5 | 6 | This can be helpful if you have a circular dependency between providers in 7 | multiple containers. 
8 | 9 | 10 | ## Creating a Lazy Provider 11 | 12 | === "Single import string" 13 | ```python 14 | from that_depends.experimental import LazyProvider 15 | 16 | lazy_p = LazyProvider("my.module.attribute.path") 17 | ``` 18 | === "Separate module and provider" 19 | ```python 20 | from that_depends.experimental import LazyProvider 21 | 22 | lazy_p = LazyProvider(module_string="my.module", provider_string="attribute.path") 23 | ``` 24 | 25 | 26 | ## Usage 27 | 28 | You can use the lazy provider in exactly the same way as you would use the referenced provider. 29 | 30 | ```python 31 | # first_container.py 32 | from that_depends import BaseContainer, providers, ContextScopes 33 | 34 | def my_creator(): 35 | yield 42 36 | 37 | class FirstContainer(BaseContainer): 38 | value_provider = providers.ContextResource(my_creator).with_config(scope=ContextScopes.APP) 39 | ``` 40 | 41 | You can lazily import this provider: 42 | ```python 43 | # second_container.py 44 | from that_depends.experimental import LazyProvider 45 | from that_depends import BaseContainer, providers 46 | class SecondContainer(BaseContainer): 47 | lazy_value = LazyProvider("first_container.FirstContainer.value_provider") 48 | 49 | 50 | with SecondContainer.lazy_value.context_sync(force=True): 51 | SecondContainer.lazy_value.resolve_sync() # 42 52 | ``` 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | "That Depends" 2 | == 3 | [![Test Coverage](https://codecov.io/gh/modern-python/that-depends/branch/main/graph/badge.svg)](https://codecov.io/gh/modern-python/that-depends) 4 | [![MyPy Strict](https://img.shields.io/badge/mypy-strict-blue)](https://mypy.readthedocs.io/en/stable/getting_started.html#strict-mode-and-configuration) 5 | [![Supported versions](https://img.shields.io/pypi/pyversions/that-depends.svg)](https://pypi.python.org/pypi/that-depends) 6 | [![PyPI 
Downloads](https://static.pepy.tech/badge/that-depends/month)](https://pepy.tech/projects/that-depends) 7 | [![GitHub stars](https://img.shields.io/github/stars/modern-python/that-depends)](https://github.com/modern-python/that-depends/stargazers) 8 | [![libs.tech recommends](https://libs.tech/project/773446541/badge.svg)](https://libs.tech/project/773446541/that-depends) 9 | [![llms.txt](https://img.shields.io/badge/llms.txt-green)](https://that-depends.readthedocs.io/llms.txt) 10 | 11 | Dependency injection framework for Python. 12 | 13 | It is production-ready and gives you the following: 14 | - Simple async-first DI framework with IOC-container. 15 | - Python 3.10+ support. 16 | - Full coverage by type annotations (mypy in strict mode). 17 | - Inbuilt FastAPI, FastStream and LiteStar compatibility. 18 | - Dependency context management with scopes. 19 | - Overriding dependencies for tests. 20 | - Injecting dependencies in functions and coroutines without wiring. 21 | - Package with zero dependencies. 
22 | 23 | 24 | ### Installation 25 | ```bash 26 | pip install that-depends 27 | ``` 28 | 29 | ## 📚 [Documentation](https://that-depends.readthedocs.io) 30 | 31 | ## 📦 [PyPi](https://pypi.org/project/that-depends) 32 | 33 | ## 📝 [License](LICENSE) 34 | -------------------------------------------------------------------------------- /tests/creators.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import types 4 | import typing 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]: 11 | logger.debug("Async resource initiated") 12 | try: 13 | yield datetime.datetime.now(tz=datetime.timezone.utc) 14 | finally: 15 | logger.debug("Async resource destructed") 16 | 17 | 18 | def create_sync_resource() -> typing.Iterator[datetime.datetime]: 19 | logger.debug("Resource initiated") 20 | try: 21 | yield datetime.datetime.now(tz=datetime.timezone.utc) 22 | finally: 23 | logger.debug("Resource destructed") 24 | 25 | 26 | class ContextManagerResource(typing.ContextManager[datetime.datetime]): 27 | def __enter__(self) -> datetime.datetime: 28 | logger.debug("Resource initiated") 29 | return datetime.datetime.now(tz=datetime.timezone.utc) 30 | 31 | def __exit__( 32 | self, 33 | exc_type: type[BaseException] | None, 34 | exc_value: BaseException | None, 35 | traceback: types.TracebackType | None, 36 | ) -> None: 37 | logger.debug("Resource destructed") 38 | 39 | 40 | class AsyncContextManagerResource(typing.AsyncContextManager[datetime.datetime]): 41 | async def __aenter__(self) -> datetime.datetime: 42 | logger.debug("Async resource initiated") 43 | return datetime.datetime.now(tz=datetime.timezone.utc) 44 | 45 | async def __aexit__( 46 | self, 47 | exc_type: type[BaseException] | None, 48 | exc_value: BaseException | None, 49 | traceback: types.TracebackType | None, 50 | ) -> None: 51 | logger.debug("Async resource 
destructed") 52 | -------------------------------------------------------------------------------- /that_depends/providers/state.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import uuid 3 | from contextlib import contextmanager 4 | from contextvars import ContextVar 5 | from typing import TypeVar 6 | 7 | from typing_extensions import override 8 | 9 | from that_depends.exceptions import StateNotInitializedError 10 | from that_depends.providers import AbstractProvider 11 | from that_depends.utils import UNSET, Unset, is_set 12 | 13 | 14 | T = TypeVar("T", covariant=False) 15 | 16 | _STATE_UNSET_ERROR_MESSAGE: typing.Final[str] = ( 17 | "State has not been initialized.\n" 18 | "Please use `with provider.init(state):` to initialize the state before resolving it." 19 | ) 20 | 21 | 22 | class State(AbstractProvider[T]): 23 | """Provides a value that can be passed into the provider at runtime.""" 24 | 25 | def __init__(self) -> None: 26 | """Create a state provider.""" 27 | super().__init__() 28 | 29 | self._state: ContextVar[T | Unset] = ContextVar(f"STATE_{uuid.uuid4()}", default=UNSET) 30 | 31 | @contextmanager 32 | def init(self, state: T) -> typing.Iterator[T]: 33 | """Set the state provider's value. 34 | 35 | Args: 36 | state: value to store. 37 | 38 | Returns: 39 | Iterator[T_co]: A context manager that yields the state value. 
40 | 41 | """ 42 | token = self._state.set(state) 43 | yield state 44 | self._state.reset(token) 45 | 46 | @override 47 | async def resolve(self) -> T: 48 | return self.resolve_sync() 49 | 50 | @override 51 | def resolve_sync(self) -> T: 52 | value = self._state.get() 53 | if is_set(value): 54 | return value 55 | raise StateNotInitializedError(_STATE_UNSET_ERROR_MESSAGE) 56 | -------------------------------------------------------------------------------- /docs/introduction/type-based-injection.md: -------------------------------------------------------------------------------- 1 | # Type based injection 2 | 3 | `that-depends` also supports dependency injection without explicitly referencing 4 | the provider of the dependency. 5 | 6 | ## Quick Start 7 | 8 | In order to make use of this, you need to bind providers to the type they will provide: 9 | ```python 10 | class Container(BaseContainer): 11 | my_provider = providers.Factory(lambda: random.random()).bind(float) 12 | ``` 13 | 14 | Then provide inject into your functions or generators: 15 | 16 | === "Option 1" 17 | ```python 18 | @Container.inject 19 | async def foo(v: float = Provide()): 20 | return v 21 | ``` 22 | === "Option 2" 23 | ```python 24 | @inject(container=Container) 25 | async def foo(v: float = Provide()): 26 | return v 27 | ``` 28 | 29 | 30 | ## Default bind 31 | 32 | Per default, providers will **not** be bound to any type, even if your creator 33 | function has type hints. So make sure to always bind your providers. 34 | 35 | You can also bind multiple types to the same provider: 36 | ```python 37 | class Container(BaseContainer): 38 | my_provider = providers.Factory(lambda: random.random()).bind(float, complex) 39 | ``` 40 | 41 | ## Contravariant binding 42 | 43 | Per default injection will be invariant to the bound types. 
44 | 45 | If you wish to enable contravariance for your bound types you can do so by setting 46 | `#!python contravariant=True` in the `bind` method: 47 | 48 | ```python hl_lines="6 9" 49 | class A: ... 50 | 51 | class B(A): ... 52 | 53 | class Container(BaseContainer): 54 | my_provider = providers.Factory(lambda: B()).bind(B, contravariant=True) 55 | 56 | @Container.inject 57 | async def foo(v: A = Provide()) -> A: # (1)! 58 | return v 59 | ``` 60 | 61 | 1. `v` will receive an instance of `B` since `A` is a supertype of `B`. 62 | -------------------------------------------------------------------------------- /docs/providers/selector.md: -------------------------------------------------------------------------------- 1 | # Selector 2 | 3 | The Selector provider chooses between providers based on a key. This resolves into a single dependency. 4 | 5 | The selector can be a callable that returns a string, an instance of `AbstractProvider` or a string. 6 | 7 | ## Callable selectors 8 | 9 | ```python 10 | import os 11 | from typing import Protocol 12 | from that_depends import BaseContainer, providers 13 | 14 | class StorageService(Protocol): 15 | ... 16 | 17 | class StorageServiceLocal(StorageService): 18 | ... 19 | 20 | class StorageServiceRemote(StorageService): 21 | ... 22 | 23 | class DIContainer(BaseContainer): 24 | storage_service = providers.Selector( 25 | lambda: os.getenv("STORAGE_BACKEND", "local"), 26 | local=providers.Factory(StorageServiceLocal), 27 | remote=providers.Factory(StorageServiceRemote), 28 | ) 29 | ``` 30 | 31 | ## Provider selectors 32 | 33 | In this example, we have a Pydantic-Settings class that contains the key to select the provider. 
34 | 35 | ```python 36 | from typing import Literal 37 | from pydantic_settings import BaseSettings 38 | 39 | class Settings(BaseSettings): 40 | storage_backend: Literal["local", "remote"] = "remote" 41 | 42 | class DIContainer(BaseContainer): 43 | settings = providers.Singleton(Settings) 44 | selector = providers.Selector( 45 | settings.cast.storage_backend, 46 | local=providers.Factory(StorageServiceLocal), 47 | remote=providers.Factory(StorageServiceRemote), 48 | ) 49 | ``` 50 | 51 | ## Fixed string selectors 52 | 53 | This can be useful for quickly testing. 54 | 55 | ```python 56 | class DIContainer(BaseContainer): 57 | selector = providers.Selector( 58 | "local", 59 | local=providers.Factory(StorageServiceLocal), 60 | remote=providers.Factory(StorageServiceRemote), 61 | ) 62 | ``` 63 | -------------------------------------------------------------------------------- /tests/integrations/faststream/test_faststream_di.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import typing 3 | 4 | from faststream import Depends 5 | from faststream.nats import NatsBroker, TestNatsBroker 6 | 7 | from tests import container 8 | from that_depends.integrations.faststream import DIContextMiddleware 9 | 10 | 11 | broker = NatsBroker(middlewares=[DIContextMiddleware(container.DIContainer.context_resource)]) 12 | 13 | TEST_SUBJECT = "test" 14 | 15 | 16 | @broker.subscriber(TEST_SUBJECT) 17 | async def index_subscriber( 18 | dependency: typing.Annotated[ 19 | container.DependentFactory, 20 | Depends(container.DIContainer.dependent_factory), 21 | ], 22 | free_dependency: typing.Annotated[ 23 | container.FreeFactory, 24 | Depends(container.DIContainer.resolver(container.FreeFactory)), 25 | ], 26 | singleton: typing.Annotated[ 27 | container.SingletonFactory, 28 | Depends(container.DIContainer.singleton), 29 | ], 30 | singleton_attribute: typing.Annotated[bool, Depends(container.DIContainer.singleton.dep1)], 31 | 
context_resource: typing.Annotated[datetime.datetime, Depends(container.DIContainer.context_resource)], 32 | ) -> datetime.datetime: 33 | assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource 34 | assert dependency.async_resource == free_dependency.dependent_factory.async_resource 35 | assert singleton.dep1 is True 36 | assert singleton_attribute is True 37 | assert isinstance(context_resource, datetime.datetime) 38 | return dependency.async_resource 39 | 40 | 41 | async def test_read_main() -> None: 42 | async with TestNatsBroker(broker) as br: 43 | result = await br.request(None, TEST_SUBJECT) 44 | 45 | result_str = typing.cast(str, await result.decode()) 46 | assert ( 47 | datetime.datetime.fromisoformat(result_str.replace("Z", "+00:00")) 48 | == await container.DIContainer.async_resource() 49 | ) 50 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # That Depends 2 | 3 | Welcome to the `that-depends` documentation! 
4 | 5 | `that-depends` is a python dependency injection framework which, among other things, 6 | supports the following: 7 | 8 | - Async and sync dependency resolution 9 | - Scopes and granular context management 10 | - Dependency injection anywhere 11 | - Fully typed and tested 12 | - Compatibility with popular frameworks like `FastAPI` and `LiteStar` 13 | - Python 3.10+ support 14 | 15 | --- 16 | 17 | ## Installation 18 | 19 | === "pip" 20 | ```bash 21 | pip install that-depends 22 | ``` 23 | === "uv" 24 | ```bash 25 | uv add that-depends 26 | ``` 27 | 28 | --- 29 | 30 | ## Quickstart 31 | 32 | ### Define a creator 33 | ```python 34 | async def create_async_resource(): 35 | logger.debug("Async resource initiated") 36 | try: 37 | yield "async resource" 38 | finally: 39 | logger.debug("Async resource destructed") 40 | ``` 41 | 42 | ### Setup Dependency Injection Container with Providers 43 | ```python 44 | from that_depends import BaseContainer, providers 45 | 46 | class Container(BaseContainer): 47 | provider = providers.Resource(create_async_resource) 48 | ``` 49 | 50 | See the [containers documentation](introduction/ioc-container.md) for more information on defining the container. 51 | 52 | For a list of providers and their usage, see the [providers section](providers/collections.md). 53 | ### Resolve dependencies in your code 54 | ```python 55 | await Container.provider() 56 | ``` 57 | 58 | ### Inject providers in function arguments 59 | ```python 60 | from that_depends import inject, Provide 61 | 62 | @inject 63 | async def some_foo(value: str = Provide[Container.provider]): 64 | return value 65 | 66 | await some_foo() # "async resource" 67 | ``` 68 | 69 | See the [injection documentation](introduction/injection.md) for more information. 70 | 71 | ## llms.txt 72 | 73 | `that-depends` provides a [llms.txt](https://that-depends.readthedocs.io/llms.txt) file. 74 | 75 | You can find documentation on how to use it [here](https://llmstxt.org/). 
76 | -------------------------------------------------------------------------------- /tests/integrations/faststream/test_faststream_di_pass_message.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from faststream import BaseMiddleware, Context, Depends 4 | from faststream.nats import NatsBroker, TestNatsBroker 5 | from faststream.nats.message import NatsMessage 6 | from packaging.version import Version 7 | 8 | from that_depends import BaseContainer, container_context, fetch_context_item, providers 9 | from that_depends.integrations.faststream import _FASTSTREAM_VERSION 10 | 11 | 12 | if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover 13 | from faststream.message import StreamMessage 14 | else: # pragma: no cover 15 | from faststream.broker.message import StreamMessage # type: ignore[import-not-found, no-redef] 16 | 17 | 18 | class ContextMiddleware(BaseMiddleware): 19 | async def consume_scope( 20 | self, 21 | call_next: typing.Callable[..., typing.Awaitable[typing.Any]], 22 | msg: StreamMessage[typing.Any], 23 | ) -> typing.Any: # noqa: ANN401 24 | async with container_context(global_context={"request": msg}): 25 | return await call_next(msg) 26 | 27 | 28 | if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover 29 | broker = NatsBroker(middlewares=(ContextMiddleware,)) 30 | 31 | else: # pragma: no cover 32 | broker = NatsBroker(middlewares=(ContextMiddleware,), validate=False) # type: ignore[call-arg] 33 | 34 | TEST_SUBJECT = "test" 35 | 36 | 37 | class DIContainer(BaseContainer): 38 | context_request = providers.Factory( 39 | lambda: fetch_context_item("request"), 40 | ) 41 | 42 | 43 | @broker.subscriber(TEST_SUBJECT) 44 | async def index_subscriber( 45 | context_request: typing.Annotated[ 46 | NatsMessage, 47 | Depends(DIContainer.context_request, cast=False), 48 | ], 49 | message: typing.Annotated[ 50 | NatsMessage, 51 | Context(), 52 | ], 53 | ) -> bool: 54 | 
return message is context_request 55 | 56 | 57 | async def test_read_main() -> None: 58 | async with TestNatsBroker(broker) as br: 59 | result = await br.request(None, TEST_SUBJECT) 60 | 61 | assert (await result.decode()) is True 62 | -------------------------------------------------------------------------------- /tests/test_multiple_containers.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import pytest 4 | 5 | from tests import container 6 | from that_depends import BaseContainer, providers 7 | 8 | 9 | class InnerContainer(BaseContainer): 10 | sync_resource = providers.Resource(container.create_sync_resource) 11 | async_resource = providers.Resource(container.create_async_resource) 12 | 13 | 14 | class OuterContainer(BaseContainer): 15 | sequence = providers.List(InnerContainer.sync_resource, InnerContainer.async_resource) 16 | 17 | 18 | OuterContainer.connect_containers(InnerContainer) 19 | 20 | 21 | async def test_included_container() -> None: 22 | sequence = await OuterContainer.sequence() 23 | assert all(isinstance(x, datetime.datetime) for x in sequence) 24 | 25 | await OuterContainer.tear_down() 26 | assert InnerContainer.sync_resource._context.instance is None 27 | assert InnerContainer.async_resource._context.instance is None 28 | 29 | await OuterContainer.init_resources() 30 | sync_resource_context = InnerContainer.sync_resource._context 31 | assert sync_resource_context 32 | assert sync_resource_context.instance is not None 33 | async_resource_context = InnerContainer.async_resource._context 34 | assert async_resource_context 35 | assert async_resource_context.instance is not None 36 | await OuterContainer.tear_down() 37 | 38 | 39 | async def test_overwriting_container_warns(recwarn: None) -> None: # noqa:ARG001 40 | class _A(BaseContainer): 41 | pass 42 | 43 | with pytest.warns(UserWarning, match="Overwriting container '_A'"): 44 | 45 | class _A(BaseContainer): # type: ignore[no-redef] 
46 | pass 47 | 48 | class _A(BaseContainer): # type: ignore[no-redef] 49 | pass 50 | 51 | 52 | async def test_overwriting_container_with_alias_warns(recwarn: None) -> None: # noqa:ARG001 53 | class _A(BaseContainer): 54 | alias = "a" 55 | 56 | with pytest.warns(UserWarning, match="Overwriting container 'a'"): 57 | 58 | class _A(BaseContainer): # type: ignore[no-redef] 59 | alias = "a" 60 | 61 | class _A(BaseContainer): # type: ignore[no-redef] 62 | alias = "a" 63 | -------------------------------------------------------------------------------- /tests/providers/test_collections.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import pytest 4 | 5 | from tests.container import create_async_resource, create_sync_resource 6 | from that_depends import BaseContainer, providers 7 | from that_depends.providers import AbstractProvider 8 | 9 | 10 | class DIContainer(BaseContainer): 11 | alias = "collection_container" 12 | sync_resource = providers.Resource(create_sync_resource) 13 | async_resource = providers.Resource(create_async_resource) 14 | sequence = providers.List(sync_resource, async_resource) 15 | mapping = providers.Dict(sync_resource=sync_resource, async_resource=async_resource) 16 | 17 | 18 | @pytest.fixture(autouse=True) 19 | async def _clear_di_container() -> typing.AsyncIterator[None]: 20 | try: 21 | yield 22 | finally: 23 | await DIContainer.tear_down() 24 | 25 | 26 | async def test_list_provider() -> None: 27 | sequence = await DIContainer.sequence() 28 | sync_resource = await DIContainer.sync_resource() 29 | async_resource = await DIContainer.async_resource() 30 | 31 | assert sequence == [sync_resource, async_resource] 32 | 33 | 34 | def test_list_failed_sync_resolve() -> None: 35 | with pytest.raises(RuntimeError, match=r"AsyncResource cannot be resolved synchronously"): 36 | DIContainer.sequence.resolve_sync() 37 | 38 | 39 | async def test_list_sync_resolve_after_init() -> None: 40 | await 
DIContainer.init_resources() 41 | DIContainer.sequence.resolve_sync() 42 | 43 | 44 | async def test_dict_provider() -> None: 45 | mapping = await DIContainer.mapping() 46 | sync_resource = await DIContainer.sync_resource() 47 | async_resource = await DIContainer.async_resource() 48 | 49 | assert mapping == {"sync_resource": sync_resource, "async_resource": async_resource} 50 | assert mapping == DIContainer.mapping.resolve_sync() 51 | 52 | 53 | @pytest.mark.parametrize("provider", [DIContainer.sequence, DIContainer.mapping]) 54 | async def test_attr_getter_in_collections_providers(provider: AbstractProvider[typing.Any]) -> None: 55 | with pytest.raises(AttributeError, match=f"'{type(provider)}' object has no attribute 'some_attribute'"): 56 | await provider.some_attribute 57 | -------------------------------------------------------------------------------- /tests/container.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import datetime 3 | import logging 4 | import typing 5 | 6 | from that_depends import BaseContainer, providers 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def create_sync_resource() -> typing.Iterator[datetime.datetime]: 13 | logger.debug("Resource initiated") 14 | try: 15 | yield datetime.datetime.now(tz=datetime.timezone.utc) 16 | finally: 17 | logger.debug("Resource destructed") 18 | 19 | 20 | async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]: 21 | logger.debug("Async resource initiated") 22 | try: 23 | yield datetime.datetime.now(tz=datetime.timezone.utc) 24 | finally: 25 | logger.debug("Async resource destructed") 26 | 27 | 28 | @dataclasses.dataclass(kw_only=True, slots=True) 29 | class SimpleFactory: 30 | dep1: str 31 | dep2: int 32 | 33 | 34 | async def async_factory(now: datetime.datetime) -> datetime.datetime: 35 | return now + datetime.timedelta(hours=1) 36 | 37 | 38 | @dataclasses.dataclass(kw_only=True, slots=True) 39 | class 
DependentFactory: 40 | simple_factory: SimpleFactory 41 | sync_resource: datetime.datetime 42 | async_resource: datetime.datetime 43 | 44 | 45 | @dataclasses.dataclass(kw_only=True, slots=True) 46 | class FreeFactory: 47 | dependent_factory: DependentFactory 48 | sync_resource: str 49 | 50 | 51 | @dataclasses.dataclass(kw_only=True, slots=True) 52 | class SingletonFactory: 53 | dep1: bool 54 | 55 | 56 | class DIContainer(BaseContainer): 57 | default_scope = None 58 | alias = "test_container" 59 | sync_resource = providers.Resource(create_sync_resource) 60 | async_resource = providers.Resource(create_async_resource) 61 | 62 | simple_factory = providers.Factory(SimpleFactory, dep1="text", dep2=123) 63 | async_factory = providers.AsyncFactory(async_factory, async_resource.cast) 64 | dependent_factory = providers.Factory( 65 | DependentFactory, 66 | simple_factory=simple_factory.cast, 67 | sync_resource=sync_resource.cast, 68 | async_resource=async_resource.cast, 69 | ) 70 | singleton = providers.Singleton(SingletonFactory, dep1=True) 71 | object = providers.Object(object()) 72 | context_resource = providers.ContextResource(create_async_resource) 73 | -------------------------------------------------------------------------------- /tests/providers/test_inject_factories.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import typing 3 | 4 | import pytest 5 | 6 | from tests import container 7 | from that_depends import BaseContainer, providers 8 | 9 | 10 | @dataclasses.dataclass(kw_only=True, slots=True) 11 | class InjectedFactories: 12 | sync_factory: typing.Callable[..., container.DependentFactory] 13 | async_factory: typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, container.DependentFactory]] 14 | 15 | 16 | class DIContainer(BaseContainer): 17 | alias = "inject_factories_container" 18 | sync_resource = providers.Resource(container.create_sync_resource) 19 | async_resource = 
providers.Resource(container.create_async_resource) 20 | 21 | simple_factory = providers.Factory(container.SimpleFactory, dep1="text", dep2=123) 22 | dependent_factory = providers.Factory( 23 | container.DependentFactory, 24 | simple_factory=simple_factory.cast, 25 | sync_resource=sync_resource.cast, 26 | async_resource=async_resource.cast, 27 | ) 28 | injected_factories = providers.Factory( 29 | InjectedFactories, 30 | sync_factory=dependent_factory.provider_sync, 31 | async_factory=dependent_factory.provider, 32 | ) 33 | 34 | 35 | async def test_async_provider() -> None: 36 | injected_factories = await DIContainer.injected_factories() 37 | instance1 = await injected_factories.async_factory() 38 | instance2 = await injected_factories.async_factory() 39 | 40 | assert isinstance(instance1, container.DependentFactory) 41 | assert isinstance(instance2, container.DependentFactory) 42 | assert instance1 is not instance2 43 | 44 | await DIContainer.tear_down() 45 | 46 | 47 | async def test_sync_provider() -> None: 48 | injected_factories = await DIContainer.injected_factories() 49 | with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"): 50 | injected_factories.sync_factory() 51 | 52 | await DIContainer.init_resources() 53 | instance1 = injected_factories.sync_factory() 54 | instance2 = injected_factories.sync_factory() 55 | 56 | assert isinstance(instance1, container.DependentFactory) 57 | assert isinstance(instance2, container.DependentFactory) 58 | assert instance1 is not instance2 59 | 60 | await DIContainer.tear_down() 61 | -------------------------------------------------------------------------------- /docs/introduction/string-injection.md: -------------------------------------------------------------------------------- 1 | ## Provider resolution by name 2 | 3 | The `@inject` decorator can be used to inject a provider by its name. 
This is useful when you want to inject a provider that is not directly imported in the current module. 4 | 5 | 6 | This serves two primary purposes: 7 | 8 | - A higher level of decoupling between the container and your code. 9 | - Avoiding circular imports. 10 | 11 | --- 12 | ## Usage 13 | 14 | To inject a provider by name, use the `Provide` marker with a string argument that has the following format: 15 | 16 | ``` 17 | Container.Provider[.attribute.attribute...] 18 | ``` 19 | 20 | The string is validated as soon as it is passed to `Provide[]`, so a malformed string raises an exception 21 | immediately. 22 | 23 | **For example**: 24 | 25 | ```python 26 | from that_depends import BaseContainer, inject, Provide 27 | 28 | class Config(BaseSettings): 29 | name: str = "Damian" 30 | 31 | class A(BaseContainer): 32 | b = providers.Factory(Config) 33 | 34 | @inject 35 | def read(val = Provide["A.b.name"]): 36 | return val 37 | 38 | assert read() == "Damian" 39 | ``` 40 | 41 | ### Container alias 42 | 43 | Containers support aliases: 44 | 45 | ```python 46 | 47 | class A(BaseContainer): 48 | alias = "C" # replaces the container name. 49 | b = providers.Factory(Config) 50 | 51 | @inject 52 | def read(val = Provide["C.b.name"]): # `A` can no longer be used. 53 | return val 54 | 55 | assert read() == "Damian" 56 | ``` 57 | 58 | --- 59 | ## Considerations 60 | 61 | This feature is primarily intended as a fallback when other options are not optimal or 62 | simply not available, so it should be used sparingly. 63 | 64 | If you do decide to use injection by name, consider the following: 65 | 66 | - In order for this type of injection to work, your container must be in scope when the injected function is called: 67 | ```python 68 | from that_depends import BaseContainer, inject, Provide 69 | 70 | @inject 71 | def injected(f = Provide["MyContainer.my_provider"]): ... 
72 | 73 | injected() # will raise an Exception 74 | 75 | class MyContainer(BaseContainer): 76 | my_provider = providers.Factory(some_creator) 77 | 78 | injected() # will resolve 79 | ``` 80 | - Validation of whether you have provided a correct container name and provider name will only happen when the function is called. 81 | -------------------------------------------------------------------------------- /tests/providers/test_main_providers.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import pytest 4 | 5 | from tests import container 6 | from tests.container import DIContainer 7 | from that_depends import providers 8 | 9 | 10 | async def test_factory_providers() -> None: 11 | simple_factory = await DIContainer.simple_factory() 12 | dependent_factory = await DIContainer.dependent_factory() 13 | async_factory = await DIContainer.async_factory() 14 | sync_resource = await DIContainer.sync_resource() 15 | async_resource = await DIContainer.async_resource() 16 | 17 | assert dependent_factory.simple_factory is not simple_factory 18 | assert DIContainer.simple_factory.resolve_sync() is not simple_factory 19 | assert dependent_factory.sync_resource == sync_resource 20 | assert dependent_factory.async_resource == async_resource 21 | assert isinstance(async_factory, datetime.datetime) 22 | 23 | 24 | async def test_factories_cannot_deregister_arguments() -> None: 25 | with pytest.raises(NotImplementedError): 26 | DIContainer.simple_factory._deregister_arguments() 27 | with pytest.raises(NotImplementedError): 28 | DIContainer.async_factory._deregister_arguments() 29 | 30 | 31 | async def test_async_resource_provider() -> None: 32 | async_resource = await DIContainer.async_resource() 33 | 34 | assert DIContainer.async_resource.resolve_sync() is async_resource 35 | 36 | 37 | def test_failed_sync_resolve() -> None: 38 | with pytest.raises(RuntimeError, match="AsyncFactory cannot be resolved synchronously"): 39 | 
DIContainer.async_factory.resolve_sync() 40 | 41 | with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"): 42 | DIContainer.async_resource.resolve_sync() 43 | 44 | 45 | def test_wrong_providers_init() -> None: 46 | with pytest.raises(TypeError, match="Unsupported resource type"): 47 | providers.Resource(lambda: None) # type: ignore[arg-type,return-value] 48 | 49 | 50 | def test_container_init_error() -> None: 51 | with pytest.raises(RuntimeError, match="DIContainer should not be instantiated"): 52 | DIContainer() 53 | 54 | 55 | async def test_free_dependency() -> None: 56 | resolver = DIContainer.resolver(container.FreeFactory) 57 | dep1 = await resolver() 58 | dep2 = await DIContainer.resolve(container.FreeFactory) 59 | assert dep1 60 | assert dep2 61 | -------------------------------------------------------------------------------- /tests/integrations/fastapi/test_fastapi_di.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import typing 3 | 4 | import fastapi 5 | import pytest 6 | from starlette import status 7 | from starlette.testclient import TestClient 8 | 9 | from tests import container 10 | from that_depends import fetch_context_item 11 | from that_depends.providers import DIContextMiddleware 12 | 13 | 14 | _GLOBAL_CONTEXT: typing.Final[dict[str, str]] = {"test2": "value2", "test1": "value1"} 15 | 16 | 17 | @pytest.fixture 18 | def fastapi_app() -> fastapi.FastAPI: 19 | app = fastapi.FastAPI() 20 | app.add_middleware(DIContextMiddleware, container.DIContainer, global_context=_GLOBAL_CONTEXT) 21 | 22 | @app.get("/") 23 | async def read_root( 24 | dependency: typing.Annotated[ 25 | container.DependentFactory, 26 | fastapi.Depends(container.DIContainer.dependent_factory), 27 | ], 28 | free_dependency: typing.Annotated[ 29 | container.FreeFactory, 30 | fastapi.Depends(container.DIContainer.resolver(container.FreeFactory)), 31 | ], 32 | singleton: typing.Annotated[ 33 | 
container.SingletonFactory, 34 | fastapi.Depends(container.DIContainer.singleton), 35 | ], 36 | singleton_attribute: typing.Annotated[bool, fastapi.Depends(container.DIContainer.singleton.dep1)], 37 | context_resource: typing.Annotated[datetime.datetime, fastapi.Depends(container.DIContainer.context_resource)], 38 | ) -> datetime.datetime: 39 | assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource 40 | assert dependency.async_resource == free_dependency.dependent_factory.async_resource 41 | assert singleton.dep1 is True 42 | assert singleton_attribute is True 43 | assert context_resource == await container.DIContainer.context_resource.resolve() 44 | for key, value in _GLOBAL_CONTEXT.items(): 45 | assert fetch_context_item(key) == value 46 | return dependency.async_resource 47 | 48 | return app 49 | 50 | 51 | @pytest.fixture 52 | def fastapi_client(fastapi_app: fastapi.FastAPI) -> TestClient: 53 | return TestClient(fastapi_app) 54 | 55 | 56 | async def test_read_main(fastapi_client: TestClient) -> None: 57 | response = fastapi_client.get("/") 58 | assert response.status_code == status.HTTP_200_OK 59 | assert ( 60 | datetime.datetime.fromisoformat(response.json().replace("Z", "+00:00")) 61 | == await container.DIContainer.async_resource() 62 | ) 63 | -------------------------------------------------------------------------------- /docs/integrations/faststream.md: -------------------------------------------------------------------------------- 1 | # Usage with `FastStream` 2 | 3 | 4 | `that-depends` is out of the box compatible with `faststream.Depends()`: 5 | 6 | ```python hl_lines="14" 7 | from typing import Annotated 8 | from faststream import Depends 9 | from faststream.asgi import AsgiFastStream 10 | from faststream.rabbit import RabbitBroker 11 | 12 | 13 | broker = RabbitBroker() 14 | app = AsgiFastStream(broker) 15 | 16 | @broker.subscriber(queue="queue") 17 | async def process( 18 | text: str, 19 | suffix: Annotated[ 20 | str, 
Depends(Container.suffix_factory) # (1)! 21 | ], 22 | ) -> None: 23 | return text + suffix 24 | ``` 25 | 26 | 1. This would be the same as `Provide[Container.suffix_factory]`. 27 | 28 | 29 | ## Context Middleware 30 | 31 | If you are using the [ContextResource](../providers/context-resources.md) provider, you will likely want to 32 | initialize a context before processing a message with `faststream`. 33 | 34 | `that-depends` provides integration for these use cases: 35 | 36 | ```shell 37 | pip install that-depends[faststream] 38 | ``` 39 | 40 | Then you can use the `DIContextMiddleware` with your broker: 41 | 42 | ```python 43 | from that_depends.integrations.faststream import DIContextMiddleware 44 | from that_depends import ContextScopes 45 | from faststream.rabbit import RabbitBroker 46 | 47 | broker = RabbitBroker(middlewares=[DIContextMiddleware(Container, scope=ContextScopes.REQUEST)]) 48 | ``` 49 | 50 | ## Example 51 | 52 | Here is an example that includes life-cycle events: 53 | 54 | ```python 55 | import datetime 56 | import contextlib 57 | import typing 58 | 59 | from faststream import FastStream, Depends, Logger 60 | from faststream.rabbit import RabbitBroker 61 | 62 | from tests import container 63 | 64 | 65 | @contextlib.asynccontextmanager 66 | async def lifespan_manager() -> typing.AsyncIterator[None]: 67 | try: 68 | yield 69 | finally: 70 | await container.DIContainer.tear_down() 71 | 72 | 73 | broker = RabbitBroker() 74 | app = FastStream(broker, lifespan=lifespan_manager) 75 | 76 | 77 | @broker.subscriber("in") 78 | async def read_root( 79 | logger: Logger, 80 | some_dependency: typing.Annotated[ 81 | container.DependentFactory, 82 | Depends(container.DIContainer.dependent_factory) 83 | ], 84 | ) -> datetime.datetime: 85 | startup_time = some_dependency.async_resource 86 | logger.info(startup_time) 87 | return startup_time 88 | 89 | 90 | @app.after_startup 91 | async def t() -> None: 92 | await broker.publish(None, "in") 93 | ``` 94 | 
-------------------------------------------------------------------------------- /tests/providers/test_state.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | 4 | import pytest 5 | 6 | from that_depends import BaseContainer 7 | from that_depends.exceptions import StateNotInitializedError 8 | from that_depends.providers import AsyncFactory, Factory, State 9 | 10 | 11 | async def test_state_raises_if_not_set_async() -> None: 12 | class _Container(BaseContainer): 13 | state: State[str] = State() 14 | 15 | with pytest.raises(StateNotInitializedError): 16 | await _Container.state.resolve() 17 | 18 | 19 | def test_state_raises_if_not_set_sync() -> None: 20 | class _Container(BaseContainer): 21 | state: State[str] = State() 22 | 23 | with pytest.raises(StateNotInitializedError): 24 | _Container.state.resolve_sync() 25 | 26 | 27 | async def test_state_set_resolves_correctly_async() -> None: 28 | class _Container(BaseContainer): 29 | state: State[str] = State() 30 | 31 | state_value = "test_value" 32 | 33 | with _Container.state.init(state_value) as value: 34 | assert value == state_value 35 | assert await _Container.state.resolve() == state_value 36 | 37 | 38 | def test_state_set_resolves_correctly_sync() -> None: 39 | class _Container(BaseContainer): 40 | state: State[str] = State() 41 | 42 | state_value = "test_value" 43 | 44 | with _Container.state.init(state_value) as value: 45 | assert value == state_value 46 | assert _Container.state.resolve_sync() == state_value 47 | 48 | 49 | async def test_state_correctly_manages_its_context() -> None: 50 | async def _async_creator(x: int) -> int: 51 | await asyncio.sleep(random.random()) 52 | return x 53 | 54 | class _Container(BaseContainer): 55 | state: State[int] = State() 56 | dependent = Factory(lambda x: x, state.cast) 57 | async_dependent = AsyncFactory(_async_creator, state.cast) 58 | 59 | async def _main(x: int) -> None: 60 | with 
_Container.state.init(x): 61 | dependent_value = await _Container.dependent.resolve() 62 | assert dependent_value == x 63 | async_dependent_value = await _Container.async_dependent.resolve() 64 | assert async_dependent_value == x 65 | new_x = x + random.randint(1, 1000) 66 | with _Container.state.init(new_x): 67 | dependent_value = await _Container.dependent.resolve() 68 | assert dependent_value == new_x 69 | async_dependent_value = await _Container.async_dependent.resolve() 70 | assert async_dependent_value == new_x 71 | 72 | dependent_value = await _Container.dependent.resolve() 73 | assert dependent_value == x 74 | async_dependent_value = await _Container.async_dependent.resolve() 75 | assert async_dependent_value == x 76 | 77 | tasks = [_main(x) for x in range(10)] 78 | 79 | await asyncio.gather(*tasks) 80 | -------------------------------------------------------------------------------- /that_depends/integrations/fastapi.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from collections.abc import Callable 3 | from inspect import signature 4 | 5 | from fastapi import Depends, Request, Response 6 | from fastapi.routing import APIRoute 7 | 8 | from that_depends.injection import StringProviderDefinition 9 | from that_depends.providers import AbstractProvider 10 | from that_depends.providers.context_resources import ContextScope, SupportsContext, container_context 11 | 12 | 13 | P = typing.ParamSpec("P") 14 | T = typing.TypeVar("T") 15 | 16 | 17 | def _adjust_fastapi_endpoint(endpoint: Callable[P, T]) -> Callable[P, T]: 18 | hints = typing.get_type_hints(endpoint, include_extras=True) 19 | sig = signature(endpoint) 20 | new_params = [] 21 | for name, param in sig.parameters.items(): 22 | if isinstance(param.default, AbstractProvider): 23 | hints[name] = Depends(param.default) 24 | new_params.append(param.replace(default=Depends(param.default))) 25 | elif isinstance(param.default, StringProviderDefinition): 
26 | provider = param.default.provider 27 | hints[name] = Depends(provider) 28 | new_params.append(param.replace(default=Depends(provider))) 29 | else: 30 | new_params.append(param) 31 | endpoint.__annotations__ = hints 32 | endpoint.__signature__ = sig.replace(parameters=new_params) # type: ignore[attr-defined] 33 | return endpoint 34 | 35 | 36 | def create_fastapi_route_class( 37 | *context_items: SupportsContext[typing.Any], 38 | global_context: dict[str, typing.Any] | None = None, 39 | scope: ContextScope | None = None, 40 | ) -> type[APIRoute]: 41 | """Create a `that-depends` fastapi route class. 42 | 43 | Args: 44 | *context_items: Items to initialize context for. 45 | global_context: Global context to use. 46 | scope: scope to enter before on request. 47 | 48 | Returns: 49 | type[APIRoute]: A custom fastapi route class. 50 | 51 | """ 52 | 53 | class _Route(APIRoute): 54 | """Custom that-depends router for FastAPI.""" 55 | 56 | def __init__(self, path: str, endpoint: Callable[..., typing.Any], **kwargs: typing.Any) -> None: # noqa: ANN401 57 | endpoint = _adjust_fastapi_endpoint(endpoint) 58 | super().__init__(path=path, endpoint=endpoint, **kwargs) 59 | 60 | def get_route_handler(self) -> Callable[[Request], typing.Coroutine[typing.Any, typing.Any, Response]]: 61 | original_route_handler = super().get_route_handler() 62 | 63 | async def _custom_route_handler(request: Request) -> Response: 64 | async with container_context(*context_items, global_context=global_context, scope=scope): 65 | response: Response = await original_route_handler(request) 66 | return response 67 | 68 | return _custom_route_handler if context_items or global_context or scope else original_route_handler 69 | 70 | return _Route 71 | -------------------------------------------------------------------------------- /that_depends/entities/resource_context.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import contextlib 3 | import 
threading 4 | import typing 5 | import warnings 6 | 7 | from typing_extensions import override 8 | 9 | from that_depends.providers.mixin import CannotTearDownSyncError, SupportsTeardown 10 | 11 | 12 | T_co = typing.TypeVar("T_co", covariant=True) 13 | 14 | 15 | class ResourceContext(SupportsTeardown, typing.Generic[T_co]): 16 | """Class to manage a resources' context.""" 17 | 18 | __slots__ = "asyncio_lock", "context_stack", "instance", "is_async", "threading_lock" 19 | 20 | def __init__(self, is_async: bool) -> None: 21 | """Create a new ResourceContext instance. 22 | 23 | Args: 24 | is_async (bool): Whether the ResourceContext was created in 25 | an async context. 26 | For example within a ``async with container_context(Container): ...`` statement. 27 | 28 | """ 29 | self.instance: T_co | None = None 30 | self.asyncio_lock: typing.Final = asyncio.Lock() 31 | self.threading_lock: typing.Final = threading.Lock() 32 | self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None 33 | self.is_async = is_async 34 | 35 | @staticmethod 36 | def is_context_stack_async( 37 | context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None, 38 | ) -> typing.TypeGuard[contextlib.AsyncExitStack]: 39 | """Check if the context stack is an async context stack.""" 40 | return isinstance(context_stack, contextlib.AsyncExitStack) 41 | 42 | @staticmethod 43 | def is_context_stack_sync( 44 | context_stack: contextlib.AsyncExitStack | contextlib.ExitStack, 45 | ) -> typing.TypeGuard[contextlib.ExitStack]: 46 | """Check if the context stack is a sync context stack.""" 47 | return isinstance(context_stack, contextlib.ExitStack) 48 | 49 | @override 50 | async def tear_down(self, propagate: bool = True) -> None: 51 | """Tear down the async context stack.""" 52 | if self.context_stack is None: 53 | return 54 | 55 | if self.is_context_stack_async(self.context_stack): 56 | await self.context_stack.aclose() 57 | elif self.is_context_stack_sync(self.context_stack): 
58 | self.context_stack.close() 59 | self.context_stack = None 60 | self.instance = None 61 | 62 | @override 63 | def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: 64 | """Tear down the sync context stack.""" 65 | if self.context_stack is None: 66 | return 67 | 68 | if self.is_context_stack_sync(self.context_stack): 69 | self.context_stack.close() 70 | self.context_stack = None 71 | self.instance = None 72 | elif self.is_context_stack_async(self.context_stack): 73 | msg = "Cannot tear down async context in sync mode" 74 | if raise_on_async: 75 | raise CannotTearDownSyncError(msg) 76 | warnings.warn(msg, RuntimeWarning, stacklevel=2) 77 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "that-depends" 3 | description = "Simple Dependency Injection framework" 4 | authors = [ 5 | { name = "Artur Shiriev", email = "me@shiriev.ru" }, 6 | ] 7 | readme = "README.md" 8 | requires-python = ">=3.10,<4" 9 | license = "MIT" 10 | keywords = ["di", "dependency injector", "ioc-container", "mocks", "python"] 11 | classifiers = [ 12 | "Programming Language :: Python :: 3.10", 13 | "Programming Language :: Python :: 3.11", 14 | "Programming Language :: Python :: 3.12", 15 | "Programming Language :: Python :: 3.13", 16 | "Programming Language :: Python :: 3.14", 17 | "Typing :: Typed", 18 | "Topic :: Software Development :: Libraries", 19 | ] 20 | version = "0" 21 | 22 | [project.optional-dependencies] 23 | fastapi = [ 24 | "fastapi", 25 | ] 26 | faststream = [ 27 | "faststream" 28 | ] 29 | 30 | [project.urls] 31 | repository = "https://github.com/modern-python/that-depends" 32 | docs = "https://that-depends.readthedocs.io" 33 | 34 | [dependency-groups] 35 | dev = [ 36 | "httpx", 37 | "pytest", 38 | "pytest-cov", 39 | "pytest-asyncio", 40 | "pytest-repeat", 41 | "ruff", 42 | "mypy", 43 | 
"typing-extensions", 44 | "pre-commit", 45 | "litestar", 46 | "faststream[nats]", 47 | "mkdocs>=1.6.1", 48 | "pytest-randomly", 49 | "mkdocs-llmstxt>=0.4.0", 50 | "pydantic<2.12.4" # causes tests to fail because of: TypeError: _eval_type() got an unexpected keyword argument 'prefer_fwd_module' 51 | ] 52 | 53 | [build-system] 54 | requires = ["uv_build"] 55 | build-backend = "uv_build" 56 | 57 | [tool.uv.build-backend] 58 | module-name = "that_depends" 59 | module-root = "" 60 | 61 | [tool.mypy] 62 | python_version = "3.10" 63 | strict = true 64 | 65 | [tool.ruff] 66 | fix = true 67 | unsafe-fixes = true 68 | line-length = 120 69 | target-version = "py310" 70 | extend-exclude = [ 71 | "docs", 72 | ] 73 | 74 | [tool.ruff.lint] 75 | select = ["ALL"] 76 | ignore = [ 77 | "B008", # Do not perform function call `Provide` in argument defaults 78 | "D100", # ignore missing module docstrings. 79 | "D105", # ignore missing docstrings in magic methods. 80 | "S101", # allow asserts 81 | "TCH", # ignore flake8-type-checking 82 | "FBT", # allow boolean args 83 | "D203", # "one-blank-line-before-class" conflicting with D211 84 | "D213", # "multi-line-summary-second-line" conflicting with D212 85 | "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes 86 | "COM812", # flake8-commas "Trailing comma missing" 87 | "ISC001", # flake8-implicit-str-concat 88 | "PT028", # Test function has default argument 89 | ] 90 | isort.lines-after-imports = 2 91 | isort.no-lines-before = ["standard-library", "local-folder"] 92 | per-file-ignores = { "tests/*"= ["D1", "SLF001"]} 93 | 94 | [tool.pytest.ini_options] 95 | addopts = "--cov=. 
--cov-report term-missing" 96 | asyncio_mode = "auto" 97 | asyncio_default_fixture_loop_scope = "function" 98 | filterwarnings = [ 99 | "ignore::UserWarning", 100 | ] 101 | 102 | [tool.coverage.report] 103 | exclude_also = ["if typing.TYPE_CHECKING:"] 104 | -------------------------------------------------------------------------------- /tests/test_container.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing 3 | 4 | from tests.container import DIContainer 5 | from that_depends import BaseContainer, providers 6 | 7 | 8 | def _sync_resource() -> typing.Iterator[float]: 9 | yield random.random() 10 | 11 | 12 | async def test_container_sync_teardown() -> None: 13 | await DIContainer.init_resources() 14 | DIContainer.tear_down_sync() 15 | for provider in DIContainer.providers.values(): 16 | if isinstance(provider, providers.Resource): 17 | if provider._is_async: 18 | assert provider._context.instance is not None 19 | else: 20 | assert provider._context.instance is None 21 | if isinstance(provider, providers.Singleton): 22 | assert provider._instance is None 23 | 24 | 25 | async def test_container_tear_down() -> None: 26 | await DIContainer.init_resources() 27 | await DIContainer.tear_down() 28 | for provider in DIContainer.providers.values(): 29 | if isinstance(provider, providers.Resource): 30 | assert provider._context.instance is None 31 | if isinstance(provider, providers.Singleton): 32 | assert provider._instance is None 33 | 34 | 35 | async def test_container_sync_tear_down_propagation() -> None: 36 | class _DependentContainer(BaseContainer): 37 | singleton = providers.Singleton(lambda: random.random()) 38 | resource = providers.Resource(_sync_resource) 39 | 40 | DIContainer.connect_containers(_DependentContainer) 41 | 42 | await DIContainer.init_resources() 43 | 44 | assert isinstance(_DependentContainer.singleton._instance, float) 45 | assert 
isinstance(_DependentContainer.resource._context.instance, float) 46 | 47 | DIContainer.tear_down_sync() 48 | 49 | assert _DependentContainer.singleton._instance is None 50 | assert _DependentContainer.resource._context.instance is None 51 | 52 | 53 | async def test_container_tear_down_propagation() -> None: 54 | async def _async_singleton() -> float: 55 | return random.random() 56 | 57 | async def _async_resource() -> typing.AsyncIterator[float]: 58 | yield random.random() 59 | 60 | class _DependentContainer(BaseContainer): 61 | async_singleton = providers.AsyncSingleton(_async_singleton) 62 | async_resource = providers.Resource(_async_resource) 63 | sync_singleton = providers.Singleton(lambda: random.random()) 64 | sync_resource = providers.Resource(_sync_resource) 65 | 66 | DIContainer.connect_containers(_DependentContainer) 67 | 68 | await DIContainer.init_resources() 69 | 70 | assert isinstance(_DependentContainer.async_singleton._instance, float) 71 | assert isinstance(_DependentContainer.async_resource._context.instance, float) 72 | assert isinstance(_DependentContainer.sync_singleton._instance, float) 73 | assert isinstance(_DependentContainer.sync_resource._context.instance, float) 74 | 75 | await DIContainer.tear_down() 76 | 77 | assert _DependentContainer.async_singleton._instance is None 78 | assert _DependentContainer.async_resource._context.instance is None 79 | assert _DependentContainer.sync_singleton._instance is None 80 | assert _DependentContainer.sync_resource._context.instance is None 81 | -------------------------------------------------------------------------------- /docs/introduction/tear-down.md: -------------------------------------------------------------------------------- 1 | # Tear-down 2 | 3 | Certain providers in `that-depends` require explicit finalization. These providers 4 | implement the `SupportsTeardown` API. Containers also support finalizing their 5 | resources. 
6 | 7 | ## Quick-start 8 | 9 | If you are using a provider that supports finalization, such as a [Singleton](../providers/singleton.md), follow these steps: 10 | 11 | First, define the container & provider. 12 | 13 | ```python 14 | from that_depends import BaseContainer, providers 15 | import random 16 | 17 | class MyContainer(BaseContainer): 18 | config = providers.Singleton(lambda: random.random()) 19 | ``` 20 | 21 | Resolve the provider somewhere in your code: 22 | 23 | ```python 24 | await MyContainer.config() 25 | ``` 26 | 27 | When you want to reset the cached value: 28 | 29 | ```python 30 | await MyContainer.config.tear_down() 31 | ``` 32 | 33 | For [Resources](../providers/resources.md) this will also call any finalization logic in your 34 | context manager. 35 | 36 | --- 37 | 38 | ## Propagation 39 | 40 | By default, `that-depends` propagates tear-down to dependent providers. 41 | 42 | This means that if you have defined a provider `A` that is dependent on provider `B`, 43 | calling `await B.tear_down()` will also execute `await A.tear_down()`. 44 | 45 | **For example:** 46 | 47 | ```python 48 | class MyContainer(BaseContainer): 49 | B = providers.Singleton(lambda: random.random()) 50 | A = providers.Singleton(lambda x: x, B) 51 | 52 | b = await MyContainer.B() 53 | a = await MyContainer.A() 54 | 55 | assert a == b 56 | 57 | await MyContainer.B.tear_down() 58 | a_new = await MyContainer.A() 59 | assert a_new != a 60 | ``` 61 | 62 | If you do not wish to propagate tear-down, simply call `tear_down(propagate=False)` or `tear_down_sync(propagate=False)`. 63 | 64 | --- 65 | 66 | ## Sync tear-down 67 | 68 | If you need to call tear-down from a sync context, you can use the `tear_down_sync()` method. However, 69 | keep in mind that because dependent resources might be async, this will fail to correctly finalize these async 70 | resources. 
71 | 72 | By default, this will raise a `CannotTearDownSyncError`: 73 | 74 | ```python 75 | async def async_creator(val: float) -> typing.AsyncIterator[float]: 76 | yield val 77 | print("Finalization!") 78 | 79 | 80 | class MyContainer(BaseContainer): 81 | B = providers.Singleton(lambda: random.random()) 82 | A = providers.Resource(async_creator, B.cast) 83 | 84 | 85 | b = await MyContainer.B() 86 | a = await MyContainer.A() 87 | 88 | MyContainer.B.tear_down_sync() # raises 89 | ``` 90 | 91 | If you do not want to see these errors, you can downgrade them to a `RuntimeWarning`: 92 | 93 | ```python 94 | MyContainer.B.tear_down_sync(raise_on_async=False) 95 | ``` 96 | 97 | --- 98 | 99 | ## Containers and tear-down 100 | 101 | Containers also support the `tear_down` and `tear_down_sync` methods. When calling 102 | `await Container.tear_down()`, all providers in the container will be torn down. 103 | 104 | 105 | `Container.tear_down_sync()` is also implemented but not recommended unless you are sure 106 | all providers in your container are sync. 107 | 108 | > Container methods do not support the `propagate` & `raise_on_async` arguments, so if you 109 | > need more granular control, try to tear down the providers explicitly. 110 | -------------------------------------------------------------------------------- /that_depends/providers/collection.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from typing_extensions import override 4 | 5 | from that_depends.providers.base import AbstractProvider 6 | 7 | 8 | T_co = typing.TypeVar("T_co", covariant=True) 9 | 10 | 11 | class List(AbstractProvider[list[T_co]]): 12 | """Provides multiple resources as a list. 13 | 14 | The `List` provider resolves multiple dependencies into a list. 
15 | 16 | Example: 17 | ```python 18 | from that_depends import providers 19 | 20 | provider1 = providers.Factory(lambda: 1) 21 | provider2 = providers.Factory(lambda: 2) 22 | 23 | list_provider = List(provider1, provider2) 24 | 25 | # Synchronous resolution 26 | resolved_list = list_provider.resolve_sync() 27 | print(resolved_list) # Output: [1, 2] 28 | 29 | # Asynchronous resolution 30 | import asyncio 31 | resolved_list_async = asyncio.run(list_provider.resolve()) 32 | print(resolved_list_async) # Output: [1, 2] 33 | ``` 34 | 35 | """ 36 | 37 | __slots__ = ("_providers",) 38 | 39 | def __init__(self, *providers: AbstractProvider[T_co]) -> None: 40 | """Create a new List provider instance. 41 | 42 | Args: 43 | *providers: List of providers to resolve. 44 | 45 | """ 46 | super().__init__() 47 | self._providers: typing.Final = providers 48 | self._register(self._providers) 49 | 50 | @override 51 | def __getattr__(self, attr_name: str) -> typing.Any: 52 | msg = f"'{type(self)}' object has no attribute '{attr_name}'" 53 | raise AttributeError(msg) 54 | 55 | @override 56 | async def resolve(self) -> list[T_co]: 57 | return [await x.resolve() for x in self._providers] 58 | 59 | @override 60 | def resolve_sync(self) -> list[T_co]: 61 | return [x.resolve_sync() for x in self._providers] 62 | 63 | @override 64 | async def __call__(self) -> list[T_co]: 65 | return await self.resolve() 66 | 67 | 68 | class Dict(AbstractProvider[dict[str, T_co]]): 69 | """Provides multiple resources as a dictionary. 70 | 71 | The `Dict` provider resolves multiple named dependencies into a dictionary. 
72 | 73 | Example: 74 | ```python 75 | from that_depends import providers 76 | 77 | provider1 = providers.Factory(lambda: 1) 78 | provider2 = providers.Factory(lambda: 2) 79 | 80 | dict_provider = Dict(key1=provider1, key2=provider2) 81 | 82 | # Synchronous resolution 83 | resolved_dict = dict_provider.resolve_sync() 84 | print(resolved_dict) # Output: {"key1": 1, "key2": 2} 85 | 86 | # Asynchronous resolution 87 | import asyncio 88 | resolved_dict_async = asyncio.run(dict_provider.resolve()) 89 | print(resolved_dict_async) # Output: {"key1": 1, "key2": 2} 90 | ``` 91 | 92 | """ 93 | 94 | __slots__ = ("_providers",) 95 | 96 | def __init__(self, **providers: AbstractProvider[T_co]) -> None: 97 | """Create a new Dict provider instance. 98 | 99 | Args: 100 | **providers: Dictionary of providers to resolve. 101 | 102 | """ 103 | super().__init__() 104 | self._providers: typing.Final = providers 105 | self._register(self._providers.values()) 106 | 107 | @override 108 | def __getattr__(self, attr_name: str) -> typing.Any: 109 | msg = f"'{type(self)}' object has no attribute '{attr_name}'" 110 | raise AttributeError(msg) 111 | 112 | @override 113 | async def resolve(self) -> dict[str, T_co]: 114 | return {key: await provider.resolve() for key, provider in self._providers.items()} 115 | 116 | @override 117 | def resolve_sync(self) -> dict[str, T_co]: 118 | return {key: provider.resolve_sync() for key, provider in self._providers.items()} 119 | -------------------------------------------------------------------------------- /that_depends/providers/selector.py: -------------------------------------------------------------------------------- 1 | """Selection based providers.""" 2 | 3 | import typing 4 | 5 | from typing_extensions import override 6 | 7 | from that_depends.providers.base import AbstractProvider 8 | 9 | 10 | T_co = typing.TypeVar("T_co", covariant=True) 11 | 12 | 13 | class Selector(AbstractProvider[T_co]): 14 | """Chooses a provider based on a key returned by a 
selector function. 15 | 16 | This class allows you to dynamically select and resolve one of several 17 | named providers at runtime. The provider key is determined by a 18 | user-supplied selector function. 19 | 20 | Examples: 21 | ```python 22 | def environment_selector(): 23 | return "local" 24 | 25 | selector_instance = Selector( 26 | environment_selector, 27 | local=LocalStorageProvider(), 28 | remote=RemoteStorageProvider(), 29 | ) 30 | 31 | # Synchronously resolve the selected provider 32 | service = selector_instance.resolve_sync() 33 | ``` 34 | 35 | """ 36 | 37 | __slots__ = "_override", "_providers", "_selector" 38 | 39 | def __init__( 40 | self, selector: typing.Callable[[], str] | AbstractProvider[str] | str, **providers: AbstractProvider[T_co] 41 | ) -> None: 42 | """Initialize a new Selector instance. 43 | 44 | Args: 45 | selector (Callable[[], str]): A function that returns the key 46 | of the provider to use. 47 | **providers (AbstractProvider[T_co]): The named providers from 48 | which one will be selected based on the `selector`. 
49 | 50 | Examples: 51 | ```python 52 | def my_selector(): 53 | return "remote" 54 | 55 | my_selector_instance = Selector( 56 | my_selector, 57 | local=LocalStorageProvider(), 58 | remote=RemoteStorageProvider(), 59 | ) 60 | 61 | # The "remote" provider will be selected 62 | selected_service = my_selector_instance.resolve_sync() 63 | ``` 64 | 65 | """ 66 | super().__init__() 67 | self._selector: typing.Final[typing.Callable[[], str] | AbstractProvider[str] | str] = selector 68 | self._providers: typing.Final = providers 69 | 70 | @override 71 | async def resolve(self) -> T_co: 72 | if self._override: 73 | return typing.cast(T_co, self._override) 74 | 75 | if isinstance(self._selector, AbstractProvider): 76 | selected_key = await self._selector.resolve() 77 | else: 78 | selected_key = self._get_selected_key() 79 | 80 | self._validate_key(selected_key) 81 | 82 | return await self._providers[selected_key].resolve() 83 | 84 | @override 85 | def resolve_sync(self) -> T_co: 86 | if self._override: 87 | return typing.cast(T_co, self._override) 88 | 89 | if isinstance(self._selector, AbstractProvider): 90 | selected_key = self._selector.resolve_sync() 91 | else: 92 | selected_key = self._get_selected_key() 93 | 94 | self._validate_key(selected_key) 95 | 96 | return self._providers[selected_key].resolve_sync() 97 | 98 | def _get_selected_key(self) -> str: 99 | if callable(self._selector): 100 | selected_key = self._selector() 101 | elif isinstance(self._selector, str): 102 | selected_key = self._selector 103 | else: 104 | msg = f"Invalid selector type: {type(self._selector)}, expected str, or a provider/callable returning str" 105 | raise TypeError(msg) 106 | return typing.cast(str, selected_key) 107 | 108 | def _validate_key(self, key: str) -> None: 109 | if key not in self._providers: 110 | msg = f"Key '{key}' not found in providers" 111 | raise KeyError(msg) 112 | -------------------------------------------------------------------------------- 
/docs/providers/resources.md: -------------------------------------------------------------------------------- 1 | # Resource Provider 2 | 3 | A **Resource** is a special provider that: 4 | 5 | - **Resolves** its dependency only **once** and **caches** the resolved instance for future injections. 6 | - **Includes** teardown (finalization) logic, unlike a plain `Singleton`. 7 | - **Supports** generator or async generator functions for creation (allowing a `yield` plus teardown in `finally`). 8 | - **Also** allows usage of classes that implement standard Python context managers (`typing.ContextManager` or `typing.AsyncContextManager`), but *does not* automatically integrate with `container_context`. 9 | 10 | This makes `Resource` ideal for dependencies that need: 11 | 12 | 1. A **single creation** step, 13 | 2. A **single finalization** step, 14 | 3. **Thread/async safety**—all consumers receive the same resource object, and concurrency is handled. 15 | 16 | --- 17 | 18 | ## How It Works 19 | 20 | ### Defining a Sync or Async Resource 21 | 22 | You can define your creation logic as either a **generator** or a **context manager** class (sync or async). 
23 | 24 | **Synchronous generator** example: 25 | 26 | ```python 27 | import typing 28 | 29 | def create_sync_resource() -> typing.Iterator[str]: 30 | print("Creating sync resource") 31 | try: 32 | yield "sync resource" 33 | finally: 34 | print("Tearing down sync resource") 35 | ``` 36 | 37 | **Asynchronous generator** example: 38 | 39 | ```python 40 | import typing 41 | 42 | async def create_async_resource() -> typing.AsyncIterator[str]: 43 | print("Creating async resource") 44 | try: 45 | yield "async resource" 46 | finally: 47 | print("Tearing down async resource") 48 | ``` 49 | 50 | You then attach them to a container: 51 | 52 | ```python 53 | from that_depends import BaseContainer 54 | from that_depends.providers import Resource 55 | 56 | class MyContainer(BaseContainer): 57 | sync_resource = Resource(create_sync_resource) 58 | async_resource = Resource(create_async_resource) 59 | ``` 60 | 61 | --- 62 | 63 | ## Resolving and Teardown 64 | 65 | Once defined, you can explicitly **resolve** the resource and **tear it down**: 66 | 67 | ```python 68 | # Synchronous resource usage 69 | value_sync = MyContainer.sync_resource.resolve_sync() 70 | print(value_sync) # "sync resource" 71 | MyContainer.sync_resource.tear_down_sync() 72 | 73 | # Asynchronous resource usage 74 | import asyncio 75 | 76 | 77 | async def main(): 78 | value_async = await MyContainer.async_resource.resolve() 79 | print(value_async) # "async resource" 80 | await MyContainer.async_resource.tear_down() 81 | 82 | 83 | asyncio.run(main()) 84 | ``` 85 | 86 | - **`resolve_sync()`** or **`resolve()`**: Creates (if needed) and returns the resource instance. 87 | - **`tear_down_sync()`** or **`tear_down()`**: Closes/cleans up the resource (triggering your `finally` block or exiting the context manager) and resets the cached instance to `None`. A subsequent resolve call will then recreate it. 
88 | 89 | --- 90 | 91 | ## Concurrency Safety 92 | 93 | `Resource` is **safe** to use under **threading** and **asyncio** concurrency. Internally, a lock ensures only one resource instance is created per container: 94 | 95 | - Multiple threads calling `resolve_sync()` simultaneously will produce a **single** instance for that container. 96 | - Multiple coroutines calling `resolve()` simultaneously will likewise produce **only one** instance for that container in an async environment. 97 | 98 | ```python 99 | # Even if multiple coroutines call resolve in parallel, 100 | # only one instance is created at a time: 101 | await MyContainer.async_resource.resolve() 102 | 103 | # Similarly, multiple threads calling resolve_sync concurrently 104 | # still yield just one instance until teardown: 105 | MyContainer.sync_resource.resolve_sync() 106 | ``` 107 | 108 | --- 109 | 110 | ## Using Context Managers Directly 111 | 112 | If your resource is a standard **context manager** or **async context manager** class, `Resource` will handle entering and exiting it under the hood. 
For example: 113 | 114 | ```python 115 | import typing 116 | from that_depends.providers import Resource 117 | 118 | 119 | class SyncFileManager: 120 | def __enter__(self) -> str: 121 | print("Opening file") 122 | return "/path/to/file" 123 | 124 | def __exit__(self, exc_type, exc_val, exc_tb): 125 | print("Closing file") 126 | 127 | 128 | sync_file_resource = Resource(SyncFileManager) 129 | 130 | # usage 131 | file_path = sync_file_resource.resolve_sync() 132 | print(file_path) 133 | sync_file_resource.tear_down_sync() 134 | ``` 135 | -------------------------------------------------------------------------------- /tests/integrations/litestar/test_litestar_di.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from functools import partial 3 | from typing import Annotated 4 | from unittest.mock import Mock 5 | 6 | from litestar import Controller, Litestar, Router, get 7 | from litestar.di import Provide 8 | from litestar.params import Dependency 9 | from litestar.status_codes import HTTP_200_OK 10 | from litestar.testing import TestClient 11 | 12 | from that_depends import BaseContainer, providers 13 | 14 | 15 | def bool_fn(value: bool) -> bool: 16 | return value 17 | 18 | 19 | def str_fn() -> str: 20 | return "" 21 | 22 | 23 | def list_fn() -> list[str]: 24 | return ["some"] 25 | 26 | 27 | def int_fn() -> int: 28 | return 1 29 | 30 | 31 | class SomeService: 32 | pass 33 | 34 | 35 | class DIContainer(BaseContainer): 36 | alias = "litestar_container" 37 | bool_fn = providers.Factory(bool_fn, value=False) 38 | str_fn = providers.Factory(str_fn) 39 | list_fn = providers.Factory(list_fn) 40 | int_fn = providers.Factory(int_fn) 41 | some_service = providers.Factory(SomeService) 42 | 43 | 44 | _NoValidationDependency = partial(Dependency, skip_validation=True) 45 | 46 | 47 | class MyController(Controller): 48 | path = "/controller" 49 | dependencies = {"controller_dependency": Provide(DIContainer.list_fn)} # noqa: 
RUF012 50 | 51 | @get( 52 | path="/handler", 53 | dependencies={ 54 | "local_dependency": Provide(DIContainer.int_fn), 55 | }, 56 | ) 57 | async def my_route_handler( 58 | self, 59 | app_dependency: bool, 60 | router_dependency: str, 61 | controller_dependency: list[str], 62 | local_dependency: int, 63 | ) -> dict[str, typing.Any]: 64 | return { 65 | "app_dependency": app_dependency, 66 | "router_dependency": router_dependency, 67 | "controller_dependency": controller_dependency, 68 | "local_dependency": local_dependency, 69 | } 70 | 71 | @get(path="/mock_overriding", dependencies={"some_service": Provide(DIContainer.some_service)}) 72 | async def mock_overriding_endpoint_handler( 73 | self, some_service: Annotated[SomeService, _NoValidationDependency()] 74 | ) -> None: 75 | assert isinstance(some_service, Mock) 76 | 77 | 78 | my_router = Router( 79 | path="/router", 80 | dependencies={"router_dependency": Provide(DIContainer.str_fn)}, 81 | route_handlers=[MyController], 82 | ) 83 | 84 | # on the app 85 | app = Litestar(route_handlers=[my_router], dependencies={"app_dependency": Provide(DIContainer.bool_fn)}, debug=True) 86 | 87 | 88 | def test_litestar_endpoint_with_mock_overriding() -> None: 89 | some_service_mock = Mock() 90 | 91 | with DIContainer.some_service.override_context_sync(some_service_mock), TestClient(app=app) as client: 92 | response = client.get("/router/controller/mock_overriding") 93 | 94 | assert response.status_code == HTTP_200_OK 95 | 96 | 97 | def test_litestar_di() -> None: 98 | with TestClient(app=app) as client: 99 | response = client.get("/router/controller/handler") 100 | assert response.status_code == HTTP_200_OK, response.text 101 | assert response.json() == { 102 | "app_dependency": False, 103 | "controller_dependency": ["some"], 104 | "local_dependency": 1, 105 | "router_dependency": "", 106 | } 107 | 108 | 109 | def test_litestar_di_override_fail_on_provider_override() -> None: 110 | mock = 12345364758999 111 | with 
TestClient(app=app) as client, DIContainer.int_fn.override_context_sync(mock): 112 | response = client.get("/router/controller/handler") 113 | 114 | assert response.status_code == HTTP_200_OK, response.text 115 | assert response.json() == { 116 | "app_dependency": False, 117 | "controller_dependency": ["some"], 118 | "local_dependency": mock, 119 | "router_dependency": "", 120 | } 121 | 122 | 123 | def test_litestar_di_override_fail_on_override_providers() -> None: 124 | mock = 12345364758999 125 | overrides = { 126 | "int_fn": mock, 127 | } 128 | with TestClient(app=app) as client, DIContainer.override_providers_sync(overrides): 129 | response = client.get("/router/controller/handler") 130 | 131 | assert response.status_code == HTTP_200_OK, response.text 132 | assert response.json() == { 133 | "app_dependency": False, 134 | "controller_dependency": ["some"], 135 | "local_dependency": mock, 136 | "router_dependency": "", 137 | } 138 | -------------------------------------------------------------------------------- /that_depends/providers/resources.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from typing_extensions import override 4 | 5 | from that_depends.entities.resource_context import ResourceContext 6 | from that_depends.providers.base import AbstractResource, ResourceCreatorType 7 | from that_depends.providers.mixin import SupportsTeardown 8 | 9 | 10 | T_co = typing.TypeVar("T_co", covariant=True) 11 | P = typing.ParamSpec("P") 12 | 13 | 14 | class Resource(SupportsTeardown, AbstractResource[T_co]): 15 | """Provides a resource that is resolved once and cached for future usage. 16 | 17 | Unlike a singleton, this provider includes finalization logic and can be 18 | used with a generator or async generator to manage resource lifecycle. 19 | It also supports usage with `typing.ContextManager` or `typing.AsyncContextManager`. 
20 | Threading and asyncio concurrency are supported, ensuring only one instance 21 | is created regardless of concurrent resolves. 22 | 23 | Example: 24 | ```python 25 | async def create_async_resource(): 26 | try: 27 | yield "async resource" 28 | finally: 29 | # Finalize resource 30 | pass 31 | 32 | class MyContainer: 33 | async_resource = Resource(create_async_resource) 34 | 35 | async def main(): 36 | async_resource_instance = await MyContainer.async_resource.resolve() 37 | await MyContainer.async_resource.tear_down() 38 | ``` 39 | 40 | """ 41 | 42 | __slots__ = ( 43 | "_args", 44 | "_context", 45 | "_creator", 46 | "_creator", 47 | "_is_async", 48 | "_kwargs", 49 | "_override", 50 | ) 51 | 52 | def __init__( 53 | self, 54 | creator: ResourceCreatorType[P, T_co], 55 | *args: P.args, 56 | **kwargs: P.kwargs, 57 | ) -> None: 58 | """Initialize the Resource provider with a callable for resource creation. 59 | 60 | The callable can be a generator or async generator that yields the resource 61 | (with optional teardown logic), or a context manager. Only one instance will be 62 | created and cached until explicitly torn down. 63 | 64 | Args: 65 | creator: The callable, generator, or context manager that creates the resource. 66 | *args: Positional arguments passed to the creator. 67 | **kwargs: Keyword arguments passed to the creator. 
68 | 69 | Example: 70 | ```python 71 | def custom_creator(name: str): 72 | try: 73 | yield f"Resource created for {name}" 74 | finally: 75 | pass # Teardown 76 | 77 | resource_provider = Resource(custom_creator, "example") 78 | instance = resource_provider.resolve_sync() 79 | resource_provider.tear_down_sync() 80 | ``` 81 | 82 | """ 83 | super().__init__(creator, *args, **kwargs) 84 | self._context: typing.Final[ResourceContext[T_co]] = ResourceContext(is_async=self.is_async) 85 | 86 | def _fetch_context(self) -> ResourceContext[T_co]: 87 | return self._context 88 | 89 | @override 90 | async def tear_down(self, propagate: bool = True) -> None: 91 | """Tear down the resource if it has been created. 92 | 93 | If the resource was never resolved, or was already torn down, 94 | calling this method has no effect. 95 | 96 | Example: 97 | ```python 98 | # Assuming my_provider was previously resolved 99 | await my_provider.tear_down() 100 | ``` 101 | 102 | """ 103 | await self._fetch_context().tear_down() 104 | self._deregister_arguments() 105 | if propagate: 106 | await self._tear_down_children() 107 | 108 | @override 109 | def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: 110 | """Sync tear down the resource if it has been created. 111 | 112 | If the resource was never resolved, or was already torn down, 113 | calling this method has no effect. 114 | 115 | If you try to sync tear down an async resource, this will raise an exception. 
116 | 117 | Example: 118 | ```python 119 | # Assuming my_provider was previously resolved 120 | my_provider.tear_down_sync() 121 | ``` 122 | 123 | """ 124 | self._fetch_context().tear_down_sync(propagate=propagate, raise_on_async=raise_on_async) 125 | self._deregister_arguments() 126 | if propagate: 127 | self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async) 128 | -------------------------------------------------------------------------------- /docs/migration/v3.md: -------------------------------------------------------------------------------- 1 | # Migrating from 2.\* to 3.\* 2 | 3 | ## How to Read This Guide 4 | 5 | This guide is intended to help you migrate existing functionality from `that-depends` version `2.*` to `3.*`. 6 | The goal is to enable you to migrate as quickly as possible while making only the minimal necessary changes to your codebase. 7 | 8 | If you want to learn more about the new features introduced in `3.*`, please refer to the [documentation](https://that-depends.readthedocs.io/) and the [release notes](https://github.com/modern-python/that-depends/releases). 9 | 10 | --- 11 | 12 | ## Deprecated or Removed Features 13 | 14 | ### **`container_context()`** can no longer be initialized without arguments. 15 | 16 | Previously the following code would reset the context for all providers in all containers: 17 | ```python 18 | async with container_context(): 19 | # all ContextResources have been reset 20 | ... 21 | ``` 22 | This was done to enable easy migration and compatibility with `1.*`. 23 | 24 | Now `container_context` must be called with at least 1 argument or keyword-argument. 25 | Thus if you want to reset the context for all containers you need to provide them explicitly: 26 | ```python 27 | async with container_context(MyContainer_1, MyContainer_2, ...): 28 | ... 29 | ``` 30 | 31 | --- 32 | 33 | ### **`container_context()`** no longer accepts `reset_all_containers` keyword argument. 
34 | 35 | You can no longer reset the context for all containers by using the `container_context` context manager. 36 | Previously you could have done something like this: 37 | ```python 38 | async with container_context(reset_all_containers=True): 39 | # all ContextResources have been reset 40 | ... 41 | ``` 42 | Now you will need to explicitly pass the containers to the container context: 43 | ```python 44 | async with container_context(MyContainer_1, MyContainer_2, ...): 45 | ... 46 | ``` 47 | 48 | --- 49 | 50 | ### **`@inject(scope=...)`** no longer enters the scope. 51 | 52 | The `@inject` decorator no longer enters the scope specified in the `scope` argument. 53 | 54 | In `2.*` the provided scope would be entered before the function was called: 55 | ```python 56 | @inject(scope=ContextScopes.REQUEST) 57 | def injected(...): ... 58 | assert get_current_scope() == ContextScopes.REQUEST 59 | ``` 60 | 61 | In `3.*` the scope is only used to resolve relevant dependencies that match the provided scope. 62 | Thus, to achieve the same behaviour as in `2.*` you need to set `enter_scope=True`: 63 | ```python 64 | @inject(scope=ContextScopes.REQUEST, enter_scope=True) 65 | def injected(...): ... 66 | assert get_current_scope() == ContextScopes.REQUEST 67 | ``` 68 | For further details, please refer to the [scopes documentation](../introduction/scopes.md#named-scopes-with-the-inject-wrapper) 69 | 70 | --- 71 | 72 | ## Changes in the API 73 | 74 | ### Changes to naming of methods. 75 | 76 | You can expect the default implementation of provider and container methods to be async. 77 | This means that methods **not** explicitly ending with `_sync` are normally async. 78 | 79 | This has also introduced new interfaces for operations where part of the implementation was missing (only async or sync implementation was provided in `2.*`). 80 | 81 | 82 | For example `tear_down()` is now async per default and `tear_down_sync()` has been introduced. 
83 | 84 | Other examples of similar changes include: 85 | 86 | - `.async_resolve()` -> `.resolve()` 87 | - `.sync_resolve()` -> `.resolve_sync()` 88 | 89 | --- 90 | 91 | ### Tear down propagation enabled per default. 92 | 93 | Tear down is now propagated to all dependencies by default. 94 | 95 | Previously if you called tear down for a provider this only reset the given provider. 96 | ```python 97 | await provider.tear_down() 98 | ``` 99 | In order to maintain this behaviour in `3.*`: 100 | ```python 101 | await provider.tear_down(propagate=False) 102 | ``` 103 | For more details regarding tear-down propagation see the [documentation](../introduction/tear-down.md). 104 | 105 | --- 106 | 107 | ### Overriding is now async per default. 108 | 109 | As mentioned [above](#changes-to-naming-of-methods), `.override()` methods are now async per default. 110 | 111 | Thus any instances of the following in `2.*`: 112 | ```python 113 | provider.override(value) 114 | ``` 115 | Should be changed to the following in `3.*`: 116 | ```python 117 | provider.override_sync(value) 118 | ``` 119 | 120 | This is also the case for the following methods: 121 | 122 | - `.override_context()` -> `.override_context_sync()` 123 | - `.reset_override()` -> `.reset_override_sync()` 124 | 125 | > **Note:** Overrides now support tear-down, read more in the [documentation](../testing/provider-overriding.md) 126 | 127 | --- 128 | 129 | ## Further Help 130 | 131 | If you continue to experience issues during migration, consider creating a [discussion](https://github.com/modern-python/that-depends/discussions) or opening an [issue](https://github.com/modern-python/that-depends/issues). 
132 | -------------------------------------------------------------------------------- /that_depends/providers/local_singleton.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import threading 3 | import typing 4 | 5 | from typing_extensions import override 6 | 7 | from that_depends.providers import AbstractProvider 8 | from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown 9 | 10 | 11 | T_co = typing.TypeVar("T_co", covariant=True) 12 | P = typing.ParamSpec("P") 13 | 14 | 15 | class ThreadLocalSingleton(ProviderWithArguments, SupportsTeardown, AbstractProvider[T_co]): 16 | """Creates a new instance for each thread using a thread-local store. 17 | 18 | This provider ensures that each thread gets its own instance, which is 19 | created via the specified factory function. Once created, the instance is 20 | cached for future injections within the same thread. 21 | 22 | Example: 23 | ```python 24 | def factory(): 25 | return random.randint(1, 100) 26 | 27 | singleton = ThreadLocalSingleton(factory) 28 | 29 | # Same thread, same instance 30 | instance1 = singleton.resolve_sync() 31 | instance2 = singleton.resolve_sync() 32 | 33 | def thread_task(): 34 | return singleton.resolve_sync() 35 | 36 | threads = [threading.Thread(target=thread_task) for i in range(10)] 37 | for thread in threads: 38 | thread.start() # Each thread will get a different instance 39 | ``` 40 | 41 | """ 42 | 43 | def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: 44 | """Initialize the ThreadLocalSingleton provider. 45 | 46 | Args: 47 | factory: A callable that returns a new instance of the dependency. 48 | *args: Positional arguments to pass to the factory. 
49 | **kwargs: Keyword arguments to pass to the factory 50 | 51 | """ 52 | super().__init__() 53 | self._factory: typing.Final = factory 54 | self._thread_local = threading.local() 55 | self._asyncio_lock = asyncio.Lock() 56 | self._args: typing.Final = args 57 | self._kwargs: typing.Final = kwargs 58 | 59 | def _register_arguments(self) -> None: 60 | self._register(self._args) 61 | self._register(self._kwargs.values()) 62 | 63 | def _deregister_arguments(self) -> None: 64 | self._deregister(self._args) 65 | self._deregister(self._kwargs.values()) 66 | 67 | @property 68 | def _instance(self) -> T_co | None: 69 | return getattr(self._thread_local, "instance", None) 70 | 71 | @_instance.setter 72 | def _instance(self, value: T_co | None) -> None: 73 | self._thread_local.instance = value 74 | 75 | @override 76 | async def resolve(self) -> T_co: 77 | if self._override is not None: 78 | return typing.cast(T_co, self._override) 79 | 80 | async with self._asyncio_lock: 81 | if self._instance is not None: 82 | return self._instance 83 | 84 | self._register_arguments() 85 | 86 | self._instance = self._factory( 87 | *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] 88 | **{ # type: ignore[arg-type] 89 | k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() 90 | }, 91 | ) 92 | return self._instance 93 | 94 | @override 95 | def resolve_sync(self) -> T_co: 96 | if self._override is not None: 97 | return typing.cast(T_co, self._override) 98 | 99 | if self._instance is not None: 100 | return self._instance 101 | 102 | self._register_arguments() 103 | 104 | self._instance = self._factory( 105 | *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type] 106 | **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type] 107 | ) 108 | return self._instance 109 | 110 | 
@override 111 | def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None: 112 | if self._instance is not None: 113 | self._instance = None 114 | self._deregister_arguments() 115 | if propagate: 116 | self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async) 117 | 118 | @override 119 | async def tear_down(self, propagate: bool = True) -> None: 120 | """Reset the thread-local instance. 121 | 122 | After calling this method, subsequent calls to `resolve_sync()` on the 123 | same thread will produce a new instance. 124 | """ 125 | if self._instance is not None: 126 | self._instance = None 127 | self._deregister_arguments() 128 | if propagate: 129 | await self._tear_down_children() 130 | -------------------------------------------------------------------------------- /docs/introduction/generator-injection.md: -------------------------------------------------------------------------------- 1 | # Injection into Generator Functions 2 | 3 | 4 | `that-depends` supports dependency injections into generator functions. However, this comes 5 | with some minor limitations compared to regular functions. 
6 | 7 | 8 | ## Quickstart 9 | 10 | You can use the `@inject` decorator to inject dependencies into generator functions: 11 | 12 | === "async generator" 13 | 14 | ```python 15 | @inject 16 | async def my_generator(value: str = Provide[Container.factory]) -> typing.AsyncGenerator[str, None]: 17 | yield value 18 | ``` 19 | 20 | === "sync generator" 21 | 22 | ```python 23 | @inject 24 | def my_generator(value: str = Provide[Container.factory]) -> typing.Generator[str, None, None]: 25 | yield value 26 | ``` 27 | 28 | === "async context manager" 29 | 30 | ```python 31 | @contextlib.asynccontextmanager 32 | @inject 33 | async def my_generator(value: str = Provide[Container.factory]) -> typing.AsyncIterator[str]: 34 | yield value 35 | ``` 36 | 37 | === "sync context manager" 38 | 39 | ```python 40 | @contextlib.contextmanager 41 | @inject 42 | def my_generator(value: str = Provide[Container.factory]) -> typing.Iterator[str]: 43 | yield value 44 | ``` 45 | 46 | ## Supported Generators 47 | 48 | ### Synchronous Generators 49 | 50 | `that-depends` supports injection into sync generator functions with the following signature: 51 | 52 | ```python 53 | Callable[P, Generator[<YieldType>, <SendType>, <ReturnType>]] 54 | ``` 55 | 56 | This means that wrapping a sync generator with `@inject` will always preserve all the behaviour of the wrapped generator: 57 | 58 | - It will yield as expected 59 | - It will accept sending values via `send()` 60 | - It will raise `StopIteration` when the generator is exhausted or otherwise returns. 
61 | 62 | 63 | ### Asynchronous Generators 64 | 65 | `that-depends` supports injection into async generator functions with the following signature: 66 | 67 | ```python 68 | Callable[P, AsyncGenerator[<YieldType>, None]] 69 | ``` 70 | 71 | This means that wrapping an async generator with `@inject` will have the following effects: 72 | 73 | - The generator will yield as expected 74 | - The generator will **not** accept values via `asend()` 75 | 76 | If you need to send values to an async generator, you can simply resolve dependencies in the generator body: 77 | 78 | ```python 79 | 80 | async def my_generator() -> typing.AsyncGenerator[float, float]: 81 | value = await Container.factory.resolve() 82 | receive = yield value # (1)! 83 | yield receive + value 84 | 85 | ``` 86 | 87 | 1. The received value will always be `None` if you wrap this generator with `@inject`. 88 | 89 | 90 | 91 | ## ContextResources 92 | 93 | `that-depends` will **not** allow context initialization for [ContextResource](../providers/context-resources.md) providers 94 | as part of dependency injection into a generator. 95 | 96 | This is the case for both async and sync injection. 97 | 98 | **For example:** 99 | ```python 100 | def sync_resource() -> typing.Iterator[float]: 101 | yield random.random() 102 | 103 | class Container(BaseContainer): 104 | sync_provider = providers.ContextResource(sync_resource).with_config(scope=ContextScopes.INJECT) 105 | dependent_provider = providers.Factory(lambda x: x, sync_provider.cast) 106 | 107 | @inject(scope=ContextScopes.INJECT) # (1)! 108 | def injected(val: float = Provide[Container.dependent_provider]) -> typing.Generator[float, None, None]: 109 | yield val 110 | 111 | # This will raise a `ContextProviderError`! 112 | next(injected()) 113 | ``` 114 | 115 | 1. Matches context scope of `sync_provider` provider, which is a dependency of the `dependent_provider` provider. 
116 | 117 | 118 | When calling `next(injected())`, `that-depends` will try to initialize a new context for the `sync_provider`, 119 | however, this is not permitted for generators, thus it will raise a `ContextProviderError`. 120 | 121 | 122 | Keep in mind that if the context does not need to be initialized, the generator injection will work as expected: 123 | 124 | ```python 125 | def sync_resource() -> typing.Iterator[float]: 126 | yield random.random() 127 | 128 | class Container(BaseContainer): 129 | sync_provider = providers.ContextResource(sync_resource).with_config(scope=ContextScopes.REQUEST) 130 | dependent_provider = providers.Factory(lambda x: x, sync_provider.cast) 131 | 132 | @inject(scope=ContextScopes.INJECT) # (1)! 133 | def injected(val: float = Provide[Container.dependent_provider]) -> typing.Generator[float, None, None]: 134 | yield val 135 | 136 | 137 | with container_context(scope=ContextScopes.REQUEST): 138 | # This will resolve as expected 139 | next(injected()) 140 | ``` 141 | 142 | Since no context initialization was needed, the generator will work as expected. 143 | 144 | 1. Scope provided to `@inject` no longer matches scope of the `sync_provider` 145 | 146 | 147 | ### Container Context 148 | 149 | Similarly, the `@container_context` decorator also does **not** support generators: 150 | 151 | ```python 152 | @container_context(Container) 153 | async def my_generator() -> typing.AsyncIterator[None]: 154 | yield 155 | ``` 156 | The above code will raise a `UserWarning`. 
157 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: that-depends 2 | repo_url: https://github.com/modern-python/that-depends 3 | site_url: https://that-depends.readthedocs.io/ 4 | docs_dir: docs 5 | edit_uri: edit/main/docs/ 6 | nav: 7 | - Quick-Start: index.md 8 | - Introduction: 9 | - Containers: introduction/ioc-container.md 10 | - Dependency Injection: introduction/injection.md 11 | - Generator Injection: introduction/generator-injection.md 12 | - Type-based Injection: introduction/type-based-injection.md 13 | - String Injection: introduction/string-injection.md 14 | - Scopes: introduction/scopes.md 15 | - Tear-down: introduction/tear-down.md 16 | - Multiple Containers: introduction/multiple-containers.md 17 | 18 | - Providers: 19 | - Collections: providers/collections.md 20 | - Context-Resources: providers/context-resources.md 21 | - Factories: providers/factories.md 22 | - Object: providers/object.md 23 | - Resources: providers/resources.md 24 | - Selector: providers/selector.md 25 | - Singletons: providers/singleton.md 26 | - State: providers/state.md 27 | - Experimental Features: 28 | - Lazy Provider: experimental/lazy.md 29 | 30 | - Integrations: 31 | - FastAPI: integrations/fastapi.md 32 | - FastStream: integrations/faststream.md 33 | - Litestar: integrations/litestar.md 34 | - Testing: 35 | - Fixtures: testing/fixture.md 36 | - Overriding: testing/provider-overriding.md 37 | - Migration: 38 | - 1.* to 2.*: migration/v2.md 39 | - 2.* to 3.*: migration/v3.md 40 | - Development: 41 | - Contributing: dev/contributing.md 42 | - Decisions: dev/main-decisions.md 43 | 44 | theme: 45 | name: material 46 | custom_dir: docs/overrides 47 | features: 48 | - content.code.copy 49 | - content.code.annotate 50 | - content.action.edit 51 | - content.action.view 52 | - navigation.footer 53 | - navigation.sections 54 | - 
navigation.expand 55 | - navigation.top 56 | - navigation.instant 57 | - header.autohide 58 | - announce.dismiss 59 | icon: 60 | edit: material/pencil 61 | view: material/eye 62 | repo: fontawesome/brands/git-alt 63 | palette: 64 | - media: "(prefers-color-scheme: light)" 65 | scheme: default 66 | primary: black 67 | accent: pink 68 | toggle: 69 | icon: material/brightness-7 70 | name: Switch to dark mode 71 | - media: "(prefers-color-scheme: dark)" 72 | scheme: slate 73 | primary: black 74 | accent: pink 75 | toggle: 76 | icon: material/brightness-4 77 | name: Switch to system preference 78 | 79 | markdown_extensions: 80 | - toc: 81 | permalink: true 82 | - pymdownx.highlight: 83 | anchor_linenums: true 84 | line_spans: __span 85 | pygments_lang_class: true 86 | - pymdownx.inlinehilite 87 | - pymdownx.snippets 88 | - pymdownx.superfences 89 | - def_list 90 | - codehilite: 91 | use_pygments: true 92 | - attr_list 93 | - md_in_html 94 | - pymdownx.superfences 95 | - pymdownx.tabbed: 96 | alternate_style: true 97 | 98 | extra_css: 99 | - css/code.css 100 | plugins: 101 | - llmstxt: 102 | full_output: llms-full.txt 103 | markdown_description: "A simple dependency injection framework for Python." 
104 | sections: 105 | Quickstart: 106 | - index.md: How to install and basic usage example 107 | Containers: 108 | - introduction/ioc-container.md: Defining the DI Container 109 | - introduction/multiple-containers.md: Using multiple DI Containers 110 | Injecting Dependencies: 111 | - introduction/injection.md: Default way to inject dependencies 112 | - introduction/generator-injection.md: Injecting dependencies into generators 113 | - introduction/type-based-injection.md: Injecting dependencies by bound type 114 | - introduction/string-injection.md: Injecting dependencies by providing a reference to a Provider by string 115 | Simple Providers: 116 | - providers/singleton.md: The Singleton and AsyncSingleton provider 117 | - providers/factories.md: The Factory and AsyncFactory provider 118 | - providers/object.md: The Object provider 119 | - providers/resources.md: The Resource and AsyncResource provider 120 | - introduction/tear-down.md: Handling tear down for Singletons and Resources 121 | Context Providers: 122 | - providers/context-resources.md: The ContextResource provider 123 | - introduction/scopes.md: Handling scopes with ContextResources 124 | - providers/state.md: The State provider 125 | Providers that interact with other providers: 126 | - providers/collections.md: The Dict and List provider 127 | - providers/selector.md: The Selector provider 128 | Integrations with other Frameworks: 129 | - integrations/fastapi.md: Integration with FastApi 130 | - integrations/faststream.md: Integration with FastStream 131 | - integrations/litestar.md: Integration with LiteStar 132 | Testing: 133 | - testing/fixture.md: Usage with fixtures 134 | - testing/provider-overriding.md: Overriding providers 135 | Experimental features: 136 | - experimental/lazy.md: The Lazy provider 137 | 138 | extra: 139 | social: 140 | - icon: fontawesome/brands/github 141 | link: https://github.com/modern-python/that-depends 142 | name: GitHub 143 | 
-------------------------------------------------------------------------------- /that_depends/integrations/faststream.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from importlib.metadata import version 3 | from types import TracebackType 4 | from typing import Any, Final, Optional 5 | 6 | from packaging.version import Version 7 | from typing_extensions import deprecated, override 8 | 9 | from that_depends import container_context 10 | from that_depends.providers.context_resources import ContextScope, SupportsContext 11 | from that_depends.utils import UNSET, Unset, is_set 12 | 13 | 14 | _FASTSTREAM_MODULE_NAME: Final[str] = "faststream" 15 | _FASTSTREAM_VERSION: Final[str] = version(_FASTSTREAM_MODULE_NAME) 16 | if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover 17 | from faststream import BaseMiddleware, ContextRepo 18 | from faststream._internal.types import AnyMsg 19 | 20 | class DIContextMiddleware(BaseMiddleware): 21 | """Initializes the container context for faststream brokers.""" 22 | 23 | def __init__( 24 | self, 25 | *context_items: SupportsContext[Any], 26 | msg: AnyMsg | None = None, 27 | context: Optional["ContextRepo"] = None, 28 | global_context: dict[str, Any] | Unset = UNSET, 29 | scope: ContextScope | Unset = UNSET, 30 | ) -> None: 31 | """Initialize the container context middleware. 32 | 33 | Args: 34 | *context_items (SupportsContext[Any]): Context items to initialize. 35 | msg (Any): Message object. 36 | context (ContextRepo): Context repository. 37 | global_context (dict[str, Any] | Unset): Global context to initialize the container. 38 | scope (ContextScope | Unset): Context scope to initialize the container. 
39 | 40 | """ 41 | super().__init__(msg, context=context) # type: ignore[arg-type] 42 | self._context: container_context | None = None 43 | self._context_items = set(context_items) 44 | self._global_context = global_context 45 | self._scope = scope 46 | 47 | @override 48 | async def on_receive(self) -> None: 49 | self._context = container_context( 50 | *self._context_items, 51 | scope=self._scope if is_set(self._scope) else None, 52 | global_context=self._global_context if is_set(self._global_context) else None, 53 | ) 54 | await self._context.__aenter__() 55 | 56 | @override 57 | async def after_processed( 58 | self, 59 | exc_type: type[BaseException] | None = None, 60 | exc_val: BaseException | None = None, 61 | exc_tb: Optional["TracebackType"] = None, 62 | ) -> bool | None: 63 | if self._context is not None: 64 | await self._context.__aexit__(exc_type, exc_val, exc_tb) 65 | return None 66 | 67 | def __call__(self, msg: Any = None, **kwargs: Any) -> "DIContextMiddleware": # noqa: ANN401 68 | """Create an instance of DIContextMiddleware. 69 | 70 | Args: 71 | msg (Any): Message object. 72 | **kwargs: Additional keyword arguments. 73 | 74 | Returns: 75 | DIContextMiddleware: A new instance of DIContextMiddleware. 76 | 77 | """ 78 | context = kwargs.get("context") 79 | 80 | return DIContextMiddleware( 81 | *self._context_items, 82 | msg=msg, 83 | context=context, 84 | scope=self._scope, 85 | global_context=self._global_context, 86 | ) 87 | else: # pragma: no cover 88 | from faststream import BaseMiddleware 89 | 90 | @deprecated("Will be removed with faststream v1") 91 | class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef] 92 | """Initializes the container context for faststream brokers.""" 93 | 94 | def __init__( 95 | self, 96 | *context_items: SupportsContext[Any], 97 | global_context: dict[str, Any] | Unset = UNSET, 98 | scope: ContextScope | Unset = UNSET, 99 | ) -> None: 100 | """Initialize the container context middleware. 
101 | 102 | Args: 103 | *context_items (SupportsContext[Any]): Context items to initialize. 104 | global_context (dict[str, Any] | Unset): Global context to initialize the container. 105 | scope (ContextScope | Unset): Context scope to initialize the container. 106 | 107 | """ 108 | super().__init__() # type: ignore[call-arg] 109 | self._context: container_context | None = None 110 | self._context_items = set(context_items) 111 | self._global_context = global_context 112 | self._scope = scope 113 | 114 | @override 115 | async def on_receive(self) -> None: 116 | self._context = container_context( 117 | *self._context_items, 118 | scope=self._scope if is_set(self._scope) else None, 119 | global_context=self._global_context if is_set(self._global_context) else None, 120 | ) 121 | await self._context.__aenter__() 122 | 123 | @override 124 | async def after_processed( 125 | self, 126 | exc_type: type[BaseException] | None = None, 127 | exc_val: BaseException | None = None, 128 | exc_tb: Optional["TracebackType"] = None, 129 | ) -> bool | None: 130 | if self._context is not None: 131 | await self._context.__aexit__(exc_type, exc_val, exc_tb) 132 | return None 133 | 134 | def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> "DIContextMiddleware": # noqa: ARG002, ANN401 135 | """Create an instance of DIContextMiddleware.""" 136 | return DIContextMiddleware(*self._context_items, scope=self._scope, global_context=self._global_context) 137 | -------------------------------------------------------------------------------- /tests/experimental/test_container_2.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections.abc import AsyncIterator, Iterator 3 | 4 | import pytest 5 | import typing_extensions 6 | 7 | from that_depends import BaseContainer, ContextScopes, container_context, providers 8 | from that_depends.experimental import LazyProvider 9 | 10 | 11 | class _RandomWrapper: 12 | def __init__(self) 
-> None: 13 | self.value = random.random() 14 | 15 | @typing_extensions.override 16 | def __eq__(self, other: object) -> bool: 17 | if isinstance(other, _RandomWrapper): 18 | return self.value == other.value 19 | return False # pragma: nocover 20 | 21 | def __hash__(self) -> int: 22 | return 0 # pragma: nocover 23 | 24 | 25 | async def _async_creator() -> AsyncIterator[float]: 26 | yield random.random() 27 | 28 | 29 | def _sync_creator() -> Iterator[_RandomWrapper]: 30 | yield _RandomWrapper() 31 | 32 | 33 | class Container2(BaseContainer): 34 | """Test Container 2.""" 35 | 36 | alias = "container_2" 37 | default_scope = ContextScopes.APP 38 | obj_1 = LazyProvider("tests.experimental.test_container_1.Container1.obj_1") 39 | obj_2 = providers.Object(2) 40 | async_context_provider = providers.ContextResource(_async_creator) 41 | sync_context_provider = providers.ContextResource(_sync_creator) 42 | singleton_provider = providers.Singleton(lambda: random.random()) 43 | 44 | 45 | async def test_lazy_provider_resolution_async() -> None: 46 | assert await Container2.obj_1.resolve() == 1 47 | 48 | 49 | def test_lazy_provider_override_sync() -> None: 50 | override_value = 42 51 | Container2.obj_1.override_sync(override_value) 52 | assert Container2.obj_1.resolve_sync() == override_value 53 | Container2.obj_1.reset_override_sync() 54 | assert Container2.obj_1.resolve_sync() == 1 55 | 56 | 57 | async def test_lazy_provider_override_async() -> None: 58 | override_value = 42 59 | await Container2.obj_1.override(override_value) 60 | assert await Container2.obj_1.resolve() == override_value 61 | await Container2.obj_1.reset_override() 62 | assert await Container2.obj_1.resolve() == 1 63 | 64 | 65 | def test_lazy_provider_invalid_state() -> None: 66 | lazy_provider = LazyProvider( 67 | module_string="tests.experimental.test_container_2", provider_string="Container2.sync_context_provider" 68 | ) 69 | lazy_provider._module_string = None 70 | with pytest.raises(RuntimeError): 71 | 
lazy_provider.resolve_sync() 72 | 73 | 74 | async def test_lazy_provider_context_resource_async() -> None: 75 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.async_context_provider") 76 | async with lazy_provider.context_async(force=True): 77 | assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve() 78 | async with Container2.async_context_provider.context_async(force=True): 79 | assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve() 80 | 81 | with pytest.raises(RuntimeError): 82 | await lazy_provider.resolve() 83 | 84 | async with container_context(Container2, scope=ContextScopes.APP): 85 | assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve() 86 | 87 | assert lazy_provider.get_scope() == ContextScopes.APP 88 | 89 | assert lazy_provider.supports_context_sync() is False 90 | 91 | 92 | def test_lazy_provider_context_resource_sync() -> None: 93 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider") 94 | with lazy_provider.context_sync(force=True): 95 | assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync() 96 | with Container2.sync_context_provider.context_sync(force=True): 97 | assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync() 98 | 99 | with pytest.raises(RuntimeError): 100 | lazy_provider.resolve_sync() 101 | 102 | with container_context(Container2, scope=ContextScopes.APP): 103 | assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync() 104 | 105 | assert lazy_provider.get_scope() == ContextScopes.APP 106 | 107 | assert lazy_provider.supports_context_sync() is True 108 | 109 | 110 | async def test_lazy_provider_tear_down_async() -> None: 111 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.singleton_provider") 112 | assert lazy_provider.resolve_sync() == 
Container2.singleton_provider.resolve_sync() 113 | 114 | await lazy_provider.tear_down() 115 | 116 | assert await lazy_provider.resolve() == Container2.singleton_provider.resolve_sync() 117 | 118 | 119 | def test_lazy_provider_tear_down_sync() -> None: 120 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.singleton_provider") 121 | assert lazy_provider.resolve_sync() == Container2.singleton_provider.resolve_sync() 122 | 123 | lazy_provider.tear_down_sync() 124 | 125 | assert lazy_provider.resolve_sync() == Container2.singleton_provider.resolve_sync() 126 | 127 | 128 | async def test_lazy_provider_not_implemented() -> None: 129 | lazy_provider = Container2.obj_1 130 | with pytest.raises(NotImplementedError): 131 | lazy_provider.get_scope() 132 | with pytest.raises(NotImplementedError): 133 | lazy_provider.context_sync() 134 | with pytest.raises(NotImplementedError): 135 | lazy_provider.context_async() 136 | with pytest.raises(NotImplementedError): 137 | lazy_provider.tear_down_sync() 138 | with pytest.raises(NotImplementedError): 139 | await lazy_provider.tear_down() 140 | 141 | 142 | def test_lazy_provider_attr_getter() -> None: 143 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider") 144 | with lazy_provider.context_sync(force=True): 145 | assert isinstance(lazy_provider.value.resolve_sync(), float) 146 | -------------------------------------------------------------------------------- /docs/introduction/injection.md: -------------------------------------------------------------------------------- 1 | # Injecting Providers in **that-depends** 2 | 3 | `that-depends` uses a decorator-based approach for both synchronous and asynchronous functions. By decorating a function with `@inject` and marking certain parameters as `Provide[...]`, **that-depends** will automatically resolve the specified providers at call time. 
4 | 5 | --- 6 | 7 | ## Overview 8 | 9 | In **that-depends**, you define your dependencies as `AbstractProvider` instances—e.g., `Singleton`, `Factory`, `Resource`, or others. These providers typically live inside a subclass of `BaseContainer`, making them globally accessible. 10 | 11 | When you want to use a provider in a function, you can mark a parameter’s **default value** as: 12 | 13 | ```python 14 | my_param = Provide[MyContainer.some_provider] 15 | ``` 16 | 17 | You then decorate the function with `@inject`. This tells `that-depends` to automatically resolve providers when you call your function. 18 | 19 | --- 20 | 21 | ## Quick Start 22 | 23 | Below is a simple example demonstrating how to define a container, declare a provider, and inject that provider into a function. 24 | 25 | ### 1. Define a Container and a Provider 26 | 27 | ```python 28 | from that_depends import BaseContainer 29 | from that_depends.providers import Singleton 30 | 31 | class MyContainer(BaseContainer): 32 | greeting_provider = Singleton(lambda: "Hello from MyContainer") 33 | ``` 34 | For more details on Containers, refer to the [Containers](ioc-container.md) documentation. 35 | 36 | ### 2. Inject the Provider into a Function 37 | 38 | ```python 39 | from that_depends import inject, Provide 40 | 41 | @inject 42 | def greet_user(greeting: str = Provide[MyContainer.greeting_provider]) -> str: 43 | return f"Greeting: {greeting}" 44 | ``` 45 | 46 | Here: 47 | 48 | 1. We used `@inject` above `greet_user`. 49 | 2. We declared a parameter `greeting`, whose default value is `Provide[MyContainer.greeting_provider]`. 50 | 51 | ### 3. Call the Function 52 | 53 | ```python 54 | print(greet_user()) # "Greeting: Hello from MyContainer" 55 | ``` 56 | 57 | --- 58 | 59 | ## The `@inject` Decorator in Detail 60 | 61 | 62 | ### Synchronous vs Asynchronous Functions 63 | 64 | `@inject` works on both sync and async functions. Just note that injecting async providers into sync functions is not supported. 
65 | 66 | ```python 67 | @inject 68 | async def async_greet_user(greeting: str = Provide[MyContainer.greeting_provider]) -> str: 69 | # asynchronous operations... 70 | return f"Greeting: {greeting}" 71 | ``` 72 | 73 | --- 74 | 75 | ## Using `Provide[...]` as a Default 76 | 77 | It is recommended to wrap your provider in `Provide[...]` when using it as a default in an injected function since it provides correct type resolution: 78 | 79 | ```python 80 | @inject 81 | def greet_user_direct( 82 | greeting: str = Provide[MyContainer.greeting_provider] # (1)! 83 | ) -> str: 84 | return f"Greeting: {greeting}" 85 | ``` 86 | 87 | 1. Notice that although `greeting` is a `str`, `mypy` and your IDE will not complain. 88 | 89 | --- 90 | 91 | ## Injection Warnings 92 | 93 | If `@inject` finds **no** parameters whose default values are providers, it will issue a warning: 94 | 95 | > `Expected injection, but nothing found. Remove @inject decorator.` 96 | 97 | This is to avoid accidentally decorating a function that doesn’t actually require injection. 98 | 99 | --- 100 | 101 | ## Specifying a Scope 102 | 103 | By default, `@inject` uses the `ContextScopes.INJECT` scope. If you want to override that, do: 104 | 105 | ```python 106 | from that_depends import inject, Provide 107 | from that_depends.providers.context_resources import ContextScopes 108 | 109 | @inject(scope=ContextScopes.REQUEST) 110 | def greet_user(greeting: str = Provide[MyContainer.greeting_provider]): 111 | ... 112 | ``` 113 | 114 | When `greet_user` is called, **that-depends**: 115 | 116 | 1. Initializes the context for all `REQUEST` (or `ANY`) scoped `args` and `kwargs`. 117 | 2. Resolves all providers in the `args` and `kwargs` of the function. 118 | 3. Calls your function with the resolved dependencies. 119 | 120 | For more details regarding scopes and context management, see the [Context Resources](../providers/context-resources.md) documentation and the [Scopes](scopes.md) documentation. 
121 | 122 | --- 123 | 124 | ## Overriding Providers 125 | 126 | In tests or specialized scenarios, you may want to override a provider’s value temporarily. You can do so with the container’s `override_providers()` method or the provider’s own `override_context()`: 127 | 128 | ```python 129 | def test_greet_override(): 130 | # Override the greeting_provider with a mock value 131 | with MyContainer.override_providers_sync({"greeting_provider": "TestHello"}): 132 | result = greet_user() 133 | assert result == "Greeting: TestHello" 134 | ``` 135 | 136 | This is especially helpful for unit tests where you want to substitute real dependencies (e.g., database connections) with mocks or stubs. 137 | 138 | For more details on overriding providers, see the [Overriding Providers](../testing/provider-overriding.md) documentation. 139 | 140 | --- 141 | 142 | ## Frequently Asked Questions 143 | 144 | 1. **Do I need to call `@inject` every time I reference a provider?** 145 | No—only when you want **automatic** injection of providers into function parameters. If you are resolving dependencies manually (e.g., `MyContainer.greeting_provider.resolve_sync()`), then `@inject` is not needed. 146 | 147 | 2. **What if I provide a custom argument to a parameter that has a default provider?** 148 | If you explicitly pass a value, that value overrides the injected default: 149 | 150 | ~~~~python 151 | @inject 152 | def foo(x: int = Provide[MyContainer.number_factory]) -> int: 153 | return x 154 | 155 | print(foo()) # uses number_factory -> 42 156 | print(foo(99)) # explicitly uses 99 157 | ~~~~ 158 | 159 | 3. **Can I combine `@inject` with other decorators?** 160 | Yes, you can. Generally, put `@inject` **below** others, depending on the order you need. If you run into issues, experiment with the order or handle context manually. 
161 | 162 | --- 163 | -------------------------------------------------------------------------------- /docs/providers/factories.md: -------------------------------------------------------------------------------- 1 | # Factories 2 | Factories are initialized on every call. 3 | 4 | ## Factory 5 | - Class or simple function is allowed. 6 | ```python 7 | import dataclasses 8 | 9 | from that_depends import BaseContainer, providers 10 | 11 | 12 | @dataclasses.dataclass(kw_only=True, slots=True) 13 | class IndependentFactory: 14 | dep1: str 15 | dep2: int 16 | 17 | 18 | class DIContainer(BaseContainer): 19 | independent_factory = providers.Factory(IndependentFactory, dep1="text", dep2=123) 20 | ``` 21 | 22 | ## AsyncFactory 23 | - Allows both sync and async creators. 24 | - Can only be resolved asynchronously, even if the creator is sync. 25 | ```python 26 | import datetime 27 | 28 | from that_depends import BaseContainer, providers 29 | 30 | 31 | async def async_factory() -> datetime.datetime: 32 | return datetime.datetime.now(tz=datetime.timezone.utc) 33 | 34 | 35 | class DIContainer(BaseContainer): 36 | async_factory = providers.AsyncFactory(async_factory) 37 | ``` 38 | 39 | > Note: If you have a class that has dependencies which need to be resolved asynchronously, you can use `AsyncFactory` to create instances of that class. The factory will handle the async resolution of dependencies. 40 | 41 | 42 | ## Retrieving provider as a Callable 43 | 44 | When you use a factory‑based provider such as `Factory` (for sync logic) or `AsyncFactory` (for async logic), the resulting provider instance has two special properties: 45 | 46 | - **`.provider`** — returns an *async callable* that, when awaited, resolves the resource. 47 | - **`.provider_sync`** — returns a *sync callable* that, when called, resolves the resource. 
48 | 49 | You can think of these as no-argument functions that produce the resource you defined—similar to calling `resolve()` or `resolve_sync()` directly, but in a more convenient form when you want a standalone function handle. 50 | 51 | --- 52 | 53 | ### Basic Usage 54 | 55 | #### Defining Providers in a Container 56 | 57 | Suppose you have a `BaseContainer` subclass that defines both a sync and an async resource: 58 | 59 | ```python 60 | from that_depends import BaseContainer 61 | from that_depends.providers import Factory, AsyncFactory 62 | 63 | def build_sync_message() -> str: 64 | return "Hello from sync provider!" 65 | 66 | async def build_async_message() -> str: 67 | # Possibly perform async setup, I/O, etc. 68 | return "Hello from async provider!" 69 | 70 | class MyContainer(BaseContainer): 71 | # Synchronous Factory 72 | sync_message = Factory(build_sync_message) 73 | 74 | # Asynchronous Factory 75 | async_message = AsyncFactory(build_async_message) 76 | ``` 77 | 78 | Here, `sync_message` is a `Factory` which calls a plain function, while `async_message` is an `AsyncFactory` which calls an async function. 79 | 80 | --- 81 | 82 | #### Resolving Resources via `.provider` and `.provider_sync` 83 | 84 | The `.provider` property gives you an *async function* to call and await, and `.provider_sync` gives you a *synchronous* callable. They effectively wrap `.resolve()` and `.resolve_sync()`. 85 | 86 | **Synchronous Resolution** 87 | 88 | ```python 89 | # In a synchronous function or interactive session 90 | >>> msg = MyContainer.sync_message.provider_sync() 91 | >>> print(msg) 92 | Hello from sync provider! 93 | ``` 94 | 95 | Here, `provider_sync` is a no-argument function that, when called, immediately returns the resolved value. 
96 | 97 | **Asynchronous Resolution** 98 | 99 | ```python 100 | import asyncio 101 | 102 | async def main(): 103 | # Acquire the async resource by calling and awaiting the provider callable 104 | msg = await MyContainer.async_message.provider() 105 | print(msg) 106 | 107 | asyncio.run(main()) 108 | ``` 109 | 110 | Within an async function, `MyContainer.async_message.provider` gives a no-argument async function to call and `await`. 111 | 112 | --- 113 | 114 | ### Passing the Provider Function Around 115 | 116 | Sometimes you may want to store or pass around the provider function itself (rather than resolving it immediately): 117 | 118 | ```python 119 | class AnotherClass: 120 | def __init__(self, sync_factory_callable: callable): 121 | self._factory_callable = sync_factory_callable 122 | 123 | def get_message(self) -> str: 124 | return self._factory_callable() 125 | 126 | 127 | # Passing MyContainer.sync_message.provider_sync to AnotherClass 128 | provider_callable = MyContainer.sync_message.provider_sync 129 | another_instance = AnotherClass(provider_callable) 130 | print(another_instance.get_message()) # "Hello from sync provider!" 131 | ``` 132 | 133 | Because `.provider_sync` is just a callable returning your dependency, it can be shared easily throughout your code. 134 | 135 | --- 136 | 137 | ### Example: Using Factories with Parameters 138 | 139 | `Factory` and `AsyncFactory` can accept dependencies (including other providers) as parameters: 140 | 141 | ```python 142 | def greet(name: str) -> str: 143 | return f"Hello, {name}!" 144 | 145 | class MyContainer(BaseContainer): 146 | name = Factory(lambda: "Alice") 147 | greeting = Factory(greet, name) 148 | ``` 149 | 150 | ```python 151 | >>> greeting_sync_fn = MyContainer.greeting.provider_sync 152 | >>> print(greeting_sync_fn()) 153 | Hello, Alice! 154 | ``` 155 | 156 | Under the hood, `greeting` calls `greet` with the result of `name.resolve_sync()`. 
157 | 158 | --- 159 | 160 | ### Context Considerations 161 | 162 | If your providers use `ContextResource` or require a named scope (for instance, `REQUEST`), you need to wrap your resolves in a context manager: 163 | 164 | ```python 165 | from that_depends.providers import container_context, ContextScopes 166 | 167 | 168 | class ContextfulContainer(BaseContainer): 169 | default_scope = ContextScopes.REQUEST 170 | # ... define context-based providers ... 171 | 172 | 173 | with container_context(ContextfulContainer, scope=ContextScopes.REQUEST): 174 | result = ContextfulContainer.some_resource.provider_sync() 175 | # ... 176 | ``` 177 | 178 | You still call `.provider_sync` or `.provider`, but the container or context usage ensures resources are valid within the required scope. 179 | 180 | This pattern simplifies passing creation logic around in your code, preserving testability and clarity—whether you need sync or async behavior. 181 | -------------------------------------------------------------------------------- /tests/providers/test_selector.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import typing 4 | 5 | import pydantic 6 | import pytest 7 | 8 | from tests.container import create_async_resource, create_sync_resource 9 | from that_depends import BaseContainer, providers 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | global_state_for_selector: typing.Literal["sync_resource", "async_resource", "missing"] = "sync_resource" 14 | 15 | 16 | class SelectorState: 17 | def __init__(self) -> None: 18 | self.selector_state: typing.Literal["sync_resource", "async_resource", "missing"] = "sync_resource" 19 | 20 | def get_selector_state(self) -> typing.Literal["sync_resource", "async_resource", "missing"]: 21 | return self.selector_state 22 | 23 | 24 | selector_state = SelectorState() 25 | 26 | 27 | class DIContainer(BaseContainer): 28 | alias = "selector_container" 29 | 
sync_resource = providers.Resource(create_sync_resource) 30 | async_resource = providers.Resource(create_async_resource) 31 | selector: providers.Selector[datetime.datetime] = providers.Selector( 32 | selector_state.get_selector_state, 33 | sync_resource=sync_resource, 34 | async_resource=async_resource, 35 | ) 36 | 37 | 38 | async def test_selector_provider_async() -> None: 39 | selector_state.selector_state = "async_resource" 40 | selected = await DIContainer.selector() 41 | async_resource = await DIContainer.async_resource() 42 | 43 | assert selected == async_resource 44 | 45 | 46 | async def test_selector_provider_async_missing() -> None: 47 | selector_state.selector_state = "missing" 48 | with pytest.raises(KeyError, match="Key 'missing' not found in providers"): 49 | await DIContainer.selector() 50 | 51 | 52 | async def test_selector_provider_sync() -> None: 53 | selector_state.selector_state = "sync_resource" 54 | selected = DIContainer.selector.resolve_sync() 55 | sync_resource = DIContainer.sync_resource.resolve_sync() 56 | 57 | assert selected == sync_resource 58 | 59 | 60 | async def test_selector_provider_sync_missing() -> None: 61 | selector_state.selector_state = "missing" 62 | with pytest.raises(KeyError, match="Key 'missing' not found in providers"): 63 | DIContainer.selector.resolve_sync() 64 | 65 | 66 | async def test_selector_provider_overriding() -> None: 67 | now = datetime.datetime.now(tz=datetime.timezone.utc) 68 | DIContainer.selector.override_sync(now) 69 | selected_async = await DIContainer.selector() 70 | selected_sync = DIContainer.selector.resolve_sync() 71 | assert selected_async == selected_sync == now 72 | 73 | DIContainer.reset_override_sync() 74 | selector_state.selector_state = "sync_resource" 75 | selected = DIContainer.selector.resolve_sync() 76 | sync_resource = DIContainer.sync_resource.resolve_sync() 77 | assert selected == sync_resource 78 | 79 | 80 | class Settings(pydantic.BaseModel): 81 | mode: typing.Literal["one", 
"two"] = "one" 82 | 83 | 84 | class SettingsSelectorContainer(BaseContainer): 85 | settings = providers.Singleton(Settings) 86 | operator = providers.Selector( 87 | settings.cast.mode, 88 | one=providers.Singleton(lambda: "Provider 1"), 89 | two=providers.Singleton(lambda: "Provider 2"), 90 | ) 91 | 92 | 93 | async def test_selector_with_attrgetter_selector_sync() -> None: 94 | assert SettingsSelectorContainer.settings.resolve_sync().mode == "one" 95 | assert SettingsSelectorContainer.operator.resolve_sync() == "Provider 1" 96 | 97 | SettingsSelectorContainer.settings.override_sync(Settings(mode="two")) 98 | assert SettingsSelectorContainer.settings.resolve_sync().mode == "two" 99 | assert SettingsSelectorContainer.operator.resolve_sync() == "Provider 2" 100 | SettingsSelectorContainer.settings.reset_override_sync() 101 | 102 | 103 | async def test_selector_with_attrgetter_selector_async() -> None: 104 | assert (await SettingsSelectorContainer.settings.resolve()).mode == "one" 105 | assert (await SettingsSelectorContainer.operator.resolve()) == "Provider 1" 106 | 107 | SettingsSelectorContainer.settings.override_sync(Settings(mode="two")) 108 | assert (await SettingsSelectorContainer.settings.resolve()).mode == "two" 109 | assert (await SettingsSelectorContainer.operator.resolve()) == "Provider 2" 110 | SettingsSelectorContainer.settings.reset_override_sync() 111 | 112 | 113 | class StringSelectorContainer(BaseContainer): 114 | selector = providers.Selector( 115 | "one", 116 | one=providers.Singleton(lambda: "Provider 1"), 117 | two=providers.Singleton(lambda: "Provider 2"), 118 | ) 119 | 120 | 121 | async def test_selector_with_fixed_string() -> None: 122 | assert StringSelectorContainer.selector.resolve_sync() == "Provider 1" 123 | assert (await StringSelectorContainer.selector.resolve()) == "Provider 1" 124 | 125 | 126 | async def mode_one() -> str: 127 | return "one" 128 | 129 | 130 | class StringProviderSelectorContainer(BaseContainer): 131 | mode = 
providers.AsyncFactory(mode_one) 132 | selector = providers.Selector( 133 | mode.cast, 134 | one=providers.Singleton(lambda: "Provider 1"), 135 | two=providers.Singleton(lambda: "Provider 2"), 136 | ) 137 | 138 | 139 | async def test_selector_with_provider_selector_async() -> None: 140 | assert (await StringProviderSelectorContainer.selector.resolve()) == "Provider 1" 141 | 142 | 143 | class InvalidSelectorContainer(BaseContainer): 144 | selector = providers.Selector( 145 | None, # type: ignore[arg-type] 146 | one=providers.Singleton(lambda: "Provider 1"), 147 | two=providers.Singleton(lambda: "Provider 2"), 148 | ) 149 | 150 | 151 | async def test_selector_with_invalid_selector() -> None: 152 | with pytest.raises( 153 | TypeError, match="Invalid selector type: <class 'NoneType'>, expected str, or a provider/callable returning str" 154 | ): 155 | InvalidSelectorContainer.selector.resolve_sync() 156 | 157 | with pytest.raises( 158 | TypeError, match="Invalid selector type: <class 'NoneType'>, expected str, or a provider/callable returning str" 159 | ): 160 | await InvalidSelectorContainer.selector.resolve() 161 | -------------------------------------------------------------------------------- /that_depends/container.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import typing 3 | from contextlib import contextmanager 4 | from typing import overload 5 | 6 | from typing_extensions import override 7 | 8 | from that_depends.meta import BaseContainerMeta 9 | from that_depends.providers import AbstractProvider, AsyncSingleton, Resource, Singleton 10 | from that_depends.providers.context_resources import ContextScope, ContextScopes 11 | 12 | 13 | if typing.TYPE_CHECKING: 14 | import typing_extensions 15 | 16 | 17 | T = typing.TypeVar("T") 18 | P = typing.ParamSpec("P") 19 | 20 | 21 | class BaseContainer(metaclass=BaseContainerMeta): 22 | """Base container class.""" 23 | 24 | alias: str | None = None 25 | providers: dict[str, 
AbstractProvider[typing.Any]] 26 | containers: list[type["BaseContainer"]] 27 | default_scope: ContextScope | None = ContextScopes.ANY 28 | 29 | @classmethod 30 | @overload 31 | def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: ... 32 | 33 | @classmethod 34 | @overload 35 | def context(cls, *, force: bool = False) -> typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: ... 36 | 37 | @classmethod 38 | def context( 39 | cls, func: typing.Callable[P, T] | None = None, force: bool = False 40 | ) -> typing.Callable[P, T] | typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: 41 | """Wrap a function with this resources' context. 42 | 43 | Args: 44 | func: function to be wrapped. 45 | force: force context initialization, ignoring scope. 46 | 47 | Returns: 48 | wrapped function or wrapper if func is None. 49 | 50 | """ 51 | 52 | def _wrapper(func: typing.Callable[P, T]) -> typing.Callable[P, T]: 53 | if inspect.iscoroutinefunction(func): 54 | 55 | async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 56 | async with cls.context_async(force=force): 57 | return await func(*args, **kwargs) # type: ignore[no-any-return] 58 | 59 | return typing.cast(typing.Callable[P, T], _async_wrapper) 60 | 61 | def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: 62 | with cls.context_sync(force=force): 63 | return func(*args, **kwargs) 64 | 65 | return _sync_wrapper 66 | 67 | if func: 68 | return _wrapper(func) 69 | return _wrapper 70 | 71 | @override 72 | def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": 73 | msg = f"{cls.__name__} should not be instantiated" 74 | raise RuntimeError(msg) 75 | 76 | @classmethod 77 | def connect_containers(cls, *containers: type["BaseContainer"]) -> None: 78 | """Connect containers. 79 | 80 | When `init_resources` and `tear_down` is called, 81 | same method of connected containers will also be called. 
82 | """ 83 | if not hasattr(cls, "containers"): 84 | cls.containers = [] 85 | 86 | cls.containers.extend(containers) 87 | 88 | @classmethod 89 | async def init_resources(cls) -> None: 90 | """Initialize all resources.""" 91 | for provider in cls.get_providers().values(): 92 | if isinstance(provider, Resource | Singleton | AsyncSingleton): 93 | await provider.resolve() 94 | 95 | for container in cls.get_containers(): 96 | await container.init_resources() 97 | 98 | @classmethod 99 | def reset_override_sync(cls) -> None: 100 | """Reset all provider overrides.""" 101 | for v in cls.get_providers().values(): 102 | v.reset_override_sync() 103 | 104 | @classmethod 105 | def resolver(cls, item: typing.Callable[P, T]) -> typing.Callable[[], typing.Awaitable[T]]: 106 | """Decorate a function to automatically resolve dependencies on call by name. 107 | 108 | Args: 109 | item: objects for which the dependencies should be resolved. 110 | 111 | Returns: 112 | Async wrapped callable with auto-injected dependencies. 
113 | 114 | """ 115 | 116 | async def _inner() -> T: 117 | return await cls.resolve(item) 118 | 119 | return _inner 120 | 121 | @classmethod 122 | async def resolve(cls, object_to_resolve: typing.Callable[..., T]) -> T: 123 | """Inject dependencies into an object automatically by name.""" 124 | signature: typing.Final = inspect.signature(object_to_resolve) 125 | kwargs = {} 126 | providers: typing.Final = cls.get_providers() 127 | for field_name, field_value in signature.parameters.items(): 128 | if field_value.default is not inspect.Parameter.empty or field_name in ("_", "__"): 129 | continue 130 | 131 | if field_name not in providers: 132 | msg = f"Provider is not found, {field_name=}" 133 | raise RuntimeError(msg) 134 | 135 | kwargs[field_name] = await providers[field_name].resolve() 136 | 137 | return object_to_resolve(**kwargs) 138 | 139 | @classmethod 140 | @contextmanager 141 | def override_providers_sync(cls, providers_for_overriding: dict[str, typing.Any]) -> typing.Iterator[None]: 142 | """Override several providers with mocks simultaneously. 143 | 144 | Args: 145 | providers_for_overriding: {provider_name: mock} dictionary. 
146 | 147 | Returns: 148 | None 149 | 150 | """ 151 | current_providers: typing.Final = cls.get_providers() 152 | current_provider_names: typing.Final = set(current_providers.keys()) 153 | given_provider_names: typing.Final = set(providers_for_overriding.keys()) 154 | 155 | for given_name in given_provider_names: 156 | if given_name not in current_provider_names: 157 | msg = f"Provider with name {given_name!r} not found" 158 | raise RuntimeError(msg) 159 | 160 | for provider_name, mock in providers_for_overriding.items(): 161 | provider = current_providers[provider_name] 162 | provider.override_sync(mock) 163 | 164 | try: 165 | yield 166 | finally: 167 | for provider_name in providers_for_overriding: 168 | provider = current_providers[provider_name] 169 | provider.reset_override_sync() 170 | -------------------------------------------------------------------------------- /tests/providers/test_singleton.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import dataclasses 3 | import threading 4 | import time 5 | import typing 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | 8 | import pydantic 9 | import pytest 10 | 11 | from that_depends import BaseContainer, providers 12 | 13 | 14 | @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) 15 | class SingletonFactory: 16 | dep1: str 17 | 18 | 19 | class Settings(pydantic.BaseModel): 20 | some_setting: str = "some_value" 21 | other_setting: str = "other_value" 22 | 23 | 24 | async def create_async_obj(value: str) -> SingletonFactory: 25 | await asyncio.sleep(0.001) 26 | return SingletonFactory(dep1=f"async {value}") 27 | 28 | 29 | async def _async_creator() -> int: 30 | await asyncio.sleep(0.001) 31 | return threading.get_ident() 32 | 33 | 34 | def _sync_creator_with_dependency(dep: int) -> str: 35 | return f"Singleton {dep}" 36 | 37 | 38 | class DIContainer(BaseContainer): 39 | alias = "singleton_container" 40 | factory: 
providers.AsyncFactory[int] = providers.AsyncFactory(_async_creator) 41 | settings: Settings = providers.Singleton(Settings).cast 42 | singleton = providers.Singleton(SingletonFactory, dep1=settings.some_setting) 43 | singleton_async = providers.AsyncSingleton(create_async_obj, value=settings.some_setting) 44 | singleton_with_dependency = providers.Singleton(_sync_creator_with_dependency, dep=factory.cast) 45 | 46 | 47 | async def test_singleton_provider() -> None: 48 | singleton1 = await DIContainer.singleton() 49 | singleton2 = await DIContainer.singleton() 50 | singleton3 = DIContainer.singleton.resolve_sync() 51 | await DIContainer.singleton.tear_down() 52 | singleton4 = DIContainer.singleton.resolve_sync() 53 | 54 | assert singleton1 is singleton2 is singleton3 55 | assert singleton4 is not singleton1 56 | 57 | await DIContainer.tear_down() 58 | 59 | 60 | async def test_singleton_attr_getter() -> None: 61 | singleton1 = await DIContainer.singleton() 62 | 63 | assert singleton1.dep1 == Settings().some_setting 64 | 65 | await DIContainer.tear_down() 66 | 67 | 68 | async def test_singleton_with_empty_list() -> None: 69 | singleton = providers.Singleton(list[object]) 70 | 71 | singleton1 = await singleton() 72 | singleton2 = await singleton() 73 | assert singleton1 is singleton2 74 | 75 | 76 | @pytest.mark.repeat(10) 77 | async def test_singleton_asyncio_concurrency() -> None: 78 | calls: int = 0 79 | 80 | async def create_resource() -> typing.AsyncIterator[str]: 81 | nonlocal calls 82 | calls += 1 83 | await asyncio.sleep(0) 84 | yield "" 85 | 86 | @dataclasses.dataclass(kw_only=True, slots=True) 87 | class SimpleCreator: 88 | dep1: str 89 | 90 | resource = providers.Resource(create_resource) 91 | factory_with_resource = providers.Singleton(SimpleCreator, dep1=resource.cast) 92 | 93 | client1, client2 = await asyncio.gather(factory_with_resource.resolve(), factory_with_resource.resolve()) 94 | 95 | assert client1 is client2 96 | assert calls == 1 97 | 98 | 99 | 
@pytest.mark.repeat(10) 100 | def test_singleton_threading_concurrency() -> None: 101 | calls: int = 0 102 | lock = threading.Lock() 103 | 104 | def create_singleton() -> str: 105 | nonlocal calls 106 | with lock: 107 | calls += 1 108 | time.sleep(0.01) 109 | return "" 110 | 111 | singleton = providers.Singleton(create_singleton) 112 | 113 | with ThreadPoolExecutor(max_workers=4) as pool: 114 | tasks = [ 115 | pool.submit(singleton.resolve_sync), 116 | pool.submit(singleton.resolve_sync), 117 | pool.submit(singleton.resolve_sync), 118 | pool.submit(singleton.resolve_sync), 119 | ] 120 | results = [x.result() for x in as_completed(tasks)] 121 | 122 | assert all(x == "" for x in results) 123 | assert calls == 1 124 | 125 | 126 | async def test_async_singleton() -> None: 127 | singleton1 = await DIContainer.singleton_async() 128 | singleton2 = await DIContainer.singleton_async() 129 | singleton3 = await DIContainer.singleton_async.resolve() 130 | await DIContainer.singleton_async.tear_down() 131 | singleton4 = await DIContainer.singleton_async.resolve() 132 | 133 | assert singleton1 is singleton2 is singleton3 134 | assert singleton4 is not singleton1 135 | 136 | await DIContainer.tear_down() 137 | 138 | 139 | async def test_async_singleton_override() -> None: 140 | singleton_async = providers.AsyncSingleton(create_async_obj, "foo") 141 | singleton_async.override_sync(SingletonFactory(dep1="bar")) 142 | 143 | result = await singleton_async.resolve() 144 | assert result == SingletonFactory(dep1="bar") 145 | 146 | 147 | async def test_async_singleton_asyncio_concurrency() -> None: 148 | singleton_async = providers.AsyncSingleton(create_async_obj, "foo") 149 | 150 | results = await asyncio.gather( 151 | singleton_async(), 152 | singleton_async(), 153 | singleton_async(), 154 | singleton_async(), 155 | singleton_async(), 156 | ) 157 | 158 | assert all(val is results[0] for val in results) 159 | 160 | 161 | async def test_async_singleton_sync_resolve_failure() -> None: 162 
| with pytest.raises(RuntimeError, match=r"AsyncSingleton cannot be resolved in an sync context."): 163 | DIContainer.singleton_async.resolve_sync() 164 | 165 | 166 | async def test_singleton_async_resolve_with_async_dependencies() -> None: 167 | expected = await DIContainer.singleton_with_dependency.resolve() 168 | 169 | assert expected == await DIContainer.singleton_with_dependency.resolve() 170 | 171 | results = await asyncio.gather(*[DIContainer.singleton_with_dependency.resolve() for _ in range(10)]) 172 | 173 | for val in results: 174 | assert val == expected 175 | 176 | results = await asyncio.gather( 177 | *[asyncio.to_thread(DIContainer.singleton_with_dependency.resolve_sync) for _ in range(10)], 178 | ) 179 | 180 | for val in results: 181 | assert val == expected 182 | 183 | 184 | async def test_async_singleton_teardown() -> None: 185 | singleton_async = providers.AsyncSingleton(create_async_obj, "foo") 186 | 187 | await singleton_async.resolve() 188 | singleton_async.tear_down_sync() 189 | 190 | assert singleton_async._instance is None 191 | 192 | await singleton_async.resolve() 193 | 194 | await singleton_async.tear_down() 195 | assert singleton_async._instance is None 196 | -------------------------------------------------------------------------------- /docs/providers/singleton.md: -------------------------------------------------------------------------------- 1 | # Singleton Provider 2 | 3 | A **Singleton** provider creates its instance once and caches it for all future injections or resolutions. When the instance is first requested (via `resolve_sync()` or `resolve()`), the underlying factory is called. On subsequent calls, the cached instance is returned without calling the factory again. 
4 | 5 | ## How it Works 6 | 7 | ```python 8 | import random 9 | 10 | from that_depends import BaseContainer, Provide, inject, providers 11 | 12 | 13 | def some_function() -> float: 14 | """Generate number between 0.0 and 1.0""" 15 | return random.random() 16 | 17 | 18 | # define container with `Singleton` provider: 19 | class MyContainer(BaseContainer): 20 | singleton = providers.Singleton(some_function) 21 | 22 | 23 | # The provider will call `some_function` once and cache the return value 24 | 25 | # 1) Synchronous resolution 26 | MyContainer.singleton.resolve_sync() # e.g. 0.3 27 | MyContainer.singleton.resolve_sync() # 0.3 (cached) 28 | 29 | # 2) Asynchronous resolution 30 | await MyContainer.singleton.resolve() # 0.3 (same cached value) 31 | 32 | 33 | # 3) Injection example 34 | @inject 35 | async def with_singleton(number: float = Provide[MyContainer.singleton]): 36 | # number == 0.3 37 | ... 38 | ``` 39 | 40 | ### Teardown Support 41 | If you need to reset the singleton (for example, in tests or at application shutdown), you can call: 42 | ```python 43 | await MyContainer.singleton.tear_down() 44 | ``` 45 | This clears the cached instance, causing a new one to be created the next time `resolve_sync()` or `resolve()` is called. 46 | *(If you only ever use synchronous resolution, you can call `MyContainer.singleton.tear_down_sync()` instead.)* 47 | 48 | For further details refer to the [teardown documentation](../introduction/tear-down.md). 49 | 50 | --- 51 | 52 | ## Concurrency Safety 53 | 54 | `Singleton` is **thread-safe** and **async-safe**: 55 | 56 | 1. **Async Concurrency** 57 | If multiple coroutines call `resolve()` concurrently, the factory function is guaranteed to be called only once. All callers receive the same cached instance. 58 | 59 | 2. **Thread Concurrency** 60 | If multiple threads call `resolve_sync()` at the same time, the factory is only called once. All threads receive the same cached instance. 
61 | 62 | ```python 63 | import threading 64 | import asyncio 65 | 66 | 67 | # In async code: 68 | async def main(): 69 | # calling resolve concurrently in different coroutines 70 | results = await asyncio.gather( 71 | MyContainer.singleton.resolve(), 72 | MyContainer.singleton.resolve(), 73 | ) 74 | # Both results point to the same instance 75 | 76 | 77 | # In threaded code: 78 | def thread_task(): 79 | instance = MyContainer.singleton.resolve_sync() 80 | ... 81 | 82 | 83 | threads = [threading.Thread(target=thread_task) for _ in range(5)] 84 | for t in threads: 85 | t.start() 86 | ``` 87 | 88 | --- 89 | 90 | ## ThreadLocalSingleton Provider 91 | 92 | If you want each *thread* to have its own, separately cached instance, use **ThreadLocalSingleton**. This provider creates a new instance per thread and reuses that instance on subsequent calls *within the same thread*. 93 | 94 | ```python 95 | import random 96 | import threading 97 | from that_depends.providers import ThreadLocalSingleton 98 | 99 | 100 | def factory() -> int: 101 | """Return a random int between 1 and 100.""" 102 | return random.randint(1, 100) 103 | 104 | 105 | # ThreadLocalSingleton caches an instance per thread 106 | singleton = ThreadLocalSingleton(factory) 107 | 108 | # In a single thread: 109 | instance1 = singleton.resolve_sync() # e.g. 56 110 | instance2 = singleton.resolve_sync() # 56 (cached in the same thread) 111 | 112 | 113 | # In multiple threads: 114 | def thread_task(): 115 | return singleton.resolve_sync() 116 | 117 | 118 | thread1 = threading.Thread(target=thread_task) 119 | thread2 = threading.Thread(target=thread_task) 120 | thread1.start() 121 | thread2.start() 122 | 123 | # thread1 and thread2 each get a different cached value 124 | ``` 125 | 126 | You can still use `.resolve()` with `ThreadLocalSingleton`, which will also maintain isolation per thread. However, note that this does *not* isolate instances per asynchronous Task – only per OS thread. 
127 | 128 | --- 129 | 130 | ## Example with `pydantic-settings` 131 | 132 | Consider a scenario where your application configuration is defined via [**pydantic-settings**](https://docs.pydantic.dev/latest/concepts/pydantic_settings/). Often, you only want to parse this configuration (e.g., from environment variables) once, then reuse it throughout the application. 133 | 134 | ```python 135 | from pydantic_settings import BaseSettings 136 | from pydantic import BaseModel 137 | 138 | 139 | class DatabaseConfig(BaseModel): 140 | address: str = "127.0.0.1" 141 | port: int = 5432 142 | db_name: str = "postgres" 143 | 144 | 145 | class Settings(BaseSettings): 146 | auth_key: str = "my_auth_key" 147 | db: DatabaseConfig = DatabaseConfig() 148 | ``` 149 | 150 | ### Defining the Container 151 | 152 | Below, we define a container with a **Singleton** provider for our settings. We also define a separate async factory that connects to the database using those settings. 153 | 154 | ```python 155 | from that_depends import BaseContainer, providers 156 | 157 | async def get_db_connection(address: str, port: int, db_name: str): 158 | # e.g., create an async DB connection 159 | ... 160 | 161 | class MyContainer(BaseContainer): 162 | # We'll parse settings only once 163 | config = providers.Singleton(Settings) 164 | 165 | # We'll pass the config's DB fields into an async factory for a DB connection 166 | db_connection = providers.AsyncFactory( 167 | get_db_connection, 168 | config.db.address, 169 | config.db.port, 170 | config.db.db_name, 171 | ) 172 | ``` 173 | 174 | ### Injecting or Resolving in Code 175 | 176 | You can now inject these values directly into your functions with the `@inject` decorator: 177 | 178 | ```python 179 | from that_depends import inject, Provide 180 | 181 | @inject 182 | async def with_db_connection(conn = Provide[MyContainer.db_connection]): 183 | # conn is the created DB connection 184 | ... 
185 | ``` 186 | 187 | Or you can manually resolve them when needed: 188 | 189 | ```python 190 | # Synchronously resolve the config 191 | cfg = MyContainer.config.resolve_sync() 192 | 193 | # Asynchronously resolve the DB connection 194 | connection = await MyContainer.db_connection.resolve() 195 | ``` 196 | 197 | By using `Singleton` for `Settings`, you avoid re-parsing the environment or re-initializing the configuration on each request. 198 | -------------------------------------------------------------------------------- /docs/testing/provider-overriding.md: -------------------------------------------------------------------------------- 1 | # Provider overriding 2 | 3 | In addition to direct dependency injection, the DI container provides another very important feature: 4 | **overriding dependencies or providers**. 5 | 6 | Any provider registered with the container can be overridden. 7 | This can help you replace objects with simple stubs, or with other objects. 8 | **An override affects all providers that use the overridden provider (_see example_)**. 
9 | 10 | ## Example 11 | 12 | ```python 13 | from pydantic_settings import BaseSettings 14 | from sqlalchemy import create_engine, Engine, text 15 | from testcontainers.postgres import PostgresContainer 16 | from that_depends import BaseContainer, providers, Provide, inject 17 | 18 | 19 | class SomeSQLADao: 20 | def __init__(self, *, sqla_engine: Engine): 21 | self.engine = sqla_engine 22 | self._connection = None 23 | 24 | def __enter__(self): 25 | self._connection = self.engine.connect() 26 | return self 27 | 28 | def __exit__(self, exc_type, exc_val, exc_tb): 29 | self._connection.close() 30 | 31 | def exec_query(self, query: str): 32 | return self._connection.execute(text(query)) 33 | 34 | 35 | class Settings(BaseSettings): 36 | db_url: str = 'some_production_db_url' 37 | 38 | 39 | class DIContainer(BaseContainer): 40 | settings = providers.Singleton(Settings) 41 | sqla_engine = providers.Singleton(create_engine, settings.db_url) 42 | some_sqla_dao = providers.Factory(SomeSQLADao, sqla_engine=sqla_engine) 43 | 44 | 45 | @inject 46 | def exec_query_example(some_sqla_dao=Provide[DIContainer.some_sqla_dao]): 47 | with some_sqla_dao: 48 | result = some_sqla_dao.exec_query('SELECT 234') 49 | 50 | return next(result) 51 | 52 | 53 | def main(): 54 | pg_container = PostgresContainer(image='postgres:alpine3.19') 55 | pg_container.start() 56 | db_url = pg_container.get_connection_url() 57 | 58 | """ 59 | We override only settings, but this override will also affect the 'sqla_engine' 60 | and 'some_sqla_dao' providers because the 'settings' provider is used by them! 
61 | """ 62 | local_testing_settings = Settings(db_url=db_url) 63 | DIContainer.settings.override_sync(local_testing_settings) 64 | 65 | try: 66 | result = exec_query_example() 67 | assert result == (234,) 68 | finally: 69 | DIContainer.settings.reset_override_sync() 70 | pg_container.stop() 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | 76 | ``` 77 | 78 | The example above shows how overriding a nested provider ('_settings_') 79 | affects the providers that depend on it ('_sqla_engine_' and '_some_sqla_dao_'). 80 | 81 | ## Override multiple providers 82 | 83 | The example above looked at overriding only one settings provider, 84 | but the container also provides the ability to override 85 | multiple providers at once with the `override_providers_sync` method. 86 | 87 | The code above could remain the same except that 88 | the single provider override could be replaced with the following code: 89 | 90 | ```python 91 | def main(): 92 | pg_container = PostgresContainer(image='postgres:alpine3.19') 93 | pg_container.start() 94 | db_url = pg_container.get_connection_url() 95 | 96 | local_testing_settings = Settings(db_url=db_url) 97 | providers_for_overriding = { 98 | 'settings': local_testing_settings, 99 | # more values... 
100 | } 101 | with DIContainer.override_providers_sync(providers_for_overriding): 102 | try: 103 | result = exec_query_example() 104 | assert result == (234,) 105 | finally: 106 | pg_container.stop() 107 | ``` 108 | 109 | --- 110 | ## Using with Litestar 111 | In order to be able to inject dependencies of any type instead of existing objects, 112 | we need to **change the typing** for the injected parameter as follows: 113 | 114 | ```python3 115 | import typing 116 | from functools import partial 117 | from typing import Annotated 118 | from unittest.mock import Mock 119 | 120 | from litestar import Litestar, Router, get 121 | from litestar.di import Provide 122 | from litestar.params import Dependency 123 | from litestar.testing import TestClient 124 | 125 | from that_depends import BaseContainer, providers 126 | 127 | 128 | class ExampleService: 129 | def do_smth(self) -> str: 130 | return "something" 131 | 132 | 133 | class DIContainer(BaseContainer): 134 | example_service = providers.Factory(ExampleService) 135 | 136 | 137 | @get(path="/another-endpoint", dependencies={"example_service": Provide(DIContainer.example_service)}) 138 | async def endpoint_handler( 139 | example_service: Annotated[ExampleService, Dependency(skip_validation=True)], 140 | ) -> dict[str, typing.Any]: 141 | return {"object": example_service.do_smth()} 142 | 143 | 144 | # or if you want a little less code 145 | NoValidationDependency = partial(Dependency, skip_validation=True) 146 | 147 | 148 | @get(path="/another-endpoint", dependencies={"example_service": Provide(DIContainer.example_service)}) 149 | async def endpoint_handler( 150 | example_service: Annotated[ExampleService, NoValidationDependency()], 151 | ) -> dict[str, typing.Any]: 152 | return {"object": example_service.do_smth()} 153 | 154 | 155 | router = Router( 156 | path="/router", 157 | route_handlers=[endpoint_handler], 158 | ) 159 | 160 | app = Litestar(route_handlers=[router]) 161 | ``` 162 | 163 | Now we are ready to write 
tests with **overriding** and this will work with **any types**: 164 | 165 | ```python3 166 | def test_litestar_endpoint_with_overriding() -> None: 167 | some_service_mock = Mock(do_smth=lambda: "mock func") 168 | 169 | with DIContainer.example_service.override_context_sync(some_service_mock), TestClient(app=app) as client: 170 | response = client.get("/router/another-endpoint") 171 | 172 | assert response.status_code == 200 173 | assert response.json()["object"] == "mock func" 174 | ``` 175 | 176 | More about `Dependency` 177 | in the [Litestar documentation](https://docs.litestar.dev/2/usage/dependency-injection.html#the-dependency-function). 178 | 179 | --- 180 | 181 | ## Overriding and tear-down 182 | 183 | If a provider `A` caches its resolved value and depends on a provider 184 | `B` that you wish to override, you might experience the following behavior: 185 | 186 | ```python 187 | class MyContainer(BaseContainer): 188 | B = providers.Singleton(lambda: 1) 189 | A = providers.Singleton(lambda x: x, B) 190 | 191 | 192 | a_old = await MyContainer.A() 193 | 194 | MyContainer.B.override_sync(32) # will not reset A's cached value 195 | 196 | a_new = await MyContainer.A() 197 | 198 | assert a_old != a_new # raises AssertionError: A still returns its cached value 199 | ``` 200 | 201 | This happens because `A` caches its value and is not reset when you override `B`. 
202 | 203 | If you wish to fix this you can tell the provider to tear-down children on override: 204 | 205 | ```python 206 | MyContainer.B.override_sync(32, tear_down_children=True) 207 | ``` 208 | -------------------------------------------------------------------------------- /tests/providers/test_resources.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import threading 4 | import time 5 | import typing 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | 8 | import pytest 9 | 10 | from tests.creators import ( 11 | AsyncContextManagerResource, 12 | ContextManagerResource, 13 | create_async_resource, 14 | create_sync_resource, 15 | ) 16 | from that_depends import BaseContainer, providers 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class DIContainer(BaseContainer): 23 | alias = "resources_container" 24 | async_resource = providers.Resource(create_async_resource) 25 | sync_resource = providers.Resource(create_sync_resource) 26 | async_resource_from_class = providers.Resource(AsyncContextManagerResource) 27 | sync_resource_from_class = providers.Resource(ContextManagerResource) 28 | 29 | 30 | @pytest.fixture(autouse=True) 31 | async def _tear_down() -> typing.AsyncIterator[None]: 32 | try: 33 | yield 34 | finally: 35 | await DIContainer.tear_down() 36 | 37 | 38 | async def test_async_resource() -> None: 39 | async_resource1 = await DIContainer.async_resource.resolve() 40 | async_resource2 = DIContainer.async_resource.resolve_sync() 41 | assert async_resource1 is async_resource2 42 | 43 | 44 | async def test_async_resource_from_class() -> None: 45 | async_resource1 = await DIContainer.async_resource_from_class.resolve() 46 | async_resource2 = DIContainer.async_resource_from_class.resolve_sync() 47 | assert async_resource1 is async_resource2 48 | 49 | 50 | async def test_sync_resource() -> None: 51 | sync_resource1 = await 
DIContainer.sync_resource.resolve() 52 | sync_resource2 = await DIContainer.sync_resource.resolve() 53 | assert sync_resource1 is sync_resource2 54 | 55 | 56 | async def test_sync_resource_from_class() -> None: 57 | sync_resource1 = await DIContainer.sync_resource_from_class.resolve() 58 | sync_resource2 = await DIContainer.sync_resource_from_class.resolve() 59 | assert sync_resource1 is sync_resource2 60 | 61 | 62 | async def test_async_resource_overridden() -> None: 63 | async_resource1 = await DIContainer.sync_resource.resolve() 64 | 65 | DIContainer.sync_resource.override_sync("override") 66 | 67 | async_resource2 = DIContainer.sync_resource.resolve_sync() 68 | async_resource3 = await DIContainer.sync_resource.resolve() 69 | 70 | DIContainer.sync_resource.reset_override_sync() 71 | 72 | async_resource4 = DIContainer.sync_resource.resolve_sync() 73 | 74 | assert async_resource2 is not async_resource1 75 | assert async_resource2 is async_resource3 76 | assert async_resource4 is async_resource1 77 | 78 | 79 | async def test_sync_resource_overridden() -> None: 80 | sync_resource1 = await DIContainer.sync_resource.resolve() 81 | 82 | DIContainer.sync_resource.override_sync("override") 83 | 84 | sync_resource2 = DIContainer.sync_resource.resolve_sync() 85 | sync_resource3 = await DIContainer.sync_resource.resolve() 86 | 87 | DIContainer.sync_resource.reset_override_sync() 88 | 89 | sync_resource4 = DIContainer.sync_resource.resolve_sync() 90 | 91 | assert sync_resource2 is not sync_resource1 92 | assert sync_resource2 is sync_resource3 93 | assert sync_resource4 is sync_resource1 94 | 95 | 96 | async def test_resource_with_empty_list() -> None: 97 | def create_sync_resource_with_list() -> typing.Iterator[list[typing.Any]]: 98 | logger.debug("Resource initiated") 99 | yield [] 100 | logger.debug("Resource destructed") 101 | 102 | async def create_async_resource_with_dict() -> typing.AsyncIterator[dict[str, typing.Any]]: 103 | logger.debug("Async resource initiated") 
104 | yield {} 105 | logger.debug("Async resource destructed") 106 | 107 | sync_resource = providers.Resource(create_sync_resource_with_list) 108 | async_resource = providers.Resource(create_async_resource_with_dict) 109 | 110 | sync_resource1 = await sync_resource() 111 | sync_resource2 = await sync_resource() 112 | assert sync_resource1 is sync_resource2 113 | 114 | async_resource1 = await async_resource() 115 | async_resource2 = await async_resource() 116 | assert async_resource1 is async_resource2 117 | 118 | await sync_resource.tear_down() 119 | await async_resource.tear_down() 120 | 121 | 122 | async def test_resource_unsupported_creator() -> None: 123 | with pytest.raises(TypeError, match=r"Unsupported resource type"): 124 | providers.Resource(None) # type: ignore[arg-type] 125 | 126 | 127 | @pytest.mark.repeat(10) 128 | async def test_async_resource_asyncio_concurrency() -> None: 129 | calls: int = 0 130 | 131 | async def create_resource() -> typing.AsyncIterator[str]: 132 | nonlocal calls 133 | calls += 1 134 | await asyncio.sleep(0) 135 | yield "" 136 | 137 | resource = providers.Resource(create_resource) 138 | 139 | await asyncio.gather(resource.resolve(), resource.resolve()) 140 | 141 | assert calls == 1 142 | 143 | 144 | @pytest.mark.repeat(10) 145 | async def test_async_resource_async_teardown() -> None: 146 | calls: int = 0 147 | 148 | async def _create_resource() -> typing.AsyncIterator[str]: 149 | yield "" 150 | nonlocal calls 151 | await asyncio.sleep(0.01) 152 | calls += 1 153 | 154 | resource = providers.Resource(_create_resource) 155 | 156 | await resource.resolve() 157 | 158 | await asyncio.gather(*[resource.tear_down() for _ in range(10)]) 159 | 160 | assert calls == 1 161 | 162 | 163 | @pytest.mark.repeat(10) 164 | def test_resource_threading_concurrency() -> None: 165 | calls: int = 0 166 | lock = threading.Lock() 167 | 168 | def create_resource() -> typing.Iterator[str]: 169 | nonlocal calls 170 | with lock: 171 | calls += 1 172 | 
time.sleep(0.01) 173 | yield "" 174 | 175 | resource = providers.Resource(create_resource) 176 | 177 | with ThreadPoolExecutor(max_workers=4) as pool: 178 | tasks = [ 179 | pool.submit(resource.resolve_sync), 180 | pool.submit(resource.resolve_sync), 181 | pool.submit(resource.resolve_sync), 182 | pool.submit(resource.resolve_sync), 183 | ] 184 | results = [x.result() for x in as_completed(tasks)] 185 | 186 | assert results == ["", "", "", ""] 187 | assert calls == 1 188 | 189 | 190 | def test_sync_resource_sync_tear_down() -> None: 191 | DIContainer.sync_resource.resolve_sync() 192 | assert DIContainer.sync_resource._context.instance is not None 193 | DIContainer.sync_resource.tear_down_sync() 194 | assert DIContainer.sync_resource._context.instance is None 195 | 196 | 197 | async def test_async_resource_sync_tear_down_raises() -> None: 198 | await DIContainer.async_resource.resolve() 199 | assert DIContainer.async_resource._context.instance is not None 200 | with pytest.raises(RuntimeError): 201 | DIContainer.async_resource.tear_down_sync() 202 | -------------------------------------------------------------------------------- /tests/providers/test_attr_getter.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import random 3 | import typing 4 | from dataclasses import dataclass, field 5 | 6 | import pytest 7 | 8 | from that_depends import BaseContainer, providers 9 | from that_depends.providers.base import _get_value_from_object_by_dotted_path 10 | from that_depends.providers.context_resources import container_context 11 | 12 | 13 | @dataclass 14 | class Nested2: 15 | some_const = 144 16 | 17 | 18 | @dataclass 19 | class Nested1: 20 | nested2_attr: Nested2 = field(default_factory=Nested2) 21 | 22 | 23 | @dataclass 24 | class Settings: 25 | some_str_value: str = "some_string_value" 26 | some_int_value: int = 3453621 27 | nested1_attr: Nested1 = field(default_factory=Nested1) 28 | 29 | 30 | async def 
return_settings_async() -> Settings: 31 | return Settings() 32 | 33 | 34 | async def yield_settings_async() -> typing.AsyncIterator[Settings]: 35 | yield Settings() 36 | 37 | 38 | def yield_settings_sync() -> typing.Iterator[Settings]: 39 | yield Settings() 40 | 41 | 42 | @dataclass 43 | class NestingTestDTO: ... 44 | 45 | 46 | @pytest.fixture( 47 | params=[ 48 | providers.Resource(yield_settings_sync), 49 | providers.Singleton(Settings), 50 | providers.ContextResource(yield_settings_sync), 51 | providers.Object(Settings()), 52 | providers.Factory(Settings), 53 | providers.Selector(lambda: "sync", sync=providers.Factory(Settings)), 54 | ] 55 | ) 56 | def some_sync_settings_provider(request: pytest.FixtureRequest) -> providers.AbstractProvider[Settings]: 57 | return typing.cast(providers.AbstractProvider[Settings], request.param) 58 | 59 | 60 | @pytest.fixture( 61 | params=[ 62 | providers.AsyncFactory(return_settings_async), 63 | providers.Resource(yield_settings_async), 64 | providers.ContextResource(yield_settings_async), 65 | providers.Selector(lambda: "asynchronous", asynchronous=providers.AsyncFactory(return_settings_async)), 66 | ] 67 | ) 68 | def some_async_settings_provider(request: pytest.FixtureRequest) -> providers.AbstractProvider[Settings]: 69 | return typing.cast(providers.AbstractProvider[Settings], request.param) 70 | 71 | 72 | def test_attr_getter_with_zero_attribute_depth_sync( 73 | some_sync_settings_provider: providers.AbstractProvider[Settings], 74 | ) -> None: 75 | attr_getter = some_sync_settings_provider.some_str_value 76 | if isinstance(some_sync_settings_provider, providers.ContextResource): 77 | with container_context(some_sync_settings_provider): 78 | assert attr_getter.resolve_sync() == Settings().some_str_value 79 | else: 80 | assert attr_getter.resolve_sync() == Settings().some_str_value 81 | 82 | 83 | async def test_attr_getter_with_zero_attribute_depth_async( 84 | some_async_settings_provider: providers.AbstractProvider[Settings], 
85 | ) -> None: 86 | attr_getter = some_async_settings_provider.some_str_value 87 | if isinstance(some_async_settings_provider, providers.ContextResource): 88 | async with container_context(some_async_settings_provider): 89 | assert await attr_getter.resolve() == Settings().some_str_value 90 | else: 91 | assert await attr_getter.resolve() == Settings().some_str_value 92 | 93 | 94 | def test_attr_getter_with_more_than_zero_attribute_depth_sync( 95 | some_sync_settings_provider: providers.AbstractProvider[Settings], 96 | ) -> None: 97 | with ( 98 | container_context(some_sync_settings_provider) 99 | if isinstance(some_sync_settings_provider, providers.ContextResource) 100 | else contextlib.nullcontext() 101 | ): 102 | attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const 103 | assert attr_getter.resolve_sync() == Nested2().some_const 104 | 105 | 106 | async def test_attr_getter_with_more_than_zero_attribute_depth_async( 107 | some_async_settings_provider: providers.AbstractProvider[Settings], 108 | ) -> None: 109 | async with ( 110 | container_context(some_async_settings_provider) 111 | if isinstance(some_async_settings_provider, providers.ContextResource) 112 | else contextlib.nullcontext() 113 | ): 114 | attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const 115 | assert await attr_getter.resolve() == Nested2().some_const 116 | 117 | 118 | @pytest.mark.parametrize( 119 | ("field_count", "test_field_name", "test_value"), 120 | [(1, "test_field", "sdf6fF^SF(FF*4ffsf"), (5, "nested_field", -252625), (50, "50_lvl_field", 909234235)], 121 | ) 122 | def test_nesting_levels(field_count: int, test_field_name: str, test_value: str | int) -> None: 123 | obj = NestingTestDTO() 124 | fields = [f"field_{i}" for i in range(1, field_count + 1)] 125 | random.shuffle(fields) 126 | 127 | attr_path = ".".join(fields) + f".{test_field_name}" 128 | obj_copy = obj 129 | 130 | while fields: 131 | field_name = fields.pop(0) 132 | 
setattr(obj_copy, field_name, NestingTestDTO()) 133 | obj_copy = obj_copy.__getattribute__(field_name) 134 | 135 | setattr(obj_copy, test_field_name, test_value) 136 | 137 | attr_value = _get_value_from_object_by_dotted_path(obj, attr_path) 138 | assert attr_value == test_value 139 | 140 | 141 | def test_attr_getter_with_invalid_attribute_sync( 142 | some_sync_settings_provider: providers.AbstractProvider[Settings], 143 | ) -> None: 144 | with pytest.raises(AttributeError): 145 | some_sync_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018 146 | with pytest.raises(AttributeError): 147 | some_sync_settings_provider.nested1_attr.__another_private__ # noqa: B018 148 | with pytest.raises(AttributeError): 149 | some_sync_settings_provider.nested1_attr._final_private_ # noqa: B018 150 | 151 | 152 | async def test_attr_getter_with_invalid_attribute_async( 153 | some_async_settings_provider: providers.AbstractProvider[Settings], 154 | ) -> None: 155 | with pytest.raises(AttributeError): 156 | some_async_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018 157 | with pytest.raises(AttributeError): 158 | some_async_settings_provider.nested1_attr.__another_private__ # noqa: B018 159 | with pytest.raises(AttributeError): 160 | some_async_settings_provider.nested1_attr._final_private_ # noqa: B018 161 | 162 | 163 | def test_attr_getter_deregister_arguments() -> None: 164 | class _Item: 165 | def __init__(self, v: float) -> None: 166 | self.v = v 167 | 168 | class _Container(BaseContainer): 169 | parent = providers.Singleton(lambda: random.random()) 170 | child = providers.Singleton(_Item, parent.cast) 171 | 172 | attr_getter = _Container.child.v 173 | 174 | attr_getter._register_arguments() 175 | 176 | assert attr_getter.resolve_sync() == _Container.parent.resolve_sync() 177 | assert _Container.parent in attr_getter._parents 178 | 179 | attr_getter._deregister_arguments() 180 | 181 | assert _Container.parent not in 
attr_getter._parents 182 | -------------------------------------------------------------------------------- /that_depends/providers/factories.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import inspect 3 | import typing 4 | from typing import overload 5 | 6 | from typing_extensions import override 7 | 8 | from that_depends.providers.base import AbstractProvider 9 | from that_depends.providers.mixin import ProviderWithArguments 10 | 11 | 12 | T_co = typing.TypeVar("T_co", covariant=True) 13 | P = typing.ParamSpec("P") 14 | 15 | 16 | class AbstractFactory(ProviderWithArguments, AbstractProvider[T_co], abc.ABC): 17 | """Base class for all factories. 18 | 19 | This class defines the interface for factories that provide 20 | resources both synchronously and asynchronously. 21 | """ 22 | 23 | @property 24 | def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, T_co]]: 25 | """Returns an async provider function. 26 | 27 | The async provider function can be awaited to resolve the resource. 28 | 29 | Returns: 30 | Callable[[], Coroutine[Any, Any, T_co]]: The async provider function. 31 | 32 | Example: 33 | ```python 34 | provider = my_factory.provider 35 | resource = await provider() 36 | ``` 37 | 38 | """ 39 | return self.resolve 40 | 41 | @property 42 | def provider_sync(self) -> typing.Callable[[], T_co]: 43 | """Return a sync provider function. 44 | 45 | The sync provider function can be called to resolve the resource synchronously. 46 | 47 | Returns: 48 | Callable[[], T_co]: The sync provider function. 49 | 50 | Example: 51 | ```python 52 | provider = my_factory.provider_sync 53 | resource = provider() 54 | ``` 55 | 56 | """ 57 | return self.resolve_sync 58 | 59 | 60 | class Factory(AbstractFactory[T_co]): 61 | """Provides an instance by calling a sync method. 62 | 63 | A typical usage scenario is to wrap a synchronous function 64 | that returns a resource. 
Each call to `provider` or `provider_sync` 65 | produces a new instance of that resource. 66 | 67 | Example: 68 | ```python 69 | def build_resource(text: str, number: int): 70 | return f"{text}-{number}" 71 | 72 | factory = Factory(build_resource, "example", 42) 73 | resource = factory.provider_sync() # "example-42" 74 | ``` 75 | 76 | """ 77 | 78 | def _register_arguments(self) -> None: 79 | self._register(self._args) 80 | self._register(self._kwargs.values()) 81 | 82 | def _deregister_arguments(self) -> None: 83 | raise NotImplementedError 84 | 85 | __slots__ = "_args", "_factory", "_kwargs", "_override" 86 | 87 | def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: 88 | """Initialize a Factory instance. 89 | 90 | Args: 91 | factory (Callable[P, T_co]): Function that returns the resource. 92 | *args: Arguments to pass to the factory function. 93 | **kwargs: Keyword arguments to pass to the factory function. 94 | 95 | """ 96 | super().__init__() 97 | self._factory: typing.Final = factory 98 | self._args: typing.Final = args 99 | self._kwargs: typing.Final = kwargs 100 | self._register_arguments() 101 | 102 | @override 103 | async def resolve(self) -> T_co: 104 | if self._override: 105 | return typing.cast(T_co, self._override) 106 | 107 | return self._factory( 108 | *[ # type: ignore[arg-type] 109 | await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args 110 | ], 111 | **{ # type: ignore[arg-type] 112 | k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items() 113 | }, 114 | ) 115 | 116 | @override 117 | def resolve_sync(self) -> T_co: 118 | if self._override: 119 | return typing.cast(T_co, self._override) 120 | 121 | return self._factory( 122 | *[ # type: ignore[arg-type] 123 | x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args 124 | ], 125 | **{ # type: ignore[arg-type] 126 | k: v.resolve_sync() if isinstance(v, AbstractProvider) 
else v for k, v in self._kwargs.items() 127 | }, 128 | ) 129 | 130 | 131 | class AsyncFactory(AbstractFactory[T_co]): 132 | """Provides an instance by calling an async method. 133 | 134 | Similar to `Factory`, but requires an async function. Each call 135 | to the provider or `provider` property is awaited to produce a new instance. 136 | 137 | Example: 138 | ```python 139 | async def async_build_resource(text: str): 140 | await some_async_operation() 141 | return text.upper() 142 | 143 | async_factory = AsyncFactory(async_build_resource, "example") 144 | resource = await async_factory.provider() # "EXAMPLE" 145 | ``` 146 | 147 | """ 148 | 149 | def _register_arguments(self) -> None: 150 | self._register(self._args) 151 | self._register(self._kwargs.values()) 152 | 153 | def _deregister_arguments(self) -> None: 154 | raise NotImplementedError 155 | 156 | __slots__ = "_args", "_factory", "_kwargs", "_override" 157 | 158 | @overload 159 | def __init__( 160 | self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs 161 | ) -> None: ... 162 | 163 | @overload 164 | def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: ... 165 | 166 | def __init__( 167 | self, factory: typing.Callable[P, T_co | typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs 168 | ) -> None: 169 | """Initialize an AsyncFactory instance. 170 | 171 | Args: 172 | factory (Callable[P, T_co | Awaitable[T_co]]): Function that returns the resource (sync or async). 173 | *args: Arguments to pass to the factory function. 
174 | **kwargs: Keyword arguments to pass to the factory function. 175 | 176 | 177 | """ 178 | super().__init__() 179 | self._factory: typing.Final = factory 180 | self._args: typing.Final = args 181 | self._kwargs: typing.Final = kwargs 182 | self._register_arguments() 183 | 184 | @override 185 | async def resolve(self) -> T_co: 186 | if self._override: 187 | return typing.cast(T_co, self._override) 188 | 189 | args = [await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args] 190 | kwargs = {k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()} 191 | 192 | result = self._factory( 193 | *args, # type:ignore[arg-type] 194 | **kwargs, # type:ignore[arg-type] 195 | ) 196 | 197 | if inspect.isawaitable(result): 198 | return await result 199 | return result 200 | 201 | @override 202 | def resolve_sync(self) -> typing.NoReturn: 203 | msg = "AsyncFactory cannot be resolved synchronously" 204 | raise RuntimeError(msg) 205 | -------------------------------------------------------------------------------- /tests/providers/test_local_singleton.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import random 3 | import threading 4 | import time 5 | import typing 6 | from concurrent.futures.thread import ThreadPoolExecutor 7 | 8 | import pytest 9 | 10 | from that_depends.providers import AsyncFactory, ThreadLocalSingleton 11 | 12 | 13 | random.seed(23) 14 | 15 | 16 | async def _async_factory() -> int: 17 | await asyncio.sleep(0.01) 18 | return threading.get_ident() 19 | 20 | 21 | def _factory() -> int: 22 | time.sleep(0.01) 23 | return random.randint(1, 100) 24 | 25 | 26 | def test_thread_local_singleton_same_thread() -> None: 27 | """Test that the same instance is returned within a single thread.""" 28 | provider = ThreadLocalSingleton(_factory) 29 | 30 | instance1 = provider.resolve_sync() 31 | instance2 = provider.resolve_sync() 32 | 33 | assert instance1 == 
instance2, "Singleton failed: Instances within the same thread should be identical." 34 | 35 | provider.tear_down_sync() 36 | 37 | assert provider._instance is None, "Tear down failed: Instance should be reset to None." 38 | 39 | 40 | async def test_async_thread_local_singleton_asyncio() -> None: 41 | """Test that the same instance is returned within a single thread.""" 42 | provider = ThreadLocalSingleton(_factory) 43 | 44 | instance1 = await provider.resolve() 45 | instance2 = await provider.resolve() 46 | 47 | assert instance1 == instance2, "Singleton failed: Instances within the same thread should be identical." 48 | 49 | await provider.tear_down() 50 | 51 | assert provider._instance is None, "Tear down failed: Instance should be reset to None." 52 | 53 | 54 | def test_thread_local_singleton_different_threads() -> None: 55 | """Test that different threads receive different instances.""" 56 | provider = ThreadLocalSingleton(_factory) 57 | results = [] 58 | 59 | def resolve_in_thread() -> None: 60 | results.append(provider.resolve_sync()) 61 | 62 | number_of_threads = 10 63 | 64 | threads = [threading.Thread(target=resolve_in_thread) for _ in range(number_of_threads)] 65 | 66 | for thread in threads: 67 | thread.start() 68 | for thread in threads: 69 | thread.join() 70 | 71 | assert len(results) == number_of_threads, "Test failed: Expected results from two threads." 72 | assert results[0] != results[1], "Thread-local failed: Instances across threads should differ." 73 | 74 | 75 | def test_thread_local_singleton_override() -> None: 76 | """Test overriding the ThreadLocalSingleton and resetting the override.""" 77 | provider = ThreadLocalSingleton(_factory) 78 | 79 | override_value = 101 80 | provider.override_sync(override_value) 81 | instance = provider.resolve_sync() 82 | assert instance == override_value, "Override failed: Expected overridden value." 
83 | 84 | # Reset override and ensure a new instance is created 85 | provider.reset_override_sync() 86 | new_instance = provider.resolve_sync() 87 | assert new_instance != override_value, "Reset override failed: Should no longer return overridden value." 88 | 89 | 90 | def test_thread_local_singleton_override_in_threads() -> None: 91 | """Test that resetting the override in one thread does not affect another thread.""" 92 | provider = ThreadLocalSingleton(_factory) 93 | results = {} 94 | 95 | def _thread_task(thread_id: int, override_value: int | None = None) -> None: 96 | if override_value is not None: 97 | provider.override_sync(override_value) 98 | results[thread_id] = provider.resolve_sync() 99 | if override_value is not None: 100 | provider.reset_override_sync() 101 | 102 | override_value: typing.Final[int] = 101 103 | thread1 = threading.Thread(target=_thread_task, args=(1, override_value)) 104 | thread2 = threading.Thread(target=_thread_task, args=(2,)) 105 | 106 | thread1.start() 107 | thread2.start() 108 | 109 | thread1.join() 110 | thread2.join() 111 | 112 | # Validate results 113 | assert results[1] == override_value, "Thread 1: Override failed." 114 | assert results[2] != override_value, "Thread 2: Should not be affected by Thread 1's override." 115 | assert results[1] != results[2], "Instances should be unique across threads." 116 | 117 | 118 | def test_thread_local_singleton_override_temporarily() -> None: 119 | """Test temporarily overriding the ThreadLocalSingleton.""" 120 | provider = ThreadLocalSingleton(_factory) 121 | 122 | override_value: typing.Final = 101 123 | # Set a temporary override 124 | with provider.override_context_sync(override_value): 125 | instance = provider.resolve_sync() 126 | assert instance == override_value, "Override context failed: Expected overridden value." 
127 | 128 | # After the context, reset to the factory 129 | new_instance = provider.resolve_sync() 130 | assert new_instance != override_value, "Override context failed: Value should reset after context." 131 | 132 | 133 | async def test_async_thread_local_singleton_asyncio_concurrency() -> None: 134 | singleton_async = ThreadLocalSingleton(_factory) 135 | 136 | expected = await singleton_async.resolve() 137 | 138 | results = await asyncio.gather( 139 | singleton_async(), 140 | singleton_async(), 141 | singleton_async(), 142 | singleton_async(), 143 | singleton_async(), 144 | ) 145 | 146 | assert all(val is results[0] for val in results) 147 | for val in results: 148 | assert val == expected, "Instances should be identical across threads." 149 | 150 | 151 | async def test_thread_local_singleton_async_resolve_with_async_dependencies() -> None: 152 | async_provider = AsyncFactory(_async_factory) 153 | 154 | def _dependent_creator(v: int) -> int: 155 | return v 156 | 157 | provider = ThreadLocalSingleton(_dependent_creator, v=async_provider.cast) 158 | 159 | expected = await provider.resolve() 160 | 161 | assert expected == await provider.resolve() 162 | 163 | results = await asyncio.gather(*[provider.resolve() for _ in range(10)]) 164 | 165 | for val in results: 166 | assert val == expected, "Instances should be identical across threads." 167 | with pytest.raises(RuntimeError): 168 | # This should raise an error because the provider is async and resolution is attempted. 
169 | await asyncio.gather(asyncio.to_thread(provider.resolve_sync)) 170 | 171 | def _run_async_in_thread(coroutine: typing.Awaitable[typing.Any]) -> typing.Any: # noqa: ANN401 172 | loop = asyncio.new_event_loop() 173 | asyncio.set_event_loop(loop) 174 | result = loop.run_until_complete(coroutine) 175 | loop.close() 176 | return result 177 | 178 | with ThreadPoolExecutor() as executor: 179 | future = executor.submit(_run_async_in_thread, provider.resolve()) 180 | 181 | result = future.result() 182 | 183 | assert result != expected, ( 184 | "Since singleton should have been newly resolved, it should not have the same thread id." 185 | ) 186 | 187 | 188 | async def test_thread_local_singleton_async_resolve_override() -> None: 189 | provider = ThreadLocalSingleton(_factory) 190 | 191 | override_value = 101 192 | 193 | provider.override_sync(override_value) 194 | 195 | assert await provider.resolve() == override_value, "Override failed: Expected overridden value." 196 | -------------------------------------------------------------------------------- /tests/integrations/fastapi/test_fastapi_route_class.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing 3 | from http import HTTPStatus 4 | from inspect import signature 5 | from typing import get_type_hints 6 | 7 | from fastapi import APIRouter, Depends, FastAPI 8 | from fastapi.testclient import TestClient 9 | 10 | from that_depends import BaseContainer, Provide, providers 11 | from that_depends.integrations.fastapi import _adjust_fastapi_endpoint, create_fastapi_route_class 12 | from that_depends.providers.context_resources import ContextScopes 13 | 14 | 15 | _FACTORY_RETURN_VALUE = 42 16 | 17 | 18 | async def _async_creator() -> typing.AsyncIterator[float]: 19 | yield random.random() 20 | 21 | 22 | def _sync_creator() -> typing.Iterator[float]: 23 | yield random.random() 24 | 25 | 26 | class Container(BaseContainer): 27 | alias = 
"fastapi_route_class_container" 28 | provider = providers.Factory(lambda: _FACTORY_RETURN_VALUE) 29 | async_context_provider = providers.ContextResource(_async_creator).with_config(scope=ContextScopes.REQUEST) 30 | sync_context_provider = providers.ContextResource(_sync_creator).with_config(scope=ContextScopes.REQUEST) 31 | 32 | 33 | async def test_endpoint_with_context_resource_async() -> None: 34 | app = FastAPI() 35 | 36 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST)) 37 | 38 | @router.get("/hello") 39 | async def _injected(v_3: float = Provide[Container.async_context_provider]) -> float: 40 | return v_3 41 | 42 | app.include_router(router) 43 | 44 | client = TestClient(app) 45 | 46 | response = client.get("/hello") 47 | 48 | response_value = response.json() 49 | 50 | assert response.status_code == HTTPStatus.OK 51 | assert response_value <= 1 52 | assert response_value >= 0 53 | 54 | assert response_value != client.get("/hello").json() 55 | 56 | 57 | def test_endpoint_with_context_resource_sync() -> None: 58 | app = FastAPI() 59 | 60 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST)) 61 | 62 | @router.get("/hello") 63 | def _injected(v_3: float = Provide[Container.sync_context_provider]) -> float: 64 | return v_3 65 | 66 | app.include_router(router) 67 | 68 | client = TestClient(app) 69 | 70 | response = client.get("/hello") 71 | 72 | response_value = response.json() 73 | 74 | assert response.status_code == HTTPStatus.OK 75 | assert response_value <= 1 76 | assert response_value >= 0 77 | 78 | assert response_value != client.get("/hello").json() 79 | 80 | 81 | async def test_endpoint_factory_async() -> None: 82 | app = FastAPI() 83 | 84 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST)) 85 | 86 | @router.get("/hello") 87 | async def _injected(v_3: float = Provide[Container.provider]) -> float: 88 | return v_3 
89 | 90 | app.include_router(router) 91 | 92 | client = TestClient(app) 93 | 94 | response = client.get("/hello") 95 | 96 | response_value = response.json() 97 | 98 | assert response.status_code == HTTPStatus.OK 99 | assert response_value == _FACTORY_RETURN_VALUE 100 | 101 | assert response_value == client.get("/hello").json() 102 | 103 | 104 | def test_endpoint_factory_sync() -> None: 105 | app = FastAPI() 106 | 107 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST)) 108 | 109 | @router.get("/hello") 110 | def _injected(v_3: float = Provide[Container.provider]) -> float: 111 | return v_3 112 | 113 | app.include_router(router) 114 | 115 | client = TestClient(app) 116 | 117 | response = client.get("/hello") 118 | 119 | response_value = response.json() 120 | 121 | assert response.status_code == HTTPStatus.OK 122 | assert response_value == _FACTORY_RETURN_VALUE 123 | 124 | assert response_value == client.get("/hello").json() 125 | 126 | 127 | async def test_endpoint_with_string_provider_async() -> None: 128 | app = FastAPI() 129 | 130 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST)) 131 | 132 | @router.get("/hello") 133 | async def _injected(v_3: float = Provide["fastapi_route_class_container.provider"]) -> float: 134 | return v_3 135 | 136 | app.include_router(router) 137 | 138 | client = TestClient(app) 139 | 140 | response = client.get("/hello") 141 | 142 | response_value = response.json() 143 | 144 | assert response.status_code == HTTPStatus.OK 145 | assert response_value == _FACTORY_RETURN_VALUE 146 | 147 | assert response_value == client.get("/hello").json() 148 | 149 | 150 | async def test_endpoint_with_string_provider_sync() -> None: 151 | app = FastAPI() 152 | 153 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST)) 154 | 155 | @router.get("/hello") 156 | def _injected(v_3: float = 
Provide["fastapi_route_class_container.provider"]) -> float: 157 | return v_3 158 | 159 | app.include_router(router) 160 | 161 | client = TestClient(app) 162 | 163 | response = client.get("/hello") 164 | 165 | response_value = response.json() 166 | 167 | assert response.status_code == HTTPStatus.OK 168 | assert response_value == _FACTORY_RETURN_VALUE 169 | 170 | assert response_value == client.get("/hello").json() 171 | 172 | 173 | def test_adjust_fastapi_endpoint_async() -> None: 174 | async def _injected( 175 | v_1: float, 176 | v_3: float = Provide["fastapi_route_class_container.provider"], 177 | v_2: float = Depends(Container.provider), 178 | ) -> float: 179 | return v_1 + v_3 + v_2 # pragma: no cover 180 | 181 | original_signature = signature(_injected) 182 | original_hints = get_type_hints(_injected) 183 | 184 | adjusted_endpoint = _adjust_fastapi_endpoint(_injected) 185 | 186 | adjusted_sig = signature(adjusted_endpoint) 187 | adjusted_hints = get_type_hints(adjusted_endpoint) 188 | 189 | assert len(original_hints) == len(adjusted_hints) 190 | assert len(original_signature.parameters) == len(adjusted_sig.parameters) 191 | assert original_hints.get("v_1") == adjusted_hints.get("v_1") 192 | assert original_hints.get("v_2") == adjusted_hints.get("v_2") 193 | assert original_hints.get("v_3") != adjusted_hints.get("v_3") 194 | assert original_signature.parameters.get("v_1") == adjusted_sig.parameters.get("v_1") 195 | assert original_signature.parameters.get("v_2") == adjusted_sig.parameters.get("v_2") 196 | assert original_signature.parameters.get("v_3") != adjusted_sig.parameters.get("v_3") 197 | 198 | 199 | def test_adjust_fastapi_endpoint_sync() -> None: 200 | def _injected( 201 | v_1: float, 202 | v_3: float = Provide["fastapi_route_class_container.provider"], 203 | v_2: float = Depends(Container.provider), 204 | ) -> float: 205 | return v_1 + v_3 + v_2 # pragma: no cover 206 | 207 | original_signature = signature(_injected) 208 | original_hints = 
get_type_hints(_injected) 209 | 210 | adjusted_endpoint = _adjust_fastapi_endpoint(_injected) 211 | 212 | adjusted_sig = signature(adjusted_endpoint) 213 | adjusted_hints = get_type_hints(adjusted_endpoint) 214 | 215 | assert len(original_hints) == len(adjusted_hints) 216 | assert len(original_signature.parameters) == len(adjusted_sig.parameters) 217 | assert original_hints.get("v_1") == adjusted_hints.get("v_1") 218 | assert original_hints.get("v_2") == adjusted_hints.get("v_2") 219 | assert original_hints.get("v_3") != adjusted_hints.get("v_3") 220 | assert original_signature.parameters.get("v_1") == adjusted_sig.parameters.get("v_1") 221 | assert original_signature.parameters.get("v_2") == adjusted_sig.parameters.get("v_2") 222 | assert original_signature.parameters.get("v_3") != adjusted_sig.parameters.get("v_3") 223 | --------------------------------------------------------------------------------