├── tests
│   ├── __init__.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── test_object.py
│   │   ├── test_factories.py
│   │   ├── test_collections.py
│   │   ├── test_inject_factories.py
│   │   ├── test_main_providers.py
│   │   ├── test_state.py
│   │   ├── test_selector.py
│   │   ├── test_singleton.py
│   │   ├── test_resources.py
│   │   ├── test_attr_getter.py
│   │   └── test_local_singleton.py
│   ├── experimental
│   │   ├── __init__.py
│   │   ├── test_container_1.py
│   │   ├── test_lazy_provider.py
│   │   └── test_container_2.py
│   ├── integrations
│   │   ├── __init__.py
│   │   ├── fastapi
│   │   │   ├── __init__.py
│   │   │   ├── test_fastapi_di_pass_request.py
│   │   │   ├── test_fastapi_di.py
│   │   │   └── test_fastapi_route_class.py
│   │   ├── litestar
│   │   │   ├── __init__.py
│   │   │   ├── test_litestar_di_simple.py
│   │   │   └── test_litestar_di.py
│   │   └── faststream
│   │       ├── __init__.py
│   │       ├── test_faststream_di.py
│   │       └── test_faststream_di_pass_message.py
│   ├── conftest.py
│   ├── test_utils.py
│   ├── test_meta.py
│   ├── test_dynamic_container.py
│   ├── test_resolver.py
│   ├── creators.py
│   ├── test_multiple_containers.py
│   ├── container.py
│   └── test_container.py
├── that_depends
│   ├── py.typed
│   ├── entities
│   │   ├── __init__.py
│   │   └── resource_context.py
│   ├── integrations
│   │   ├── __init__.py
│   │   ├── fastapi.py
│   │   └── faststream.py
│   ├── experimental
│   │   └── __init__.py
│   ├── exceptions.py
│   ├── utils.py
│   ├── __init__.py
│   ├── providers
│   │   ├── __init__.py
│   │   ├── mixin.py
│   │   ├── object.py
│   │   ├── state.py
│   │   ├── collection.py
│   │   ├── selector.py
│   │   ├── resources.py
│   │   ├── local_singleton.py
│   │   └── factories.py
│   └── container.py
├── examples
│   └── benchmark
│       ├── __init__.py
│       ├── RESULTS.md
│       └── injection.py
├── docs
│   ├── requirements.txt
│   ├── css
│   │   └── code.css
│   ├── overrides
│   │   └── main.html
│   ├── providers
│   │   ├── object.md
│   │   ├── collections.md
│   │   ├── state.md
│   │   ├── selector.md
│   │   ├── resources.md
│   │   ├── factories.md
│   │   └── singleton.md
│   ├── dev
│   │   ├── main-decisions.md
│   │   └── contributing.md
│   ├── testing
│   │   ├── fixture.md
│   │   └── provider-overriding.md
│   ├── introduction
│   │   ├── multiple-containers.md
│   │   ├── ioc-container.md
│   │   ├── type-based-injection.md
│   │   ├── string-injection.md
│   │   ├── tear-down.md
│   │   ├── generator-injection.md
│   │   └── injection.md
│   ├── integrations
│   │   ├── litestar.md
│   │   └── faststream.md
│   ├── experimental
│   │   └── lazy.md
│   ├── index.md
│   └── migration
│       └── v3.md
├── .readthedocs.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .github
│   └── workflows
│       ├── publish.yml
│       └── ci.yml
├── Justfile
├── LICENSE
├── README.md
├── pyproject.toml
└── mkdocs.yml
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/that_depends/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/providers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/experimental/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | """Benchmarks."""
2 |
--------------------------------------------------------------------------------
/that_depends/entities/__init__.py:
--------------------------------------------------------------------------------
1 | """Entities."""
2 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | mkdocs
2 | mkdocs-material
3 | mkdocs-llmstxt
4 |
--------------------------------------------------------------------------------
/that_depends/integrations/__init__.py:
--------------------------------------------------------------------------------
1 | """Integrations with other frameworks."""
2 |
--------------------------------------------------------------------------------
/tests/integrations/fastapi/__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | pytest.importorskip("fastapi")
5 |
--------------------------------------------------------------------------------
/tests/integrations/litestar/__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | pytest.importorskip("litestar")
5 |
--------------------------------------------------------------------------------
/tests/integrations/faststream/__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | pytest.importorskip("faststream")
5 | pytest.importorskip("nats")
6 |
--------------------------------------------------------------------------------
/that_depends/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | """Experimental features."""
2 |
3 | from that_depends.experimental.providers import LazyProvider
4 |
5 |
6 | __all__ = [
7 | "LazyProvider",
8 | ]
9 |
--------------------------------------------------------------------------------
/docs/css/code.css:
--------------------------------------------------------------------------------
1 | /* Round the corners of code blocks */
2 | .md-typeset .highlight code {
3 | border-radius: 10px !important;
4 | }
5 |
6 | @media (max-width: 600px) {
7 | .md-typeset code {
8 | border-radius: 4px;
9 | }
10 | }
11 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.10"
7 |
8 | python:
9 | install:
10 | - requirements: docs/requirements.txt
11 |
12 |
13 | mkdocs:
14 | configuration: mkdocs.yml
15 |
--------------------------------------------------------------------------------
/docs/overrides/main.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block announce %}
4 | that-depends 3.0 has just been released! If you are upgrading, check out the migration guide.
5 | {% endblock %}
6 |
--------------------------------------------------------------------------------
/that_depends/exceptions.py:
--------------------------------------------------------------------------------
1 | class TypeNotBoundError(Exception):
2 | """Exception raised when a type is not bound to a provider."""
3 |
4 |
5 | class StateNotInitializedError(Exception):
6 | """Exception raised when the state is not initialized."""
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generic things
2 | *.pyc
3 | *~
4 | __pycache__/*
5 | *.swp
6 | *.sqlite3
7 | *.map
8 | .vscode
9 | .idea
10 | .DS_Store
11 | .env
12 | .mypy_cache
13 | .pytest_cache
14 | .ruff_cache
15 | .coverage*
16 | htmlcov/
17 | coverage.xml
18 | pytest.xml
19 | dist/
20 | .python-version
21 | .venv
22 | uv.lock
23 |
--------------------------------------------------------------------------------
/docs/providers/object.md:
--------------------------------------------------------------------------------
1 | # Object
2 |
3 | The `Object` provider returns an object “as is”.
4 |
5 | ```python
6 | from that_depends import BaseContainer, providers
7 |
8 |
9 | class DIContainer(BaseContainer):
10 | object_provider = providers.Object(1)
11 |
12 |
13 | assert DIContainer.object_provider.resolve_sync() == 1
14 | ```
15 |
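16 | Like any provider, it can also be resolved asynchronously; a short sketch mirroring the test suite:
17 |
18 | ```python
19 | assert await DIContainer.object_provider() == 1
20 | ```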
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest
4 |
5 | from tests import container
6 |
7 |
8 | @pytest.fixture(autouse=True)
9 | async def _clear_di_container() -> typing.AsyncIterator[None]:
10 | try:
11 | yield
12 | finally:
13 | container.DIContainer.reset_override_sync()
14 | await container.DIContainer.tear_down()
15 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: lint
5 | name: lint
6 | entry: just
7 | args: [lint]
8 | language: system
9 | types: [python]
10 | pass_filenames: false
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v5.0.0
13 | hooks:
14 | - id: end-of-file-fixer
15 | - id: mixed-line-ending
16 |
--------------------------------------------------------------------------------
/examples/benchmark/RESULTS.md:
--------------------------------------------------------------------------------
1 | # Injection
2 | Based on this [benchmark](injection.py):
3 |
4 | | version / iterations | 10^4 | 10^5 | 10^6 |
5 | |----------------------|--------|--------|---------|
6 | | 3.2.0 | 0.1039 | 1.0563 | 10.3576 |
7 | | 3.0.0.a2 | 0.0870 | 0.9430 | 8.8815 |
8 | | 3.0.0.a1 | 0.1399 | 1.4136 | 14.1829 |
9 | | 2.2.0 | 0.1465 | 1.4927 | 14.9044 |
10 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from that_depends.utils import UNSET, Unset, is_set
2 |
3 |
4 | def test_is_set_with_unset() -> None:
5 | assert not is_set(UNSET)
6 |
7 |
8 | def test_is_set_with_value() -> None:
9 | assert is_set(42)
10 | assert is_set("hello")
11 | assert is_set([1, 2, 3])
12 | assert is_set(None)
13 |
14 |
15 | def test_unset_is_singleton() -> None:
16 | assert isinstance(UNSET, Unset)
17 | assert UNSET is Unset()
18 |
--------------------------------------------------------------------------------
/tests/providers/test_object.py:
--------------------------------------------------------------------------------
1 | from that_depends import BaseContainer, providers
2 |
3 |
4 | instance = "some string"
5 |
6 |
7 | class DIContainer(BaseContainer):
8 | alias = "object_container"
9 | instance = providers.Object(instance)
10 |
11 |
12 | async def test_object_provider() -> None:
13 | instance1 = await DIContainer.instance()
14 | instance2 = DIContainer.instance.resolve_sync()
15 |
16 | assert instance1 is instance2 is instance
17 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish Package
2 |
3 | on:
4 | release:
5 | types:
6 | - published
7 |
8 | jobs:
9 | publish:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - uses: extractions/setup-just@v2
14 | - uses: astral-sh/setup-uv@v5
15 | with:
16 | cache-dependency-glob: "**/pyproject.toml"
17 | - run: just publish
18 | env:
19 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
20 |
--------------------------------------------------------------------------------
/tests/test_meta.py:
--------------------------------------------------------------------------------
1 | from that_depends import ContextScopes
2 | from that_depends.meta import BaseContainerMeta
3 |
4 |
5 | def test_base_container_meta_has_correct_default_scope() -> None:
6 | class _Test(metaclass=BaseContainerMeta):
7 | pass
8 |
9 | assert _Test.get_scope() == ContextScopes.ANY
10 |
11 |
12 | def test_base_container_meta_has_correct_default_name() -> None:
13 | class _Test(metaclass=BaseContainerMeta):
14 | pass
15 |
16 | assert _Test.name() == "_Test"
17 |
--------------------------------------------------------------------------------
/docs/dev/main-decisions.md:
--------------------------------------------------------------------------------
1 | # Main decisions
2 | 1. Dependency resolution is async by default:
3 |     - the framework was developed primarily for async Python applications;
4 |     - sync resolution is also possible, but it will fail at runtime if an async dependency has to be resolved synchronously (see the sketch below);
5 | 2. Containers are global:
6 |     - this is needed for injection without wiring to work;
7 |     - this way most of the logic stays in the providers;
8 | 3. Focus on maximum compatibility with mypy:
9 |     - no need for `# type: ignore`
10 |     - no need for `typing.cast`
11 |
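12 | A minimal sketch of decision 1, assuming a container with an async-only dependency (the exact exception raised on sync resolution may differ between versions):
13 |
14 | ```python
15 | import asyncio
16 |
17 | from that_depends import BaseContainer, providers
18 |
19 |
20 | async def _create_value() -> str:
21 |     return "value"
22 |
23 |
24 | class Container(BaseContainer):
25 |     value = providers.AsyncFactory(_create_value)
26 |
27 |
28 | # Async resolution works as usual.
29 | assert asyncio.run(Container.value()) == "value"
30 |
31 | # Sync resolution of an async-only dependency fails at runtime.
32 | Container.value.resolve_sync()
33 | ```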
--------------------------------------------------------------------------------
/tests/experimental/test_container_1.py:
--------------------------------------------------------------------------------
1 | from tests.experimental.test_container_2 import Container2
2 | from that_depends import BaseContainer, providers
3 | from that_depends.experimental import LazyProvider
4 |
5 |
6 | class Container1(BaseContainer):
7 | """Test Container 1."""
8 |
9 | alias = "container_1"
10 | obj_1 = providers.Object(1)
11 | obj_2 = LazyProvider(module_string="tests.experimental.test_container_2", provider_string="Container2.obj_2")
12 |
13 |
14 | def test_lazy_provider_resolution_sync() -> None:
15 | assert Container2.obj_2.resolve_sync() == 2 # noqa: PLR2004
16 |
--------------------------------------------------------------------------------
/tests/test_dynamic_container.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import pytest
4 |
5 | from tests import container
6 | from that_depends import BaseContainer, providers
7 |
8 |
9 | class DIContainer(BaseContainer):
10 | alias = "dynamic_container"
11 | sync_resource: providers.Resource[datetime.datetime]
12 |
13 |
14 | async def test_dynamic_container_not_supported() -> None:
15 | new_provider = providers.Resource(container.create_sync_resource)
16 | with pytest.raises(AttributeError):
17 | DIContainer.sync_resource = new_provider
18 |
19 | with pytest.raises(AttributeError):
20 | DIContainer.something_new = new_provider
21 |
--------------------------------------------------------------------------------
/that_depends/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, ClassVar, TypeGuard, TypeVar
2 |
3 |
4 | T = TypeVar("T")
5 |
6 |
7 | class _Singleton(type):
8 | _instances: ClassVar[dict[type, Any]] = {}
9 |
10 | def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401
11 | if cls not in cls._instances:
12 | cls._instances[cls] = super().__call__(*args, **kwargs)
13 | return cls._instances[cls]
14 |
15 |
16 | class Unset(metaclass=_Singleton):
17 | """Represents an unset value."""
18 |
19 |
20 | UNSET = Unset()
21 |
22 |
23 | def is_set(value: T | Unset) -> TypeGuard[T]:
24 | """Check if a value is set (not UNSET)."""
25 | return value is not UNSET
26 |
--------------------------------------------------------------------------------
/docs/dev/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 | `that-depends` is an open-source project, and we are open to new contributors.
3 |
4 | ## Getting started
5 | 1. Make sure that you have [uv](https://docs.astral.sh/uv/) and [just](https://just.systems/) installed.
6 | 2. Clone the project:
7 | ```
8 | git clone git@github.com:modern-python/that-depends.git
9 | cd that-depends
10 | ```
11 | 3. Install dependencies by running `just install`.
12 |
13 | ## Running linters
14 | `Ruff` and `mypy` are used for static analysis.
15 |
16 | Run all checks with `just lint`.
17 |
18 | ## Running tests
19 | Run all tests with `just test`.
20 |
21 | ## Building and running the documentation
22 | Host the documentation locally by running `just docs`
23 |
--------------------------------------------------------------------------------
/docs/testing/fixture.md:
--------------------------------------------------------------------------------
1 | # Fixture
2 | ## Dependencies teardown
3 | When using dependency injection in tests, it's important to properly tear down resources after tests complete.
4 |
5 | Without proper teardown, the Python event loop might close before resources have a chance to shut down properly, leading to errors like `RuntimeError: Event loop is closed`.
6 |
7 | You can set up automatic teardown using a pytest fixture:
8 |
9 | ```python
10 | import pytest_asyncio
11 | from typing import AsyncGenerator
12 |
13 | from my_project import DIContainer
14 |
15 | @pytest_asyncio.fixture(autouse=True)
16 | async def di_container_teardown() -> AsyncGenerator[None, None]:
17 | try:
18 | yield
19 | finally:
20 | await DIContainer.tear_down()
21 | ```
22 |
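23 | If your tests also override providers, the same fixture can reset overrides before tearing the container down. A sketch mirroring this project's own `tests/conftest.py` (assuming your pytest setup runs async fixtures, e.g. `pytest-asyncio` in auto mode):
24 |
25 | ```python
26 | import typing
27 |
28 | import pytest
29 |
30 | from my_project import DIContainer
31 |
32 |
33 | @pytest.fixture(autouse=True)
34 | async def _clear_di_container() -> typing.AsyncIterator[None]:
35 |     try:
36 |         yield
37 |     finally:
38 |         # Drop any provider overrides set during the test, then release resources.
39 |         DIContainer.reset_override_sync()
40 |         await DIContainer.tear_down()
41 | ```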
--------------------------------------------------------------------------------
/Justfile:
--------------------------------------------------------------------------------
1 | default: install lint test
2 |
3 | install:
4 | uv lock --upgrade
5 | uv sync --all-extras --frozen
6 | @just hook
7 |
8 | lint:
9 | uv run ruff format
10 | uv run ruff check --fix
11 | uv run mypy .
12 |
13 | lint-ci:
14 | uv run ruff format --check
15 | uv run ruff check --no-fix
16 | uv run mypy .
17 |
18 | test *args:
19 | uv run --no-sync pytest {{ args }}
20 |
21 | publish:
22 | rm -rf dist
23 | uv version $GITHUB_REF_NAME
24 | uv build
25 | uv publish --token $PYPI_TOKEN
26 |
27 | hook:
28 | uv run pre-commit install --install-hooks --overwrite
29 |
30 | unhook:
31 | uv run pre-commit uninstall
32 |
33 | docs:
34 | uv pip install -r docs/requirements.txt
35 | uv run mkdocs serve
36 |
--------------------------------------------------------------------------------
/that_depends/__init__.py:
--------------------------------------------------------------------------------
1 | """Dependency injection framework for Python."""
2 |
3 | from that_depends import providers
4 | from that_depends.container import BaseContainer
5 | from that_depends.injection import Provide, inject
6 | from that_depends.providers import container_context
7 | from that_depends.providers.context_resources import (
8 | ContextScope,
9 | ContextScopes,
10 | fetch_context_item,
11 | fetch_context_item_by_type,
12 | get_current_scope,
13 | )
14 |
15 |
16 | __all__ = [
17 | "BaseContainer",
18 | "ContextScope",
19 | "ContextScopes",
20 | "Provide",
21 | "container_context",
22 | "fetch_context_item",
23 | "fetch_context_item_by_type",
24 | "get_current_scope",
25 | "inject",
26 | "providers",
27 | ]
28 |
--------------------------------------------------------------------------------
/tests/test_resolver.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import pytest
4 |
5 | from tests import container
6 | from tests.container import DIContainer
7 |
8 |
9 | @dataclasses.dataclass(kw_only=True, slots=True)
10 | class WrongFactory:
11 | some_dep: int = 1
12 | not_existing_name: container.DependentFactory
13 | sync_resource: str
14 |
15 |
16 | async def test_dependency_resolver() -> None:
17 | resolver = DIContainer.resolver(container.FreeFactory)
18 | dep1 = await resolver()
19 | dep2 = await DIContainer.resolve(container.FreeFactory)
20 |
21 | assert dep1
22 | assert dep2
23 | assert dep1 is not dep2
24 |
25 |
26 | async def test_dependency_resolver_failed() -> None:
27 | resolver = DIContainer.resolver(WrongFactory)
28 | with pytest.raises(RuntimeError, match=r"Provider is not found, field_name='not_existing_name'"):
29 | await resolver()
30 |
--------------------------------------------------------------------------------
/tests/integrations/litestar/test_litestar_di_simple.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from litestar import Litestar, get
4 | from litestar.di import Provide
5 | from litestar.status_codes import HTTP_200_OK
6 | from litestar.testing import TestClient
7 |
8 | from tests import container
9 |
10 |
11 | @get("/")
12 | async def index(injected: datetime.datetime) -> datetime.datetime:
13 | return injected
14 |
15 |
16 | app = Litestar([index], dependencies={"injected": Provide(container.DIContainer.async_resource)})
17 |
18 |
19 | async def test_litestar_di() -> None:
20 | with TestClient(app=app) as client:
21 | response = client.get("/")
22 | assert response.status_code == HTTP_200_OK, response.text
23 | assert (
24 | datetime.datetime.fromisoformat(response.json().replace("Z", "+00:00"))
25 | == await container.DIContainer.async_resource()
26 | )
27 |
--------------------------------------------------------------------------------
/docs/introduction/multiple-containers.md:
--------------------------------------------------------------------------------
1 | # Usage with multiple containers
2 |
3 | You can use providers from other containers as follows:
4 | ```python
5 | from tests import container
6 | from that_depends import BaseContainer, providers
7 |
8 |
9 | class InnerContainer(BaseContainer):
10 | sync_resource = providers.Resource(container.create_sync_resource)
11 | async_resource = providers.Resource(container.create_async_resource)
12 |
13 |
14 | class OuterContainer(BaseContainer):
15 | sequence = providers.List(InnerContainer.sync_resource, InnerContainer.async_resource)
16 | ```
17 |
18 | But this way you have to manage `InnerContainer` lifecycle:
19 |
20 | ```python
21 | await InnerContainer.tear_down()
22 | ```
23 |
24 | Or you can connect sub-containers to the main container:
25 |
26 | ```python
27 | OuterContainer.connect_containers(InnerContainer)
28 |
29 |
30 | # this will init resources for `InnerContainer` also
31 | await OuterContainer.init_resources()
32 |
33 | # and this will tear down resources for `InnerContainer` also
34 | await OuterContainer.tear_down()
35 | ```
36 |
--------------------------------------------------------------------------------
/docs/providers/collections.md:
--------------------------------------------------------------------------------
1 | # Collections
2 | There are two collection providers: `List` and `Dict`.
3 |
4 | ## List
5 | - The `List` provider contains other providers.
6 | - It resolves into a list of dependencies.
7 |
8 | ```python
9 | import random
10 | from that_depends import BaseContainer, providers
11 |
12 |
13 | class DIContainer(BaseContainer):
14 | random_number = providers.Factory(random.random)
15 | numbers_sequence = providers.List(random_number, random_number)
16 |
17 |
18 | DIContainer.numbers_sequence.resolve_sync()
19 | # [0.3035656170071561, 0.8280498192037787]
20 | ```
21 |
22 | ## Dict
23 | - The `Dict` provider is a collection of named providers.
24 | - It resolves into a dict of dependencies.
25 |
26 | ```python
27 | import random
28 | from that_depends import BaseContainer, providers
29 |
30 |
31 | class DIContainer(BaseContainer):
32 | random_number = providers.Factory(random.random)
33 | numbers_map = providers.Dict(key1=random_number, key2=random_number)
34 |
35 |
36 | DIContainer.numbers_map.resolve_sync()
37 | # {'key1': 0.6851384528299208, 'key2': 0.41044920948045294}
38 | ```
39 |
--------------------------------------------------------------------------------
/docs/integrations/litestar.md:
--------------------------------------------------------------------------------
1 | # Usage with `Litestar`
2 |
3 | ```python
4 | import typing
5 | import contextlib
6 |
7 | from litestar import Litestar, get
8 | from litestar.di import Provide
9 | from litestar.status_codes import HTTP_200_OK
10 | from litestar.testing import TestClient
11 |
12 | from tests import container
13 |
14 |
15 | @get("/")
16 | async def index(injected: str) -> str:
17 | return injected
18 |
19 |
20 | @contextlib.asynccontextmanager
21 | async def lifespan_manager(_: Litestar) -> typing.AsyncIterator[None]:
22 | try:
23 | yield
24 | finally:
25 | await container.DIContainer.tear_down()
26 |
27 |
28 | app = Litestar(
29 | route_handlers=[index],
30 | dependencies={"injected": Provide(container.DIContainer.async_resource)},
31 | lifespan=[lifespan_manager],
32 | )
33 |
34 |
35 | def test_litestar_di() -> None:
36 |     with TestClient(app=app) as client:
37 | response = client.get("/")
38 | assert response.status_code == HTTP_200_OK, response.text
39 | assert response.text == "async resource"
40 | ```
41 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 modern-python
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/that_depends/providers/__init__.py:
--------------------------------------------------------------------------------
1 | """Providers."""
2 |
3 | from that_depends.providers.base import AbstractProvider, AttrGetter
4 | from that_depends.providers.collection import Dict, List
5 | from that_depends.providers.context_resources import (
6 | ContextResource,
7 | DIContextMiddleware,
8 | container_context,
9 | )
10 | from that_depends.providers.factories import AsyncFactory, Factory
11 | from that_depends.providers.local_singleton import ThreadLocalSingleton
12 | from that_depends.providers.object import Object
13 | from that_depends.providers.resources import Resource
14 | from that_depends.providers.selector import Selector
15 | from that_depends.providers.singleton import AsyncSingleton, Singleton
16 | from that_depends.providers.state import State
17 |
18 |
19 | __all__ = [
20 | "AbstractProvider",
21 | "AsyncFactory",
22 | "AsyncSingleton",
23 | "AttrGetter",
24 | "ContextResource",
25 | "DIContextMiddleware",
26 | "Dict",
27 | "Factory",
28 | "List",
29 | "Object",
30 | "Resource",
31 | "Selector",
32 | "Singleton",
33 | "State",
34 | "ThreadLocalSingleton",
35 | "container_context",
36 | ]
37 |
--------------------------------------------------------------------------------
/tests/experimental/test_lazy_provider.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from that_depends.experimental import LazyProvider
4 |
5 |
6 | def test_lazy_provider_incorrect_initialization() -> None:
7 | with pytest.raises(
8 | ValueError,
9 | match=r"You must provide either import_string "
10 | "OR both module_string AND provider_string, but not both or neither.",
11 | ):
12 | LazyProvider(module_string="3213") # type: ignore[call-overload]
13 |
14 | with pytest.raises(ValueError, match=r"Invalid import_string ''"):
15 | LazyProvider("")
16 |
17 | with pytest.raises(ValueError, match=r"Invalid provider_string ''"):
18 | LazyProvider(module_string="some.module", provider_string="")
19 |
20 | with pytest.raises(ValueError, match=r"Invalid module_string '.'"):
21 | LazyProvider(module_string=".", provider_string="SomeProvider")
22 |
23 | with pytest.raises(ValueError, match=r"Invalid import_string 'import.'"):
24 | LazyProvider("import.")
25 |
26 |
27 | def test_lazy_provider_incorrect_import_string() -> None:
28 | p = LazyProvider("some.random.path")
29 | with pytest.raises(ImportError):
30 | p.resolve_sync()
31 |
--------------------------------------------------------------------------------
/that_depends/providers/mixin.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class CannotTearDownSyncError(RuntimeError):
5 | """Raised when attempting to tear down an async resource in sync mode."""
6 |
7 |
8 | class SupportsTeardown(abc.ABC):
9 | """Interface for objects that support teardown."""
10 |
11 | @abc.abstractmethod
12 | async def tear_down(self, propagate: bool = True) -> None:
13 | """Perform any necessary cleanup operations.
14 |
15 | This method is called when the object is no longer needed.
16 | """
17 |
18 | @abc.abstractmethod
19 | def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None:
20 | """Perform any necessary cleanup operations.
21 |
22 | This method is called when the object is no longer needed.
23 | """
24 |
25 |
26 | class ProviderWithArguments(abc.ABC):
27 | """Interface for providers that require arguments."""
28 |
29 | @abc.abstractmethod
30 | def _register_arguments(self) -> None:
31 | """Register arguments for the provider."""
32 |
33 | @abc.abstractmethod
34 | def _deregister_arguments(self) -> None:
35 | """Deregister arguments for the provider."""
36 |
--------------------------------------------------------------------------------
/that_depends/providers/object.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | from typing_extensions import override
4 |
5 | from that_depends.providers.base import AbstractProvider
6 |
7 |
8 | T_co = typing.TypeVar("T_co", covariant=True)
9 | P = typing.ParamSpec("P")
10 |
11 |
12 | class Object(AbstractProvider[T_co]):
13 | """Provides an object "as is" without any modification.
14 |
15 | This provider always returns the same object that was given during
16 | initialization.
17 |
18 | Example:
19 | ```python
20 | provider = Object(1)
21 | result = provider.resolve_sync()
22 | print(result) # 1
23 | ```
24 |
25 | """
26 |
27 | __slots__ = ("_obj",)
28 |
29 | def __init__(self, obj: T_co) -> None:
30 | """Initialize the provider with the given object.
31 |
32 | Args:
33 | obj (T_co): The object to be provided.
34 |
35 | """
36 | super().__init__()
37 | self._obj: typing.Final = obj
38 |
39 | @override
40 | async def resolve(self) -> T_co:
41 | return self.resolve_sync()
42 |
43 | @override
44 | def resolve_sync(self) -> T_co:
45 | if self._override is not None:
46 | return typing.cast(T_co, self._override)
47 | return self._obj
48 |
--------------------------------------------------------------------------------
/examples/benchmark/injection.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import time
4 | import typing
5 |
6 | from that_depends import BaseContainer, ContextScopes, Provide, inject, providers
7 |
8 |
9 | logging.basicConfig(level=logging.INFO)
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def _iterator(x: float) -> typing.Iterator[float]:
14 | yield x
15 |
16 |
17 | class _Container(BaseContainer):
18 | grandparent = providers.Singleton(lambda: random.random())
19 | parent = providers.Factory(lambda x: x, grandparent.cast)
20 | item = providers.ContextResource(_iterator, parent.cast).with_config(scope=ContextScopes.REQUEST)
21 |
22 |
23 | @inject(scope=ContextScopes.REQUEST)
24 | def _injected(
25 | x_1: float = Provide[_Container.grandparent],
26 | x_2: float = Provide[_Container.parent],
27 | x_3: float = Provide[_Container.item],
28 | ) -> float:
29 | return x_1 + x_2 + x_3
30 |
31 |
32 | def _bench(n_iterations: int) -> float:
33 | start = time.time()
34 | for _ in range(n_iterations):
35 | _injected()
36 | end = time.time()
37 | _Container.tear_down_sync()
38 | return end - start
39 |
40 |
41 | if __name__ == "__main__":
42 | for n in [10000, 100000, 1000000]:
43 | duration = _bench(n)
44 | logger.info(f"Injected {n} times in {duration:.4f} seconds") # noqa: G004
45 |
--------------------------------------------------------------------------------
/tests/integrations/fastapi/test_fastapi_di_pass_request.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import fastapi
4 | from starlette import status
5 | from starlette.testclient import TestClient
6 |
7 | from tests import container
8 | from that_depends import BaseContainer, container_context, fetch_context_item, providers
9 |
10 |
11 | async def init_di_context(request: fastapi.Request) -> typing.AsyncIterator[None]:
12 | async with container_context(global_context={"request": request}):
13 | yield
14 |
15 |
16 | app = fastapi.FastAPI(dependencies=[fastapi.Depends(init_di_context)])
17 |
18 |
19 | class DIContainer(BaseContainer):
20 | alias = "fastapi_container"
21 | context_request = providers.Factory(
22 | lambda: fetch_context_item("request"),
23 | )
24 |
25 |
26 | @app.get("/")
27 | async def read_root(
28 | context_request: typing.Annotated[
29 | container.DependentFactory, # incorrect type but available for backwards compatibility
30 | fastapi.Depends(DIContainer.context_request),
31 | ],
32 | request: fastapi.Request,
33 | ) -> None:
34 | assert isinstance(request, fastapi.Request)
35 | assert request == context_request
36 |
37 |
38 | client = TestClient(app)
39 |
40 |
41 | async def test_read_main() -> None:
42 | response = client.get("/")
43 | assert response.status_code == status.HTTP_200_OK
44 |
--------------------------------------------------------------------------------
/tests/providers/test_factories.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from that_depends import BaseContainer, Provide, inject, providers
4 |
5 |
6 | async def test_async_factory_with_sync_creator() -> None:
7 | return_value = random.random()
8 | f = providers.AsyncFactory(lambda: return_value)
9 |
10 | assert await f.resolve() == return_value
11 |
12 |
13 | async def test_async_factory_with_sync_creator_multiple_parents() -> None:
14 | """Dependencies of async factory get resolved correctly with a sync creator."""
15 | _return_value_1 = 32
16 | _return_value_2 = 12
17 |
18 | async def _async_creator() -> int:
19 | return 32
20 |
21 | def _sync_creator() -> int:
22 | return _return_value_2
23 |
24 | class _Adder:
25 | def __init__(self, x: int, y: int) -> None:
26 | self.x = x
27 | self.y = y
28 |
29 | def result(self) -> int:
30 | return self.x + self.y
31 |
32 | class _Container(BaseContainer):
33 | p1 = providers.AsyncFactory(_async_creator)
34 | p2 = providers.Factory(_sync_creator)
35 | p3 = providers.AsyncFactory(_Adder, x=p1.cast, y=p2.cast)
36 |
37 | @inject
38 | async def _injected(adder: _Adder = Provide[_Container.p3]) -> int:
39 | return adder.result()
40 |
41 | assert await _injected() == _return_value_1 + _return_value_2
42 |
--------------------------------------------------------------------------------
/docs/providers/state.md:
--------------------------------------------------------------------------------
1 | # State
2 |
3 | The `State` provider stores a value as part of a context.
4 |
5 | It is useful when you want to pass a value into your Container that other providers depend on.
6 |
7 |
8 | ## Creating a state provider
9 |
10 | The `State` provider does not accept any arguments when it is created.
11 | ```python
12 | from that_depends import BaseContainer, providers
13 | class Container(BaseContainer):
14 | my_state: providers.State[int] = providers.State()
15 | ```
16 |
17 | ## Initializing state
18 |
19 | === "async"
20 | ```python
21 |
22 | with Container.my_state.init(42):
23 | print(await Container.my_state.resolve()) # 42
24 |
25 | ```
26 |
27 | === "sync"
28 | ```python
29 | with Container.my_state.init(42):
30 | print(Container.my_state.resolve_sync()) # 42
31 |
32 | ```
33 |
34 | > Note: If you try to resolve a `State` provider without initializing it first, it will raise a `StateNotInitializedError`.
35 |
36 |
37 | ## Nested state
38 |
39 | The `State` provider always resolves to the most recently initialized value.
40 |
41 | ```python
42 | with Container.my_state.init(1):
43 | print(Container.my_state.resolve_sync()) # 1
44 |
45 | with Container.my_state.init(2):
46 | print(Container.my_state.resolve_sync()) # 2
47 |
48 | print(Container.my_state.resolve_sync()) # 1
49 | ```
50 |
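51 | ## Depending on state
52 |
53 | Other providers can depend on a `State` provider just like on any other provider. A minimal sketch (the `Service` class is only for illustration):
54 |
55 | ```python
56 | from that_depends import BaseContainer, providers
57 |
58 |
59 | class Service:
60 |     def __init__(self, value: int) -> None:
61 |         self.value = value
62 |
63 |
64 | class Container(BaseContainer):
65 |     my_state: providers.State[int] = providers.State()
66 |     service = providers.Factory(Service, value=my_state.cast)
67 |
68 |
69 | with Container.my_state.init(42):
70 |     assert Container.service.resolve_sync().value == 42
71 | ```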
--------------------------------------------------------------------------------
/docs/introduction/ioc-container.md:
--------------------------------------------------------------------------------
1 | # The Dependency Injection Container
2 |
3 | Containers serve as a central place to store and manage providers. You also define
4 | your dependency graph in the containers.
5 |
6 | While providers can be defined outside of containers with `that-depends`, this is not recommended
7 | if you want to use any [context features](../providers/context-resources.md).
8 |
9 |
10 | ## Quickstart
11 |
12 | Define a container by subclassing `BaseContainer` and define your providers as class attributes.
13 | ```python
14 | from that_depends import BaseContainer
15 |
16 | class Container(BaseContainer):
17 |     ...  # define your providers here
18 | ```
19 |
20 | Then you can build your dependency graph within the container:
21 |
22 | ```python
23 | from that_depends import providers
24 |
25 | class Container(BaseContainer):
26 | config = providers.Singleton(Config)
27 | session = providers.Factory(create_db_session, config=config.db) # (1)!
28 |
29 |     user_repository = providers.Factory(
30 |         UserRepository,
31 |         config.users,
32 |         session=session.cast,  # (3)!
33 |     )  # (2)!
34 |
35 |
36 | ```
37 |
38 | 1. The configuration will be resolved and then the `.db` attribute will be passed to the `create_db_session` creator
39 | as a keyword argument when resolving the `session` provider.
40 | 2. Depends on both the session and configuration providers.
41 | 3. Providers have the `cast` property, which changes their type to the return type of their creator; use it to prevent type errors.
42 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: main
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request: {}
8 |
9 | concurrency:
10 | group: ${{ github.head_ref || github.run_id }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | lint:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v5
18 | - uses: extractions/setup-just@v3
19 | - uses: astral-sh/setup-uv@v7
20 | with:
21 | cache-dependency-glob: "**/pyproject.toml"
22 | - run: uv python install 3.10
23 | - run: just install lint-ci
24 |
25 | pytest:
26 | runs-on: ubuntu-latest
27 | strategy:
28 | fail-fast: false
29 | matrix:
30 | python-version:
31 | - "3.10"
32 | - "3.11"
33 | - "3.12"
34 | - "3.13"
35 | - "3.14"
36 | faststream-version:
37 | - "<0.6.0"
38 | - ">=0.6.0"
39 | steps:
40 | - uses: actions/checkout@v5
41 | - uses: extractions/setup-just@v3
42 | - uses: astral-sh/setup-uv@v7
43 | with:
44 | cache-dependency-glob: "**/pyproject.toml"
45 | - run: uv python install ${{ matrix.python-version }}
46 | - run: just install
47 | - run: uv run --with "faststream${{ matrix.faststream-version }}" pytest . --cov=. --cov-report xml
48 | - uses: codecov/codecov-action@v5.4.3
49 | env:
50 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
51 | with:
52 | files: ./coverage.xml
53 | flags: unittests
54 | name: codecov-${{ matrix.python-version }}
55 |
--------------------------------------------------------------------------------
/docs/experimental/lazy.md:
--------------------------------------------------------------------------------
1 | # Lazy Provider
2 |
3 | The `LazyProvider` enables you to reference other providers without explicitly
4 | importing them into your module.
5 |
6 | This can be helpful if you have a circular dependency between providers in
7 | multiple containers.
8 |
9 |
10 | ## Creating a Lazy Provider
11 |
12 | === "Single import string"
13 | ```python
14 | from that_depends.experimental import LazyProvider
15 |
16 | lazy_p = LazyProvider("my.module.attribute.path")
17 | ```
18 | === "Separate module and provider"
19 | ```python
20 | from that_depends.experimental import LazyProvider
21 |
22 | lazy_p = LazyProvider(module_string="my.module", provider_string="attribute.path")
23 | ```
24 |
25 |
26 | ## Usage
27 |
28 | You can use the lazy provider in exactly the same way as you would use the referenced provider.
29 |
30 | ```python
31 | # first_container.py
32 | from that_depends import BaseContainer, providers, ContextScopes
33 |
34 | def my_creator():
35 | yield 42
36 |
37 | class FirstContainer(BaseContainer):
38 | value_provider = providers.ContextResource(my_creator).with_config(scope=ContextScopes.APP)
39 | ```
40 |
41 | You can lazily import this provider:
42 | ```python
43 | # second_container.py
44 | from that_depends.experimental import LazyProvider
45 | from that_depends import BaseContainer, providers
46 | class SecondContainer(BaseContainer):
47 | lazy_value = LazyProvider("first_container.FirstContainer.value_provider")
48 |
49 |
50 | with SecondContainer.lazy_value.context_sync(force=True):
51 | SecondContainer.lazy_value.resolve_sync() # 42
52 | ```
53 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | "That Depends"
2 | ==
3 | [](https://codecov.io/gh/modern-python/that-depends)
4 | [](https://mypy.readthedocs.io/en/stable/getting_started.html#strict-mode-and-configuration)
5 | [](https://pypi.python.org/pypi/that-depends)
6 | [](https://pepy.tech/projects/that-depends)
7 | [](https://github.com/modern-python/that-depends/stargazers)
8 | [](https://libs.tech/project/773446541/that-depends)
9 | [](https://that-depends.readthedocs.io/llms.txt)
10 |
11 | Dependency injection framework for Python.
12 |
13 | It is production-ready and gives you the following:
14 | - Simple async-first DI framework with an IoC container.
15 | - Python 3.10+ support.
16 | - Full coverage with type annotations (mypy in strict mode).
17 | - Inbuilt FastAPI, FastStream and LiteStar compatibility.
18 | - Dependency context management with scopes.
19 | - Overriding dependencies for tests.
20 | - Injecting dependencies in functions and coroutines without wiring.
21 | - Package with zero dependencies.
22 |
23 |
24 | ### Installation
25 | ```bash
26 | pip install that-depends
27 | ```
28 |
29 | ## 📚 [Documentation](https://that-depends.readthedocs.io)
30 |
31 | ## 📦 [PyPi](https://pypi.org/project/that-depends)
32 |
33 | ## 📝 [License](LICENSE)
34 |
--------------------------------------------------------------------------------
/tests/creators.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import types
4 | import typing
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]:
11 | logger.debug("Async resource initiated")
12 | try:
13 | yield datetime.datetime.now(tz=datetime.timezone.utc)
14 | finally:
15 | logger.debug("Async resource destructed")
16 |
17 |
18 | def create_sync_resource() -> typing.Iterator[datetime.datetime]:
19 | logger.debug("Resource initiated")
20 | try:
21 | yield datetime.datetime.now(tz=datetime.timezone.utc)
22 | finally:
23 | logger.debug("Resource destructed")
24 |
25 |
26 | class ContextManagerResource(typing.ContextManager[datetime.datetime]):
27 | def __enter__(self) -> datetime.datetime:
28 | logger.debug("Resource initiated")
29 | return datetime.datetime.now(tz=datetime.timezone.utc)
30 |
31 | def __exit__(
32 | self,
33 | exc_type: type[BaseException] | None,
34 | exc_value: BaseException | None,
35 | traceback: types.TracebackType | None,
36 | ) -> None:
37 | logger.debug("Resource destructed")
38 |
39 |
40 | class AsyncContextManagerResource(typing.AsyncContextManager[datetime.datetime]):
41 | async def __aenter__(self) -> datetime.datetime:
42 | logger.debug("Async resource initiated")
43 | return datetime.datetime.now(tz=datetime.timezone.utc)
44 |
45 | async def __aexit__(
46 | self,
47 | exc_type: type[BaseException] | None,
48 | exc_value: BaseException | None,
49 | traceback: types.TracebackType | None,
50 | ) -> None:
51 | logger.debug("Async resource destructed")
52 |
--------------------------------------------------------------------------------
/that_depends/providers/state.py:
--------------------------------------------------------------------------------
1 | import typing
2 | import uuid
3 | from contextlib import contextmanager
4 | from contextvars import ContextVar
5 | from typing import TypeVar
6 |
7 | from typing_extensions import override
8 |
9 | from that_depends.exceptions import StateNotInitializedError
10 | from that_depends.providers import AbstractProvider
11 | from that_depends.utils import UNSET, Unset, is_set
12 |
13 |
14 | T = TypeVar("T", covariant=False)
15 |
16 | _STATE_UNSET_ERROR_MESSAGE: typing.Final[str] = (
17 | "State has not been initialized.\n"
18 | "Please use `with provider.init(state):` to initialize the state before resolving it."
19 | )
20 |
21 |
22 | class State(AbstractProvider[T]):
23 | """Provides a value that can be passed into the provider at runtime."""
24 |
25 | def __init__(self) -> None:
26 | """Create a state provider."""
27 | super().__init__()
28 |
29 | self._state: ContextVar[T | Unset] = ContextVar(f"STATE_{uuid.uuid4()}", default=UNSET)
30 |
31 | @contextmanager
32 | def init(self, state: T) -> typing.Iterator[T]:
33 | """Set the state provider's value.
34 |
35 | Args:
36 | state: value to store.
37 |
38 | Returns:
39 |             Iterator[T]: A context manager that yields the state value.
40 |
41 | """
42 | token = self._state.set(state)
43 | yield state
44 | self._state.reset(token)
45 |
46 | @override
47 | async def resolve(self) -> T:
48 | return self.resolve_sync()
49 |
50 | @override
51 | def resolve_sync(self) -> T:
52 | value = self._state.get()
53 | if is_set(value):
54 | return value
55 | raise StateNotInitializedError(_STATE_UNSET_ERROR_MESSAGE)
56 |
--------------------------------------------------------------------------------
/docs/introduction/type-based-injection.md:
--------------------------------------------------------------------------------
1 | # Type based injection
2 |
3 | `that-depends` also supports dependency injection without explicitly referencing
4 | the provider of the dependency.
5 |
6 | ## Quick Start
7 |
8 | In order to make use of this, you need to bind providers to the type they will provide:
9 | ```python
10 | class Container(BaseContainer):
11 | my_provider = providers.Factory(lambda: random.random()).bind(float)
12 | ```
13 |
14 | Then you can inject the dependency into your functions or generators:
15 |
16 | === "Option 1"
17 | ```python
18 | @Container.inject
19 | async def foo(v: float = Provide()):
20 | return v
21 | ```
22 | === "Option 2"
23 | ```python
24 | @inject(container=Container)
25 | async def foo(v: float = Provide()):
26 | return v
27 | ```
28 |
29 |
30 | ## Default bind
31 |
32 | By default, providers will **not** be bound to any type, even if your creator
33 | function has type hints. So make sure to always bind your providers.
34 |
35 | You can also bind multiple types to the same provider:
36 | ```python
37 | class Container(BaseContainer):
38 | my_provider = providers.Factory(lambda: random.random()).bind(float, complex)
39 | ```
40 |
41 | ## Contravariant binding
42 |
43 | By default, injection is invariant with respect to the bound types.
44 |
45 | If you wish to enable contravariance for your bound types you can do so by setting
46 | `#!python contravariant=True` in the `bind` method:
47 |
48 | ```python hl_lines="6 9"
49 | class A: ...
50 |
51 | class B(A): ...
52 |
53 | class Container(BaseContainer):
54 | my_provider = providers.Factory(lambda: B()).bind(B, contravariant=True)
55 |
56 | @Container.inject
57 | async def foo(v: A = Provide()) -> A: # (1)!
58 | return v
59 | ```
60 |
61 | 1. `v` will receive an instance of `B` since `A` is a supertype of `B`.
62 |
--------------------------------------------------------------------------------
/docs/providers/selector.md:
--------------------------------------------------------------------------------
1 | # Selector
2 |
3 | The `Selector` provider chooses between providers based on a key and resolves into a single dependency.
4 |
5 | The selector can be a callable that returns a string, an instance of `AbstractProvider`, or a fixed string.
6 |
7 | ## Callable selectors
8 |
9 | ```python
10 | import os
11 | from typing import Protocol
12 | from that_depends import BaseContainer, providers
13 |
14 | class StorageService(Protocol):
15 | ...
16 |
17 | class StorageServiceLocal(StorageService):
18 | ...
19 |
20 | class StorageServiceRemote(StorageService):
21 | ...
22 |
23 | class DIContainer(BaseContainer):
24 | storage_service = providers.Selector(
25 | lambda: os.getenv("STORAGE_BACKEND", "local"),
26 | local=providers.Factory(StorageServiceLocal),
27 | remote=providers.Factory(StorageServiceRemote),
28 | )
29 | ```
30 |
31 | ## Provider selectors
32 |
33 | In this example, we have a Pydantic-Settings class that contains the key to select the provider.
34 |
35 | ```python
36 | from typing import Literal
37 | from pydantic_settings import BaseSettings
38 |
39 | class Settings(BaseSettings):
40 | storage_backend: Literal["local", "remote"] = "remote"
41 |
42 | class DIContainer(BaseContainer):
43 | settings = providers.Singleton(Settings)
44 | selector = providers.Selector(
45 | settings.cast.storage_backend,
46 | local=providers.Factory(StorageServiceLocal),
47 | remote=providers.Factory(StorageServiceRemote),
48 | )
49 | ```
50 |
51 | ## Fixed string selectors
52 |
53 | This can be useful for quick testing.
54 |
55 | ```python
56 | class DIContainer(BaseContainer):
57 | selector = providers.Selector(
58 | "local",
59 | local=providers.Factory(StorageServiceLocal),
60 | remote=providers.Factory(StorageServiceRemote),
61 | )
62 | ```
63 |
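64 | Resolving any of these selectors returns the dependency from the provider whose key is selected at resolution time. A minimal sketch for the fixed-string selector above:
65 |
66 | ```python
67 | service = DIContainer.selector.resolve_sync()
68 | assert isinstance(service, StorageServiceLocal)
69 | ```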
--------------------------------------------------------------------------------
/tests/integrations/faststream/test_faststream_di.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import typing
3 |
4 | from faststream import Depends
5 | from faststream.nats import NatsBroker, TestNatsBroker
6 |
7 | from tests import container
8 | from that_depends.integrations.faststream import DIContextMiddleware
9 |
10 |
11 | broker = NatsBroker(middlewares=[DIContextMiddleware(container.DIContainer.context_resource)])
12 |
13 | TEST_SUBJECT = "test"
14 |
15 |
16 | @broker.subscriber(TEST_SUBJECT)
17 | async def index_subscriber(
18 | dependency: typing.Annotated[
19 | container.DependentFactory,
20 | Depends(container.DIContainer.dependent_factory),
21 | ],
22 | free_dependency: typing.Annotated[
23 | container.FreeFactory,
24 | Depends(container.DIContainer.resolver(container.FreeFactory)),
25 | ],
26 | singleton: typing.Annotated[
27 | container.SingletonFactory,
28 | Depends(container.DIContainer.singleton),
29 | ],
30 | singleton_attribute: typing.Annotated[bool, Depends(container.DIContainer.singleton.dep1)],
31 | context_resource: typing.Annotated[datetime.datetime, Depends(container.DIContainer.context_resource)],
32 | ) -> datetime.datetime:
33 | assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource
34 | assert dependency.async_resource == free_dependency.dependent_factory.async_resource
35 | assert singleton.dep1 is True
36 | assert singleton_attribute is True
37 | assert isinstance(context_resource, datetime.datetime)
38 | return dependency.async_resource
39 |
40 |
41 | async def test_read_main() -> None:
42 | async with TestNatsBroker(broker) as br:
43 | result = await br.request(None, TEST_SUBJECT)
44 |
45 | result_str = typing.cast(str, await result.decode())
46 | assert (
47 | datetime.datetime.fromisoformat(result_str.replace("Z", "+00:00"))
48 | == await container.DIContainer.async_resource()
49 | )
50 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # That Depends
2 |
3 | Welcome to the `that-depends` documentation!
4 |
5 | `that-depends` is a Python dependency injection framework which, among other things,
6 | supports the following:
7 |
8 | - Async and sync dependency resolution
9 | - Scopes and granular context management
10 | - Dependency injection anywhere
11 | - Fully typed and tested
12 | - Compatibility with popular frameworks like `FastAPI` and `LiteStar`
13 | - Python 3.10+ support
14 |
15 | ---
16 |
17 | ## Installation
18 |
19 | === "pip"
20 | ```bash
21 | pip install that-depends
22 | ```
23 | === "uv"
24 | ```bash
25 | uv add that-depends
26 | ```
27 |
28 | ---
29 |
30 | ## Quickstart
31 |
32 | ### Define a creator
33 | ```python
34 | async def create_async_resource():
35 | logger.debug("Async resource initiated")
36 | try:
37 | yield "async resource"
38 | finally:
39 | logger.debug("Async resource destructed")
40 | ```
41 |
42 | ### Setup Dependency Injection Container with Providers
43 | ```python
44 | from that_depends import BaseContainer, providers
45 |
46 | class Container(BaseContainer):
47 | provider = providers.Resource(create_async_resource)
48 | ```
49 |
50 | See the [containers documentation](introduction/ioc-container.md) for more information on defining the container.
51 |
52 | For a list of providers and their usage, see the [providers section](providers/collections.md).
53 | ### Resolve dependencies in your code
54 | ```python
55 | await Container.provider()
56 | ```
57 |
58 | ### Inject providers in function arguments
59 | ```python
60 | from that_depends import inject, Provide
61 |
62 | @inject
63 | async def some_foo(value: str = Provide[Container.provider]):
64 | return value
65 |
66 | await some_foo() # "async resource"
67 | ```
68 |
69 | See the [injection documentation](introduction/injection.md) for more information.
70 |
71 | ## llms.txt
72 |
73 | `that-depends` provides a [llms.txt](https://that-depends.readthedocs.io/llms.txt) file.
74 |
75 | You can find documentation on how to use it [here](https://llmstxt.org/).
76 |
--------------------------------------------------------------------------------
/tests/integrations/faststream/test_faststream_di_pass_message.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | from faststream import BaseMiddleware, Context, Depends
4 | from faststream.nats import NatsBroker, TestNatsBroker
5 | from faststream.nats.message import NatsMessage
6 | from packaging.version import Version
7 |
8 | from that_depends import BaseContainer, container_context, fetch_context_item, providers
9 | from that_depends.integrations.faststream import _FASTSTREAM_VERSION
10 |
11 |
12 | if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover
13 | from faststream.message import StreamMessage
14 | else: # pragma: no cover
15 | from faststream.broker.message import StreamMessage # type: ignore[import-not-found, no-redef]
16 |
17 |
18 | class ContextMiddleware(BaseMiddleware):
19 | async def consume_scope(
20 | self,
21 | call_next: typing.Callable[..., typing.Awaitable[typing.Any]],
22 | msg: StreamMessage[typing.Any],
23 | ) -> typing.Any: # noqa: ANN401
24 | async with container_context(global_context={"request": msg}):
25 | return await call_next(msg)
26 |
27 |
28 | if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover
29 | broker = NatsBroker(middlewares=(ContextMiddleware,))
30 |
31 | else: # pragma: no cover
32 | broker = NatsBroker(middlewares=(ContextMiddleware,), validate=False) # type: ignore[call-arg]
33 |
34 | TEST_SUBJECT = "test"
35 |
36 |
37 | class DIContainer(BaseContainer):
38 | context_request = providers.Factory(
39 | lambda: fetch_context_item("request"),
40 | )
41 |
42 |
43 | @broker.subscriber(TEST_SUBJECT)
44 | async def index_subscriber(
45 | context_request: typing.Annotated[
46 | NatsMessage,
47 | Depends(DIContainer.context_request, cast=False),
48 | ],
49 | message: typing.Annotated[
50 | NatsMessage,
51 | Context(),
52 | ],
53 | ) -> bool:
54 | return message is context_request
55 |
56 |
57 | async def test_read_main() -> None:
58 | async with TestNatsBroker(broker) as br:
59 | result = await br.request(None, TEST_SUBJECT)
60 |
61 | assert (await result.decode()) is True
62 |
--------------------------------------------------------------------------------
/tests/test_multiple_containers.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import pytest
4 |
5 | from tests import container
6 | from that_depends import BaseContainer, providers
7 |
8 |
9 | class InnerContainer(BaseContainer):
10 | sync_resource = providers.Resource(container.create_sync_resource)
11 | async_resource = providers.Resource(container.create_async_resource)
12 |
13 |
14 | class OuterContainer(BaseContainer):
15 | sequence = providers.List(InnerContainer.sync_resource, InnerContainer.async_resource)
16 |
17 |
18 | OuterContainer.connect_containers(InnerContainer)
19 |
20 |
21 | async def test_included_container() -> None:
22 | sequence = await OuterContainer.sequence()
23 | assert all(isinstance(x, datetime.datetime) for x in sequence)
24 |
25 | await OuterContainer.tear_down()
26 | assert InnerContainer.sync_resource._context.instance is None
27 | assert InnerContainer.async_resource._context.instance is None
28 |
29 | await OuterContainer.init_resources()
30 | sync_resource_context = InnerContainer.sync_resource._context
31 | assert sync_resource_context
32 | assert sync_resource_context.instance is not None
33 | async_resource_context = InnerContainer.async_resource._context
34 | assert async_resource_context
35 | assert async_resource_context.instance is not None
36 | await OuterContainer.tear_down()
37 |
38 |
39 | async def test_overwriting_container_warns(recwarn: None) -> None: # noqa:ARG001
40 | class _A(BaseContainer):
41 | pass
42 |
43 | with pytest.warns(UserWarning, match="Overwriting container '_A'"):
44 |
45 | class _A(BaseContainer): # type: ignore[no-redef]
46 | pass
47 |
48 | class _A(BaseContainer): # type: ignore[no-redef]
49 | pass
50 |
51 |
52 | async def test_overwriting_container_with_alias_warns(recwarn: None) -> None: # noqa:ARG001
53 | class _A(BaseContainer):
54 | alias = "a"
55 |
56 | with pytest.warns(UserWarning, match="Overwriting container 'a'"):
57 |
58 | class _A(BaseContainer): # type: ignore[no-redef]
59 | alias = "a"
60 |
61 | class _A(BaseContainer): # type: ignore[no-redef]
62 | alias = "a"
63 |
--------------------------------------------------------------------------------
/tests/providers/test_collections.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest
4 |
5 | from tests.container import create_async_resource, create_sync_resource
6 | from that_depends import BaseContainer, providers
7 | from that_depends.providers import AbstractProvider
8 |
9 |
10 | class DIContainer(BaseContainer):
11 | alias = "collection_container"
12 | sync_resource = providers.Resource(create_sync_resource)
13 | async_resource = providers.Resource(create_async_resource)
14 | sequence = providers.List(sync_resource, async_resource)
15 | mapping = providers.Dict(sync_resource=sync_resource, async_resource=async_resource)
16 |
17 |
18 | @pytest.fixture(autouse=True)
19 | async def _clear_di_container() -> typing.AsyncIterator[None]:
20 | try:
21 | yield
22 | finally:
23 | await DIContainer.tear_down()
24 |
25 |
26 | async def test_list_provider() -> None:
27 | sequence = await DIContainer.sequence()
28 | sync_resource = await DIContainer.sync_resource()
29 | async_resource = await DIContainer.async_resource()
30 |
31 | assert sequence == [sync_resource, async_resource]
32 |
33 |
34 | def test_list_failed_sync_resolve() -> None:
35 | with pytest.raises(RuntimeError, match=r"AsyncResource cannot be resolved synchronously"):
36 | DIContainer.sequence.resolve_sync()
37 |
38 |
39 | async def test_list_sync_resolve_after_init() -> None:
40 | await DIContainer.init_resources()
41 | DIContainer.sequence.resolve_sync()
42 |
43 |
44 | async def test_dict_provider() -> None:
45 | mapping = await DIContainer.mapping()
46 | sync_resource = await DIContainer.sync_resource()
47 | async_resource = await DIContainer.async_resource()
48 |
49 | assert mapping == {"sync_resource": sync_resource, "async_resource": async_resource}
50 | assert mapping == DIContainer.mapping.resolve_sync()
51 |
52 |
53 | @pytest.mark.parametrize("provider", [DIContainer.sequence, DIContainer.mapping])
54 | async def test_attr_getter_in_collections_providers(provider: AbstractProvider[typing.Any]) -> None:
55 | with pytest.raises(AttributeError, match=f"'{type(provider)}' object has no attribute 'some_attribute'"):
56 | await provider.some_attribute
57 |
--------------------------------------------------------------------------------
/tests/container.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import datetime
3 | import logging
4 | import typing
5 |
6 | from that_depends import BaseContainer, providers
7 |
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def create_sync_resource() -> typing.Iterator[datetime.datetime]:
13 | logger.debug("Resource initiated")
14 | try:
15 | yield datetime.datetime.now(tz=datetime.timezone.utc)
16 | finally:
17 | logger.debug("Resource destructed")
18 |
19 |
20 | async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]:
21 | logger.debug("Async resource initiated")
22 | try:
23 | yield datetime.datetime.now(tz=datetime.timezone.utc)
24 | finally:
25 | logger.debug("Async resource destructed")
26 |
27 |
28 | @dataclasses.dataclass(kw_only=True, slots=True)
29 | class SimpleFactory:
30 | dep1: str
31 | dep2: int
32 |
33 |
34 | async def async_factory(now: datetime.datetime) -> datetime.datetime:
35 | return now + datetime.timedelta(hours=1)
36 |
37 |
38 | @dataclasses.dataclass(kw_only=True, slots=True)
39 | class DependentFactory:
40 | simple_factory: SimpleFactory
41 | sync_resource: datetime.datetime
42 | async_resource: datetime.datetime
43 |
44 |
45 | @dataclasses.dataclass(kw_only=True, slots=True)
46 | class FreeFactory:
47 | dependent_factory: DependentFactory
48 | sync_resource: str
49 |
50 |
51 | @dataclasses.dataclass(kw_only=True, slots=True)
52 | class SingletonFactory:
53 | dep1: bool
54 |
55 |
56 | class DIContainer(BaseContainer):
57 | default_scope = None
58 | alias = "test_container"
59 | sync_resource = providers.Resource(create_sync_resource)
60 | async_resource = providers.Resource(create_async_resource)
61 |
62 | simple_factory = providers.Factory(SimpleFactory, dep1="text", dep2=123)
63 | async_factory = providers.AsyncFactory(async_factory, async_resource.cast)
64 | dependent_factory = providers.Factory(
65 | DependentFactory,
66 | simple_factory=simple_factory.cast,
67 | sync_resource=sync_resource.cast,
68 | async_resource=async_resource.cast,
69 | )
70 | singleton = providers.Singleton(SingletonFactory, dep1=True)
71 | object = providers.Object(object())
72 | context_resource = providers.ContextResource(create_async_resource)
73 |
--------------------------------------------------------------------------------
/tests/providers/test_inject_factories.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import typing
3 |
4 | import pytest
5 |
6 | from tests import container
7 | from that_depends import BaseContainer, providers
8 |
9 |
10 | @dataclasses.dataclass(kw_only=True, slots=True)
11 | class InjectedFactories:
12 | sync_factory: typing.Callable[..., container.DependentFactory]
13 | async_factory: typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, container.DependentFactory]]
14 |
15 |
16 | class DIContainer(BaseContainer):
17 | alias = "inject_factories_container"
18 | sync_resource = providers.Resource(container.create_sync_resource)
19 | async_resource = providers.Resource(container.create_async_resource)
20 |
21 | simple_factory = providers.Factory(container.SimpleFactory, dep1="text", dep2=123)
22 | dependent_factory = providers.Factory(
23 | container.DependentFactory,
24 | simple_factory=simple_factory.cast,
25 | sync_resource=sync_resource.cast,
26 | async_resource=async_resource.cast,
27 | )
28 | injected_factories = providers.Factory(
29 | InjectedFactories,
30 | sync_factory=dependent_factory.provider_sync,
31 | async_factory=dependent_factory.provider,
32 | )
33 |
34 |
35 | async def test_async_provider() -> None:
36 | injected_factories = await DIContainer.injected_factories()
37 | instance1 = await injected_factories.async_factory()
38 | instance2 = await injected_factories.async_factory()
39 |
40 | assert isinstance(instance1, container.DependentFactory)
41 | assert isinstance(instance2, container.DependentFactory)
42 | assert instance1 is not instance2
43 |
44 | await DIContainer.tear_down()
45 |
46 |
47 | async def test_sync_provider() -> None:
48 | injected_factories = await DIContainer.injected_factories()
49 | with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
50 | injected_factories.sync_factory()
51 |
52 | await DIContainer.init_resources()
53 | instance1 = injected_factories.sync_factory()
54 | instance2 = injected_factories.sync_factory()
55 |
56 | assert isinstance(instance1, container.DependentFactory)
57 | assert isinstance(instance2, container.DependentFactory)
58 | assert instance1 is not instance2
59 |
60 | await DIContainer.tear_down()
61 |
--------------------------------------------------------------------------------
/docs/introduction/string-injection.md:
--------------------------------------------------------------------------------
1 | ## Provider resolution by name
2 |
3 | Providers can be injected by name by passing a string to the `Provide` marker. This is useful when the container that defines the provider is not imported in the current module.
4 |
5 |
6 | This serves two primary purposes:
7 |
8 | - A higher level of decoupling between container and your code.
9 | - Avoiding circular imports.
10 |
11 | ---
12 | ## Usage
13 |
14 | To inject a provider by name, use the `Provide` marker with a string argument that has the following format:
15 |
16 | ```
17 | Container.Provider[.attribute.attribute...]
18 | ```
19 |
20 | The string is validated as soon as it is passed to `Provide[]`, so a malformed definition raises an
21 | exception immediately.
22 |
23 | **For example**:
24 |
25 | ```python
26 | from that_depends import BaseContainer, inject, Provide, providers
27 |
28 | class Config(BaseSettings):  # e.g. pydantic's BaseSettings
29 | name: str = "Damian"
30 |
31 | class A(BaseContainer):
32 | b = providers.Factory(Config)
33 |
34 | @inject
35 | def read(val = Provide["A.b.name"]):
36 | return val
37 |
38 | assert read() == "Damian"
39 | ```
40 |
41 | ### Container alias
42 |
43 | Containers support aliases:
44 |
45 | ```python
46 |
47 | class A(BaseContainer):
48 | alias = "C" # replaces the container name.
49 | b = providers.Factory(Config)
50 |
51 | @inject
52 | def read(val = Provide["C.b.name"]): # `A` can no longer be used.
53 | return val
54 |
55 | assert read() == "Damian"
56 | ```
57 |
58 | ---
59 | ## Considerations
60 |
61 | This feature is primarily intended as a fallback when other options are not optimal or
62 | simply not available, so it should be used sparingly.
63 |
64 | If you do decide to use injection by name, consider the following:
65 |
66 | - In order for this type of injection to work, your container must already be defined by the time the injected function is called:
67 | ```python
68 | from that_depends import BaseContainer, inject, Provide
69 |
70 | @inject
71 | def injected(f = Provide["MyContainer.my_provider"]): ...
72 |
73 | injected() # will raise an Exception
74 |
75 | class MyContainer(BaseContainer):
76 | my_provider = providers.Factory(some_creator)
77 |
78 | injected() # will resolve
79 | ```
80 | - Validation of whether you have provided a correct container name and provider name will only happen when the function is called.
81 |
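82 | ---
83 | ## Async functions
84 | 
85 | Injection by name is not limited to synchronous functions. A minimal sketch with an async function (the
86 | container and class names here are illustrative):
87 | 
88 | ```python
89 | import asyncio
90 | 
91 | from that_depends import BaseContainer, inject, Provide, providers
92 | 
93 | 
94 | class Settings:
95 |     name: str = "Damian"
96 | 
97 | 
98 | class MyContainer(BaseContainer):
99 |     settings = providers.Factory(Settings)
100 | 
101 | 
102 | @inject
103 | async def read_async(val: str = Provide["MyContainer.settings.name"]) -> str:
104 |     return val
105 | 
106 | 
107 | assert asyncio.run(read_async()) == "Damian"
108 | ```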
--------------------------------------------------------------------------------
/tests/providers/test_main_providers.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import pytest
4 |
5 | from tests import container
6 | from tests.container import DIContainer
7 | from that_depends import providers
8 |
9 |
10 | async def test_factory_providers() -> None:
11 | simple_factory = await DIContainer.simple_factory()
12 | dependent_factory = await DIContainer.dependent_factory()
13 | async_factory = await DIContainer.async_factory()
14 | sync_resource = await DIContainer.sync_resource()
15 | async_resource = await DIContainer.async_resource()
16 |
17 | assert dependent_factory.simple_factory is not simple_factory
18 | assert DIContainer.simple_factory.resolve_sync() is not simple_factory
19 | assert dependent_factory.sync_resource == sync_resource
20 | assert dependent_factory.async_resource == async_resource
21 | assert isinstance(async_factory, datetime.datetime)
22 |
23 |
24 | async def test_factories_cannot_deregister_arguments() -> None:
25 | with pytest.raises(NotImplementedError):
26 | DIContainer.simple_factory._deregister_arguments()
27 | with pytest.raises(NotImplementedError):
28 | DIContainer.async_factory._deregister_arguments()
29 |
30 |
31 | async def test_async_resource_provider() -> None:
32 | async_resource = await DIContainer.async_resource()
33 |
34 | assert DIContainer.async_resource.resolve_sync() is async_resource
35 |
36 |
37 | def test_failed_sync_resolve() -> None:
38 | with pytest.raises(RuntimeError, match="AsyncFactory cannot be resolved synchronously"):
39 | DIContainer.async_factory.resolve_sync()
40 |
41 | with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
42 | DIContainer.async_resource.resolve_sync()
43 |
44 |
45 | def test_wrong_providers_init() -> None:
46 | with pytest.raises(TypeError, match="Unsupported resource type"):
47 | providers.Resource(lambda: None) # type: ignore[arg-type,return-value]
48 |
49 |
50 | def test_container_init_error() -> None:
51 | with pytest.raises(RuntimeError, match="DIContainer should not be instantiated"):
52 | DIContainer()
53 |
54 |
55 | async def test_free_dependency() -> None:
56 | resolver = DIContainer.resolver(container.FreeFactory)
57 | dep1 = await resolver()
58 | dep2 = await DIContainer.resolve(container.FreeFactory)
59 | assert dep1
60 | assert dep2
61 |
--------------------------------------------------------------------------------
/tests/integrations/fastapi/test_fastapi_di.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import typing
3 |
4 | import fastapi
5 | import pytest
6 | from starlette import status
7 | from starlette.testclient import TestClient
8 |
9 | from tests import container
10 | from that_depends import fetch_context_item
11 | from that_depends.providers import DIContextMiddleware
12 |
13 |
14 | _GLOBAL_CONTEXT: typing.Final[dict[str, str]] = {"test2": "value2", "test1": "value1"}
15 |
16 |
17 | @pytest.fixture
18 | def fastapi_app() -> fastapi.FastAPI:
19 | app = fastapi.FastAPI()
20 | app.add_middleware(DIContextMiddleware, container.DIContainer, global_context=_GLOBAL_CONTEXT)
21 |
22 | @app.get("/")
23 | async def read_root(
24 | dependency: typing.Annotated[
25 | container.DependentFactory,
26 | fastapi.Depends(container.DIContainer.dependent_factory),
27 | ],
28 | free_dependency: typing.Annotated[
29 | container.FreeFactory,
30 | fastapi.Depends(container.DIContainer.resolver(container.FreeFactory)),
31 | ],
32 | singleton: typing.Annotated[
33 | container.SingletonFactory,
34 | fastapi.Depends(container.DIContainer.singleton),
35 | ],
36 | singleton_attribute: typing.Annotated[bool, fastapi.Depends(container.DIContainer.singleton.dep1)],
37 | context_resource: typing.Annotated[datetime.datetime, fastapi.Depends(container.DIContainer.context_resource)],
38 | ) -> datetime.datetime:
39 | assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource
40 | assert dependency.async_resource == free_dependency.dependent_factory.async_resource
41 | assert singleton.dep1 is True
42 | assert singleton_attribute is True
43 | assert context_resource == await container.DIContainer.context_resource.resolve()
44 | for key, value in _GLOBAL_CONTEXT.items():
45 | assert fetch_context_item(key) == value
46 | return dependency.async_resource
47 |
48 | return app
49 |
50 |
51 | @pytest.fixture
52 | def fastapi_client(fastapi_app: fastapi.FastAPI) -> TestClient:
53 | return TestClient(fastapi_app)
54 |
55 |
56 | async def test_read_main(fastapi_client: TestClient) -> None:
57 | response = fastapi_client.get("/")
58 | assert response.status_code == status.HTTP_200_OK
59 | assert (
60 | datetime.datetime.fromisoformat(response.json().replace("Z", "+00:00"))
61 | == await container.DIContainer.async_resource()
62 | )
63 |
--------------------------------------------------------------------------------
/docs/integrations/faststream.md:
--------------------------------------------------------------------------------
1 | # Usage with `FastStream`
2 |
3 |
4 | `that-depends` is compatible with `faststream.Depends()` out of the box:
5 |
6 | ```python hl_lines="14"
7 | from typing import Annotated
8 | from faststream import Depends
9 | from faststream.asgi import AsgiFastStream
10 | from faststream.rabbit import RabbitBroker
11 |
12 |
13 | broker = RabbitBroker()
14 | app = AsgiFastStream(broker)
15 |
16 | @broker.subscriber(queue="queue")
17 | async def process(
18 | text: str,
19 | suffix: Annotated[
20 | str, Depends(Container.suffix_factory) # (1)!
21 | ],
22 | ) -> str:
23 | return text + suffix
24 | ```
25 |
26 | 1. This would be the same as `Provide[Container.suffix_factory]`.
27 |
28 |
29 | ## Context Middleware
30 |
31 | If you are using the [ContextResource](../providers/context-resources.md) provider, you will likely want to
32 | initialize a context before processing a message with `faststream`.
33 |
34 | `that-depends` provides an integration for this use case:
35 |
36 | ```shell
37 | pip install that-depends[faststream]
38 | ```
39 |
40 | Then you can use the `DIContextMiddleware` with your broker:
41 |
42 | ```python
43 | from that_depends.integrations.faststream import DIContextMiddleware
44 | from that_depends import ContextScopes
45 | from faststream.rabbit import RabbitBroker
46 |
47 | broker = RabbitBroker(middlewares=[DIContextMiddleware(Container, scope=ContextScopes.REQUEST)])
48 | ```
49 |
50 | ## Example
51 |
52 | Here is an example that includes life-cycle events:
53 |
54 | ```python
55 | import datetime
56 | import contextlib
57 | import typing
58 |
59 | from faststream import FastStream, Depends, Logger
60 | from faststream.rabbit import RabbitBroker
61 |
62 | from tests import container
63 |
64 |
65 | @contextlib.asynccontextmanager
66 | async def lifespan_manager() -> typing.AsyncIterator[None]:
67 | try:
68 | yield
69 | finally:
70 | await container.DIContainer.tear_down()
71 |
72 |
73 | broker = RabbitBroker()
74 | app = FastStream(broker, lifespan=lifespan_manager)
75 |
76 |
77 | @broker.subscriber("in")
78 | async def read_root(
79 | logger: Logger,
80 | some_dependency: typing.Annotated[
81 | container.DependentFactory,
82 | Depends(container.DIContainer.dependent_factory)
83 | ],
84 | ) -> datetime.datetime:
85 | startup_time = some_dependency.async_resource
86 | logger.info(startup_time)
87 | return startup_time
88 |
89 |
90 | @app.after_startup
91 | async def t() -> None:
92 | await broker.publish(None, "in")
93 | ```
94 |
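95 | ## Accessing the incoming message
96 | 
97 | If a provider needs access to the raw broker message, you can place the message into the global context from a
98 | custom middleware and read it back with `fetch_context_item`. A minimal sketch of this pattern, mirroring the
99 | project's integration tests (shown here with the NATS broker):
100 | 
101 | ```python
102 | import typing
103 | 
104 | from faststream import BaseMiddleware
105 | from faststream.nats import NatsBroker
106 | 
107 | from that_depends import BaseContainer, container_context, fetch_context_item, providers
108 | 
109 | 
110 | class MessageContextMiddleware(BaseMiddleware):
111 |     async def consume_scope(
112 |         self,
113 |         call_next: typing.Callable[..., typing.Awaitable[typing.Any]],
114 |         msg: typing.Any,
115 |     ) -> typing.Any:
116 |         # expose the raw message to providers for the duration of this consume call
117 |         async with container_context(global_context={"request": msg}):
118 |             return await call_next(msg)
119 | 
120 | 
121 | class Container(BaseContainer):
122 |     # resolves to whatever message the middleware stored in the global context
123 |     request = providers.Factory(lambda: fetch_context_item("request"))
124 | 
125 | 
126 | broker = NatsBroker(middlewares=(MessageContextMiddleware,))
127 | ```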
--------------------------------------------------------------------------------
/tests/providers/test_state.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import random
3 |
4 | import pytest
5 |
6 | from that_depends import BaseContainer
7 | from that_depends.exceptions import StateNotInitializedError
8 | from that_depends.providers import AsyncFactory, Factory, State
9 |
10 |
11 | async def test_state_raises_if_not_set_async() -> None:
12 | class _Container(BaseContainer):
13 | state: State[str] = State()
14 |
15 | with pytest.raises(StateNotInitializedError):
16 | await _Container.state.resolve()
17 |
18 |
19 | def test_state_raises_if_not_set_sync() -> None:
20 | class _Container(BaseContainer):
21 | state: State[str] = State()
22 |
23 | with pytest.raises(StateNotInitializedError):
24 | _Container.state.resolve_sync()
25 |
26 |
27 | async def test_state_set_resolves_correctly_async() -> None:
28 | class _Container(BaseContainer):
29 | state: State[str] = State()
30 |
31 | state_value = "test_value"
32 |
33 | with _Container.state.init(state_value) as value:
34 | assert value == state_value
35 | assert await _Container.state.resolve() == state_value
36 |
37 |
38 | def test_state_set_resolves_correctly_sync() -> None:
39 | class _Container(BaseContainer):
40 | state: State[str] = State()
41 |
42 | state_value = "test_value"
43 |
44 | with _Container.state.init(state_value) as value:
45 | assert value == state_value
46 | assert _Container.state.resolve_sync() == state_value
47 |
48 |
49 | async def test_state_correctly_manages_its_context() -> None:
50 | async def _async_creator(x: int) -> int:
51 | await asyncio.sleep(random.random())
52 | return x
53 |
54 | class _Container(BaseContainer):
55 | state: State[int] = State()
56 | dependent = Factory(lambda x: x, state.cast)
57 | async_dependent = AsyncFactory(_async_creator, state.cast)
58 |
59 | async def _main(x: int) -> None:
60 | with _Container.state.init(x):
61 | dependent_value = await _Container.dependent.resolve()
62 | assert dependent_value == x
63 | async_dependent_value = await _Container.async_dependent.resolve()
64 | assert async_dependent_value == x
65 | new_x = x + random.randint(1, 1000)
66 | with _Container.state.init(new_x):
67 | dependent_value = await _Container.dependent.resolve()
68 | assert dependent_value == new_x
69 | async_dependent_value = await _Container.async_dependent.resolve()
70 | assert async_dependent_value == new_x
71 |
72 | dependent_value = await _Container.dependent.resolve()
73 | assert dependent_value == x
74 | async_dependent_value = await _Container.async_dependent.resolve()
75 | assert async_dependent_value == x
76 |
77 | tasks = [_main(x) for x in range(10)]
78 |
79 | await asyncio.gather(*tasks)
80 |
--------------------------------------------------------------------------------
/that_depends/integrations/fastapi.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from collections.abc import Callable
3 | from inspect import signature
4 |
5 | from fastapi import Depends, Request, Response
6 | from fastapi.routing import APIRoute
7 |
8 | from that_depends.injection import StringProviderDefinition
9 | from that_depends.providers import AbstractProvider
10 | from that_depends.providers.context_resources import ContextScope, SupportsContext, container_context
11 |
12 |
13 | P = typing.ParamSpec("P")
14 | T = typing.TypeVar("T")
15 |
16 |
17 | def _adjust_fastapi_endpoint(endpoint: Callable[P, T]) -> Callable[P, T]:
18 | hints = typing.get_type_hints(endpoint, include_extras=True)
19 | sig = signature(endpoint)
20 | new_params = []
21 | for name, param in sig.parameters.items():
22 | if isinstance(param.default, AbstractProvider):
23 | hints[name] = Depends(param.default)
24 | new_params.append(param.replace(default=Depends(param.default)))
25 | elif isinstance(param.default, StringProviderDefinition):
26 | provider = param.default.provider
27 | hints[name] = Depends(provider)
28 | new_params.append(param.replace(default=Depends(provider)))
29 | else:
30 | new_params.append(param)
31 | endpoint.__annotations__ = hints
32 | endpoint.__signature__ = sig.replace(parameters=new_params) # type: ignore[attr-defined]
33 | return endpoint
34 |
35 |
36 | def create_fastapi_route_class(
37 | *context_items: SupportsContext[typing.Any],
38 | global_context: dict[str, typing.Any] | None = None,
39 | scope: ContextScope | None = None,
40 | ) -> type[APIRoute]:
41 | """Create a `that-depends` fastapi route class.
42 |
43 | Args:
44 | *context_items: Items to initialize context for.
45 | global_context: Global context to use.
46 |         scope: Scope to enter before handling each request.
47 |
48 | Returns:
49 | type[APIRoute]: A custom fastapi route class.
50 |
51 | """
52 |
53 | class _Route(APIRoute):
54 | """Custom that-depends router for FastAPI."""
55 |
56 | def __init__(self, path: str, endpoint: Callable[..., typing.Any], **kwargs: typing.Any) -> None: # noqa: ANN401
57 | endpoint = _adjust_fastapi_endpoint(endpoint)
58 | super().__init__(path=path, endpoint=endpoint, **kwargs)
59 |
60 | def get_route_handler(self) -> Callable[[Request], typing.Coroutine[typing.Any, typing.Any, Response]]:
61 | original_route_handler = super().get_route_handler()
62 |
63 | async def _custom_route_handler(request: Request) -> Response:
64 | async with container_context(*context_items, global_context=global_context, scope=scope):
65 | response: Response = await original_route_handler(request)
66 | return response
67 |
68 | return _custom_route_handler if context_items or global_context or scope else original_route_handler
69 |
70 | return _Route
71 |
--------------------------------------------------------------------------------
/that_depends/entities/resource_context.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import contextlib
3 | import threading
4 | import typing
5 | import warnings
6 |
7 | from typing_extensions import override
8 |
9 | from that_depends.providers.mixin import CannotTearDownSyncError, SupportsTeardown
10 |
11 |
12 | T_co = typing.TypeVar("T_co", covariant=True)
13 |
14 |
15 | class ResourceContext(SupportsTeardown, typing.Generic[T_co]):
16 | """Class to manage a resources' context."""
17 |
18 | __slots__ = "asyncio_lock", "context_stack", "instance", "is_async", "threading_lock"
19 |
20 | def __init__(self, is_async: bool) -> None:
21 | """Create a new ResourceContext instance.
22 |
23 | Args:
24 | is_async (bool): Whether the ResourceContext was created in
25 | an async context.
26 | For example within a ``async with container_context(Container): ...`` statement.
27 |
28 | """
29 | self.instance: T_co | None = None
30 | self.asyncio_lock: typing.Final = asyncio.Lock()
31 | self.threading_lock: typing.Final = threading.Lock()
32 | self.context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None
33 | self.is_async = is_async
34 |
35 | @staticmethod
36 | def is_context_stack_async(
37 | context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None,
38 | ) -> typing.TypeGuard[contextlib.AsyncExitStack]:
39 | """Check if the context stack is an async context stack."""
40 | return isinstance(context_stack, contextlib.AsyncExitStack)
41 |
42 | @staticmethod
43 | def is_context_stack_sync(
44 | context_stack: contextlib.AsyncExitStack | contextlib.ExitStack,
45 | ) -> typing.TypeGuard[contextlib.ExitStack]:
46 | """Check if the context stack is a sync context stack."""
47 | return isinstance(context_stack, contextlib.ExitStack)
48 |
49 | @override
50 | async def tear_down(self, propagate: bool = True) -> None:
51 |         """Tear down the context stack, whether it is sync or async."""
52 | if self.context_stack is None:
53 | return
54 |
55 | if self.is_context_stack_async(self.context_stack):
56 | await self.context_stack.aclose()
57 | elif self.is_context_stack_sync(self.context_stack):
58 | self.context_stack.close()
59 | self.context_stack = None
60 | self.instance = None
61 |
62 | @override
63 | def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None:
64 | """Tear down the sync context stack."""
65 | if self.context_stack is None:
66 | return
67 |
68 | if self.is_context_stack_sync(self.context_stack):
69 | self.context_stack.close()
70 | self.context_stack = None
71 | self.instance = None
72 | elif self.is_context_stack_async(self.context_stack):
73 | msg = "Cannot tear down async context in sync mode"
74 | if raise_on_async:
75 | raise CannotTearDownSyncError(msg)
76 | warnings.warn(msg, RuntimeWarning, stacklevel=2)
77 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "that-depends"
3 | description = "Simple Dependency Injection framework"
4 | authors = [
5 | { name = "Artur Shiriev", email = "me@shiriev.ru" },
6 | ]
7 | readme = "README.md"
8 | requires-python = ">=3.10,<4"
9 | license = "MIT"
10 | keywords = ["di", "dependency injector", "ioc-container", "mocks", "python"]
11 | classifiers = [
12 | "Programming Language :: Python :: 3.10",
13 | "Programming Language :: Python :: 3.11",
14 | "Programming Language :: Python :: 3.12",
15 | "Programming Language :: Python :: 3.13",
16 | "Programming Language :: Python :: 3.14",
17 | "Typing :: Typed",
18 | "Topic :: Software Development :: Libraries",
19 | ]
20 | version = "0"
21 |
22 | [project.optional-dependencies]
23 | fastapi = [
24 | "fastapi",
25 | ]
26 | faststream = [
27 | "faststream"
28 | ]
29 |
30 | [project.urls]
31 | repository = "https://github.com/modern-python/that-depends"
32 | docs = "https://that-depends.readthedocs.io"
33 |
34 | [dependency-groups]
35 | dev = [
36 | "httpx",
37 | "pytest",
38 | "pytest-cov",
39 | "pytest-asyncio",
40 | "pytest-repeat",
41 | "ruff",
42 | "mypy",
43 | "typing-extensions",
44 | "pre-commit",
45 | "litestar",
46 | "faststream[nats]",
47 | "mkdocs>=1.6.1",
48 | "pytest-randomly",
49 | "mkdocs-llmstxt>=0.4.0",
50 | "pydantic<2.12.4" # causes tests to fail because of: TypeError: _eval_type() got an unexpected keyword argument 'prefer_fwd_module'
51 | ]
52 |
53 | [build-system]
54 | requires = ["uv_build"]
55 | build-backend = "uv_build"
56 |
57 | [tool.uv.build-backend]
58 | module-name = "that_depends"
59 | module-root = ""
60 |
61 | [tool.mypy]
62 | python_version = "3.10"
63 | strict = true
64 |
65 | [tool.ruff]
66 | fix = true
67 | unsafe-fixes = true
68 | line-length = 120
69 | target-version = "py310"
70 | extend-exclude = [
71 | "docs",
72 | ]
73 |
74 | [tool.ruff.lint]
75 | select = ["ALL"]
76 | ignore = [
77 | "B008", # Do not perform function call `Provide` in argument defaults
78 | "D100", # ignore missing module docstrings.
79 | "D105", # ignore missing docstrings in magic methods.
80 | "S101", # allow asserts
81 | "TCH", # ignore flake8-type-checking
82 | "FBT", # allow boolean args
83 | "D203", # "one-blank-line-before-class" conflicting with D211
84 | "D213", # "multi-line-summary-second-line" conflicting with D212
85 | "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
86 | "COM812", # flake8-commas "Trailing comma missing"
87 | "ISC001", # flake8-implicit-str-concat
88 | "PT028", # Test function has default argument
89 | ]
90 | isort.lines-after-imports = 2
91 | isort.no-lines-before = ["standard-library", "local-folder"]
92 | per-file-ignores = { "tests/*"= ["D1", "SLF001"]}
93 |
94 | [tool.pytest.ini_options]
95 | addopts = "--cov=. --cov-report term-missing"
96 | asyncio_mode = "auto"
97 | asyncio_default_fixture_loop_scope = "function"
98 | filterwarnings = [
99 | "ignore::UserWarning",
100 | ]
101 |
102 | [tool.coverage.report]
103 | exclude_also = ["if typing.TYPE_CHECKING:"]
104 |
--------------------------------------------------------------------------------
/tests/test_container.py:
--------------------------------------------------------------------------------
1 | import random
2 | import typing
3 |
4 | from tests.container import DIContainer
5 | from that_depends import BaseContainer, providers
6 |
7 |
8 | def _sync_resource() -> typing.Iterator[float]:
9 | yield random.random()
10 |
11 |
12 | async def test_container_sync_teardown() -> None:
13 | await DIContainer.init_resources()
14 | DIContainer.tear_down_sync()
15 | for provider in DIContainer.providers.values():
16 | if isinstance(provider, providers.Resource):
17 | if provider._is_async:
18 | assert provider._context.instance is not None
19 | else:
20 | assert provider._context.instance is None
21 | if isinstance(provider, providers.Singleton):
22 | assert provider._instance is None
23 |
24 |
25 | async def test_container_tear_down() -> None:
26 | await DIContainer.init_resources()
27 | await DIContainer.tear_down()
28 | for provider in DIContainer.providers.values():
29 | if isinstance(provider, providers.Resource):
30 | assert provider._context.instance is None
31 | if isinstance(provider, providers.Singleton):
32 | assert provider._instance is None
33 |
34 |
35 | async def test_container_sync_tear_down_propagation() -> None:
36 | class _DependentContainer(BaseContainer):
37 | singleton = providers.Singleton(lambda: random.random())
38 | resource = providers.Resource(_sync_resource)
39 |
40 | DIContainer.connect_containers(_DependentContainer)
41 |
42 | await DIContainer.init_resources()
43 |
44 | assert isinstance(_DependentContainer.singleton._instance, float)
45 | assert isinstance(_DependentContainer.resource._context.instance, float)
46 |
47 | DIContainer.tear_down_sync()
48 |
49 | assert _DependentContainer.singleton._instance is None
50 | assert _DependentContainer.resource._context.instance is None
51 |
52 |
53 | async def test_container_tear_down_propagation() -> None:
54 | async def _async_singleton() -> float:
55 | return random.random()
56 |
57 | async def _async_resource() -> typing.AsyncIterator[float]:
58 | yield random.random()
59 |
60 | class _DependentContainer(BaseContainer):
61 | async_singleton = providers.AsyncSingleton(_async_singleton)
62 | async_resource = providers.Resource(_async_resource)
63 | sync_singleton = providers.Singleton(lambda: random.random())
64 | sync_resource = providers.Resource(_sync_resource)
65 |
66 | DIContainer.connect_containers(_DependentContainer)
67 |
68 | await DIContainer.init_resources()
69 |
70 | assert isinstance(_DependentContainer.async_singleton._instance, float)
71 | assert isinstance(_DependentContainer.async_resource._context.instance, float)
72 | assert isinstance(_DependentContainer.sync_singleton._instance, float)
73 | assert isinstance(_DependentContainer.sync_resource._context.instance, float)
74 |
75 | await DIContainer.tear_down()
76 |
77 | assert _DependentContainer.async_singleton._instance is None
78 | assert _DependentContainer.async_resource._context.instance is None
79 | assert _DependentContainer.sync_singleton._instance is None
80 | assert _DependentContainer.sync_resource._context.instance is None
81 |
--------------------------------------------------------------------------------
/docs/introduction/tear-down.md:
--------------------------------------------------------------------------------
1 | # Tear-down
2 |
3 | Certain providers in `that-depends` require explicit finalization. These providers
4 | implement the `SupportsTeardown` API. Containers also support finalizing their
5 | resources.
6 |
7 | ## Quick-start
8 |
9 | If you are using a provider that supports finalization, such as a [Singleton](../providers/singleton.md):
10 |
11 | First, define the container & provider.
12 |
13 | ```python
14 | from that_depends import BaseContainer, providers
15 | import random
16 |
17 | class MyContainer(BaseContainer):
18 | config = providers.Singleton(lambda: random.random())
19 | ```
20 |
21 | Resolve the provider somewhere in your code:
22 |
23 | ```python
24 | await MyContainer.config()
25 | ```
26 |
27 | When you want to reset the cached value:
28 |
29 | ```python
30 | await MyContainer.config.tear_down()
31 | ```
32 |
33 | For [Resources](../providers/resources.md) this will also call any finalization logic in your
34 | context manager.
35 |
36 | ---
37 |
38 | ## Propagation
39 |
40 | By default, `that-depends` will propagate tear-down to dependent providers.
41 |
42 | This means that if a provider `A` depends on provider `B`,
43 | calling `await B.tear_down()` will also execute `await A.tear_down()`.
44 |
45 | **For example:**
46 |
47 | ```python
48 | class MyContainer(BaseContainer):
49 | B = providers.Singleton(lambda: random.random())
50 | A = providers.Singleton(lambda x: x, B)
51 |
52 | b = await MyContainer.B()
53 | a = await MyContainer.A()
54 |
55 | assert a == b
56 |
57 | await MyContainer.B.tear_down()
58 | a_new = await MyContainer.A()
59 | assert a_new != a
60 | ```
61 |
62 | If you do not wish to propagate tear-down, simply call `tear_down(propagate=False)` or `tear_down_sync(propagate=False)`.
63 |
64 | ---
65 |
66 | ## Sync tear-down
67 |
68 | If you need to call tear-down from a sync context, you can use the `tear_down_sync()` method. However,
69 | keep in mind that because dependent resources might be async, this will fail to correctly finalize these async
70 | resources.
71 |
72 | By default, this will raise a `CannotTearDownSyncError`:
73 |
74 | ```python
75 | async def async_creator(val: float) -> typing.AsyncIterator[float]:
76 | yield val
77 | print("Finalization!")
78 |
79 |
80 | class MyContainer(BaseContainer):
81 | B = providers.Singleton(lambda: random.random())
82 | A = providers.Resource(async_creator, B.cast)
83 |
84 |
85 | b = await MyContainer.B()
86 | a = await MyContainer.A()
87 |
88 | MyContainer.B.tear_down_sync() # raises
89 | ```
90 |
91 | If you do not want these errors, you can reduce them to a `RuntimeWarning`:
92 |
93 | ```python
94 | MyContainer.B.tear_down_sync(raise_on_async=False)
95 | ```
96 |
97 | ---
98 |
99 | ## Containers and tear-down
100 |
101 | Containers also support the `tear_down` and `tear_down_sync` methods. When calling
102 | `await Container.tear_down()`, all providers in the container will be torn down.
103 |
104 |
105 | `Container.tear_down_sync()` is also implemented but not recommended unless you are sure
106 | all providers in your container are sync.
107 |
108 | > Container methods do not support the `propagate` & `raise_on_async` arguments, so if you
109 | > need more granular control, tear down the providers explicitly.
110 |
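111 | **For example**, a minimal sketch of tearing down a whole container:
112 | 
113 | ```python
114 | import random
115 | 
116 | from that_depends import BaseContainer, providers
117 | 
118 | 
119 | class MyContainer(BaseContainer):
120 |     first = providers.Singleton(lambda: random.random())
121 |     second = providers.Singleton(lambda x: x + 1, first.cast)
122 | 
123 | 
124 | value = await MyContainer.second()
125 | 
126 | await MyContainer.tear_down()  # tears down every provider in the container
127 | 
128 | assert await MyContainer.second() != value  # the cached singletons were reset
129 | ```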
--------------------------------------------------------------------------------
/that_depends/providers/collection.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | from typing_extensions import override
4 |
5 | from that_depends.providers.base import AbstractProvider
6 |
7 |
8 | T_co = typing.TypeVar("T_co", covariant=True)
9 |
10 |
11 | class List(AbstractProvider[list[T_co]]):
12 | """Provides multiple resources as a list.
13 |
14 | The `List` provider resolves multiple dependencies into a list.
15 |
16 | Example:
17 | ```python
18 | from that_depends import providers
19 |
20 | provider1 = providers.Factory(lambda: 1)
21 | provider2 = providers.Factory(lambda: 2)
22 |
23 | list_provider = List(provider1, provider2)
24 |
25 | # Synchronous resolution
26 | resolved_list = list_provider.resolve_sync()
27 | print(resolved_list) # Output: [1, 2]
28 |
29 | # Asynchronous resolution
30 | import asyncio
31 | resolved_list_async = asyncio.run(list_provider.resolve())
32 | print(resolved_list_async) # Output: [1, 2]
33 | ```
34 |
35 | """
36 |
37 | __slots__ = ("_providers",)
38 |
39 | def __init__(self, *providers: AbstractProvider[T_co]) -> None:
40 | """Create a new List provider instance.
41 |
42 | Args:
43 | *providers: List of providers to resolve.
44 |
45 | """
46 | super().__init__()
47 | self._providers: typing.Final = providers
48 | self._register(self._providers)
49 |
50 | @override
51 | def __getattr__(self, attr_name: str) -> typing.Any:
52 | msg = f"'{type(self)}' object has no attribute '{attr_name}'"
53 | raise AttributeError(msg)
54 |
55 | @override
56 | async def resolve(self) -> list[T_co]:
57 | return [await x.resolve() for x in self._providers]
58 |
59 | @override
60 | def resolve_sync(self) -> list[T_co]:
61 | return [x.resolve_sync() for x in self._providers]
62 |
63 | @override
64 | async def __call__(self) -> list[T_co]:
65 | return await self.resolve()
66 |
67 |
68 | class Dict(AbstractProvider[dict[str, T_co]]):
69 | """Provides multiple resources as a dictionary.
70 |
71 | The `Dict` provider resolves multiple named dependencies into a dictionary.
72 |
73 | Example:
74 | ```python
75 | from that_depends import providers
76 |
77 | provider1 = providers.Factory(lambda: 1)
78 | provider2 = providers.Factory(lambda: 2)
79 |
80 | dict_provider = Dict(key1=provider1, key2=provider2)
81 |
82 | # Synchronous resolution
83 | resolved_dict = dict_provider.resolve_sync()
84 | print(resolved_dict) # Output: {"key1": 1, "key2": 2}
85 |
86 | # Asynchronous resolution
87 | import asyncio
88 | resolved_dict_async = asyncio.run(dict_provider.resolve())
89 | print(resolved_dict_async) # Output: {"key1": 1, "key2": 2}
90 | ```
91 |
92 | """
93 |
94 | __slots__ = ("_providers",)
95 |
96 | def __init__(self, **providers: AbstractProvider[T_co]) -> None:
97 | """Create a new Dict provider instance.
98 |
99 | Args:
100 | **providers: Dictionary of providers to resolve.
101 |
102 | """
103 | super().__init__()
104 | self._providers: typing.Final = providers
105 | self._register(self._providers.values())
106 |
107 | @override
108 | def __getattr__(self, attr_name: str) -> typing.Any:
109 | msg = f"'{type(self)}' object has no attribute '{attr_name}'"
110 | raise AttributeError(msg)
111 |
112 | @override
113 | async def resolve(self) -> dict[str, T_co]:
114 | return {key: await provider.resolve() for key, provider in self._providers.items()}
115 |
116 | @override
117 | def resolve_sync(self) -> dict[str, T_co]:
118 | return {key: provider.resolve_sync() for key, provider in self._providers.items()}
119 |
--------------------------------------------------------------------------------
/that_depends/providers/selector.py:
--------------------------------------------------------------------------------
1 | """Selection based providers."""
2 |
3 | import typing
4 |
5 | from typing_extensions import override
6 |
7 | from that_depends.providers.base import AbstractProvider
8 |
9 |
10 | T_co = typing.TypeVar("T_co", covariant=True)
11 |
12 |
13 | class Selector(AbstractProvider[T_co]):
14 | """Chooses a provider based on a key returned by a selector function.
15 |
16 | This class allows you to dynamically select and resolve one of several
17 | named providers at runtime. The provider key is determined by a
18 | user-supplied selector function.
19 |
20 | Examples:
21 | ```python
22 | def environment_selector():
23 | return "local"
24 |
25 | selector_instance = Selector(
26 | environment_selector,
27 | local=LocalStorageProvider(),
28 | remote=RemoteStorageProvider(),
29 | )
30 |
31 | # Synchronously resolve the selected provider
32 | service = selector_instance.resolve_sync()
33 | ```
34 |
35 | """
36 |
37 | __slots__ = "_override", "_providers", "_selector"
38 |
39 | def __init__(
40 | self, selector: typing.Callable[[], str] | AbstractProvider[str] | str, **providers: AbstractProvider[T_co]
41 | ) -> None:
42 | """Initialize a new Selector instance.
43 |
44 | Args:
45 |             selector: A string key, a provider resolving to a key, or a callable
46 |                 returning the key of the provider to use.
47 | **providers (AbstractProvider[T_co]): The named providers from
48 | which one will be selected based on the `selector`.
49 |
50 | Examples:
51 | ```python
52 | def my_selector():
53 | return "remote"
54 |
55 | my_selector_instance = Selector(
56 | my_selector,
57 | local=LocalStorageProvider(),
58 | remote=RemoteStorageProvider(),
59 | )
60 |
61 | # The "remote" provider will be selected
62 | selected_service = my_selector_instance.resolve_sync()
63 | ```
64 |
65 | """
66 | super().__init__()
67 | self._selector: typing.Final[typing.Callable[[], str] | AbstractProvider[str] | str] = selector
68 | self._providers: typing.Final = providers
69 |
70 | @override
71 | async def resolve(self) -> T_co:
72 | if self._override:
73 | return typing.cast(T_co, self._override)
74 |
75 | if isinstance(self._selector, AbstractProvider):
76 | selected_key = await self._selector.resolve()
77 | else:
78 | selected_key = self._get_selected_key()
79 |
80 | self._validate_key(selected_key)
81 |
82 | return await self._providers[selected_key].resolve()
83 |
84 | @override
85 | def resolve_sync(self) -> T_co:
86 | if self._override:
87 | return typing.cast(T_co, self._override)
88 |
89 | if isinstance(self._selector, AbstractProvider):
90 | selected_key = self._selector.resolve_sync()
91 | else:
92 | selected_key = self._get_selected_key()
93 |
94 | self._validate_key(selected_key)
95 |
96 | return self._providers[selected_key].resolve_sync()
97 |
98 | def _get_selected_key(self) -> str:
99 | if callable(self._selector):
100 | selected_key = self._selector()
101 | elif isinstance(self._selector, str):
102 | selected_key = self._selector
103 | else:
104 | msg = f"Invalid selector type: {type(self._selector)}, expected str, or a provider/callable returning str"
105 | raise TypeError(msg)
106 | return typing.cast(str, selected_key)
107 |
108 | def _validate_key(self, key: str) -> None:
109 | if key not in self._providers:
110 | msg = f"Key '{key}' not found in providers"
111 | raise KeyError(msg)
112 |
--------------------------------------------------------------------------------
/docs/providers/resources.md:
--------------------------------------------------------------------------------
1 | # Resource Provider
2 |
3 | A **Resource** is a special provider that:
4 |
5 | - **Resolves** its dependency only **once** and **caches** the resolved instance for future injections.
6 | - **Includes** teardown (finalization) logic, unlike a plain `Singleton`.
7 | - **Supports** generator or async generator functions for creation (allowing a `yield` plus teardown in `finally`).
8 | - **Also** allows usage of classes that implement standard Python context managers (`typing.ContextManager` or `typing.AsyncContextManager`), but *does not* automatically integrate with `container_context`.
9 |
10 | This makes `Resource` ideal for dependencies that need:
11 |
12 | 1. A **single creation** step,
13 | 2. A **single finalization** step,
14 | 3. **Thread/async safety**—all consumers receive the same resource object, and concurrency is handled.
15 |
16 | ---
17 |
18 | ## How It Works
19 |
20 | ### Defining a Sync or Async Resource
21 |
22 | You can define your creation logic as either a **generator** or a **context manager** class (sync or async).
23 |
24 | **Synchronous generator** example:
25 |
26 | ```python
27 | import typing
28 |
29 | def create_sync_resource() -> typing.Iterator[str]:
30 | print("Creating sync resource")
31 | try:
32 | yield "sync resource"
33 | finally:
34 | print("Tearing down sync resource")
35 | ```
36 |
37 | **Asynchronous generator** example:
38 |
39 | ```python
40 | import typing
41 |
42 | async def create_async_resource() -> typing.AsyncIterator[str]:
43 | print("Creating async resource")
44 | try:
45 | yield "async resource"
46 | finally:
47 | print("Tearing down async resource")
48 | ```
49 |
50 | You then attach them to a container:
51 |
52 | ```python
53 | from that_depends import BaseContainer
54 | from that_depends.providers import Resource
55 |
56 | class MyContainer(BaseContainer):
57 | sync_resource = Resource(create_sync_resource)
58 | async_resource = Resource(create_async_resource)
59 | ```
60 |
61 | ---
62 |
63 | ## Resolving and Teardown
64 |
65 | Once defined, you can explicitly **resolve** the resource and **tear it down**:
66 |
67 | ```python
68 | # Synchronous resource usage
69 | value_sync = MyContainer.sync_resource.resolve_sync()
70 | print(value_sync) # "sync resource"
71 | MyContainer.sync_resource.tear_down_sync()
72 |
73 | # Asynchronous resource usage
74 | import asyncio
75 |
76 |
77 | async def main():
78 | value_async = await MyContainer.async_resource.resolve()
79 | print(value_async) # "async resource"
80 | await MyContainer.async_resource.tear_down()
81 |
82 |
83 | asyncio.run(main())
84 | ```
85 |
86 | - **`resolve_sync()`** or **`resolve()`**: Creates (if needed) and returns the resource instance.
87 | - **`tear_down_sync()`** or **`tear_down()`**: Closes/cleans up the resource (triggering your `finally` block or exiting the context manager) and resets the cached instance to `None`. A subsequent resolve call will then recreate it.
88 |
89 | ---
90 |
91 | ## Concurrency Safety
92 |
93 | `Resource` is **safe** to use under **threading** and **asyncio** concurrency. Internally, a lock ensures only one instance is created per `Resource` provider:
94 | 
95 | - Multiple threads calling `resolve_sync()` simultaneously will produce a **single** instance for that provider.
96 | - Multiple coroutines calling `resolve()` simultaneously will likewise produce **only one** instance for that provider in an async environment.
97 |
98 | ```python
99 | # Even if multiple coroutines call resolve in parallel,
100 | # only one instance is created at a time:
101 | await MyContainer.async_resource.resolve()
102 |
103 | # Similarly, multiple threads calling resolve_sync concurrently
104 | # still yield just one instance until teardown:
105 | MyContainer.sync_resource.resolve_sync()
106 | ```
107 |
108 | ---
109 |
110 | ## Using Context Managers Directly
111 |
112 | If your resource is a standard **context manager** or **async context manager** class, `Resource` will handle entering and exiting it under the hood. For example:
113 |
114 | ```python
115 | import typing
116 | from that_depends.providers import Resource
117 |
118 |
119 | class SyncFileManager:
120 | def __enter__(self) -> str:
121 | print("Opening file")
122 | return "/path/to/file"
123 |
124 | def __exit__(self, exc_type, exc_val, exc_tb):
125 | print("Closing file")
126 |
127 |
128 | sync_file_resource = Resource(SyncFileManager)
129 |
130 | # usage
131 | file_path = sync_file_resource.resolve_sync()
132 | print(file_path)
133 | sync_file_resource.tear_down_sync()
134 | ```
135 |
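136 | An **async context manager** class works the same way. A minimal sketch mirroring the example above:
137 | 
138 | ```python
139 | import asyncio
140 | from that_depends.providers import Resource
141 | 
142 | 
143 | class AsyncFileManager:
144 |     async def __aenter__(self) -> str:
145 |         print("Opening file asynchronously")
146 |         return "/path/to/file"
147 | 
148 |     async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
149 |         print("Closing file asynchronously")
150 | 
151 | 
152 | async_file_resource = Resource(AsyncFileManager)
153 | 
154 | 
155 | async def main():
156 |     file_path = await async_file_resource.resolve()
157 |     print(file_path)
158 |     await async_file_resource.tear_down()
159 | 
160 | 
161 | asyncio.run(main())
162 | ```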
--------------------------------------------------------------------------------
/tests/integrations/litestar/test_litestar_di.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from functools import partial
3 | from typing import Annotated
4 | from unittest.mock import Mock
5 |
6 | from litestar import Controller, Litestar, Router, get
7 | from litestar.di import Provide
8 | from litestar.params import Dependency
9 | from litestar.status_codes import HTTP_200_OK
10 | from litestar.testing import TestClient
11 |
12 | from that_depends import BaseContainer, providers
13 |
14 |
15 | def bool_fn(value: bool) -> bool:
16 | return value
17 |
18 |
19 | def str_fn() -> str:
20 | return ""
21 |
22 |
23 | def list_fn() -> list[str]:
24 | return ["some"]
25 |
26 |
27 | def int_fn() -> int:
28 | return 1
29 |
30 |
31 | class SomeService:
32 | pass
33 |
34 |
35 | class DIContainer(BaseContainer):
36 | alias = "litestar_container"
37 | bool_fn = providers.Factory(bool_fn, value=False)
38 | str_fn = providers.Factory(str_fn)
39 | list_fn = providers.Factory(list_fn)
40 | int_fn = providers.Factory(int_fn)
41 | some_service = providers.Factory(SomeService)
42 |
43 |
44 | _NoValidationDependency = partial(Dependency, skip_validation=True)
45 |
46 |
47 | class MyController(Controller):
48 | path = "/controller"
49 | dependencies = {"controller_dependency": Provide(DIContainer.list_fn)} # noqa: RUF012
50 |
51 | @get(
52 | path="/handler",
53 | dependencies={
54 | "local_dependency": Provide(DIContainer.int_fn),
55 | },
56 | )
57 | async def my_route_handler(
58 | self,
59 | app_dependency: bool,
60 | router_dependency: str,
61 | controller_dependency: list[str],
62 | local_dependency: int,
63 | ) -> dict[str, typing.Any]:
64 | return {
65 | "app_dependency": app_dependency,
66 | "router_dependency": router_dependency,
67 | "controller_dependency": controller_dependency,
68 | "local_dependency": local_dependency,
69 | }
70 |
71 | @get(path="/mock_overriding", dependencies={"some_service": Provide(DIContainer.some_service)})
72 | async def mock_overriding_endpoint_handler(
73 | self, some_service: Annotated[SomeService, _NoValidationDependency()]
74 | ) -> None:
75 | assert isinstance(some_service, Mock)
76 |
77 |
78 | my_router = Router(
79 | path="/router",
80 | dependencies={"router_dependency": Provide(DIContainer.str_fn)},
81 | route_handlers=[MyController],
82 | )
83 |
84 | # on the app
85 | app = Litestar(route_handlers=[my_router], dependencies={"app_dependency": Provide(DIContainer.bool_fn)}, debug=True)
86 |
87 |
88 | def test_litestar_endpoint_with_mock_overriding() -> None:
89 | some_service_mock = Mock()
90 |
91 | with DIContainer.some_service.override_context_sync(some_service_mock), TestClient(app=app) as client:
92 | response = client.get("/router/controller/mock_overriding")
93 |
94 | assert response.status_code == HTTP_200_OK
95 |
96 |
97 | def test_litestar_di() -> None:
98 | with TestClient(app=app) as client:
99 | response = client.get("/router/controller/handler")
100 | assert response.status_code == HTTP_200_OK, response.text
101 | assert response.json() == {
102 | "app_dependency": False,
103 | "controller_dependency": ["some"],
104 | "local_dependency": 1,
105 | "router_dependency": "",
106 | }
107 |
108 |
109 | def test_litestar_di_override_fail_on_provider_override() -> None:
110 | mock = 12345364758999
111 | with TestClient(app=app) as client, DIContainer.int_fn.override_context_sync(mock):
112 | response = client.get("/router/controller/handler")
113 |
114 | assert response.status_code == HTTP_200_OK, response.text
115 | assert response.json() == {
116 | "app_dependency": False,
117 | "controller_dependency": ["some"],
118 | "local_dependency": mock,
119 | "router_dependency": "",
120 | }
121 |
122 |
123 | def test_litestar_di_override_fail_on_override_providers() -> None:
124 | mock = 12345364758999
125 | overrides = {
126 | "int_fn": mock,
127 | }
128 | with TestClient(app=app) as client, DIContainer.override_providers_sync(overrides):
129 | response = client.get("/router/controller/handler")
130 |
131 | assert response.status_code == HTTP_200_OK, response.text
132 | assert response.json() == {
133 | "app_dependency": False,
134 | "controller_dependency": ["some"],
135 | "local_dependency": mock,
136 | "router_dependency": "",
137 | }
138 |
--------------------------------------------------------------------------------
/that_depends/providers/resources.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | from typing_extensions import override
4 |
5 | from that_depends.entities.resource_context import ResourceContext
6 | from that_depends.providers.base import AbstractResource, ResourceCreatorType
7 | from that_depends.providers.mixin import SupportsTeardown
8 |
9 |
10 | T_co = typing.TypeVar("T_co", covariant=True)
11 | P = typing.ParamSpec("P")
12 |
13 |
14 | class Resource(SupportsTeardown, AbstractResource[T_co]):
15 | """Provides a resource that is resolved once and cached for future usage.
16 |
17 | Unlike a singleton, this provider includes finalization logic and can be
18 | used with a generator or async generator to manage resource lifecycle.
19 | It also supports usage with `typing.ContextManager` or `typing.AsyncContextManager`.
20 | Threading and asyncio concurrency are supported, ensuring only one instance
21 | is created regardless of concurrent resolves.
22 |
23 | Example:
24 | ```python
25 | async def create_async_resource():
26 | try:
27 | yield "async resource"
28 | finally:
29 | # Finalize resource
30 | pass
31 |
32 | class MyContainer:
33 | async_resource = Resource(create_async_resource)
34 |
35 | async def main():
36 | async_resource_instance = await MyContainer.async_resource.resolve()
37 | await MyContainer.async_resource.tear_down()
38 | ```
39 |
40 | """
41 |
42 | __slots__ = (
43 | "_args",
44 | "_context",
45 |         "_creator",
47 | "_is_async",
48 | "_kwargs",
49 | "_override",
50 | )
51 |
52 | def __init__(
53 | self,
54 | creator: ResourceCreatorType[P, T_co],
55 | *args: P.args,
56 | **kwargs: P.kwargs,
57 | ) -> None:
58 | """Initialize the Resource provider with a callable for resource creation.
59 |
60 | The callable can be a generator or async generator that yields the resource
61 | (with optional teardown logic), or a context manager. Only one instance will be
62 | created and cached until explicitly torn down.
63 |
64 | Args:
65 | creator: The callable, generator, or context manager that creates the resource.
66 | *args: Positional arguments passed to the creator.
67 | **kwargs: Keyword arguments passed to the creator.
68 |
69 | Example:
70 | ```python
71 | def custom_creator(name: str):
72 | try:
73 | yield f"Resource created for {name}"
74 | finally:
75 | pass # Teardown
76 |
77 | resource_provider = Resource(custom_creator, "example")
78 |             instance = resource_provider.resolve_sync()
79 |             resource_provider.tear_down_sync()
80 | ```
81 |
82 | """
83 | super().__init__(creator, *args, **kwargs)
84 | self._context: typing.Final[ResourceContext[T_co]] = ResourceContext(is_async=self.is_async)
85 |
86 | def _fetch_context(self) -> ResourceContext[T_co]:
87 | return self._context
88 |
89 | @override
90 | async def tear_down(self, propagate: bool = True) -> None:
91 | """Tear down the resource if it has been created.
92 |
93 | If the resource was never resolved, or was already torn down,
94 | calling this method has no effect.
95 |
96 | Example:
97 | ```python
98 | # Assuming my_provider was previously resolved
99 | await my_provider.tear_down()
100 | ```
101 |
102 | """
103 | await self._fetch_context().tear_down()
104 | self._deregister_arguments()
105 | if propagate:
106 | await self._tear_down_children()
107 |
108 | @override
109 | def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None:
110 | """Sync tear down the resource if it has been created.
111 |
112 | If the resource was never resolved, or was already torn down,
113 | calling this method has no effect.
114 |
115 | If you try to sync tear down an async resource, this will raise an exception.
116 |
117 | Example:
118 | ```python
119 | # Assuming my_provider was previously resolved
120 |             my_provider.tear_down_sync()
121 | ```
122 |
123 | """
124 | self._fetch_context().tear_down_sync(propagate=propagate, raise_on_async=raise_on_async)
125 | self._deregister_arguments()
126 | if propagate:
127 | self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async)
128 |
--------------------------------------------------------------------------------
/docs/migration/v3.md:
--------------------------------------------------------------------------------
1 | # Migrating from 2.\* to 3.\*
2 |
3 | ## How to Read This Guide
4 |
5 | This guide is intended to help you migrate existing functionality from `that-depends` version `2.*` to `3.*`.
6 | The goal is to enable you to migrate as quickly as possible while making only the minimal necessary changes to your codebase.
7 |
8 | If you want to learn more about the new features introduced in `3.*`, please refer to the [documentation](https://that-depends.readthedocs.io/) and the [release notes](https://github.com/modern-python/that-depends/releases).
9 |
10 | ---
11 |
12 | ## Deprecated or Removed Features
13 |
14 | ### **`container_context()`** can no longer be initialized without arguments.
15 |
16 | Previously, the following code would reset the context for all providers in all containers:
17 | ```python
18 | async with container_context():
19 | # all ContextResources have been reset
20 | ...
21 | ```
22 | This was done to enable easy migration and compatibility with `1.*`.
23 |
24 | Now `container_context` must be called with at least one argument or keyword argument.
25 | Thus, if you want to reset the context for all containers, you need to provide them explicitly:
26 | ```python
27 | async with container_context(MyContainer_1, MyContainer_2, ...):
28 | ...
29 | ```
30 |
31 | ---
32 |
33 | ### **`container_context()`** no longer accepts `reset_all_containers` keyword argument.
34 |
35 | You can no longer reset the context for all containers by using the `container_context` context manager.
36 | Previously you could have done something like this:
37 | ```python
38 | async with container_context(reset_all_containers=True):
39 | # all ContextResources have been reset
40 | ...
41 | ```
42 | Now you will need to explicitly pass the containers to the container context:
43 | ```python
44 | async with container_context(MyContainer_1, MyContainer_2, ...):
45 | ...
46 | ```
47 |
48 | ---
49 |
50 | ### **`@inject(scope=...)`** no longer enters the scope.
51 |
52 | The `@inject` decorator no longer enters the scope specified in the `scope` argument.
53 |
54 | In `2.*` the provided scope would be entered before the function was called:
55 | ```python
56 | @inject(scope=ContextScopes.REQUEST)
57 | def injected(...): ...
58 | assert get_current_scope() == ContextScopes.REQUEST
59 | ```
60 |
61 | In `3.*` the scope is only used to resolve relevant dependencies that match the provided scope.
62 | Thus, to achieve the same behaviour as in `2.*` you need to set `enter_scope=True`:
63 | ```python
64 | @inject(scope=ContextScopes.REQUEST, enter_scope=True)
65 | def injected(...): ...
66 | assert get_current_scope() == ContextScopes.REQUEST
67 | ```
68 | For further details, please refer to the [scopes documentation](../introduction/scopes.md#named-scopes-with-the-inject-wrapper)
69 |
70 | ---
71 |
72 | ## Changes in the API
73 |
74 | ### Changes to naming of methods.
75 |
76 | The default implementation of provider and container methods is now async.
77 | This means that methods **not** explicitly ending with `_sync` are generally async.
78 |
79 | This also introduced new interfaces for operations where `2.*` only shipped one half of the implementation (only an async or only a sync variant).
80 |
81 |
82 | For example, `tear_down()` is now async by default, and `tear_down_sync()` has been introduced.
83 |
84 | Other examples of similar changes include (see the sketch after this list for a before/after comparison):
85 |
86 | - `.async_resolve()` -> `.resolve()`
87 | - `.sync_resolve()` -> `.resolve_sync()`
88 |
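89 | As a minimal before/after sketch (using a generic `provider` instance purely for illustration):
90 |
91 | ```python
92 | # 2.*
93 | value = await provider.async_resolve()
94 | value = provider.sync_resolve()
95 |
96 | # 3.*
97 | value = await provider.resolve()
98 | value = provider.resolve_sync()
99 | ```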
89 | ---
90 |
91 | ### Tear-down propagation enabled by default.
92 |
93 | Tear-down is now propagated to all dependencies by default.
94 |
95 | Previously, calling tear-down on a provider only reset that provider:
96 | ```python
97 | await provider.tear_down()
98 | ```
99 | In order to maintain this behaviour in `3.*`:
100 | ```python
101 | await provider.tear_down(propagate=False)
102 | ```
103 | For more details regarding tear-down propagation see the [documentation](../introduction/tear-down.md).
104 |
105 | ---
106 |
107 | ### Overriding is now async by default.
108 |
109 | As mentioned [above](#changes-to-naming-of-methods), `.override()` methods are now async by default.
110 |
111 | Thus, any instances of the following in `2.*`:
112 | ```python
113 | provider.override(value)
114 | ```
115 | Should be changed to the following in `3.*`:
116 | ```python
117 | provider.override_sync(value)
118 | ```
119 |
120 | This is also the case for the following methods:
121 |
122 | - `.override_context()` -> `.override_context_sync()`
123 | - `.reset_override()` -> `.reset_override_sync()`
124 |
125 | > **Note:** Overrides now support tear-down; read more in the [documentation](../testing/provider-overriding.md).
126 |
127 | ---
128 |
129 | ## Further Help
130 |
131 | If you continue to experience issues during migration, consider creating a [discussion](https://github.com/modern-python/that-depends/discussions) or opening an [issue](https://github.com/modern-python/that-depends/issues).
132 |
--------------------------------------------------------------------------------
/that_depends/providers/local_singleton.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import threading
3 | import typing
4 |
5 | from typing_extensions import override
6 |
7 | from that_depends.providers import AbstractProvider
8 | from that_depends.providers.mixin import ProviderWithArguments, SupportsTeardown
9 |
10 |
11 | T_co = typing.TypeVar("T_co", covariant=True)
12 | P = typing.ParamSpec("P")
13 |
14 |
15 | class ThreadLocalSingleton(ProviderWithArguments, SupportsTeardown, AbstractProvider[T_co]):
16 | """Creates a new instance for each thread using a thread-local store.
17 |
18 | This provider ensures that each thread gets its own instance, which is
19 | created via the specified factory function. Once created, the instance is
20 | cached for future injections within the same thread.
21 |
22 | Example:
23 | ```python
24 | def factory():
25 | return random.randint(1, 100)
26 |
27 | singleton = ThreadLocalSingleton(factory)
28 |
29 | # Same thread, same instance
30 | instance1 = singleton.resolve_sync()
31 | instance2 = singleton.resolve_sync()
32 |
33 | def thread_task():
34 | return singleton.resolve_sync()
35 |
36 | threads = [threading.Thread(target=thread_task) for i in range(10)]
37 | for thread in threads:
38 | thread.start() # Each thread will get a different instance
39 | ```
40 |
41 | """
42 |
43 | def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None:
44 | """Initialize the ThreadLocalSingleton provider.
45 |
46 | Args:
47 | factory: A callable that returns a new instance of the dependency.
48 | *args: Positional arguments to pass to the factory.
49 |             **kwargs: Keyword arguments to pass to the factory.
50 |
51 | """
52 | super().__init__()
53 | self._factory: typing.Final = factory
54 | self._thread_local = threading.local()
55 | self._asyncio_lock = asyncio.Lock()
56 | self._args: typing.Final = args
57 | self._kwargs: typing.Final = kwargs
58 |
59 | def _register_arguments(self) -> None:
60 | self._register(self._args)
61 | self._register(self._kwargs.values())
62 |
63 | def _deregister_arguments(self) -> None:
64 | self._deregister(self._args)
65 | self._deregister(self._kwargs.values())
66 |
67 | @property
68 | def _instance(self) -> T_co | None:
69 | return getattr(self._thread_local, "instance", None)
70 |
71 | @_instance.setter
72 | def _instance(self, value: T_co | None) -> None:
73 | self._thread_local.instance = value
74 |
75 | @override
76 | async def resolve(self) -> T_co:
77 | if self._override is not None:
78 | return typing.cast(T_co, self._override)
79 |
80 | async with self._asyncio_lock:
81 | if self._instance is not None:
82 | return self._instance
83 |
84 | self._register_arguments()
85 |
86 | self._instance = self._factory(
87 | *[await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
88 | **{ # type: ignore[arg-type]
89 | k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()
90 | },
91 | )
92 | return self._instance
93 |
94 | @override
95 | def resolve_sync(self) -> T_co:
96 | if self._override is not None:
97 | return typing.cast(T_co, self._override)
98 |
99 | if self._instance is not None:
100 | return self._instance
101 |
102 | self._register_arguments()
103 |
104 | self._instance = self._factory(
105 | *[x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args], # type: ignore[arg-type]
106 | **{k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}, # type: ignore[arg-type]
107 | )
108 | return self._instance
109 |
110 | @override
111 | def tear_down_sync(self, propagate: bool = True, raise_on_async: bool = True) -> None:
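112 |         """Synchronously reset the thread-local instance for the current thread."""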
112 | if self._instance is not None:
113 | self._instance = None
114 | self._deregister_arguments()
115 | if propagate:
116 | self._tear_down_children_sync(propagate=propagate, raise_on_async=raise_on_async)
117 |
118 | @override
119 | async def tear_down(self, propagate: bool = True) -> None:
120 | """Reset the thread-local instance.
121 |
122 | After calling this method, subsequent calls to `resolve_sync()` on the
123 | same thread will produce a new instance.
124 | """
125 | if self._instance is not None:
126 | self._instance = None
127 | self._deregister_arguments()
128 | if propagate:
129 | await self._tear_down_children()
130 |
--------------------------------------------------------------------------------
/docs/introduction/generator-injection.md:
--------------------------------------------------------------------------------
1 | # Injection into Generator Functions
2 |
3 |
4 | `that-depends` supports dependency injection into generator functions. However, this comes
5 | with some minor limitations compared to regular functions.
6 |
7 |
8 | ## Quickstart
9 |
10 | You can use the `@inject` decorator to inject dependencies into generator functions:
11 |
12 | === "async generator"
13 |
14 | ```python
15 | @inject
16 | async def my_generator(value: str = Provide[Container.factory]) -> typing.AsyncGenerator[str, None]:
17 | yield value
18 | ```
19 |
20 | === "sync generator"
21 |
22 | ```python
23 | @inject
24 | def my_generator(value: str = Provide[Container.factory]) -> typing.Generator[str, None, None]:
25 | yield value
26 | ```
27 |
28 | === "async context manager"
29 |
30 | ```python
31 | @contextlib.asynccontextmanager
32 | @inject
33 | async def my_generator(value: str = Provide[Container.factory]) -> typing.AsyncIterator[str]:
34 | yield value
35 | ```
36 |
37 | === "sync context manager"
38 |
39 | ```python
40 | @contextlib.contextmanager
41 | @inject
42 | def my_generator(value: str = Provide[Container.factory]) -> typing.Iterator[str]:
43 | yield value
44 | ```
45 |
46 | ## Supported Generators
47 |
48 | ### Synchronous Generators
49 |
50 | `that-depends` supports injection into sync generator functions with the following signature:
51 |
52 | ```python
53 | Callable[P, Generator[YieldType, SendType, ReturnType]]
54 | ```
55 |
56 | This means that wrapping a sync generator with `@inject` will always preserve the behaviour of the wrapped generator (see the sketch after this list):
57 |
58 | - It will yield as expected
59 | - It will accept sending values via `send()`
60 | - It will raise `StopIteration` when the generator is exhausted or otherwise returns.
61 |
62 |
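63 | A minimal sketch of this behaviour, assuming `inject`, `Provide` and `Container` as in the quickstart above, with a hypothetical `Container.number_factory` provider that produces an `int`:
64 |
65 | ```python
66 | @inject
67 | def accumulate(start: int = Provide[Container.number_factory]) -> typing.Generator[int, int, None]:
68 |     received = yield start      # yields as expected
69 |     yield start + received      # a value passed via send() arrives here
70 |
71 | gen = accumulate()
72 | first = next(gen)       # the injected start value
73 | second = gen.send(5)    # send() still works through @inject
74 | # a further next() raises StopIteration once the generator is exhausted
75 | ```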
63 | ### Asynchronous Generators
64 |
65 | `that-depends` supports injection into async generator functions with the following signature:
66 |
67 | ```python
68 | Callable[P, AsyncGenerator[YieldType, None]]
69 | ```
70 |
71 | This means that wrapping an async generator with `@inject` will have the following effects:
72 |
73 | - The generator will yield as expected
74 | - The generator will **not** accept values via `asend()`
75 |
76 | If you need to send values to an async generator, you can simply resolve dependencies in the generator body:
77 |
78 | ```python
79 |
80 | async def my_generator() -> typing.AsyncGenerator[float, float]:
81 | value = await Container.factory.resolve()
82 | receive = yield value # (1)!
83 | yield receive + value
84 |
85 | ```
86 |
87 | 1. This `receive` value will always be `None` if you wrap this generator with `@inject`.
88 |
89 |
90 |
91 | ## ContextResources
92 |
93 | `that-depends` will **not** allow context initialization for [ContextResource](../providers/context-resources.md) providers
94 | as part of dependency injection into a generator.
95 |
96 | This is the case for both async and sync injection.
97 |
98 | **For example:**
99 | ```python
100 | def sync_resource() -> typing.Iterator[float]:
101 | yield random.random()
102 |
103 | class Container(BaseContainer):
104 | sync_provider = providers.ContextResource(sync_resource).with_config(scope=ContextScopes.INJECT)
105 | dependent_provider = providers.Factory(lambda x: x, sync_provider.cast)
106 |
107 | @inject(scope=ContextScopes.INJECT) # (1)!
108 | def injected(val: float = Provide[Container.dependent_provider]) -> typing.Generator[float, None, None]:
109 | yield val
110 |
111 | # This will raise a `ContextProviderError`!
112 | next(injected())
113 | ```
114 |
115 | 1. Matches context scope of `sync_provider` provider, which is a dependency of the `dependent_provider` provider.
116 |
117 |
118 | When calling `next(injected())`, `that-depends` will try to initialize a new context for the `sync_provider`;
119 | however, this is not permitted for generators, so it will raise a `ContextProviderError`.
120 |
121 |
122 | Keep in mind that if no context needs to be initialized, generator injection will work as expected:
123 |
124 | ```python
125 | def sync_resource() -> typing.Iterator[float]:
126 | yield random.random()
127 |
128 | class Container(BaseContainer):
129 | sync_provider = providers.ContextResource(sync_resource).with_config(scope=ContextScopes.REQUEST)
130 | dependent_provider = providers.Factory(lambda x: x, sync_provider.cast)
131 |
132 | @inject(scope=ContextScopes.INJECT) # (1)!
133 | def injected(val: float = Provide[Container.dependent_provider]) -> typing.Generator[float, None, None]:
134 | yield val
135 |
136 |
137 | with container_context(scope=ContextScopes.REQUEST):
138 | # This will resolve as expected
139 |     next(injected())
140 | ```
141 |
142 | Since no context initialization was needed, the generator will work as expected.
143 |
144 | 1. Scope provided to `@inject` no longer matches scope of the `sync_provider`
145 |
146 |
147 | ### Container Context
148 |
149 | Similarly to the above, `container_context` used as a decorator also does **not** support generators:
150 |
151 | ```python
152 | @container_context(Container)
153 | async def my_generator() -> typing.AsyncIterator[None]:
154 | yield
155 | ```
156 | The above code will raise a `UserWarning`.
157 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: that-depends
2 | repo_url: https://github.com/modern-python/that-depends
3 | site_url: https://that-depends.readthedocs.io/
4 | docs_dir: docs
5 | edit_uri: edit/main/docs/
6 | nav:
7 | - Quick-Start: index.md
8 | - Introduction:
9 | - Containers: introduction/ioc-container.md
10 | - Dependency Injection: introduction/injection.md
11 | - Generator Injection: introduction/generator-injection.md
12 | - Type-based Injection: introduction/type-based-injection.md
13 | - String Injection: introduction/string-injection.md
14 | - Scopes: introduction/scopes.md
15 | - Tear-down: introduction/tear-down.md
16 | - Multiple Containers: introduction/multiple-containers.md
17 |
18 | - Providers:
19 | - Collections: providers/collections.md
20 | - Context-Resources: providers/context-resources.md
21 | - Factories: providers/factories.md
22 | - Object: providers/object.md
23 | - Resources: providers/resources.md
24 | - Selector: providers/selector.md
25 | - Singletons: providers/singleton.md
26 | - State: providers/state.md
27 | - Experimental Features:
28 | - Lazy Provider: experimental/lazy.md
29 |
30 | - Integrations:
31 | - FastAPI: integrations/fastapi.md
32 | - FastStream: integrations/faststream.md
33 | - Litestar: integrations/litestar.md
34 | - Testing:
35 | - Fixtures: testing/fixture.md
36 | - Overriding: testing/provider-overriding.md
37 | - Migration:
38 | - 1.* to 2.*: migration/v2.md
39 | - 2.* to 3.*: migration/v3.md
40 | - Development:
41 | - Contributing: dev/contributing.md
42 | - Decisions: dev/main-decisions.md
43 |
44 | theme:
45 | name: material
46 | custom_dir: docs/overrides
47 | features:
48 | - content.code.copy
49 | - content.code.annotate
50 | - content.action.edit
51 | - content.action.view
52 | - navigation.footer
53 | - navigation.sections
54 | - navigation.expand
55 | - navigation.top
56 | - navigation.instant
57 | - header.autohide
58 | - announce.dismiss
59 | icon:
60 | edit: material/pencil
61 | view: material/eye
62 | repo: fontawesome/brands/git-alt
63 | palette:
64 | - media: "(prefers-color-scheme: light)"
65 | scheme: default
66 | primary: black
67 | accent: pink
68 | toggle:
69 | icon: material/brightness-7
70 | name: Switch to dark mode
71 | - media: "(prefers-color-scheme: dark)"
72 | scheme: slate
73 | primary: black
74 | accent: pink
75 | toggle:
76 | icon: material/brightness-4
77 | name: Switch to system preference
78 |
79 | markdown_extensions:
80 | - toc:
81 | permalink: true
82 | - pymdownx.highlight:
83 | anchor_linenums: true
84 | line_spans: __span
85 | pygments_lang_class: true
86 | - pymdownx.inlinehilite
87 | - pymdownx.snippets
88 | - pymdownx.superfences
89 | - def_list
90 | - codehilite:
91 | use_pygments: true
92 | - attr_list
93 | - md_in_html
94 | - pymdownx.superfences
95 | - pymdownx.tabbed:
96 | alternate_style: true
97 |
98 | extra_css:
99 | - css/code.css
100 | plugins:
101 | - llmstxt:
102 | full_output: llms-full.txt
103 | markdown_description: "A simple dependency injection framework for Python."
104 | sections:
105 | Quickstart:
106 | - index.md: How to install and basic usage example
107 | Containers:
108 | - introduction/ioc-container.md: Defining the DI Container
109 | - introduction/multiple-containers.md: Using multiple DI Containers
110 | Injecting Dependencies:
111 | - introduction/injection.md: Default way to inject dependencies
112 | - introduction/generator-injection.md: Injecting dependencies into generators
113 | - introduction/type-based-injection.md: Injecting dependencies by bound type
114 | - introduction/string-injection.md: Injecting dependencies by providing a reference to a Provider by string
115 | Simple Providers:
116 | - providers/singleton.md: The Singleton and AsyncSingleton provider
117 | - providers/factories.md: The Factory and AsyncFactory provider
118 | - providers/object.md: The Object provider
119 | - providers/resources.md: The Resource and AsyncResource provider
120 | - introduction/tear-down.md: Handling tear down for Singletons and Resources
121 | Context Providers:
122 | - providers/context-resources.md: The ContextResource provider
123 | - introduction/scopes.md: Handling scopes with ContextResources
124 | - providers/state.md: The State provider
125 | Providers that interact with other providers:
126 | - providers/collections.md: The Dict and List provider
127 | - providers/selector.md: The Selector provider
128 | Integrations with other Frameworks:
129 | - integrations/fastapi.md: Integration with FastApi
130 | - integrations/faststream.md: Integration with FastStream
131 | - integrations/litestar.md: Integration with LiteStar
132 | Testing:
133 | - testing/fixture.md: Usage with fixtures
134 | - testing/provider-overriding.md: Overriding providers
135 | Experimental features:
136 | - experimental/lazy.md: The Lazy provider
137 |
138 | extra:
139 | social:
140 | - icon: fontawesome/brands/github
141 | link: https://github.com/modern-python/that-depends
142 | name: GitHub
143 |
--------------------------------------------------------------------------------
/that_depends/integrations/faststream.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from importlib.metadata import version
3 | from types import TracebackType
4 | from typing import Any, Final, Optional
5 |
6 | from packaging.version import Version
7 | from typing_extensions import deprecated, override
8 |
9 | from that_depends import container_context
10 | from that_depends.providers.context_resources import ContextScope, SupportsContext
11 | from that_depends.utils import UNSET, Unset, is_set
12 |
13 |
14 | _FASTSTREAM_MODULE_NAME: Final[str] = "faststream"
15 | _FASTSTREAM_VERSION: Final[str] = version(_FASTSTREAM_MODULE_NAME)
16 | if Version(_FASTSTREAM_VERSION) >= Version("0.6.0"): # pragma: no cover
17 | from faststream import BaseMiddleware, ContextRepo
18 | from faststream._internal.types import AnyMsg
19 |
20 | class DIContextMiddleware(BaseMiddleware):
21 | """Initializes the container context for faststream brokers."""
22 |
23 | def __init__(
24 | self,
25 | *context_items: SupportsContext[Any],
26 | msg: AnyMsg | None = None,
27 | context: Optional["ContextRepo"] = None,
28 | global_context: dict[str, Any] | Unset = UNSET,
29 | scope: ContextScope | Unset = UNSET,
30 | ) -> None:
31 | """Initialize the container context middleware.
32 |
33 | Args:
34 | *context_items (SupportsContext[Any]): Context items to initialize.
35 | msg (Any): Message object.
36 | context (ContextRepo): Context repository.
37 | global_context (dict[str, Any] | Unset): Global context to initialize the container.
38 | scope (ContextScope | Unset): Context scope to initialize the container.
39 |
40 | """
41 | super().__init__(msg, context=context) # type: ignore[arg-type]
42 | self._context: container_context | None = None
43 | self._context_items = set(context_items)
44 | self._global_context = global_context
45 | self._scope = scope
46 |
47 | @override
48 | async def on_receive(self) -> None:
49 | self._context = container_context(
50 | *self._context_items,
51 | scope=self._scope if is_set(self._scope) else None,
52 | global_context=self._global_context if is_set(self._global_context) else None,
53 | )
54 | await self._context.__aenter__()
55 |
56 | @override
57 | async def after_processed(
58 | self,
59 | exc_type: type[BaseException] | None = None,
60 | exc_val: BaseException | None = None,
61 | exc_tb: Optional["TracebackType"] = None,
62 | ) -> bool | None:
63 | if self._context is not None:
64 | await self._context.__aexit__(exc_type, exc_val, exc_tb)
65 | return None
66 |
67 | def __call__(self, msg: Any = None, **kwargs: Any) -> "DIContextMiddleware": # noqa: ANN401
68 | """Create an instance of DIContextMiddleware.
69 |
70 | Args:
71 | msg (Any): Message object.
72 | **kwargs: Additional keyword arguments.
73 |
74 | Returns:
75 | DIContextMiddleware: A new instance of DIContextMiddleware.
76 |
77 | """
78 | context = kwargs.get("context")
79 |
80 | return DIContextMiddleware(
81 | *self._context_items,
82 | msg=msg,
83 | context=context,
84 | scope=self._scope,
85 | global_context=self._global_context,
86 | )
87 | else: # pragma: no cover
88 | from faststream import BaseMiddleware
89 |
90 | @deprecated("Will be removed with faststream v1")
91 | class DIContextMiddleware(BaseMiddleware): # type: ignore[no-redef]
92 | """Initializes the container context for faststream brokers."""
93 |
94 | def __init__(
95 | self,
96 | *context_items: SupportsContext[Any],
97 | global_context: dict[str, Any] | Unset = UNSET,
98 | scope: ContextScope | Unset = UNSET,
99 | ) -> None:
100 | """Initialize the container context middleware.
101 |
102 | Args:
103 | *context_items (SupportsContext[Any]): Context items to initialize.
104 | global_context (dict[str, Any] | Unset): Global context to initialize the container.
105 | scope (ContextScope | Unset): Context scope to initialize the container.
106 |
107 | """
108 | super().__init__() # type: ignore[call-arg]
109 | self._context: container_context | None = None
110 | self._context_items = set(context_items)
111 | self._global_context = global_context
112 | self._scope = scope
113 |
114 | @override
115 | async def on_receive(self) -> None:
116 | self._context = container_context(
117 | *self._context_items,
118 | scope=self._scope if is_set(self._scope) else None,
119 | global_context=self._global_context if is_set(self._global_context) else None,
120 | )
121 | await self._context.__aenter__()
122 |
123 | @override
124 | async def after_processed(
125 | self,
126 | exc_type: type[BaseException] | None = None,
127 | exc_val: BaseException | None = None,
128 | exc_tb: Optional["TracebackType"] = None,
129 | ) -> bool | None:
130 | if self._context is not None:
131 | await self._context.__aexit__(exc_type, exc_val, exc_tb)
132 | return None
133 |
134 | def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> "DIContextMiddleware": # noqa: ARG002, ANN401
135 | """Create an instance of DIContextMiddleware."""
136 | return DIContextMiddleware(*self._context_items, scope=self._scope, global_context=self._global_context)
137 |
--------------------------------------------------------------------------------
/tests/experimental/test_container_2.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections.abc import AsyncIterator, Iterator
3 |
4 | import pytest
5 | import typing_extensions
6 |
7 | from that_depends import BaseContainer, ContextScopes, container_context, providers
8 | from that_depends.experimental import LazyProvider
9 |
10 |
11 | class _RandomWrapper:
12 | def __init__(self) -> None:
13 | self.value = random.random()
14 |
15 | @typing_extensions.override
16 | def __eq__(self, other: object) -> bool:
17 | if isinstance(other, _RandomWrapper):
18 | return self.value == other.value
19 | return False # pragma: nocover
20 |
21 | def __hash__(self) -> int:
22 | return 0 # pragma: nocover
23 |
24 |
25 | async def _async_creator() -> AsyncIterator[float]:
26 | yield random.random()
27 |
28 |
29 | def _sync_creator() -> Iterator[_RandomWrapper]:
30 | yield _RandomWrapper()
31 |
32 |
33 | class Container2(BaseContainer):
34 | """Test Container 2."""
35 |
36 | alias = "container_2"
37 | default_scope = ContextScopes.APP
38 | obj_1 = LazyProvider("tests.experimental.test_container_1.Container1.obj_1")
39 | obj_2 = providers.Object(2)
40 | async_context_provider = providers.ContextResource(_async_creator)
41 | sync_context_provider = providers.ContextResource(_sync_creator)
42 | singleton_provider = providers.Singleton(lambda: random.random())
43 |
44 |
45 | async def test_lazy_provider_resolution_async() -> None:
46 | assert await Container2.obj_1.resolve() == 1
47 |
48 |
49 | def test_lazy_provider_override_sync() -> None:
50 | override_value = 42
51 | Container2.obj_1.override_sync(override_value)
52 | assert Container2.obj_1.resolve_sync() == override_value
53 | Container2.obj_1.reset_override_sync()
54 | assert Container2.obj_1.resolve_sync() == 1
55 |
56 |
57 | async def test_lazy_provider_override_async() -> None:
58 | override_value = 42
59 | await Container2.obj_1.override(override_value)
60 | assert await Container2.obj_1.resolve() == override_value
61 | await Container2.obj_1.reset_override()
62 | assert await Container2.obj_1.resolve() == 1
63 |
64 |
65 | def test_lazy_provider_invalid_state() -> None:
66 | lazy_provider = LazyProvider(
67 | module_string="tests.experimental.test_container_2", provider_string="Container2.sync_context_provider"
68 | )
69 | lazy_provider._module_string = None
70 | with pytest.raises(RuntimeError):
71 | lazy_provider.resolve_sync()
72 |
73 |
74 | async def test_lazy_provider_context_resource_async() -> None:
75 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.async_context_provider")
76 | async with lazy_provider.context_async(force=True):
77 | assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve()
78 | async with Container2.async_context_provider.context_async(force=True):
79 | assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve()
80 |
81 | with pytest.raises(RuntimeError):
82 | await lazy_provider.resolve()
83 |
84 | async with container_context(Container2, scope=ContextScopes.APP):
85 | assert await lazy_provider.resolve() == await Container2.async_context_provider.resolve()
86 |
87 | assert lazy_provider.get_scope() == ContextScopes.APP
88 |
89 | assert lazy_provider.supports_context_sync() is False
90 |
91 |
92 | def test_lazy_provider_context_resource_sync() -> None:
93 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider")
94 | with lazy_provider.context_sync(force=True):
95 | assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync()
96 | with Container2.sync_context_provider.context_sync(force=True):
97 | assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync()
98 |
99 | with pytest.raises(RuntimeError):
100 | lazy_provider.resolve_sync()
101 |
102 | with container_context(Container2, scope=ContextScopes.APP):
103 | assert lazy_provider.resolve_sync() == Container2.sync_context_provider.resolve_sync()
104 |
105 | assert lazy_provider.get_scope() == ContextScopes.APP
106 |
107 | assert lazy_provider.supports_context_sync() is True
108 |
109 |
110 | async def test_lazy_provider_tear_down_async() -> None:
111 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.singleton_provider")
112 | assert lazy_provider.resolve_sync() == Container2.singleton_provider.resolve_sync()
113 |
114 | await lazy_provider.tear_down()
115 |
116 | assert await lazy_provider.resolve() == Container2.singleton_provider.resolve_sync()
117 |
118 |
119 | def test_lazy_provider_tear_down_sync() -> None:
120 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.singleton_provider")
121 | assert lazy_provider.resolve_sync() == Container2.singleton_provider.resolve_sync()
122 |
123 | lazy_provider.tear_down_sync()
124 |
125 | assert lazy_provider.resolve_sync() == Container2.singleton_provider.resolve_sync()
126 |
127 |
128 | async def test_lazy_provider_not_implemented() -> None:
129 | lazy_provider = Container2.obj_1
130 | with pytest.raises(NotImplementedError):
131 | lazy_provider.get_scope()
132 | with pytest.raises(NotImplementedError):
133 | lazy_provider.context_sync()
134 | with pytest.raises(NotImplementedError):
135 | lazy_provider.context_async()
136 | with pytest.raises(NotImplementedError):
137 | lazy_provider.tear_down_sync()
138 | with pytest.raises(NotImplementedError):
139 | await lazy_provider.tear_down()
140 |
141 |
142 | def test_lazy_provider_attr_getter() -> None:
143 | lazy_provider = LazyProvider("tests.experimental.test_container_2.Container2.sync_context_provider")
144 | with lazy_provider.context_sync(force=True):
145 | assert isinstance(lazy_provider.value.resolve_sync(), float)
146 |
--------------------------------------------------------------------------------
/docs/introduction/injection.md:
--------------------------------------------------------------------------------
1 | # Injecting Providers in **that-depends**
2 |
3 | `that-depends` uses a decorator-based approach for both synchronous and asynchronous functions. By decorating a function with `@inject` and marking certain parameters as `Provide[...]`, **that-depends** will automatically resolve the specified providers at call time.
4 |
5 | ---
6 |
7 | ## Overview
8 |
9 | In **that-depends**, you define your dependencies as `AbstractProvider` instances—e.g., `Singleton`, `Factory`, `Resource`, or others. These providers typically live inside a subclass of `BaseContainer`, making them globally accessible.
10 |
11 | When you want to use a provider in a function, you can mark a parameter’s **default value** as:
12 |
13 | ```python
14 | my_param = Provide[MyContainer.some_provider]
15 | ```
16 |
17 | You then decorate the function with `@inject`. This tells `that-depends` to automatically resolve providers when you call your function.
18 |
19 | ---
20 |
21 | ## Quick Start
22 |
23 | Below is a simple example demonstrating how to define a container, declare a provider, and inject that provider into a function.
24 |
25 | ### 1. Define a Container and a Provider
26 |
27 | ```python
28 | from that_depends import BaseContainer
29 | from that_depends.providers import Singleton
30 |
31 | class MyContainer(BaseContainer):
32 | greeting_provider = Singleton(lambda: "Hello from MyContainer")
33 | ```
34 | For more details on Containers, refer to the [Containers](ioc-container.md) documentation.
35 |
36 | ### 2. Inject the Provider into a Function
37 |
38 | ```python
39 | from that_depends import inject, Provide
40 |
41 | @inject
42 | def greet_user(greeting: str = Provide[MyContainer.greeting_provider]) -> str:
43 | return f"Greeting: {greeting}"
44 | ```
45 |
46 | Here:
47 |
48 | 1. We used `@inject` above `greet_user`.
49 | 2. We declared a parameter `greeting`, whose default value is `Provide[MyContainer.greeting_provider]`.
50 |
51 | ### 3. Call the Function
52 |
53 | ```python
54 | print(greet_user()) # "Greeting: Hello from MyContainer"
55 | ```
56 |
57 | ---
58 |
59 | ## The `@inject` Decorator in Detail
60 |
61 |
62 | ### Synchronous vs Asynchronous Functions
63 |
64 | `@inject` works on both sync and async functions. Just note that injecting async providers into sync functions is not supported.
65 |
66 | ```python
67 | @inject
68 | async def async_greet_user(greeting: str = Provide[MyContainer.greeting_provider]) -> str:
69 | # asynchronous operations...
70 | return f"Greeting: {greeting}"
71 | ```
72 |
73 | ---
74 |
75 | ## Using `Provide[...]` as a Default
76 |
77 | It is recommended to wrap your provider in `Provide[...]` when using it as a default in an injected function, since this ensures correct type resolution:
78 |
79 | ```python
80 | @inject
81 | def greet_user_direct(
82 | greeting: str = Provide[MyContainer.greeting_provider] # (1)!
83 | ) -> str:
84 | return f"Greeting: {greeting}"
85 | ```
86 |
87 | 1. Notice that although `greeting` is annotated as `str`, neither `mypy` nor your IDE will complain.
88 |
89 | ---
90 |
91 | ## Injection Warnings
92 |
93 | If `@inject` finds **no** parameters whose default values are providers, it will issue a warning:
94 |
95 | > `Expected injection, but nothing found. Remove @inject decorator.`
96 |
97 | This is to avoid accidentally decorating a function that doesn’t actually require injection.
98 |
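99 | For example, a function like the following (a hypothetical `plain_greet`, shown only for illustration) would trigger the warning, because none of its parameters default to `Provide[...]`:
100 |
101 | ```python
102 | @inject  # warns: there is nothing to inject here
103 | def plain_greet(name: str) -> str:
104 |     return f"Greeting: {name}"
105 | ```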
99 | ---
100 |
101 | ## Specifying a Scope
102 |
103 | By default, `@inject` uses the `ContextScopes.INJECT` scope. If you want to override that, do:
104 |
105 | ```python
106 | from that_depends import inject
107 | from that_depends.providers.context_resources import ContextScopes
108 |
109 | @inject(scope=ContextScopes.REQUEST)
110 | def greet_user(greeting: str = Provide[MyContainer.greeting_provider]):
111 | ...
112 | ```
113 |
114 | When `greet_user` is called, **that-depends**:
115 |
116 | 1. Initializes the context for all `REQUEST` (or `ANY`) scoped `args` and `kwargs`.
117 | 2. Resolves all providers in the `args` and `kwargs` of the function.
118 | 3. Calls your function with the resolved dependencies.
119 |
120 | For more details regarding scopes and context management, see the [Context Resources](../providers/context-resources.md) documentation and the [Scopes](scopes.md) documentation.
121 |
122 | ---
123 |
124 | ## Overriding Providers
125 |
126 | In tests or specialized scenarios, you may want to override a provider’s value temporarily. You can do so with the container’s `override_providers()` method or the provider’s own `override_context()`:
127 |
128 | ```python
129 | def test_greet_override():
130 | # Override the greeting_provider with a mock value
131 | with MyContainer.override_providers_sync({"greeting_provider": "TestHello"}):
132 | result = greet_user()
133 | assert result == "Greeting: TestHello"
134 | ```
135 |
136 | This is especially helpful for unit tests where you want to substitute real dependencies (e.g., database connections) with mocks or stubs.
137 |
138 | For more details on overriding providers, see the [Overriding Providers](../testing/provider-overriding.md) documentation.
139 |
140 | ---
141 |
142 | ## Frequently Asked Questions
143 |
144 | 1. **Do I need to call `@inject` every time I reference a provider?**
145 | No—only when you want **automatic** injection of providers into function parameters. If you are resolving dependencies manually (e.g., `MyContainer.greeting_provider.resolve_sync()`), then `@inject` is not needed.
146 |
147 | 2. **What if I provide a custom argument to a parameter that has a default provider?**
148 | If you explicitly pass a value, that value overrides the injected default:
149 |
150 | ~~~~python
151 | @inject
152 | def foo(x: int = Provide[MyContainer.number_factory]) -> int:
153 | return x
154 |
155 | print(foo()) # uses number_factory -> 42
156 | print(foo(99)) # explicitly uses 99
157 | ~~~~
158 |
159 | 3. **Can I combine `@inject` with other decorators?**
160 | Yes, you can. Generally, put `@inject` **below** others, depending on the order you need. If you run into issues, experiment with the order or handle context manually.
161 |
162 | ---
163 |
--------------------------------------------------------------------------------
/docs/providers/factories.md:
--------------------------------------------------------------------------------
1 | # Factories
2 | Factory providers create a new instance on every call.
3 |
4 | ## Factory
5 | - A class or a simple function can be used as the creator.
6 | ```python
7 | import dataclasses
8 |
9 | from that_depends import BaseContainer, providers
10 |
11 |
12 | @dataclasses.dataclass(kw_only=True, slots=True)
13 | class IndependentFactory:
14 | dep1: str
15 | dep2: int
16 |
17 |
18 | class DIContainer(BaseContainer):
19 | independent_factory = providers.Factory(IndependentFactory, dep1="text", dep2=123)
20 | ```
21 |
22 | ## AsyncFactory
23 | - Allows both sync and async creators.
24 | - Can only be resolved asynchronously, even if the creator is sync.
25 | ```python
26 | import datetime
27 |
28 | from that_depends import BaseContainer, providers
29 |
30 |
31 | async def async_factory() -> datetime.datetime:
32 | return datetime.datetime.now(tz=datetime.timezone.utc)
33 |
34 |
35 | class DIContainer(BaseContainer):
36 | async_factory = providers.AsyncFactory(async_factory)
37 | ```
38 |
39 | > Note: If you have a class that has dependencies which need to be resolved asynchronously, you can use `AsyncFactory` to create instances of that class. The factory will handle the async resolution of dependencies.
40 |
41 |
42 | ## Retrieving provider as a Callable
43 |
44 | When you use a factory‑based provider such as `Factory` (for sync logic) or `AsyncFactory` (for async logic), the resulting provider instance has two special properties:
45 |
46 | - **`.provider`** — returns an *async callable* that, when awaited, resolves the resource.
47 | - **`.provider_sync`** — returns a *sync callable* that, when called, resolves the resource.
48 |
49 | You can think of these as no-argument functions that produce the resource you defined—similar to calling `resolve()` or `resolve_sync()` directly, but in a more convenient form when you want a standalone function handle.
50 |
51 | ---
52 |
53 | ### Basic Usage
54 |
55 | #### Defining Providers in a Container
56 |
57 | Suppose you have a `BaseContainer` subclass that defines both a sync and an async resource:
58 |
59 | ```python
60 | from that_depends import BaseContainer
61 | from that_depends.providers import Factory, AsyncFactory
62 |
63 | def build_sync_message() -> str:
64 | return "Hello from sync provider!"
65 |
66 | async def build_async_message() -> str:
67 | # Possibly perform async setup, I/O, etc.
68 | return "Hello from async provider!"
69 |
70 | class MyContainer(BaseContainer):
71 | # Synchronous Factory
72 | sync_message = Factory(build_sync_message)
73 |
74 | # Asynchronous Factory
75 | async_message = AsyncFactory(build_async_message)
76 | ```
77 |
78 | Here, `sync_message` is a `Factory` which calls a plain function, while `async_message` is an `AsyncFactory` which calls an async function.
79 |
80 | ---
81 |
82 | #### Resolving Resources via `.provider` and `.provider_sync`
83 |
84 | The `.provider` property gives you an *async function* to await, and `.provider_sync` gives you a *synchronous* callable. They effectively wrap `.resolve()` and `.resolve_sync()`.
85 |
86 | **Synchronous Resolution**
87 |
88 | ```python
89 | # In a synchronous function or interactive session
90 | >>> msg = MyContainer.sync_message.provider_sync()
91 | >>> print(msg)
92 | Hello from sync provider!
93 | ```
94 |
95 | Here, `provider_sync` is a no-argument function that immediately returns the resolved value.
96 |
97 | **Asynchronous Resolution**
98 |
99 | ```python
100 | import asyncio
101 |
102 | async def main():
103 |     # Acquire the async resource by calling and awaiting the provider callable
104 |     msg = await MyContainer.async_message.provider()
105 | print(msg)
106 |
107 | asyncio.run(main())
108 | ```
109 |
110 | Within an async function, `MyContainer.async_message.provider` gives you a no-argument async function, which you call and `await` to obtain the value.
111 |
112 | ---
113 |
114 | ### Passing the Provider Function Around
115 |
116 | Sometimes you may want to store or pass around the provider function itself (rather than resolving it immediately):
117 |
118 | ```python
119 | class AnotherClass:
120 | def __init__(self, sync_factory_callable: callable):
121 | self._factory_callable = sync_factory_callable
122 |
123 | def get_message(self) -> str:
124 | return self._factory_callable()
125 |
126 |
127 | # Passing MyContainer.sync_message.provider_sync to AnotherClass
128 | provider_callable = MyContainer.sync_message.provider_sync
129 | another_instance = AnotherClass(provider_callable)
130 | print(another_instance.get_message()) # "Hello from sync provider!"
131 | ```
132 |
133 | Because `.provider_sync` is just a callable returning your dependency, it can be shared easily throughout your code.
134 |
135 | ---
136 |
137 | ### Example: Using Factories with Parameters
138 |
139 | `Factory` and `AsyncFactory` can accept dependencies (including other providers) as parameters:
140 |
141 | ```python
142 | def greet(name: str) -> str:
143 | return f"Hello, {name}!"
144 |
145 | class MyContainer(BaseContainer):
146 | name = Factory(lambda: "Alice")
147 | greeting = Factory(greet, name)
148 | ```
149 |
150 | ```python
151 | >>> greeting_sync_fn = MyContainer.greeting.provider_sync
152 | >>> print(greeting_sync_fn())
153 | Hello, Alice!
154 | ```
155 |
156 | Under the hood, `greeting` calls `greet` with the result of `name.resolve_sync()`.
157 |
158 | ---
159 |
160 | ### Context Considerations
161 |
162 | If your providers use `ContextResource` or require a named scope (for instance, `REQUEST`), you need to wrap your resolves in a context manager:
163 |
164 | ```python
165 | from that_depends.providers import container_context, ContextScopes
166 |
167 |
168 | class ContextfulContainer(BaseContainer):
169 | default_scope = ContextScopes.REQUEST
170 | # ... define context-based providers ...
171 |
172 |
173 | with container_context(ContextfulContainer, scope=ContextScopes.REQUEST):
174 | result = ContextfulContainer.some_resource.provider_sync()
175 | # ...
176 | ```
177 |
178 | You still call `.provider_sync` or `.provider`, but the container or context usage ensures resources are valid within the required scope.
179 |
180 | This pattern simplifies passing creation logic around in your code, preserving testability and clarity—whether you need sync or async behavior.
181 |
--------------------------------------------------------------------------------
/tests/providers/test_selector.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import typing
4 |
5 | import pydantic
6 | import pytest
7 |
8 | from tests.container import create_async_resource, create_sync_resource
9 | from that_depends import BaseContainer, providers
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 | global_state_for_selector: typing.Literal["sync_resource", "async_resource", "missing"] = "sync_resource"
14 |
15 |
16 | class SelectorState:
17 | def __init__(self) -> None:
18 | self.selector_state: typing.Literal["sync_resource", "async_resource", "missing"] = "sync_resource"
19 |
20 | def get_selector_state(self) -> typing.Literal["sync_resource", "async_resource", "missing"]:
21 | return self.selector_state
22 |
23 |
24 | selector_state = SelectorState()
25 |
26 |
27 | class DIContainer(BaseContainer):
28 | alias = "selector_container"
29 | sync_resource = providers.Resource(create_sync_resource)
30 | async_resource = providers.Resource(create_async_resource)
31 | selector: providers.Selector[datetime.datetime] = providers.Selector(
32 | selector_state.get_selector_state,
33 | sync_resource=sync_resource,
34 | async_resource=async_resource,
35 | )
36 |
37 |
38 | async def test_selector_provider_async() -> None:
39 | selector_state.selector_state = "async_resource"
40 | selected = await DIContainer.selector()
41 | async_resource = await DIContainer.async_resource()
42 |
43 | assert selected == async_resource
44 |
45 |
46 | async def test_selector_provider_async_missing() -> None:
47 | selector_state.selector_state = "missing"
48 | with pytest.raises(KeyError, match="Key 'missing' not found in providers"):
49 | await DIContainer.selector()
50 |
51 |
52 | async def test_selector_provider_sync() -> None:
53 | selector_state.selector_state = "sync_resource"
54 | selected = DIContainer.selector.resolve_sync()
55 | sync_resource = DIContainer.sync_resource.resolve_sync()
56 |
57 | assert selected == sync_resource
58 |
59 |
60 | async def test_selector_provider_sync_missing() -> None:
61 | selector_state.selector_state = "missing"
62 | with pytest.raises(KeyError, match="Key 'missing' not found in providers"):
63 | DIContainer.selector.resolve_sync()
64 |
65 |
66 | async def test_selector_provider_overriding() -> None:
67 | now = datetime.datetime.now(tz=datetime.timezone.utc)
68 | DIContainer.selector.override_sync(now)
69 | selected_async = await DIContainer.selector()
70 | selected_sync = DIContainer.selector.resolve_sync()
71 | assert selected_async == selected_sync == now
72 |
73 | DIContainer.reset_override_sync()
74 | selector_state.selector_state = "sync_resource"
75 | selected = DIContainer.selector.resolve_sync()
76 | sync_resource = DIContainer.sync_resource.resolve_sync()
77 | assert selected == sync_resource
78 |
79 |
80 | class Settings(pydantic.BaseModel):
81 | mode: typing.Literal["one", "two"] = "one"
82 |
83 |
84 | class SettingsSelectorContainer(BaseContainer):
85 | settings = providers.Singleton(Settings)
86 | operator = providers.Selector(
87 | settings.cast.mode,
88 | one=providers.Singleton(lambda: "Provider 1"),
89 | two=providers.Singleton(lambda: "Provider 2"),
90 | )
91 |
92 |
93 | async def test_selector_with_attrgetter_selector_sync() -> None:
94 | assert SettingsSelectorContainer.settings.resolve_sync().mode == "one"
95 | assert SettingsSelectorContainer.operator.resolve_sync() == "Provider 1"
96 |
97 | SettingsSelectorContainer.settings.override_sync(Settings(mode="two"))
98 | assert SettingsSelectorContainer.settings.resolve_sync().mode == "two"
99 | assert SettingsSelectorContainer.operator.resolve_sync() == "Provider 2"
100 | SettingsSelectorContainer.settings.reset_override_sync()
101 |
102 |
103 | async def test_selector_with_attrgetter_selector_async() -> None:
104 | assert (await SettingsSelectorContainer.settings.resolve()).mode == "one"
105 | assert (await SettingsSelectorContainer.operator.resolve()) == "Provider 1"
106 |
107 | SettingsSelectorContainer.settings.override_sync(Settings(mode="two"))
108 | assert (await SettingsSelectorContainer.settings.resolve()).mode == "two"
109 | assert (await SettingsSelectorContainer.operator.resolve()) == "Provider 2"
110 | SettingsSelectorContainer.settings.reset_override_sync()
111 |
112 |
113 | class StringSelectorContainer(BaseContainer):
114 | selector = providers.Selector(
115 | "one",
116 | one=providers.Singleton(lambda: "Provider 1"),
117 | two=providers.Singleton(lambda: "Provider 2"),
118 | )
119 |
120 |
121 | async def test_selector_with_fixed_string() -> None:
122 | assert StringSelectorContainer.selector.resolve_sync() == "Provider 1"
123 | assert (await StringSelectorContainer.selector.resolve()) == "Provider 1"
124 |
125 |
126 | async def mode_one() -> str:
127 | return "one"
128 |
129 |
130 | class StringProviderSelectorContainer(BaseContainer):
131 | mode = providers.AsyncFactory(mode_one)
132 | selector = providers.Selector(
133 | mode.cast,
134 | one=providers.Singleton(lambda: "Provider 1"),
135 | two=providers.Singleton(lambda: "Provider 2"),
136 | )
137 |
138 |
139 | async def test_selector_with_provider_selector_async() -> None:
140 | assert (await StringProviderSelectorContainer.selector.resolve()) == "Provider 1"
141 |
142 |
143 | class InvalidSelectorContainer(BaseContainer):
144 | selector = providers.Selector(
145 | None, # type: ignore[arg-type]
146 | one=providers.Singleton(lambda: "Provider 1"),
147 | two=providers.Singleton(lambda: "Provider 2"),
148 | )
149 |
150 |
151 | async def test_selector_with_invalid_selector() -> None:
152 | with pytest.raises(
153 | TypeError, match="Invalid selector type: , expected str, or a provider/callable returning str"
154 | ):
155 | InvalidSelectorContainer.selector.resolve_sync()
156 |
157 | with pytest.raises(
158 | TypeError, match="Invalid selector type: , expected str, or a provider/callable returning str"
159 | ):
160 | await InvalidSelectorContainer.selector.resolve()
161 |
--------------------------------------------------------------------------------
/that_depends/container.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import typing
3 | from contextlib import contextmanager
4 | from typing import overload
5 |
6 | from typing_extensions import override
7 |
8 | from that_depends.meta import BaseContainerMeta
9 | from that_depends.providers import AbstractProvider, AsyncSingleton, Resource, Singleton
10 | from that_depends.providers.context_resources import ContextScope, ContextScopes
11 |
12 |
13 | if typing.TYPE_CHECKING:
14 | import typing_extensions
15 |
16 |
17 | T = typing.TypeVar("T")
18 | P = typing.ParamSpec("P")
19 |
20 |
21 | class BaseContainer(metaclass=BaseContainerMeta):
22 | """Base container class."""
23 |
24 | alias: str | None = None
25 | providers: dict[str, AbstractProvider[typing.Any]]
26 | containers: list[type["BaseContainer"]]
27 | default_scope: ContextScope | None = ContextScopes.ANY
28 |
29 | @classmethod
30 | @overload
31 | def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: ...
32 |
33 | @classmethod
34 | @overload
35 | def context(cls, *, force: bool = False) -> typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: ...
36 |
37 | @classmethod
38 | def context(
39 | cls, func: typing.Callable[P, T] | None = None, force: bool = False
40 | ) -> typing.Callable[P, T] | typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]:
41 |         """Wrap a function with this container's context.
42 |
43 | Args:
44 | func: function to be wrapped.
45 | force: force context initialization, ignoring scope.
46 |
47 | Returns:
48 | wrapped function or wrapper if func is None.
49 |
50 | """
51 |
52 | def _wrapper(func: typing.Callable[P, T]) -> typing.Callable[P, T]:
53 | if inspect.iscoroutinefunction(func):
54 |
55 | async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
56 | async with cls.context_async(force=force):
57 | return await func(*args, **kwargs) # type: ignore[no-any-return]
58 |
59 | return typing.cast(typing.Callable[P, T], _async_wrapper)
60 |
61 | def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
62 | with cls.context_sync(force=force):
63 | return func(*args, **kwargs)
64 |
65 | return _sync_wrapper
66 |
67 | if func:
68 | return _wrapper(func)
69 | return _wrapper
70 |
71 | @override
72 | def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self":
73 | msg = f"{cls.__name__} should not be instantiated"
74 | raise RuntimeError(msg)
75 |
76 | @classmethod
77 | def connect_containers(cls, *containers: type["BaseContainer"]) -> None:
78 | """Connect containers.
79 |
80 |         When `init_resources` or `tear_down` is called,
81 |         the same method of the connected containers will also be called.
82 | """
83 | if not hasattr(cls, "containers"):
84 | cls.containers = []
85 |
86 | cls.containers.extend(containers)
87 |
88 | @classmethod
89 | async def init_resources(cls) -> None:
90 | """Initialize all resources."""
91 | for provider in cls.get_providers().values():
92 | if isinstance(provider, Resource | Singleton | AsyncSingleton):
93 | await provider.resolve()
94 |
95 | for container in cls.get_containers():
96 | await container.init_resources()
97 |
98 | @classmethod
99 | def reset_override_sync(cls) -> None:
100 | """Reset all provider overrides."""
101 | for v in cls.get_providers().values():
102 | v.reset_override_sync()
103 |
104 | @classmethod
105 | def resolver(cls, item: typing.Callable[P, T]) -> typing.Callable[[], typing.Awaitable[T]]:
106 | """Decorate a function to automatically resolve dependencies on call by name.
107 |
108 | Args:
109 |             item: object for which the dependencies should be resolved.
110 |
111 | Returns:
112 | Async wrapped callable with auto-injected dependencies.
113 |
114 | """
115 |
116 | async def _inner() -> T:
117 | return await cls.resolve(item)
118 |
119 | return _inner
120 |
121 | @classmethod
122 | async def resolve(cls, object_to_resolve: typing.Callable[..., T]) -> T:
123 | """Inject dependencies into an object automatically by name."""
124 | signature: typing.Final = inspect.signature(object_to_resolve)
125 | kwargs = {}
126 | providers: typing.Final = cls.get_providers()
127 | for field_name, field_value in signature.parameters.items():
128 | if field_value.default is not inspect.Parameter.empty or field_name in ("_", "__"):
129 | continue
130 |
131 | if field_name not in providers:
132 | msg = f"Provider is not found, {field_name=}"
133 | raise RuntimeError(msg)
134 |
135 | kwargs[field_name] = await providers[field_name].resolve()
136 |
137 | return object_to_resolve(**kwargs)
138 |
139 | @classmethod
140 | @contextmanager
141 | def override_providers_sync(cls, providers_for_overriding: dict[str, typing.Any]) -> typing.Iterator[None]:
142 | """Override several providers with mocks simultaneously.
143 |
144 | Args:
145 | providers_for_overriding: {provider_name: mock} dictionary.
146 |
147 | Returns:
148 | None
149 |
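        Example:
            A sketch (assumes `MyContainer` defines providers named `settings` and `client`):

            ```python
            with MyContainer.override_providers_sync({"settings": settings_mock, "client": client_mock}):
                ...  # both providers resolve to the given mocks here
            # overrides are reset automatically on exit
            ```
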
150 | """
151 | current_providers: typing.Final = cls.get_providers()
152 | current_provider_names: typing.Final = set(current_providers.keys())
153 | given_provider_names: typing.Final = set(providers_for_overriding.keys())
154 |
155 | for given_name in given_provider_names:
156 | if given_name not in current_provider_names:
157 | msg = f"Provider with name {given_name!r} not found"
158 | raise RuntimeError(msg)
159 |
160 | for provider_name, mock in providers_for_overriding.items():
161 | provider = current_providers[provider_name]
162 | provider.override_sync(mock)
163 |
164 | try:
165 | yield
166 | finally:
167 | for provider_name in providers_for_overriding:
168 | provider = current_providers[provider_name]
169 | provider.reset_override_sync()
170 |
--------------------------------------------------------------------------------
/tests/providers/test_singleton.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import dataclasses
3 | import threading
4 | import time
5 | import typing
6 | from concurrent.futures import ThreadPoolExecutor, as_completed
7 |
8 | import pydantic
9 | import pytest
10 |
11 | from that_depends import BaseContainer, providers
12 |
13 |
14 | @dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
15 | class SingletonFactory:
16 | dep1: str
17 |
18 |
19 | class Settings(pydantic.BaseModel):
20 | some_setting: str = "some_value"
21 | other_setting: str = "other_value"
22 |
23 |
24 | async def create_async_obj(value: str) -> SingletonFactory:
25 | await asyncio.sleep(0.001)
26 | return SingletonFactory(dep1=f"async {value}")
27 |
28 |
29 | async def _async_creator() -> int:
30 | await asyncio.sleep(0.001)
31 | return threading.get_ident()
32 |
33 |
34 | def _sync_creator_with_dependency(dep: int) -> str:
35 | return f"Singleton {dep}"
36 |
37 |
38 | class DIContainer(BaseContainer):
39 | alias = "singleton_container"
40 | factory: providers.AsyncFactory[int] = providers.AsyncFactory(_async_creator)
41 | settings: Settings = providers.Singleton(Settings).cast
42 | singleton = providers.Singleton(SingletonFactory, dep1=settings.some_setting)
43 | singleton_async = providers.AsyncSingleton(create_async_obj, value=settings.some_setting)
44 | singleton_with_dependency = providers.Singleton(_sync_creator_with_dependency, dep=factory.cast)
45 |
46 |
47 | async def test_singleton_provider() -> None:
48 | singleton1 = await DIContainer.singleton()
49 | singleton2 = await DIContainer.singleton()
50 | singleton3 = DIContainer.singleton.resolve_sync()
51 | await DIContainer.singleton.tear_down()
52 | singleton4 = DIContainer.singleton.resolve_sync()
53 |
54 | assert singleton1 is singleton2 is singleton3
55 | assert singleton4 is not singleton1
56 |
57 | await DIContainer.tear_down()
58 |
59 |
60 | async def test_singleton_attr_getter() -> None:
61 | singleton1 = await DIContainer.singleton()
62 |
63 | assert singleton1.dep1 == Settings().some_setting
64 |
65 | await DIContainer.tear_down()
66 |
67 |
68 | async def test_singleton_with_empty_list() -> None:
69 | singleton = providers.Singleton(list[object])
70 |
71 | singleton1 = await singleton()
72 | singleton2 = await singleton()
73 | assert singleton1 is singleton2
74 |
75 |
76 | @pytest.mark.repeat(10)
77 | async def test_singleton_asyncio_concurrency() -> None:
78 | calls: int = 0
79 |
80 | async def create_resource() -> typing.AsyncIterator[str]:
81 | nonlocal calls
82 | calls += 1
83 | await asyncio.sleep(0)
84 | yield ""
85 |
86 | @dataclasses.dataclass(kw_only=True, slots=True)
87 | class SimpleCreator:
88 | dep1: str
89 |
90 | resource = providers.Resource(create_resource)
91 | factory_with_resource = providers.Singleton(SimpleCreator, dep1=resource.cast)
92 |
93 | client1, client2 = await asyncio.gather(factory_with_resource.resolve(), factory_with_resource.resolve())
94 |
95 | assert client1 is client2
96 | assert calls == 1
97 |
98 |
99 | @pytest.mark.repeat(10)
100 | def test_singleton_threading_concurrency() -> None:
101 | calls: int = 0
102 | lock = threading.Lock()
103 |
104 | def create_singleton() -> str:
105 | nonlocal calls
106 | with lock:
107 | calls += 1
108 | time.sleep(0.01)
109 | return ""
110 |
111 | singleton = providers.Singleton(create_singleton)
112 |
113 | with ThreadPoolExecutor(max_workers=4) as pool:
114 | tasks = [
115 | pool.submit(singleton.resolve_sync),
116 | pool.submit(singleton.resolve_sync),
117 | pool.submit(singleton.resolve_sync),
118 | pool.submit(singleton.resolve_sync),
119 | ]
120 | results = [x.result() for x in as_completed(tasks)]
121 |
122 | assert all(x == "" for x in results)
123 | assert calls == 1
124 |
125 |
126 | async def test_async_singleton() -> None:
127 | singleton1 = await DIContainer.singleton_async()
128 | singleton2 = await DIContainer.singleton_async()
129 | singleton3 = await DIContainer.singleton_async.resolve()
130 | await DIContainer.singleton_async.tear_down()
131 | singleton4 = await DIContainer.singleton_async.resolve()
132 |
133 | assert singleton1 is singleton2 is singleton3
134 | assert singleton4 is not singleton1
135 |
136 | await DIContainer.tear_down()
137 |
138 |
139 | async def test_async_singleton_override() -> None:
140 | singleton_async = providers.AsyncSingleton(create_async_obj, "foo")
141 | singleton_async.override_sync(SingletonFactory(dep1="bar"))
142 |
143 | result = await singleton_async.resolve()
144 | assert result == SingletonFactory(dep1="bar")
145 |
146 |
147 | async def test_async_singleton_asyncio_concurrency() -> None:
148 | singleton_async = providers.AsyncSingleton(create_async_obj, "foo")
149 |
150 | results = await asyncio.gather(
151 | singleton_async(),
152 | singleton_async(),
153 | singleton_async(),
154 | singleton_async(),
155 | singleton_async(),
156 | )
157 |
158 | assert all(val is results[0] for val in results)
159 |
160 |
161 | async def test_async_singleton_sync_resolve_failure() -> None:
162 | with pytest.raises(RuntimeError, match=r"AsyncSingleton cannot be resolved in an sync context."):
163 | DIContainer.singleton_async.resolve_sync()
164 |
165 |
166 | async def test_singleton_async_resolve_with_async_dependencies() -> None:
167 | expected = await DIContainer.singleton_with_dependency.resolve()
168 |
169 | assert expected == await DIContainer.singleton_with_dependency.resolve()
170 |
171 | results = await asyncio.gather(*[DIContainer.singleton_with_dependency.resolve() for _ in range(10)])
172 |
173 | for val in results:
174 | assert val == expected
175 |
176 | results = await asyncio.gather(
177 | *[asyncio.to_thread(DIContainer.singleton_with_dependency.resolve_sync) for _ in range(10)],
178 | )
179 |
180 | for val in results:
181 | assert val == expected
182 |
183 |
184 | async def test_async_singleton_teardown() -> None:
185 | singleton_async = providers.AsyncSingleton(create_async_obj, "foo")
186 |
187 | await singleton_async.resolve()
188 | singleton_async.tear_down_sync()
189 |
190 | assert singleton_async._instance is None
191 |
192 | await singleton_async.resolve()
193 |
194 | await singleton_async.tear_down()
195 | assert singleton_async._instance is None
196 |
--------------------------------------------------------------------------------
/docs/providers/singleton.md:
--------------------------------------------------------------------------------
1 | # Singleton Provider
2 |
3 | A **Singleton** provider creates its instance once and caches it for all future injections or resolutions. When the instance is first requested (via `resolve_sync()` or `resolve()`), the underlying factory is called. On subsequent calls, the cached instance is returned without calling the factory again.
4 |
5 | ## How it Works
6 |
7 | ```python
8 | import random
9 |
10 | from that_depends import BaseContainer, Provide, inject, providers
11 |
12 |
13 | def some_function() -> float:
14 | """Generate number between 0.0 and 1.0"""
15 | return random.random()
16 |
17 |
18 | # define container with `Singleton` provider:
19 | class MyContainer(BaseContainer):
20 | singleton = providers.Singleton(some_function)
21 |
22 |
23 | # The provider will call `some_function` once and cache the return value
24 |
25 | # 1) Synchronous resolution
26 | MyContainer.singleton.resolve_sync() # e.g. 0.3
27 | MyContainer.singleton.resolve_sync() # 0.3 (cached)
28 |
29 | # 2) Asynchronous resolution
30 | await MyContainer.singleton.resolve() # 0.3 (same cached value)
31 |
32 |
33 | # 3) Injection example
34 | @inject
35 | async def with_singleton(number: float = Provide[MyContainer.singleton]):
36 | # number == 0.3
37 | ...
38 | ```
39 |
40 | ### Teardown Support
41 | If you need to reset the singleton (for example, in tests or at application shutdown), you can call:
42 | ```python
43 | await MyContainer.singleton.tear_down()
44 | ```
45 | This clears the cached instance, causing a new one to be created the next time `resolve_sync()` or `resolve()` is called.
46 | *(If you only ever use synchronous resolution, you can call `MyContainer.singleton.tear_down_sync()` instead.)*
47 |
48 | For further details refer to the [teardown documentation](../introduction/tear-down.md).
49 |
50 | ---
51 |
52 | ## Concurrency Safety
53 |
54 | `Singleton` is **thread-safe** and **async-safe**:
55 |
56 | 1. **Async Concurrency**
57 | If multiple coroutines call `resolve()` concurrently, the factory function is guaranteed to be called only once. All callers receive the same cached instance.
58 |
59 | 2. **Thread Concurrency**
60 | If multiple threads call `resolve_sync()` at the same time, the factory is only called once. All threads receive the same cached instance.
61 |
62 | ```python
63 | import threading
64 | import asyncio
65 |
66 |
67 | # In async code:
68 | async def main():
69 | # calling resolve concurrently in different coroutines
70 | results = await asyncio.gather(
71 | MyContainer.singleton.resolve(),
72 | MyContainer.singleton.resolve(),
73 | )
74 | # Both results point to the same instance
75 |
76 |
77 | # In threaded code:
78 | def thread_task():
79 | instance = MyContainer.singleton.resolve_sync()
80 | ...
81 |
82 |
83 | threads = [threading.Thread(target=thread_task) for _ in range(5)]
84 | for t in threads:
85 | t.start()
86 | ```
87 |
88 | ---
89 |
90 | ## ThreadLocalSingleton Provider
91 |
92 | If you want each *thread* to have its own, separately cached instance, use **ThreadLocalSingleton**. This provider creates a new instance per thread and reuses that instance on subsequent calls *within the same thread*.
93 |
94 | ```python
95 | import random
96 | import threading
97 | from that_depends.providers import ThreadLocalSingleton
98 |
99 |
100 | def factory() -> int:
101 | """Return a random int between 1 and 100."""
102 | return random.randint(1, 100)
103 |
104 |
105 | # ThreadLocalSingleton caches an instance per thread
106 | singleton = ThreadLocalSingleton(factory)
107 |
108 | # In a single thread:
109 | instance1 = singleton.resolve_sync() # e.g. 56
110 | instance2 = singleton.resolve_sync() # 56 (cached in the same thread)
111 |
112 |
113 | # In multiple threads:
114 | def thread_task():
115 | return singleton.resolve_sync()
116 |
117 |
118 | thread1 = threading.Thread(target=thread_task)
119 | thread2 = threading.Thread(target=thread_task)
120 | thread1.start()
121 | thread2.start()
122 |
123 | # thread1 and thread2 each get a different cached value
124 | ```
125 |
126 | You can still use `.resolve()` with `ThreadLocalSingleton`, which will also maintain isolation per thread. However, note that this does *not* isolate instances per asynchronous Task – only per OS thread.
127 |
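For example, a minimal sketch reusing the `singleton` defined above (the exact numbers depend on the random factory):

```python
import asyncio


async def main():
    # Both coroutines run on the same OS thread, so they share one cached value.
    first, second = await asyncio.gather(singleton.resolve(), singleton.resolve())
    assert first == second

    # Resolving from a worker thread creates a separate, thread-local instance,
    # which may differ from the value cached in the main thread.
    other = await asyncio.to_thread(singleton.resolve_sync)
    print(first, other)


asyncio.run(main())
```
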
128 | ---
129 |
130 | ## Example with `pydantic-settings`
131 |
132 | Consider a scenario where your application configuration is defined via [**pydantic-settings**](https://docs.pydantic.dev/latest/concepts/pydantic_settings/). Often, you only want to parse this configuration (e.g., from environment variables) once, then reuse it throughout the application.
133 |
134 | ```python
135 | from pydantic_settings import BaseSettings
136 | from pydantic import BaseModel
137 |
138 |
139 | class DatabaseConfig(BaseModel):
140 | address: str = "127.0.0.1"
141 | port: int = 5432
142 | db_name: str = "postgres"
143 |
144 |
145 | class Settings(BaseSettings):
146 | auth_key: str = "my_auth_key"
147 | db: DatabaseConfig = DatabaseConfig()
148 | ```
149 |
150 | ### Defining the Container
151 |
152 | Below, we define a container with a **Singleton** provider for our settings. We also define a separate async factory that connects to the database using those settings.
153 |
154 | ```python
155 | from that_depends import BaseContainer, providers
156 |
157 | async def get_db_connection(address: str, port: int, db_name: str):
158 | # e.g., create an async DB connection
159 | ...
160 |
161 | class MyContainer(BaseContainer):
162 | # We'll parse settings only once
163 | config = providers.Singleton(Settings)
164 |
165 | # We'll pass the config's DB fields into an async factory for a DB connection
166 | db_connection = providers.AsyncFactory(
167 | get_db_connection,
168 | config.db.address,
169 | config.db.port,
170 | config.db.db_name,
171 | )
172 | ```
173 |
174 | ### Injecting or Resolving in Code
175 |
176 | You can now inject these values directly into your functions with the `@inject` decorator:
177 |
178 | ```python
179 | from that_depends import inject, Provide
180 |
181 | @inject
182 | async def with_db_connection(conn = Provide[MyContainer.db_connection]):
183 | # conn is the created DB connection
184 | ...
185 | ```
186 |
187 | Or you can manually resolve them when needed:
188 |
189 | ```python
190 | # Synchronously resolve the config
191 | cfg = MyContainer.config.resolve_sync()
192 |
193 | # Asynchronously resolve the DB connection
194 | connection = await MyContainer.db_connection.resolve()
195 | ```
196 |
197 | By using `Singleton` for `Settings`, you avoid re-parsing the environment or re-initializing the configuration on each request.
198 |
--------------------------------------------------------------------------------
/docs/testing/provider-overriding.md:
--------------------------------------------------------------------------------
1 | # Provider overriding
2 |
3 | In addition to direct dependency injection, the DI container provides another very important capability:
4 | **overriding dependencies or providers**.
5 | 
6 | Any provider registered with a container can be overridden.
7 | This lets you replace real objects with simple stubs or with other objects.
8 | **Overriding a provider affects all providers that use it (_see the example below_)**.
9 |
10 | ## Example
11 |
12 | ```python
13 | from pydantic_settings import BaseSettings
14 | from sqlalchemy import create_engine, Engine, text
15 | from testcontainers.postgres import PostgresContainer
16 | from that_depends import BaseContainer, providers, Provide, inject
17 |
18 |
19 | class SomeSQLADao:
20 | def __init__(self, *, sqla_engine: Engine):
21 | self.engine = sqla_engine
22 | self._connection = None
23 |
24 | def __enter__(self):
25 | self._connection = self.engine.connect()
26 | return self
27 |
28 | def __exit__(self, exc_type, exc_val, exc_tb):
29 | self._connection.close()
30 |
31 | def exec_query(self, query: str):
32 | return self._connection.execute(text(query))
33 |
34 |
35 | class Settings(BaseSettings):
36 | db_url: str = 'some_production_db_url'
37 |
38 |
39 | class DIContainer(BaseContainer):
40 | settings = providers.Singleton(Settings)
41 | sqla_engine = providers.Singleton(create_engine, settings.db_url)
42 | some_sqla_dao = providers.Factory(SomeSQLADao, sqla_engine=sqla_engine)
43 |
44 |
45 | @inject
46 | def exec_query_example(some_sqla_dao=Provide[DIContainer.some_sqla_dao]):
47 | with some_sqla_dao:
48 | result = some_sqla_dao.exec_query('SELECT 234')
49 |
50 | return next(result)
51 |
52 |
53 | def main():
54 | pg_container = PostgresContainer(image='postgres:alpine3.19')
55 | pg_container.start()
56 | db_url = pg_container.get_connection_url()
57 |
58 | """
59 | We override only settings, but this override will also affect the 'sqla_engine'
60 | and 'some_sqla_dao' providers because the 'settings' provider is used by them!
61 | """
62 | local_testing_settings = Settings(db_url=db_url)
63 | DIContainer.settings.override_sync(local_testing_settings)
64 |
65 | try:
66 | result = exec_query_example()
67 | assert result == (234,)
68 | finally:
69 | DIContainer.settings.reset_override_sync()
70 | pg_container.stop()
71 |
72 |
73 | if __name__ == '__main__':
74 | main()
75 |
76 | ```
77 |
78 | The example above shows how overriding a nested provider ('_settings_')
79 | also affects the providers that depend on it ('_sqla_engine_' and '_some_sqla_dao_').
80 |
81 | ## Override multiple providers
82 |
83 | The example above overrode only the settings provider,
84 | but the container can also override
85 | multiple providers at once with the `override_providers_sync` method.
86 | 
87 | The code above can stay the same, except that
88 | the single provider override is replaced with the following code:
89 |
90 | ```python
91 | def main():
92 | pg_container = PostgresContainer(image='postgres:alpine3.19')
93 | pg_container.start()
94 | db_url = pg_container.get_connection_url()
95 |
96 | local_testing_settings = Settings(db_url=db_url)
97 | providers_for_overriding = {
98 | 'settings': local_testing_settings,
99 | # more values...
100 | }
101 | with DIContainer.override_providers_sync(providers_for_overriding):
102 | try:
103 | result = exec_query_example()
104 | assert result == (234,)
105 | finally:
106 | pg_container.stop()
107 | ```
108 |
109 | ---
110 | ## Using with Litestar
111 | To be able to inject dependencies of any type in place of the existing objects,
112 | we need to **change the typing** of the injected parameter as follows:
113 |
114 | ```python3
115 | import typing
116 | from functools import partial
117 | from typing import Annotated
118 | from unittest.mock import Mock
119 |
120 | from litestar import Litestar, Router, get
121 | from litestar.di import Provide
122 | from litestar.params import Dependency
123 | from litestar.testing import TestClient
124 |
125 | from that_depends import BaseContainer, providers
126 |
127 |
128 | class ExampleService:
129 | def do_smth(self) -> str:
130 | return "something"
131 |
132 |
133 | class DIContainer(BaseContainer):
134 | example_service = providers.Factory(ExampleService)
135 |
136 |
137 | @get(path="/another-endpoint", dependencies={"example_service": Provide(DIContainer.example_service)})
138 | async def endpoint_handler(
139 | example_service: Annotated[ExampleService, Dependency(skip_validation=True)],
140 | ) -> dict[str, typing.Any]:
141 | return {"object": example_service.do_smth()}
142 |
143 |
144 | # or if you want a little less code
145 | NoValidationDependency = partial(Dependency, skip_validation=True)
146 |
147 |
148 | @get(path="/another-endpoint", dependencies={"example_service": Provide(DIContainer.example_service)})
149 | async def endpoint_handler(
150 | example_service: Annotated[ExampleService, NoValidationDependency()],
151 | ) -> dict[str, typing.Any]:
152 | return {"object": example_service.do_smth()}
153 |
154 |
155 | router = Router(
156 | path="/router",
157 | route_handlers=[endpoint_handler],
158 | )
159 |
160 | app = Litestar(route_handlers=[router])
161 | ```
162 |
163 | Now we are ready to write tests with **overriding** and this will work with **any types**:
164 |
165 | ```python3
166 | def test_litestar_endpoint_with_overriding() -> None:
167 | some_service_mock = Mock(do_smth=lambda: "mock func")
168 |
169 | with DIContainer.example_service.override_context_sync(some_service_mock), TestClient(app=app) as client:
170 | response = client.get("/router/another-endpoint")
171 |
172 | assert response.status_code == 200
173 | assert response.json()["object"] == "mock func"
174 | ```
175 |
176 | More about `Dependency`
177 | in the [Litestar documentation](https://docs.litestar.dev/2/usage/dependency-injection.html#the-dependency-function).
178 |
179 | ---
180 |
181 | ## Overriding and tear-down
182 |
183 | If you have a provider `A` that caches its resolved value and depends on a provider
184 | `B` that you wish to override, you might experience the following behavior:
185 |
186 | ```python
187 | class MyContainer(BaseContainer):
188 | B = providers.Singleton(lambda: 1)
189 | A = providers.Singleton(lambda x: x, B)
190 |
191 |
192 | a_old = await MyContainer.A()
193 |
194 | MyContainer.B.override_sync(32) # will not reset A's cached value
195 |
196 | a_new = await MyContainer.A()
197 |
198 | assert a_old != a_new # raises
199 | ```
200 |
201 | This is because `A` caches its value and is not reset when you override `B`.
202 |
203 | If you wish to fix this, you can tell the provider to tear down its children on override:
204 |
205 | ```python
206 | MyContainer.B.override_sync(32, tear_down_children=True)
207 | ```
208 |
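With `tear_down_children=True`, the cached values of providers that depend on `B` are
discarded as well, so the earlier snippet now behaves as expected. A short sketch reusing
`MyContainer` from above:

```python
a_old = await MyContainer.A()  # resolved through B, caches the value 1

MyContainer.B.override_sync(32, tear_down_children=True)  # also clears A's cached value

a_new = await MyContainer.A()  # re-resolved with the override in place

assert a_old != a_new
```
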
--------------------------------------------------------------------------------
/tests/providers/test_resources.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import threading
4 | import time
5 | import typing
6 | from concurrent.futures import ThreadPoolExecutor, as_completed
7 |
8 | import pytest
9 |
10 | from tests.creators import (
11 | AsyncContextManagerResource,
12 | ContextManagerResource,
13 | create_async_resource,
14 | create_sync_resource,
15 | )
16 | from that_depends import BaseContainer, providers
17 |
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class DIContainer(BaseContainer):
23 | alias = "resources_container"
24 | async_resource = providers.Resource(create_async_resource)
25 | sync_resource = providers.Resource(create_sync_resource)
26 | async_resource_from_class = providers.Resource(AsyncContextManagerResource)
27 | sync_resource_from_class = providers.Resource(ContextManagerResource)
28 |
29 |
30 | @pytest.fixture(autouse=True)
31 | async def _tear_down() -> typing.AsyncIterator[None]:
32 | try:
33 | yield
34 | finally:
35 | await DIContainer.tear_down()
36 |
37 |
38 | async def test_async_resource() -> None:
39 | async_resource1 = await DIContainer.async_resource.resolve()
40 | async_resource2 = DIContainer.async_resource.resolve_sync()
41 | assert async_resource1 is async_resource2
42 |
43 |
44 | async def test_async_resource_from_class() -> None:
45 | async_resource1 = await DIContainer.async_resource_from_class.resolve()
46 | async_resource2 = DIContainer.async_resource_from_class.resolve_sync()
47 | assert async_resource1 is async_resource2
48 |
49 |
50 | async def test_sync_resource() -> None:
51 | sync_resource1 = await DIContainer.sync_resource.resolve()
52 | sync_resource2 = await DIContainer.sync_resource.resolve()
53 | assert sync_resource1 is sync_resource2
54 |
55 |
56 | async def test_sync_resource_from_class() -> None:
57 | sync_resource1 = await DIContainer.sync_resource_from_class.resolve()
58 | sync_resource2 = await DIContainer.sync_resource_from_class.resolve()
59 | assert sync_resource1 is sync_resource2
60 |
61 |
62 | async def test_async_resource_overridden() -> None:
63 | async_resource1 = await DIContainer.sync_resource.resolve()
64 |
65 | DIContainer.sync_resource.override_sync("override")
66 |
67 | async_resource2 = DIContainer.sync_resource.resolve_sync()
68 | async_resource3 = await DIContainer.sync_resource.resolve()
69 |
70 | DIContainer.sync_resource.reset_override_sync()
71 |
72 | async_resource4 = DIContainer.sync_resource.resolve_sync()
73 |
74 | assert async_resource2 is not async_resource1
75 | assert async_resource2 is async_resource3
76 | assert async_resource4 is async_resource1
77 |
78 |
79 | async def test_sync_resource_overridden() -> None:
80 | sync_resource1 = await DIContainer.sync_resource.resolve()
81 |
82 | DIContainer.sync_resource.override_sync("override")
83 |
84 | sync_resource2 = DIContainer.sync_resource.resolve_sync()
85 | sync_resource3 = await DIContainer.sync_resource.resolve()
86 |
87 | DIContainer.sync_resource.reset_override_sync()
88 |
89 | sync_resource4 = DIContainer.sync_resource.resolve_sync()
90 |
91 | assert sync_resource2 is not sync_resource1
92 | assert sync_resource2 is sync_resource3
93 | assert sync_resource4 is sync_resource1
94 |
95 |
96 | async def test_resource_with_empty_list() -> None:
97 | def create_sync_resource_with_list() -> typing.Iterator[list[typing.Any]]:
98 | logger.debug("Resource initiated")
99 | yield []
100 | logger.debug("Resource destructed")
101 |
102 | async def create_async_resource_with_dict() -> typing.AsyncIterator[dict[str, typing.Any]]:
103 | logger.debug("Async resource initiated")
104 | yield {}
105 | logger.debug("Async resource destructed")
106 |
107 | sync_resource = providers.Resource(create_sync_resource_with_list)
108 | async_resource = providers.Resource(create_async_resource_with_dict)
109 |
110 | sync_resource1 = await sync_resource()
111 | sync_resource2 = await sync_resource()
112 | assert sync_resource1 is sync_resource2
113 |
114 | async_resource1 = await async_resource()
115 | async_resource2 = await async_resource()
116 | assert async_resource1 is async_resource2
117 |
118 | await sync_resource.tear_down()
119 | await async_resource.tear_down()
120 |
121 |
122 | async def test_resource_unsupported_creator() -> None:
123 | with pytest.raises(TypeError, match=r"Unsupported resource type"):
124 | providers.Resource(None) # type: ignore[arg-type]
125 |
126 |
127 | @pytest.mark.repeat(10)
128 | async def test_async_resource_asyncio_concurrency() -> None:
129 | calls: int = 0
130 |
131 | async def create_resource() -> typing.AsyncIterator[str]:
132 | nonlocal calls
133 | calls += 1
134 | await asyncio.sleep(0)
135 | yield ""
136 |
137 | resource = providers.Resource(create_resource)
138 |
139 | await asyncio.gather(resource.resolve(), resource.resolve())
140 |
141 | assert calls == 1
142 |
143 |
144 | @pytest.mark.repeat(10)
145 | async def test_async_resource_async_teardown() -> None:
146 | calls: int = 0
147 |
148 | async def _create_resource() -> typing.AsyncIterator[str]:
149 | yield ""
150 | nonlocal calls
151 | await asyncio.sleep(0.01)
152 | calls += 1
153 |
154 | resource = providers.Resource(_create_resource)
155 |
156 | await resource.resolve()
157 |
158 | await asyncio.gather(*[resource.tear_down() for _ in range(10)])
159 |
160 | assert calls == 1
161 |
162 |
163 | @pytest.mark.repeat(10)
164 | def test_resource_threading_concurrency() -> None:
165 | calls: int = 0
166 | lock = threading.Lock()
167 |
168 | def create_resource() -> typing.Iterator[str]:
169 | nonlocal calls
170 | with lock:
171 | calls += 1
172 | time.sleep(0.01)
173 | yield ""
174 |
175 | resource = providers.Resource(create_resource)
176 |
177 | with ThreadPoolExecutor(max_workers=4) as pool:
178 | tasks = [
179 | pool.submit(resource.resolve_sync),
180 | pool.submit(resource.resolve_sync),
181 | pool.submit(resource.resolve_sync),
182 | pool.submit(resource.resolve_sync),
183 | ]
184 | results = [x.result() for x in as_completed(tasks)]
185 |
186 | assert results == ["", "", "", ""]
187 | assert calls == 1
188 |
189 |
190 | def test_sync_resource_sync_tear_down() -> None:
191 | DIContainer.sync_resource.resolve_sync()
192 | assert DIContainer.sync_resource._context.instance is not None
193 | DIContainer.sync_resource.tear_down_sync()
194 | assert DIContainer.sync_resource._context.instance is None
195 |
196 |
197 | async def test_async_resource_sync_tear_down_raises() -> None:
198 | await DIContainer.async_resource.resolve()
199 | assert DIContainer.async_resource._context.instance is not None
200 | with pytest.raises(RuntimeError):
201 | DIContainer.async_resource.tear_down_sync()
202 |
--------------------------------------------------------------------------------
/tests/providers/test_attr_getter.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import random
3 | import typing
4 | from dataclasses import dataclass, field
5 |
6 | import pytest
7 |
8 | from that_depends import BaseContainer, providers
9 | from that_depends.providers.base import _get_value_from_object_by_dotted_path
10 | from that_depends.providers.context_resources import container_context
11 |
12 |
13 | @dataclass
14 | class Nested2:
15 | some_const = 144
16 |
17 |
18 | @dataclass
19 | class Nested1:
20 | nested2_attr: Nested2 = field(default_factory=Nested2)
21 |
22 |
23 | @dataclass
24 | class Settings:
25 | some_str_value: str = "some_string_value"
26 | some_int_value: int = 3453621
27 | nested1_attr: Nested1 = field(default_factory=Nested1)
28 |
29 |
30 | async def return_settings_async() -> Settings:
31 | return Settings()
32 |
33 |
34 | async def yield_settings_async() -> typing.AsyncIterator[Settings]:
35 | yield Settings()
36 |
37 |
38 | def yield_settings_sync() -> typing.Iterator[Settings]:
39 | yield Settings()
40 |
41 |
42 | @dataclass
43 | class NestingTestDTO: ...
44 |
45 |
46 | @pytest.fixture(
47 | params=[
48 | providers.Resource(yield_settings_sync),
49 | providers.Singleton(Settings),
50 | providers.ContextResource(yield_settings_sync),
51 | providers.Object(Settings()),
52 | providers.Factory(Settings),
53 | providers.Selector(lambda: "sync", sync=providers.Factory(Settings)),
54 | ]
55 | )
56 | def some_sync_settings_provider(request: pytest.FixtureRequest) -> providers.AbstractProvider[Settings]:
57 | return typing.cast(providers.AbstractProvider[Settings], request.param)
58 |
59 |
60 | @pytest.fixture(
61 | params=[
62 | providers.AsyncFactory(return_settings_async),
63 | providers.Resource(yield_settings_async),
64 | providers.ContextResource(yield_settings_async),
65 | providers.Selector(lambda: "asynchronous", asynchronous=providers.AsyncFactory(return_settings_async)),
66 | ]
67 | )
68 | def some_async_settings_provider(request: pytest.FixtureRequest) -> providers.AbstractProvider[Settings]:
69 | return typing.cast(providers.AbstractProvider[Settings], request.param)
70 |
71 |
72 | def test_attr_getter_with_zero_attribute_depth_sync(
73 | some_sync_settings_provider: providers.AbstractProvider[Settings],
74 | ) -> None:
75 | attr_getter = some_sync_settings_provider.some_str_value
76 | if isinstance(some_sync_settings_provider, providers.ContextResource):
77 | with container_context(some_sync_settings_provider):
78 | assert attr_getter.resolve_sync() == Settings().some_str_value
79 | else:
80 | assert attr_getter.resolve_sync() == Settings().some_str_value
81 |
82 |
83 | async def test_attr_getter_with_zero_attribute_depth_async(
84 | some_async_settings_provider: providers.AbstractProvider[Settings],
85 | ) -> None:
86 | attr_getter = some_async_settings_provider.some_str_value
87 | if isinstance(some_async_settings_provider, providers.ContextResource):
88 | async with container_context(some_async_settings_provider):
89 | assert await attr_getter.resolve() == Settings().some_str_value
90 | else:
91 | assert await attr_getter.resolve() == Settings().some_str_value
92 |
93 |
94 | def test_attr_getter_with_more_than_zero_attribute_depth_sync(
95 | some_sync_settings_provider: providers.AbstractProvider[Settings],
96 | ) -> None:
97 | with (
98 | container_context(some_sync_settings_provider)
99 | if isinstance(some_sync_settings_provider, providers.ContextResource)
100 | else contextlib.nullcontext()
101 | ):
102 | attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const
103 | assert attr_getter.resolve_sync() == Nested2().some_const
104 |
105 |
106 | async def test_attr_getter_with_more_than_zero_attribute_depth_async(
107 | some_async_settings_provider: providers.AbstractProvider[Settings],
108 | ) -> None:
109 | async with (
110 | container_context(some_async_settings_provider)
111 | if isinstance(some_async_settings_provider, providers.ContextResource)
112 | else contextlib.nullcontext()
113 | ):
114 | attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const
115 | assert await attr_getter.resolve() == Nested2().some_const
116 |
117 |
118 | @pytest.mark.parametrize(
119 | ("field_count", "test_field_name", "test_value"),
120 | [(1, "test_field", "sdf6fF^SF(FF*4ffsf"), (5, "nested_field", -252625), (50, "50_lvl_field", 909234235)],
121 | )
122 | def test_nesting_levels(field_count: int, test_field_name: str, test_value: str | int) -> None:
123 | obj = NestingTestDTO()
124 | fields = [f"field_{i}" for i in range(1, field_count + 1)]
125 | random.shuffle(fields)
126 |
127 | attr_path = ".".join(fields) + f".{test_field_name}"
128 | obj_copy = obj
129 |
130 | while fields:
131 | field_name = fields.pop(0)
132 | setattr(obj_copy, field_name, NestingTestDTO())
133 | obj_copy = obj_copy.__getattribute__(field_name)
134 |
135 | setattr(obj_copy, test_field_name, test_value)
136 |
137 | attr_value = _get_value_from_object_by_dotted_path(obj, attr_path)
138 | assert attr_value == test_value
139 |
140 |
141 | def test_attr_getter_with_invalid_attribute_sync(
142 | some_sync_settings_provider: providers.AbstractProvider[Settings],
143 | ) -> None:
144 | with pytest.raises(AttributeError):
145 | some_sync_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018
146 | with pytest.raises(AttributeError):
147 | some_sync_settings_provider.nested1_attr.__another_private__ # noqa: B018
148 | with pytest.raises(AttributeError):
149 | some_sync_settings_provider.nested1_attr._final_private_ # noqa: B018
150 |
151 |
152 | async def test_attr_getter_with_invalid_attribute_async(
153 | some_async_settings_provider: providers.AbstractProvider[Settings],
154 | ) -> None:
155 | with pytest.raises(AttributeError):
156 | some_async_settings_provider.nested1_attr.nested2_attr.__some_private__ # noqa: B018
157 | with pytest.raises(AttributeError):
158 | some_async_settings_provider.nested1_attr.__another_private__ # noqa: B018
159 | with pytest.raises(AttributeError):
160 | some_async_settings_provider.nested1_attr._final_private_ # noqa: B018
161 |
162 |
163 | def test_attr_getter_deregister_arguments() -> None:
164 | class _Item:
165 | def __init__(self, v: float) -> None:
166 | self.v = v
167 |
168 | class _Container(BaseContainer):
169 | parent = providers.Singleton(lambda: random.random())
170 | child = providers.Singleton(_Item, parent.cast)
171 |
172 | attr_getter = _Container.child.v
173 |
174 | attr_getter._register_arguments()
175 |
176 | assert attr_getter.resolve_sync() == _Container.parent.resolve_sync()
177 | assert _Container.parent in attr_getter._parents
178 |
179 | attr_getter._deregister_arguments()
180 |
181 | assert _Container.parent not in attr_getter._parents
182 |
--------------------------------------------------------------------------------
/that_depends/providers/factories.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import inspect
3 | import typing
4 | from typing import overload
5 |
6 | from typing_extensions import override
7 |
8 | from that_depends.providers.base import AbstractProvider
9 | from that_depends.providers.mixin import ProviderWithArguments
10 |
11 |
12 | T_co = typing.TypeVar("T_co", covariant=True)
13 | P = typing.ParamSpec("P")
14 |
15 |
16 | class AbstractFactory(ProviderWithArguments, AbstractProvider[T_co], abc.ABC):
17 | """Base class for all factories.
18 |
19 | This class defines the interface for factories that provide
20 | resources both synchronously and asynchronously.
21 | """
22 |
23 | @property
24 | def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, T_co]]:
25 | """Returns an async provider function.
26 |
27 | The async provider function can be awaited to resolve the resource.
28 |
29 | Returns:
30 | Callable[[], Coroutine[Any, Any, T_co]]: The async provider function.
31 |
32 | Example:
33 | ```python
34 | provider = my_factory.provider
35 | resource = await provider()
36 | ```
37 |
38 | """
39 | return self.resolve
40 |
41 | @property
42 | def provider_sync(self) -> typing.Callable[[], T_co]:
43 | """Return a sync provider function.
44 |
45 | The sync provider function can be called to resolve the resource synchronously.
46 |
47 | Returns:
48 | Callable[[], T_co]: The sync provider function.
49 |
50 | Example:
51 | ```python
52 | provider = my_factory.provider_sync
53 | resource = provider()
54 | ```
55 |
56 | """
57 | return self.resolve_sync
58 |
59 |
60 | class Factory(AbstractFactory[T_co]):
61 | """Provides an instance by calling a sync method.
62 |
63 | A typical usage scenario is to wrap a synchronous function
64 |     that returns a resource. Each call to `provider` or `provider_sync`
65 |     produces a new instance of that resource.
66 |
67 | Example:
68 | ```python
69 | def build_resource(text: str, number: int):
70 | return f"{text}-{number}"
71 |
72 | factory = Factory(build_resource, "example", 42)
73 | resource = factory.provider_sync() # "example-42"
74 | ```
75 |
76 | """
77 |
78 | def _register_arguments(self) -> None:
79 | self._register(self._args)
80 | self._register(self._kwargs.values())
81 |
82 | def _deregister_arguments(self) -> None:
83 | raise NotImplementedError
84 |
85 | __slots__ = "_args", "_factory", "_kwargs", "_override"
86 |
87 | def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None:
88 | """Initialize a Factory instance.
89 |
90 | Args:
91 | factory (Callable[P, T_co]): Function that returns the resource.
92 | *args: Arguments to pass to the factory function.
93 | **kwargs: Keyword arguments to pass to the factory function.
94 |
95 | """
96 | super().__init__()
97 | self._factory: typing.Final = factory
98 | self._args: typing.Final = args
99 | self._kwargs: typing.Final = kwargs
100 | self._register_arguments()
101 |
102 | @override
103 | async def resolve(self) -> T_co:
104 | if self._override:
105 | return typing.cast(T_co, self._override)
106 |
107 | return self._factory(
108 | *[ # type: ignore[arg-type]
109 | await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args
110 | ],
111 | **{ # type: ignore[arg-type]
112 | k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()
113 | },
114 | )
115 |
116 | @override
117 | def resolve_sync(self) -> T_co:
118 | if self._override:
119 | return typing.cast(T_co, self._override)
120 |
121 | return self._factory(
122 | *[ # type: ignore[arg-type]
123 | x.resolve_sync() if isinstance(x, AbstractProvider) else x for x in self._args
124 | ],
125 | **{ # type: ignore[arg-type]
126 | k: v.resolve_sync() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()
127 | },
128 | )
129 |
130 |
131 | class AsyncFactory(AbstractFactory[T_co]):
132 | """Provides an instance by calling an async method.
133 |
134 |     Similar to `Factory`, but resolution is asynchronous. Each call
135 |     to `resolve` or the `provider` property is awaited and produces a new instance.
136 |
137 | Example:
138 | ```python
139 | async def async_build_resource(text: str):
140 | await some_async_operation()
141 | return text.upper()
142 |
143 | async_factory = AsyncFactory(async_build_resource, "example")
144 | resource = await async_factory.provider() # "EXAMPLE"
145 | ```
146 |
147 | """
148 |
149 | def _register_arguments(self) -> None:
150 | self._register(self._args)
151 | self._register(self._kwargs.values())
152 |
153 | def _deregister_arguments(self) -> None:
154 | raise NotImplementedError
155 |
156 | __slots__ = "_args", "_factory", "_kwargs", "_override"
157 |
158 | @overload
159 | def __init__(
160 | self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs
161 | ) -> None: ...
162 |
163 | @overload
164 | def __init__(self, factory: typing.Callable[P, T_co], *args: P.args, **kwargs: P.kwargs) -> None: ...
165 |
166 | def __init__(
167 | self, factory: typing.Callable[P, T_co | typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs
168 | ) -> None:
169 | """Initialize an AsyncFactory instance.
170 |
171 | Args:
172 | factory (Callable[P, T_co | Awaitable[T_co]]): Function that returns the resource (sync or async).
173 | *args: Arguments to pass to the factory function.
174 |             **kwargs: Keyword arguments to pass to the factory function.
175 | 
177 | """
178 | super().__init__()
179 | self._factory: typing.Final = factory
180 | self._args: typing.Final = args
181 | self._kwargs: typing.Final = kwargs
182 | self._register_arguments()
183 |
184 | @override
185 | async def resolve(self) -> T_co:
186 | if self._override:
187 | return typing.cast(T_co, self._override)
188 |
189 | args = [await x.resolve() if isinstance(x, AbstractProvider) else x for x in self._args]
190 | kwargs = {k: await v.resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()}
191 |
192 | result = self._factory(
193 | *args, # type:ignore[arg-type]
194 | **kwargs, # type:ignore[arg-type]
195 | )
196 |
197 | if inspect.isawaitable(result):
198 | return await result
199 | return result
200 |
201 | @override
202 | def resolve_sync(self) -> typing.NoReturn:
203 | msg = "AsyncFactory cannot be resolved synchronously"
204 | raise RuntimeError(msg)
205 |
--------------------------------------------------------------------------------
/tests/providers/test_local_singleton.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import random
3 | import threading
4 | import time
5 | import typing
6 | from concurrent.futures.thread import ThreadPoolExecutor
7 |
8 | import pytest
9 |
10 | from that_depends.providers import AsyncFactory, ThreadLocalSingleton
11 |
12 |
13 | random.seed(23)
14 |
15 |
16 | async def _async_factory() -> int:
17 | await asyncio.sleep(0.01)
18 | return threading.get_ident()
19 |
20 |
21 | def _factory() -> int:
22 | time.sleep(0.01)
23 | return random.randint(1, 100)
24 |
25 |
26 | def test_thread_local_singleton_same_thread() -> None:
27 | """Test that the same instance is returned within a single thread."""
28 | provider = ThreadLocalSingleton(_factory)
29 |
30 | instance1 = provider.resolve_sync()
31 | instance2 = provider.resolve_sync()
32 |
33 | assert instance1 == instance2, "Singleton failed: Instances within the same thread should be identical."
34 |
35 | provider.tear_down_sync()
36 |
37 | assert provider._instance is None, "Tear down failed: Instance should be reset to None."
38 |
39 |
40 | async def test_async_thread_local_singleton_asyncio() -> None:
41 | """Test that the same instance is returned within a single thread."""
42 | provider = ThreadLocalSingleton(_factory)
43 |
44 | instance1 = await provider.resolve()
45 | instance2 = await provider.resolve()
46 |
47 | assert instance1 == instance2, "Singleton failed: Instances within the same thread should be identical."
48 |
49 | await provider.tear_down()
50 |
51 | assert provider._instance is None, "Tear down failed: Instance should be reset to None."
52 |
53 |
54 | def test_thread_local_singleton_different_threads() -> None:
55 | """Test that different threads receive different instances."""
56 | provider = ThreadLocalSingleton(_factory)
57 | results = []
58 |
59 | def resolve_in_thread() -> None:
60 | results.append(provider.resolve_sync())
61 |
62 | number_of_threads = 10
63 |
64 | threads = [threading.Thread(target=resolve_in_thread) for _ in range(number_of_threads)]
65 |
66 | for thread in threads:
67 | thread.start()
68 | for thread in threads:
69 | thread.join()
70 |
71 |     assert len(results) == number_of_threads, "Test failed: Expected results from all threads."
72 | assert results[0] != results[1], "Thread-local failed: Instances across threads should differ."
73 |
74 |
75 | def test_thread_local_singleton_override() -> None:
76 | """Test overriding the ThreadLocalSingleton and resetting the override."""
77 | provider = ThreadLocalSingleton(_factory)
78 |
79 | override_value = 101
80 | provider.override_sync(override_value)
81 | instance = provider.resolve_sync()
82 | assert instance == override_value, "Override failed: Expected overridden value."
83 |
84 | # Reset override and ensure a new instance is created
85 | provider.reset_override_sync()
86 | new_instance = provider.resolve_sync()
87 | assert new_instance != override_value, "Reset override failed: Should no longer return overridden value."
88 |
89 |
90 | def test_thread_local_singleton_override_in_threads() -> None:
91 | """Test that resetting the override in one thread does not affect another thread."""
92 | provider = ThreadLocalSingleton(_factory)
93 | results = {}
94 |
95 | def _thread_task(thread_id: int, override_value: int | None = None) -> None:
96 | if override_value is not None:
97 | provider.override_sync(override_value)
98 | results[thread_id] = provider.resolve_sync()
99 | if override_value is not None:
100 | provider.reset_override_sync()
101 |
102 | override_value: typing.Final[int] = 101
103 | thread1 = threading.Thread(target=_thread_task, args=(1, override_value))
104 | thread2 = threading.Thread(target=_thread_task, args=(2,))
105 |
106 | thread1.start()
107 | thread2.start()
108 |
109 | thread1.join()
110 | thread2.join()
111 |
112 | # Validate results
113 | assert results[1] == override_value, "Thread 1: Override failed."
114 | assert results[2] != override_value, "Thread 2: Should not be affected by Thread 1's override."
115 | assert results[1] != results[2], "Instances should be unique across threads."
116 |
117 |
118 | def test_thread_local_singleton_override_temporarily() -> None:
119 | """Test temporarily overriding the ThreadLocalSingleton."""
120 | provider = ThreadLocalSingleton(_factory)
121 |
122 | override_value: typing.Final = 101
123 | # Set a temporary override
124 | with provider.override_context_sync(override_value):
125 | instance = provider.resolve_sync()
126 | assert instance == override_value, "Override context failed: Expected overridden value."
127 |
128 | # After the context, reset to the factory
129 | new_instance = provider.resolve_sync()
130 | assert new_instance != override_value, "Override context failed: Value should reset after context."
131 |
132 |
133 | async def test_async_thread_local_singleton_asyncio_concurrency() -> None:
134 | singleton_async = ThreadLocalSingleton(_factory)
135 |
136 | expected = await singleton_async.resolve()
137 |
138 | results = await asyncio.gather(
139 | singleton_async(),
140 | singleton_async(),
141 | singleton_async(),
142 | singleton_async(),
143 | singleton_async(),
144 | )
145 |
146 | assert all(val is results[0] for val in results)
147 | for val in results:
148 | assert val == expected, "Instances should be identical across threads."
149 |
150 |
151 | async def test_thread_local_singleton_async_resolve_with_async_dependencies() -> None:
152 | async_provider = AsyncFactory(_async_factory)
153 |
154 | def _dependent_creator(v: int) -> int:
155 | return v
156 |
157 | provider = ThreadLocalSingleton(_dependent_creator, v=async_provider.cast)
158 |
159 | expected = await provider.resolve()
160 |
161 | assert expected == await provider.resolve()
162 |
163 | results = await asyncio.gather(*[provider.resolve() for _ in range(10)])
164 |
165 | for val in results:
166 | assert val == expected, "Instances should be identical across threads."
167 | with pytest.raises(RuntimeError):
168 |         # Should raise: the provider has an async dependency but synchronous resolution is attempted.
169 | await asyncio.gather(asyncio.to_thread(provider.resolve_sync))
170 |
171 | def _run_async_in_thread(coroutine: typing.Awaitable[typing.Any]) -> typing.Any: # noqa: ANN401
172 | loop = asyncio.new_event_loop()
173 | asyncio.set_event_loop(loop)
174 | result = loop.run_until_complete(coroutine)
175 | loop.close()
176 | return result
177 |
178 | with ThreadPoolExecutor() as executor:
179 | future = executor.submit(_run_async_in_thread, provider.resolve())
180 |
181 | result = future.result()
182 |
183 | assert result != expected, (
184 | "Since singleton should have been newly resolved, it should not have the same thread id."
185 | )
186 |
187 |
188 | async def test_thread_local_singleton_async_resolve_override() -> None:
189 | provider = ThreadLocalSingleton(_factory)
190 |
191 | override_value = 101
192 |
193 | provider.override_sync(override_value)
194 |
195 | assert await provider.resolve() == override_value, "Override failed: Expected overridden value."
196 |
--------------------------------------------------------------------------------
/tests/integrations/fastapi/test_fastapi_route_class.py:
--------------------------------------------------------------------------------
1 | import random
2 | import typing
3 | from http import HTTPStatus
4 | from inspect import signature
5 | from typing import get_type_hints
6 |
7 | from fastapi import APIRouter, Depends, FastAPI
8 | from fastapi.testclient import TestClient
9 |
10 | from that_depends import BaseContainer, Provide, providers
11 | from that_depends.integrations.fastapi import _adjust_fastapi_endpoint, create_fastapi_route_class
12 | from that_depends.providers.context_resources import ContextScopes
13 |
14 |
15 | _FACTORY_RETURN_VALUE = 42
16 |
17 |
18 | async def _async_creator() -> typing.AsyncIterator[float]:
19 | yield random.random()
20 |
21 |
22 | def _sync_creator() -> typing.Iterator[float]:
23 | yield random.random()
24 |
25 |
26 | class Container(BaseContainer):
27 | alias = "fastapi_route_class_container"
28 | provider = providers.Factory(lambda: _FACTORY_RETURN_VALUE)
29 | async_context_provider = providers.ContextResource(_async_creator).with_config(scope=ContextScopes.REQUEST)
30 | sync_context_provider = providers.ContextResource(_sync_creator).with_config(scope=ContextScopes.REQUEST)
31 |
32 |
33 | async def test_endpoint_with_context_resource_async() -> None:
34 | app = FastAPI()
35 |
36 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST))
37 |
38 | @router.get("/hello")
39 | async def _injected(v_3: float = Provide[Container.async_context_provider]) -> float:
40 | return v_3
41 |
42 | app.include_router(router)
43 |
44 | client = TestClient(app)
45 |
46 | response = client.get("/hello")
47 |
48 | response_value = response.json()
49 |
50 | assert response.status_code == HTTPStatus.OK
51 | assert response_value <= 1
52 | assert response_value >= 0
53 |
54 | assert response_value != client.get("/hello").json()
55 |
56 |
57 | def test_endpoint_with_context_resource_sync() -> None:
58 | app = FastAPI()
59 |
60 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST))
61 |
62 | @router.get("/hello")
63 | def _injected(v_3: float = Provide[Container.sync_context_provider]) -> float:
64 | return v_3
65 |
66 | app.include_router(router)
67 |
68 | client = TestClient(app)
69 |
70 | response = client.get("/hello")
71 |
72 | response_value = response.json()
73 |
74 | assert response.status_code == HTTPStatus.OK
75 | assert response_value <= 1
76 | assert response_value >= 0
77 |
78 | assert response_value != client.get("/hello").json()
79 |
80 |
81 | async def test_endpoint_factory_async() -> None:
82 | app = FastAPI()
83 |
84 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST))
85 |
86 | @router.get("/hello")
87 | async def _injected(v_3: float = Provide[Container.provider]) -> float:
88 | return v_3
89 |
90 | app.include_router(router)
91 |
92 | client = TestClient(app)
93 |
94 | response = client.get("/hello")
95 |
96 | response_value = response.json()
97 |
98 | assert response.status_code == HTTPStatus.OK
99 | assert response_value == _FACTORY_RETURN_VALUE
100 |
101 | assert response_value == client.get("/hello").json()
102 |
103 |
104 | def test_endpoint_factory_sync() -> None:
105 | app = FastAPI()
106 |
107 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST))
108 |
109 | @router.get("/hello")
110 | def _injected(v_3: float = Provide[Container.provider]) -> float:
111 | return v_3
112 |
113 | app.include_router(router)
114 |
115 | client = TestClient(app)
116 |
117 | response = client.get("/hello")
118 |
119 | response_value = response.json()
120 |
121 | assert response.status_code == HTTPStatus.OK
122 | assert response_value == _FACTORY_RETURN_VALUE
123 |
124 | assert response_value == client.get("/hello").json()
125 |
126 |
127 | async def test_endpoint_with_string_provider_async() -> None:
128 | app = FastAPI()
129 |
130 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST))
131 |
132 | @router.get("/hello")
133 | async def _injected(v_3: float = Provide["fastapi_route_class_container.provider"]) -> float:
134 | return v_3
135 |
136 | app.include_router(router)
137 |
138 | client = TestClient(app)
139 |
140 | response = client.get("/hello")
141 |
142 | response_value = response.json()
143 |
144 | assert response.status_code == HTTPStatus.OK
145 | assert response_value == _FACTORY_RETURN_VALUE
146 |
147 | assert response_value == client.get("/hello").json()
148 |
149 |
150 | async def test_endpoint_with_string_provider_sync() -> None:
151 | app = FastAPI()
152 |
153 | router = APIRouter(route_class=create_fastapi_route_class(Container, scope=ContextScopes.REQUEST))
154 |
155 | @router.get("/hello")
156 | def _injected(v_3: float = Provide["fastapi_route_class_container.provider"]) -> float:
157 | return v_3
158 |
159 | app.include_router(router)
160 |
161 | client = TestClient(app)
162 |
163 | response = client.get("/hello")
164 |
165 | response_value = response.json()
166 |
167 | assert response.status_code == HTTPStatus.OK
168 | assert response_value == _FACTORY_RETURN_VALUE
169 |
170 | assert response_value == client.get("/hello").json()
171 |
172 |
173 | def test_adjust_fastapi_endpoint_async() -> None:
174 | async def _injected(
175 | v_1: float,
176 | v_3: float = Provide["fastapi_route_class_container.provider"],
177 | v_2: float = Depends(Container.provider),
178 | ) -> float:
179 | return v_1 + v_3 + v_2 # pragma: no cover
180 |
181 | original_signature = signature(_injected)
182 | original_hints = get_type_hints(_injected)
183 |
184 | adjusted_endpoint = _adjust_fastapi_endpoint(_injected)
185 |
186 | adjusted_sig = signature(adjusted_endpoint)
187 | adjusted_hints = get_type_hints(adjusted_endpoint)
188 |
189 | assert len(original_hints) == len(adjusted_hints)
190 | assert len(original_signature.parameters) == len(adjusted_sig.parameters)
191 | assert original_hints.get("v_1") == adjusted_hints.get("v_1")
192 | assert original_hints.get("v_2") == adjusted_hints.get("v_2")
193 | assert original_hints.get("v_3") != adjusted_hints.get("v_3")
194 | assert original_signature.parameters.get("v_1") == adjusted_sig.parameters.get("v_1")
195 | assert original_signature.parameters.get("v_2") == adjusted_sig.parameters.get("v_2")
196 | assert original_signature.parameters.get("v_3") != adjusted_sig.parameters.get("v_3")
197 |
198 |
199 | def test_adjust_fastapi_endpoint_sync() -> None:
200 | def _injected(
201 | v_1: float,
202 | v_3: float = Provide["fastapi_route_class_container.provider"],
203 | v_2: float = Depends(Container.provider),
204 | ) -> float:
205 | return v_1 + v_3 + v_2 # pragma: no cover
206 |
207 | original_signature = signature(_injected)
208 | original_hints = get_type_hints(_injected)
209 |
210 | adjusted_endpoint = _adjust_fastapi_endpoint(_injected)
211 |
212 | adjusted_sig = signature(adjusted_endpoint)
213 | adjusted_hints = get_type_hints(adjusted_endpoint)
214 |
215 | assert len(original_hints) == len(adjusted_hints)
216 | assert len(original_signature.parameters) == len(adjusted_sig.parameters)
217 | assert original_hints.get("v_1") == adjusted_hints.get("v_1")
218 | assert original_hints.get("v_2") == adjusted_hints.get("v_2")
219 | assert original_hints.get("v_3") != adjusted_hints.get("v_3")
220 | assert original_signature.parameters.get("v_1") == adjusted_sig.parameters.get("v_1")
221 | assert original_signature.parameters.get("v_2") == adjusted_sig.parameters.get("v_2")
222 | assert original_signature.parameters.get("v_3") != adjusted_sig.parameters.get("v_3")
223 |
--------------------------------------------------------------------------------