├── .flake8 ├── .github └── workflows │ └── build.yml ├── .gitignore ├── .isort.cfg ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── examples-summary.py ├── examples ├── README.md ├── example-01.py ├── example-02.py └── example-03.py ├── mypy.ini ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── rodi.code-workspace ├── rodi ├── __about__.py ├── __init__.py └── py.typed └── tests ├── __init__.py ├── examples.py ├── test_examples.py ├── test_fn_exec.py └── test_services.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = __pycache__,built,build,venv 3 | ignore = E203, E266, W503, E701, E704 4 | max-line-length = 88 5 | max-complexity = 18 6 | select = B,C,E,F,W,T4,B9 7 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | publish_artifacts: 7 | description: 'Publish artifacts (Y|N)' 8 | required: true 9 | default: 'N' 10 | release: 11 | types: [published] 12 | push: 13 | branches: 14 | - main 15 | - ci 16 | paths-ignore: 17 | - README.md 18 | - CHANGELOG.md 19 | pull_request: 20 | branches: 21 | - "*" 22 | 23 | env: 24 | PROJECT_NAME: rodi 25 | 26 | jobs: 27 | build: 28 | runs-on: ubuntu-latest 29 | strategy: 30 | fail-fast: false 31 | matrix: 32 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 33 | 34 | steps: 35 | - uses: actions/checkout@v1 36 | with: 37 | fetch-depth: 9 38 | submodules: false 39 | 40 | - name: Use Python ${{ matrix.python-version }} 41 | uses: actions/setup-python@v4 42 | with: 43 | python-version: ${{ matrix.python-version }} 44 | 45 | - name: Download dependencies 46 | run: | 47 | pip install -r requirements.txt 48 | 49 | - name: Run tests 50 | run: | 51 | pip install -e . 
52 | pytest --doctest-modules --junitxml=junit/pytest-results-${{ matrix.python-version }}.xml --cov=$PROJECT_NAME --cov-report=xml tests/ 53 | 54 | - name: Run linters 55 | run: | 56 | echo "Running linters - if build fails here, please be patient!" 57 | 58 | flake8 $PROJECT_NAME 59 | flake8 tests 60 | isort --check-only $PROJECT_NAME 2>&1 61 | isort --check-only tests 2>&1 62 | black --check $PROJECT_NAME 2>&1 63 | black --check tests 2>&1 64 | 65 | - name: Upload pytest test results 66 | uses: actions/upload-artifact@master 67 | with: 68 | name: pytest-results-${{ matrix.python-version }} 69 | path: junit/pytest-results-${{ matrix.python-version }}.xml 70 | if: always() 71 | 72 | - name: Codecov 73 | run: | 74 | bash <(sed -i 's/filename=\"/filename=\"rodi\//g' coverage.xml) 75 | bash <(curl -s https://codecov.io/bash) 76 | 77 | - name: Install distribution dependencies 78 | run: pip install --upgrade build 79 | if: matrix.python-version == 3.12 80 | 81 | - name: Create distribution package 82 | run: python -m build 83 | if: matrix.python-version == 3.12 84 | 85 | - name: Upload distribution package 86 | uses: actions/upload-artifact@v4 87 | with: 88 | name: dist 89 | path: dist 90 | if: matrix.python-version == 3.12 91 | 92 | publish: 93 | runs-on: ubuntu-latest 94 | needs: build 95 | if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish_artifacts == 'Y') 96 | steps: 97 | - name: Download a distribution artifact 98 | uses: actions/download-artifact@v4 99 | with: 100 | name: dist 101 | path: dist 102 | 103 | - name: Use Python 3.12 104 | uses: actions/setup-python@v1 105 | with: 106 | python-version: '3.12' 107 | 108 | - name: Install dependencies 109 | run: | 110 | pip install twine 111 | 112 | - name: Publish distribution 📦 to Test PyPI 113 | run: | 114 | twine upload -r testpypi dist/* 115 | env: 116 | TWINE_USERNAME: __token__ 117 | TWINE_PASSWORD: ${{ secrets.test_pypi_password }} 118 | 119 | - name: 
Publish distribution 📦 to PyPI 120 | run: | 121 | twine upload -r pypi dist/* 122 | env: 123 | TWINE_USERNAME: __token__ 124 | TWINE_PASSWORD: ${{ secrets.pypi_password }} 125 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | venv* 3 | htmlcov 4 | .coverage 5 | __pycache__ 6 | *.egg-info 7 | *.tar.gz 8 | .mypy_cache 9 | build 10 | dist 11 | deps 12 | *.py,cover 13 | junit 14 | coverage.xml 15 | .venv 16 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | multi_line_output = 3 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [2.0.8] - 2025-04-12 9 | 10 | - Add the link to the [documentation](https://www.neoteroi.dev/rodi/). 11 | - Remove the `UnsupportedUnionTypeException` as `Rodi` supports union types, 12 | they only require proper handling. 13 | 14 | ## [2.0.7] - 2025-03-28 15 | 16 | - Add the possibility to specify the `ActivationScope` class when instantiating 17 | the `Container` or the `Services` object. This class will be used when 18 | creating scopes. For the issue #55. 19 | - Add an **experimental** class, `TrackingActivationScope` to support nested 20 | scopes transparently, using `contextvars.ContextVar`. For more context, see 21 | the tests `test_nested_scope_1`, `test_nested_scope_2`, 22 | `test_nested_scope_async_1`. For the issue #55. 
23 | - Raise a `TypeError` if trying to obtain a service from a disposed scope. 24 | - Remove Python 3.8 from the build matrix, add Python 3.13. 25 | - Handle setuptools warning: _SetuptoolsDeprecationWarning: License classifiers are deprecated_. 26 | 27 | ## [2.0.6] - 2023-12-09 :hammer: 28 | - Fixes import for Protocols support regardless of Python version (partially 29 | broken for Python 3.9), by @fennel-akunesh 30 | 31 | ## [2.0.5] - 2023-11-25 :lab_coat: 32 | - Adds support for resolving `Protocol` classes even when they don't define an 33 | `__init__` method, by @lucas-labs 34 | - Fixes bug in service provider build logic causing singletons to be instantiated 35 | n times when they are registered after its dependant, by @lucas-labs 36 | - Changes the "ignore attributes" logic so that if a class variable has already 37 | been initialized externally, rodi doesn't attempt to reinitialize it (and to 38 | also prevent overriding it if the initialized class variable is also a 39 | registered object), by @lucas-labs 40 | 41 | ## [2.0.4] - 2023-10-28 :dragon: 42 | - Fixes bug in Singleton implementation: stop singleton provider from recreating 43 | objects implementing `__len__`, by [Klavionik](https://github.com/Klavionik). 44 | - Add Python 3.12 and remove Python 3.7 from the build matrix. 45 | 46 | ## [2.0.3] - 2023-08-14 :sun_with_face: 47 | - Checks `scoped_services` before resolving from map when in a scope, by [StummeJ](https://github.com/StummeJ). 48 | - Allow getting from scope context without needing to provide scope, by [StummeJ](https://github.com/StummeJ). 49 | 50 | ## [2.0.2] - 2023-03-31 :flamingo: 51 | - Ignores `ClassVar` properties when resolving dependencies by class notations. 52 | - Marks `rodi` 2 as stable. 
53 | 54 | ## [2.0.1] - 2023-03-14 :croissant: 55 | - Removes the strict requirement for resolved classes to have `__init__` 56 | methods defined, to add support for `Protocol`s that do not define an 57 | `__init__` method (thus using `*args`, `**kwargs`), 58 | [akhundMurad](https://github.com/akhundMurad)'s contribution. 59 | - Corrects a code smell, replacing an `i` counter with `enumerate`, 60 | [GLEF1X](https://github.com/GLEF1X)'s contribution. 61 | 62 | ## [2.0.0] - 2023-01-07 :star: 63 | - Introduces a `ContainerProtocol` to improve interoperability between 64 | libraries and alternative implementations of DI containers. The protocol is 65 | inspired by [punq](https://github.com/bobthemighty/punq), since its code API 66 | is the most user-friendly and intelligible of those that were reviewed. 67 | The `ContainerProtocol` can be used through [composition](https://en.wikipedia.org/wiki/Composition_over_inheritance) 68 | to replace `rodi` with alternative implementations of dependency injection in 69 | those libraries that use `DI`. 70 | - Simplifies the code API of the library to support using the `Container` class 71 | to `register` and `resolve` services. The class `Services` is still used and 72 | available, but it's no more necessary to use it directly. 73 | - Replaces `setup.py` with `pyproject.toml`. 74 | - Renames context classes: "GetServiceContext" to "ActivationScope", 75 | "ResolveContext" to "ResolutionContext". 76 | - The "add_exact*" methods have been made private, to simplify the public API. 77 | - Improves type annotations; [MaximZayats](https://github.com/MaximZayats)' contribution. 78 | - Adds typehints to GetServiceContext init params; [guscardvs](https://github.com/guscardvs)' contribution. 79 | 80 | ## [1.1.3] - 2022-03-27 :droplet: 81 | - Corrects a bug that would cause false positives when raising exceptions 82 | for circular dependencies. 
The code now lets recursion errors happen if they 83 | need to happen, re-raising a circular dependency error. 84 | 85 | ## [1.1.2] - 2022-03-14 :rabbit: 86 | - Adds `py.typed` file. 87 | - Applies `isort` and enforces `isort` and `black` checks in CI pipeline. 88 | - Corrects the type annotation for `FactoryCallableType`. 89 | 90 | ## [1.1.1] - 2021-02-23 :cactus: 91 | - Adds support for Generics and GenericAlias `Mapping[X, Y]`, `Iterable[T]`, 92 | `List[T]`, `Set[T]`, `Tuple[T, ...]`, Python 3.9 `list[T]`, etc. ([fixes 93 | #9](https://github.com/Neoteroi/rodi/issues/9)). 94 | - Improves typing-friendliness by making the `ServiceProvider.get` method 95 | return a value of the input type. 96 | - Adds support for Python 3.10.0a5 ✨. However, when classes and functions 97 | require locals, they need to be decorated. See [PEP 98 | 563](https://www.python.org/dev/peps/pep-0563/). 99 | 100 | ## [1.1.0] - 2021-01-31 :grapes: 101 | - Adds support for resolving class attribute annotations. 102 | - Changes how classes without an `__init__` method are handled. 103 | - Updates links to the GitHub organization, [Neoteroi](https://github.com/Neoteroi). 104 | 105 | ## [1.0.9] - 2020-11-08 :octocat: 106 | - Completely migrates to GitHub Workflows. 107 | - Improves build to test Python 3.6 and 3.9. 108 | - Adds a changelog. 109 | - Improves badges.
110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Roberto Prevato 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: release test 2 | 3 | 4 | artifacts: test 5 | python -m build 6 | 7 | 8 | clean: 9 | rm -rf dist/ 10 | 11 | 12 | prepforbuild: 13 | pip install build 14 | 15 | 16 | build: 17 | python -m build 18 | 19 | 20 | test-release: 21 | twine upload --repository testpypi dist/* 22 | 23 | 24 | release: 25 | twine upload --repository pypi dist/* 26 | 27 | 28 | test: 29 | pytest 30 | 31 | 32 | test-cov: 33 | pytest --cov-report html --cov=rodi tests/ 34 | 35 | 36 | format: 37 | isort rodi 38 | isort tests 39 | black rodi 40 | black tests 41 | 42 | 43 | lint-types: 44 | mypy rodi --explicit-package-bases 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Build](https://github.com/Neoteroi/rodi/workflows/Build/badge.svg) 2 | [![pypi](https://img.shields.io/pypi/v/rodi.svg)](https://pypi.python.org/pypi/rodi) 3 | [![versions](https://img.shields.io/pypi/pyversions/rodi.svg)](https://github.com/Neoteroi/rodi) 4 | [![codecov](https://codecov.io/gh/Neoteroi/rodi/branch/main/graph/badge.svg?token=VzAnusWIZt)](https://codecov.io/gh/Neoteroi/rodi) 5 | [![license](https://img.shields.io/github/license/Neoteroi/rodi.svg)](https://github.com/Neoteroi/rodi/blob/main/LICENSE) 6 | [![documentation](https://img.shields.io/badge/📖-docs-purple)](https://www.neoteroi.dev/rodi/) 7 | 8 | # Implementation of dependency injection for Python 3 9 | 10 | **Features:** 11 | 12 | * types resolution by signature types annotations (_type hints_) 13 | * types resolution by class annotations (_type hints_) 14 | * types resolution by names and aliases (_convention over configuration_) 15 | * unintrusive: builds objects graph **without** the need to change the 16 | source code of classes 17 | * 
minimum overhead to obtain services, once the object graph is built 18 | * support for singletons, transient, and scoped services 19 | 20 | This library is freely inspired by .NET Standard 21 | `Microsoft.Extensions.DependencyInjection` implementation (_ref. [MSDN, 22 | Dependency injection in ASP.NET 23 | Core](https://docs.microsoft.com/en-us/aspnet/core/fundamentals/dependency-injection?view=aspnetcore-2.1), 24 | [Using dependency injection in a .Net Core console 25 | application](https://andrewlock.net/using-dependency-injection-in-a-net-core-console-application/)_). 26 | The `ContainerProtocol` for v2 is inspired by [punq](https://github.com/bobthemighty/punq). 27 | 28 | ## Documentation 29 | 30 | Rodi is documented here: [https://www.neoteroi.dev/rodi/](https://www.neoteroi.dev/rodi/). 31 | 32 | ## Installation 33 | 34 | ```bash 35 | pip install rodi 36 | ``` 37 | 38 | ## Efficient 39 | 40 | `rodi` works by inspecting code **once** at runtime, to generate 41 | functions that return instances of desired types - as long as the object graph 42 | is not altered. Inspections are done either on constructors 43 | (`__init__`) or class annotations. Validation steps, for 44 | example to detect circular dependencies or missing services, are done when 45 | building these functions, so additional validation is not needed when 46 | activating services. 47 | 48 | ## Flexible 49 | 50 | `rodi` offers two code APIs: 51 | 52 | - one is kept as generic as possible, using a `ContainerProtocol` for scenarios 53 | in which it is desirable to be able to replace `rodi` with alternative 54 | implementations of dependency injection for Python. The protocol only expects 55 | a class to be able to `register` and `resolve` types, and to tell if a type 56 | is configured in it (`__contains__`).
Even if other implementations of DI 57 | don't implement these three methods, it should be easy to use 58 | [composition](https://en.wikipedia.org/wiki/Composition_over_inheritance) to 59 | wrap other libraries with a compatible class. 60 | - one is a more concrete implementation, for scenarios where it's not desirable 61 | to consider alternative implementations of dependency injection. 62 | 63 | For this reason, the examples show two ways to achieve certain things. 64 | 65 | ### Examples 66 | 67 | For examples, refer to the [examples folder](./examples). 68 | 69 | ### Recommended practices 70 | 71 | All services should be configured once, when an application starts, and the 72 | object graph should *not* be altered during normal program execution. 73 | Example: if you build a web application, configure the object graph when 74 | bootstrapping the application, and avoid altering the `Container` configuration 75 | while handling web requests. 76 | 77 | Aim to keep the `Container` and service graphs abstracted from the front-end 78 | layer of your application, and avoid mixing runtime values with container 79 | configuration. Example: if you build a web application, avoid if possible 80 | relying on the HTTP Request object being a service registered in your container. 81 | 82 | ## Service life style 83 | 84 | * singleton - instantiated only once per service provider 85 | * transient - services are instantiated every time they are required 86 | * scoped - instantiated only once per root service resolution call 87 | (e.g. once per web request) 88 | 89 | ## Usage in BlackSheep 90 | 91 | `rodi` is used in the [BlackSheep](https://www.neoteroi.dev/blacksheep/) 92 | web framework to implement [dependency injection](https://www.neoteroi.dev/blacksheep/dependency-injection/) for 93 | request handlers.
94 | -------------------------------------------------------------------------------- /examples-summary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generates a README.md file for the examples folder. 3 | """ 4 | 5 | import glob 6 | import importlib 7 | import sys 8 | 9 | examples = [file for file in glob.glob("./examples/*.py")] 10 | examples.sort() 11 | sys.path.append("./examples") 12 | 13 | with open("./examples/README.md", mode="wt", encoding="utf8 ") as examples_readme: 14 | examples_readme.write( 15 | "\n\n" 16 | ) 17 | examples_readme.write("""# Examples""") 18 | 19 | for file_path in examples: 20 | if "__init__" in file_path: 21 | continue 22 | 23 | module_name = file_path.replace("./examples/", "").replace(".py", "") 24 | 25 | module = importlib.import_module(module_name) 26 | 27 | if not module.__doc__: 28 | continue 29 | 30 | examples_readme.write(f"\n\n## {module_name}.py\n") 31 | examples_readme.write(str(module.__doc__)) 32 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Examples 4 | 5 | ## example-01.py 6 | 7 | This example illustrates a basic usage of the Container class to register 8 | two types, and automatic resolution achieved through types inspection. 9 | 10 | Two services are registered as "transient" services, meaning that a new instance is 11 | created whenever needed. 12 | 13 | 14 | ## example-02.py 15 | 16 | This example illustrates a basic usage of the Container class to register 17 | a concrete type by base type, and its activation by base type. 18 | 19 | This pattern helps writing code that is decoupled (e.g. business layer logic separated 20 | from exact implementations of data access logic). 21 | 22 | 23 | ## example-03.py 24 | 25 | This example illustrates how to configure a singleton object. 
26 | -------------------------------------------------------------------------------- /examples/example-01.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example illustrates a basic usage of the Container class to register 3 | two types, and automatic resolution achieved through types inspection. 4 | 5 | Two services are registered as "transient" services, meaning that a new instance is 6 | created whenever needed. 7 | """ 8 | 9 | from rodi import Container 10 | 11 | 12 | class A: 13 | pass 14 | 15 | 16 | class B: 17 | friend: A # resolved from the class annotation: B defines no __init__ 18 | 19 | 20 | container = Container() 21 | 22 | container.register(A) 23 | container.register(B) 24 | 25 | example_1 = container.resolve(B) 26 | 27 | assert isinstance(example_1, B) 28 | assert isinstance(example_1.friend, A) 29 | 30 | 31 | example_2 = container.resolve(B) 32 | 33 | assert isinstance(example_2, B) 34 | assert isinstance(example_2.friend, A) 35 | 36 | assert example_1 is not example_2 # transient: every resolve creates a new B 37 | -------------------------------------------------------------------------------- /examples/example-02.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example illustrates a basic usage of the Container class to register 3 | a concrete type by base type, and its activation by base type. 4 | 5 | This pattern helps writing code that is decoupled (e.g. business layer logic separated 6 | from exact implementations of data access logic).
7 | """ 8 | from abc import ABC, abstractmethod 9 | from dataclasses import dataclass 10 | 11 | from rodi import Container 12 | 13 | 14 | @dataclass 15 | class Cat: 16 | id: str 17 | name: str 18 | 19 | 20 | class CatsRepository(ABC): 21 | @abstractmethod 22 | def get_cat(self, cat_id: str) -> Cat: 23 | """Gets information of a cat by ID.""" 24 | 25 | 26 | class SQLiteCatsRepository(CatsRepository): 27 | def get_cat(self, cat_id: str) -> Cat: 28 | """Gets information of a cat by ID, from a source SQLite DB.""" 29 | raise NotImplementedError() # stub: this example only demonstrates resolution 30 | 31 | 32 | container = Container() 33 | 34 | container.register(CatsRepository, SQLiteCatsRepository) # base type -> concrete type 35 | 36 | example_1 = container.resolve(CatsRepository) 37 | 38 | assert isinstance(example_1, SQLiteCatsRepository) 39 | -------------------------------------------------------------------------------- /examples/example-03.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example illustrates how to configure a singleton object. 3 | """ 4 | from dataclasses import dataclass 5 | 6 | from rodi import Container 7 | 8 | 9 | @dataclass 10 | class Cat: 11 | id: str 12 | name: str 13 | 14 | 15 | # Using the ContainerProtocol (recommended if it is desirable to possibly replace the 16 | # library with an alternative implementation of dependency injection) 17 | container = Container() 18 | 19 | container.register(Cat, instance=Cat("1", "Celine")) # this same instance is returned on every resolve 20 | 21 | example = container.resolve(Cat) 22 | 23 | assert isinstance(example, Cat) 24 | assert example.id == "1" and example.name == "Celine" 25 | 26 | assert example is container.resolve(Cat) 27 | 28 | 29 | # Using the original code API 30 | class Foo: 31 | ...
32 | 33 | 34 | container.add_instance(Foo()) # singleton registration via the original code API 35 | 36 | assert container.resolve(Foo) is container.resolve(Foo) 37 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.11 3 | follow_imports = skip 4 | ignore_missing_imports = True 5 | ignore_errors = False 6 | warn_redundant_casts = True 7 | warn_unused_configs = True 8 | show_column_numbers = True 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "rodi" 7 | dynamic = ["version"] 8 | authors = [{ name = "Roberto Prevato", email = "roberto.prevato@gmail.com" }] 9 | description = "Implementation of dependency injection for Python 3" 10 | license = { file = "LICENSE" } 11 | readme = "README.md" 12 | requires-python = ">=3.7" 13 | classifiers = [ 14 | "Development Status :: 5 - Production/Stable", 15 | "Programming Language :: Python :: 3", 16 | "Programming Language :: Python :: 3.7", 17 | "Programming Language :: Python :: 3.8", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Programming Language :: Python :: 3.13", 23 | "Operating System :: OS Independent", 24 | ] 25 | keywords = ["dependency", "injection", "type", "hints", "typing"] 26 | dependencies = ["typing_extensions; python_version < '3.8'"] 27 | 28 | [tool.hatch.build.targets.sdist] 29 | exclude = [ 30 | "/.github", 31 | "/docs", 32 | "/examples", 33 | "/deps", 34 | "/htmlcov", 35 | "/tests", 36 | "mkdocs-plugins.code-workspace", 37 | "Makefile", 38 | "CODE_OF_CONDUCT.md", 39 | ".isort.cfg", 40 | ".gitignore", 41 |
".flake8", 42 | "junit", 43 | "rodi.code-workspace", 44 | "requirements.txt", 45 | "mypy.ini", 46 | "pytest.ini", 47 | "examples-summary.py", 48 | ] 49 | 50 | [tool.hatch.version] 51 | path = "rodi/__about__.py" 52 | 53 | [project.urls] 54 | "Homepage" = "https://github.com/Neoteroi/rodi" 55 | "Bug Tracker" = "https://github.com/Neoteroi/rodi/issues" 56 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | cqa: Code Quality Assurance 4 | junit_family=xunit1 5 | asyncio_mode=strict 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-cov 3 | pytest-asyncio 4 | flake8 5 | black 6 | isort 7 | mypy 8 | dataclasses==0.8; python_version < '3.7' 9 | -------------------------------------------------------------------------------- /rodi.code-workspace: -------------------------------------------------------------------------------- 1 | { 2 | "folders": [ 3 | { 4 | "path": "." 5 | } 6 | ], 7 | "settings": { 8 | "python.testing.pytestArgs": [ 9 | "." 
10 | ], 11 | "files.trimTrailingWhitespace": true, 12 | "files.trimFinalNewlines": true, 13 | "python.testing.unittestEnabled": false, 14 | "python.testing.nosetestsEnabled": false, 15 | "python.testing.pytestEnabled": true, 16 | "python.linting.pylintEnabled": false, 17 | "python.linting.flake8Enabled": true, 18 | "python.linting.mypyEnabled": false, 19 | "python.linting.enabled": true, 20 | "[python]": { 21 | "editor.tabSize": 4, 22 | "editor.rulers": [ 23 | 88, 24 | 100 25 | ] 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /rodi/__about__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.0.8" 2 | -------------------------------------------------------------------------------- /rodi/__init__.py: -------------------------------------------------------------------------------- 1 | import contextvars 2 | import inspect 3 | import re 4 | import sys 5 | from collections import defaultdict 6 | from enum import Enum 7 | from inspect import Signature, _empty, isabstract, isclass, iscoroutinefunction 8 | from typing import ( 9 | Any, 10 | Callable, 11 | ClassVar, 12 | DefaultDict, 13 | Dict, 14 | Mapping, 15 | Optional, 16 | Set, 17 | Type, 18 | TypeVar, 19 | Union, 20 | cast, 21 | get_type_hints, 22 | ) 23 | 24 | if sys.version_info >= (3, 8): # pragma: no cover 25 | try: 26 | from typing import _no_init_or_replace_init as _no_init 27 | except ImportError: # pragma: no cover 28 | from typing import _no_init 29 | 30 | try: 31 | from typing import Protocol 32 | except ImportError: # pragma: no cover 33 | from typing_extensions import Protocol 34 | 35 | 36 | T = TypeVar("T") 37 | 38 | 39 | class ContainerProtocol(Protocol): 40 | """ 41 | Generic interface of DI Container that can register and resolve services, 42 | and tell if a type is configured. 
43 | """ 44 | 45 | def register(self, obj_type: Union[Type, str], *args, **kwargs): 46 | """Registers a type in the container, with optional arguments.""" 47 | 48 | def resolve(self, obj_type: Union[Type[T], str], *args, **kwargs) -> T: 49 | """Activates an instance of the given type, with optional arguments.""" 50 | 51 | def __contains__(self, item) -> bool: 52 | """ 53 | Returns a value indicating whether a given type is configured in this container. 54 | """ 55 | 56 | 57 | AliasesTypeHint = Dict[str, Type] 58 | 59 | 60 | def inject(globalsns=None, localns=None) -> Callable[..., Any]: 61 | """ 62 | Marks a class or a function as injected. This method is only necessary if the class 63 | uses locals and the user uses Python >= 3.10, to bind the function's locals to the 64 | factory. 65 | """ 66 | if localns is None or globalsns is None: 67 | frame = inspect.currentframe() 68 | try: 69 | if localns is None: 70 | localns = frame.f_back.f_locals # type: ignore 71 | if globalsns is None: 72 | globalsns = frame.f_back.f_globals # type: ignore 73 | finally: 74 | del frame 75 | 76 | def decorator(f): 77 | f._locals = localns 78 | f._globals = globalsns 79 | return f 80 | 81 | return decorator 82 | 83 | 84 | def _get_obj_locals(obj) -> Optional[Dict[str, Any]]: 85 | return getattr(obj, "_locals", None) 86 | 87 | 88 | def class_name(input_type): 89 | if input_type in {list, set} and str( # noqa: E721 90 | type(input_type) == "" 91 | ): 92 | # for Python 3.9 list[T], set[T] 93 | return str(input_type) 94 | try: 95 | return input_type.__name__ 96 | except AttributeError: 97 | # for example, this is the case for List[str], Tuple[str, ...], etc. 98 | return str(input_type) 99 | 100 | 101 | class DIException(Exception): 102 | """Base exception class for DI exceptions.""" 103 | 104 | 105 | class FactoryMissingContextException(DIException): 106 | def __init__(self, function) -> None: 107 | super().__init__( 108 | f"The factory '{function.__name__}' lacks locals and globals data. 
" 109 | "Decorate the function with the `@inject()` decorator defined in " 110 | "`rodi`. This is necessary since PEP 563." 111 | ) 112 | 113 | 114 | class CannotResolveTypeException(DIException): 115 | """ 116 | Exception risen when it is not possible to resolve a Type.""" 117 | 118 | def __init__(self, desired_type): 119 | super().__init__(f"Unable to resolve the type '{desired_type}'.") 120 | 121 | 122 | class CannotResolveParameterException(DIException): 123 | """ 124 | Exception risen when it is not possible to resolve a parameter, 125 | necessary to instantiate a type.""" 126 | 127 | def __init__(self, param_name, desired_type): 128 | super().__init__( 129 | f"Unable to resolve parameter '{param_name}' " 130 | f"when resolving '{class_name(desired_type)}'" 131 | ) 132 | 133 | 134 | class OverridingServiceException(DIException): 135 | """ 136 | Exception risen when registering a service 137 | would override an existing one.""" 138 | 139 | def __init__(self, key, value): 140 | key_name = key if isinstance(key, str) else class_name(key) 141 | super().__init__( 142 | f"A service with key '{key_name}' is already " 143 | f"registered and would be overridden by value {value}." 144 | ) 145 | 146 | 147 | class CircularDependencyException(DIException): 148 | """Exception risen when a circular dependency between a type and 149 | one of its parameters is detected.""" 150 | 151 | def __init__(self, expected_type, desired_type): 152 | super().__init__( 153 | "A circular dependency was detected for the service " 154 | f"of type '{class_name(expected_type)}' " 155 | f"for '{class_name(desired_type)}'" 156 | ) 157 | 158 | 159 | class InvalidOperationInStrictMode(DIException): 160 | def __init__(self): 161 | super().__init__( 162 | "The services are configured in strict mode, the operation is invalid." 
163 | ) 164 | 165 | 166 | class AliasAlreadyDefined(DIException): 167 | """Exception risen when trying to add an alias that already exists.""" 168 | 169 | def __init__(self, name): 170 | super().__init__( 171 | f"Cannot define alias '{name}'. " 172 | f"An alias with given name is already defined." 173 | ) 174 | 175 | 176 | class AliasConfigurationError(DIException): 177 | def __init__(self, name, _type): 178 | super().__init__( 179 | f"An alias '{name}' for type '{class_name(_type)}' was defined, " 180 | f"but the type was not configured in the Container." 181 | ) 182 | 183 | 184 | class MissingTypeException(DIException): 185 | """Exception risen when a type must be specified to use a factory""" 186 | 187 | def __init__(self): 188 | super().__init__( 189 | "Please specify the factory return type or " 190 | "annotate its return type; func() -> Foo:" 191 | ) 192 | 193 | 194 | class InvalidFactory(DIException): 195 | """Exception risen when a factory is not valid""" 196 | 197 | def __init__(self, _type): 198 | super().__init__( 199 | f"The factory specified for type {class_name(_type)} is not " 200 | f"valid, it must be a function with either these signatures: " 201 | f"def example_factory(context, type): " 202 | f"or," 203 | f"def example_factory(context): " 204 | f"or," 205 | f"def example_factory(): " 206 | ) 207 | 208 | 209 | class ServiceLifeStyle(Enum): 210 | TRANSIENT = 1 211 | SCOPED = 2 212 | SINGLETON = 3 213 | 214 | 215 | def _get_factory_annotations_or_throw(factory): 216 | factory_locals = getattr(factory, "_locals", None) 217 | factory_globals = getattr(factory, "_globals", None) 218 | 219 | if factory_locals is None: 220 | raise FactoryMissingContextException(factory) 221 | 222 | return get_type_hints(factory, globalns=factory_globals, localns=factory_locals) 223 | 224 | 225 | class ActivationScope: 226 | __slots__ = ("scoped_services", "provider") 227 | 228 | def __init__( 229 | self, 230 | provider: Optional["Services"] = None, 231 | scoped_services: 
Optional[Dict[Union[Type[T], str], T]] = None, 232 | ): 233 | self.provider = provider or Services() 234 | self.scoped_services = scoped_services or {} 235 | 236 | def __enter__(self): 237 | if self.scoped_services is None: 238 | self.scoped_services = {} 239 | return self 240 | 241 | def __exit__(self, exc_type, exc_val, exc_tb): 242 | self.dispose() 243 | 244 | def get( 245 | self, 246 | desired_type: Union[Type[T], str], 247 | scope: Optional["ActivationScope"] = None, 248 | *, 249 | default: Optional[Any] = ..., 250 | ) -> T: 251 | if self.provider is None: 252 | raise TypeError("This scope is disposed.") 253 | return self.provider.get(desired_type, scope or self, default=default) 254 | 255 | def dispose(self): 256 | if self.provider: 257 | self.provider = None 258 | 259 | if self.scoped_services: 260 | self.scoped_services.clear() 261 | self.scoped_services = None 262 | 263 | 264 | class TrackingActivationScope(ActivationScope): 265 | """ 266 | This is an experimental class to support nested scopes transparently. 267 | To use it, create a container including the `scope_cls` parameter: 268 | `Container(scope_cls=TrackingActivationScope)`. 
269 | """ 270 | 271 | _active_scopes = contextvars.ContextVar("active_scopes", default=[]) 272 | 273 | __slots__ = ("scoped_services", "provider", "parent_scope") 274 | 275 | def __init__(self, provider=None, scoped_services=None): 276 | # Get the current stack of active scopes 277 | stack = self._active_scopes.get() 278 | 279 | # Detect the parent scope if it exists 280 | self.parent_scope = stack[-1] if stack else None 281 | 282 | # Initialize scoped services 283 | scoped_services = scoped_services or {} 284 | if self.parent_scope: 285 | scoped_services.update(self.parent_scope.scoped_services) 286 | 287 | super().__init__(provider, scoped_services) 288 | 289 | def __enter__(self): 290 | # Push this scope onto the stack 291 | stack = self._active_scopes.get() 292 | self._active_scopes.set(stack + [self]) 293 | return self 294 | 295 | def __exit__(self, exc_type, exc_val, exc_tb): 296 | # Pop this scope from the stack 297 | stack = self._active_scopes.get() 298 | self._active_scopes.set(stack[:-1]) 299 | self.dispose() 300 | 301 | def dispose(self): 302 | if self.provider: 303 | self.provider = None 304 | 305 | 306 | class ResolutionContext: 307 | __slots__ = ("resolved", "dynamic_chain") 308 | __deletable__ = ("resolved",) 309 | 310 | def __init__(self): 311 | self.resolved = {} 312 | self.dynamic_chain = [] 313 | 314 | def __enter__(self): 315 | return self 316 | 317 | def __exit__(self, exc_type, exc_val, exc_tb): 318 | self.dispose() 319 | 320 | def dispose(self): 321 | del self.resolved 322 | self.dynamic_chain.clear() 323 | 324 | 325 | class InstanceProvider: 326 | __slots__ = ("instance",) 327 | 328 | def __init__(self, instance): 329 | self.instance = instance 330 | 331 | def __call__(self, context, parent_type): 332 | return self.instance 333 | 334 | 335 | class TypeProvider: 336 | __slots__ = ("_type",) 337 | 338 | def __init__(self, _type): 339 | self._type = _type 340 | 341 | def __call__(self, context, parent_type): 342 | return self._type() 343 | 344 
| 345 | class ScopedTypeProvider: 346 | __slots__ = ("_type",) 347 | 348 | def __init__(self, _type): 349 | self._type = _type 350 | 351 | def __call__(self, context: ActivationScope, parent_type): 352 | if self._type in context.scoped_services: 353 | return context.scoped_services[self._type] 354 | 355 | service = self._type() 356 | context.scoped_services[self._type] = service 357 | return service 358 | 359 | 360 | class ArgsTypeProvider: 361 | __slots__ = ("_type", "_args_callbacks") 362 | 363 | def __init__(self, _type, args_callbacks): 364 | self._type = _type 365 | self._args_callbacks = args_callbacks 366 | 367 | def __call__(self, context, parent_type): 368 | return self._type(*[fn(context, self._type) for fn in self._args_callbacks]) 369 | 370 | 371 | class FactoryTypeProvider: 372 | __slots__ = ("_type", "factory") 373 | 374 | def __init__(self, _type, factory): 375 | self._type = _type 376 | self.factory = factory 377 | 378 | def __call__(self, context: ActivationScope, parent_type): 379 | assert isinstance(context, ActivationScope) 380 | return self.factory(context, parent_type) 381 | 382 | 383 | class SingletonFactoryTypeProvider: 384 | __slots__ = ("_type", "factory", "instance") 385 | 386 | def __init__(self, _type, factory): 387 | self._type = _type 388 | self.factory = factory 389 | self.instance = None 390 | 391 | def __call__(self, context: ActivationScope, parent_type): 392 | if self.instance is None: 393 | self.instance = self.factory(context, parent_type) 394 | return self.instance 395 | 396 | 397 | class ScopedFactoryTypeProvider: 398 | __slots__ = ("_type", "factory") 399 | 400 | def __init__(self, _type, factory): 401 | self._type = _type 402 | self.factory = factory 403 | 404 | def __call__(self, context: ActivationScope, parent_type): 405 | if self._type in context.scoped_services: 406 | return context.scoped_services[self._type] 407 | 408 | instance = self.factory(context, parent_type) 409 | context.scoped_services[self._type] = instance 
410 | return instance 411 | 412 | 413 | class ScopedArgsTypeProvider: 414 | __slots__ = ("_type", "_args_callbacks") 415 | 416 | def __init__(self, _type, args_callbacks): 417 | self._type = _type 418 | self._args_callbacks = args_callbacks 419 | 420 | def __call__(self, context: ActivationScope, parent_type): 421 | if self._type in context.scoped_services: 422 | return context.scoped_services[self._type] 423 | 424 | service = self._type(*[fn(context, self._type) for fn in self._args_callbacks]) 425 | context.scoped_services[self._type] = service 426 | return service 427 | 428 | 429 | class SingletonTypeProvider: 430 | __slots__ = ("_type", "_instance", "_args_callbacks") 431 | 432 | def __init__(self, _type, _args_callbacks): 433 | self._type = _type 434 | self._args_callbacks = _args_callbacks 435 | self._instance = None 436 | 437 | def __call__(self, context, parent_type): 438 | if self._instance is None: 439 | self._instance = ( 440 | self._type(*[fn(context, self._type) for fn in self._args_callbacks]) 441 | if self._args_callbacks 442 | else self._type() 443 | ) 444 | 445 | return self._instance 446 | 447 | 448 | def get_annotations_type_provider( 449 | concrete_type: Type, 450 | resolvers: Mapping[str, Callable], 451 | life_style: ServiceLifeStyle, 452 | resolver_context: ResolutionContext, 453 | ): 454 | def factory(context, parent_type): 455 | instance = concrete_type() 456 | for name, resolver in resolvers.items(): 457 | setattr(instance, name, resolver(context, parent_type)) 458 | return instance 459 | 460 | return FactoryResolver(concrete_type, factory, life_style)(resolver_context) 461 | 462 | 463 | def _get_plain_class_factory(concrete_type: Type): 464 | def factory(*args): 465 | return concrete_type() 466 | 467 | return factory 468 | 469 | 470 | class InstanceResolver: 471 | __slots__ = ("instance",) 472 | 473 | def __init__(self, instance): 474 | self.instance = instance 475 | 476 | def __repr__(self): 477 | return f"" 478 | 479 | def __call__(self, 
context: ResolutionContext): 480 | return InstanceProvider(self.instance) 481 | 482 | 483 | class Dependency: 484 | __slots__ = ("name", "annotation") 485 | 486 | def __init__(self, name, annotation): 487 | self.name = name 488 | self.annotation = annotation 489 | 490 | 491 | class DynamicResolver: 492 | __slots__ = ("_concrete_type", "services", "life_style") 493 | 494 | def __init__(self, concrete_type, services, life_style): 495 | assert isclass(concrete_type) 496 | assert not isabstract(concrete_type) 497 | 498 | self._concrete_type = concrete_type 499 | self.services = services 500 | self.life_style = life_style 501 | 502 | @property 503 | def concrete_type(self) -> Type: 504 | return self._concrete_type 505 | 506 | def _get_resolver(self, desired_type, context: ResolutionContext): 507 | # NB: the following two lines are important to ensure that singletons 508 | # are instantiated only once per service provider 509 | # to not repeat operations more than once 510 | if desired_type in context.resolved: 511 | return context.resolved[desired_type] 512 | 513 | reg = self.services._map.get(desired_type) 514 | assert ( 515 | reg is not None 516 | ), f"A resolver for type {class_name(desired_type)} is not configured" 517 | resolver = reg(context) 518 | 519 | # add the resolver to the context, so we can find it 520 | # next time we need it 521 | context.resolved[desired_type] = resolver 522 | return resolver 523 | 524 | def _get_resolvers_for_parameters( 525 | self, 526 | concrete_type, 527 | context: ResolutionContext, 528 | params: Mapping[str, Dependency], 529 | ): 530 | fns = [] 531 | services = self.services 532 | 533 | for param_name, param in params.items(): 534 | if param_name in ("self", "args", "kwargs"): 535 | continue 536 | 537 | param_type = param.annotation 538 | 539 | if param_type is _empty: 540 | if services.strict: 541 | raise CannotResolveParameterException(param_name, concrete_type) 542 | 543 | # support for exact, user defined aliases, without 
ambiguity 544 | exact_alias = services._exact_aliases.get(param_name) 545 | 546 | if exact_alias: 547 | param_type = exact_alias 548 | else: 549 | aliases = services._aliases[param_name] 550 | 551 | if aliases: 552 | assert ( 553 | len(aliases) == 1 554 | ), "Configured aliases cannot be ambiguous" 555 | for param_type in aliases: 556 | break 557 | 558 | if param_type not in services._map: 559 | raise CannotResolveParameterException(param_name, concrete_type) 560 | 561 | param_resolver = self._get_resolver(param_type, context) 562 | fns.append(param_resolver) 563 | return fns 564 | 565 | def _resolve_by_init_method(self, context: ResolutionContext): 566 | sig = Signature.from_callable(self.concrete_type.__init__) 567 | params = { 568 | key: Dependency(key, value.annotation) 569 | for key, value in sig.parameters.items() 570 | } 571 | 572 | if sys.version_info >= (3, 10): # pragma: no cover 573 | # Python 3.10 574 | annotations = get_type_hints( 575 | self.concrete_type.__init__, 576 | vars(sys.modules[self.concrete_type.__module__]), 577 | _get_obj_locals(self.concrete_type), 578 | ) 579 | for key, value in params.items(): 580 | if key in annotations: 581 | value.annotation = annotations[key] 582 | 583 | concrete_type = self.concrete_type 584 | 585 | if len(params) == 1 and next(iter(params.keys())) == "self": 586 | if self.life_style == ServiceLifeStyle.SINGLETON: 587 | return SingletonTypeProvider(concrete_type, None) 588 | 589 | if self.life_style == ServiceLifeStyle.SCOPED: 590 | return ScopedTypeProvider(concrete_type) 591 | 592 | return TypeProvider(concrete_type) 593 | 594 | fns = self._get_resolvers_for_parameters(concrete_type, context, params) 595 | 596 | if self.life_style == ServiceLifeStyle.SINGLETON: 597 | return SingletonTypeProvider(concrete_type, fns) 598 | 599 | if self.life_style == ServiceLifeStyle.SCOPED: 600 | return ScopedArgsTypeProvider(concrete_type, fns) 601 | 602 | return ArgsTypeProvider(concrete_type, fns) 603 | 604 | def 
_ignore_class_attribute(self, key: str, value) -> bool: 605 | """ 606 | Returns a value indicating whether a class attribute should be ignored for 607 | dependency resolution, by name and value. 608 | It's ignored if it's a ClassVar or if it's already initialized explicitly. 609 | """ 610 | is_classvar = getattr(value, "__origin__", None) is ClassVar 611 | is_initialized = getattr(self.concrete_type, key, None) is not None 612 | 613 | return is_classvar or is_initialized 614 | 615 | def _has_default_init(self): 616 | init = getattr(self.concrete_type, "__init__", None) 617 | 618 | if init is object.__init__: 619 | return True 620 | 621 | if sys.version_info >= (3, 8): # pragma: no cover 622 | if init is _no_init: 623 | return True 624 | return False 625 | 626 | def _resolve_by_annotations( 627 | self, context: ResolutionContext, annotations: Dict[str, Type] 628 | ): 629 | params = { 630 | key: Dependency(key, value) 631 | for key, value in annotations.items() 632 | if not self._ignore_class_attribute(key, value) 633 | } 634 | concrete_type = self.concrete_type 635 | 636 | fns = self._get_resolvers_for_parameters(concrete_type, context, params) 637 | resolvers = {} 638 | 639 | for i, name in enumerate(params.keys()): 640 | resolvers[name] = fns[i] 641 | 642 | return get_annotations_type_provider( 643 | self.concrete_type, resolvers, self.life_style, context 644 | ) 645 | 646 | def __call__(self, context: ResolutionContext): 647 | concrete_type = self.concrete_type 648 | 649 | chain = context.dynamic_chain 650 | chain.append(concrete_type) 651 | 652 | if self._has_default_init(): 653 | annotations = get_type_hints( 654 | concrete_type, 655 | vars(sys.modules[concrete_type.__module__]), 656 | _get_obj_locals(concrete_type), 657 | ) 658 | 659 | if annotations: 660 | try: 661 | return self._resolve_by_annotations(context, annotations) 662 | except RecursionError: 663 | raise CircularDependencyException(chain[0], concrete_type) 664 | 665 | return FactoryResolver( 666 | 
concrete_type, _get_plain_class_factory(concrete_type), self.life_style 667 | )(context) 668 | 669 | try: 670 | return self._resolve_by_init_method(context) 671 | except RecursionError: 672 | raise CircularDependencyException(chain[0], concrete_type) 673 | 674 | 675 | class FactoryResolver: 676 | __slots__ = ("concrete_type", "factory", "params", "life_style") 677 | 678 | def __init__(self, concrete_type, factory, life_style): 679 | self.factory = factory 680 | self.concrete_type = concrete_type 681 | self.life_style = life_style 682 | 683 | def __call__(self, context: ResolutionContext): 684 | if self.life_style == ServiceLifeStyle.SINGLETON: 685 | return SingletonFactoryTypeProvider(self.concrete_type, self.factory) 686 | 687 | if self.life_style == ServiceLifeStyle.SCOPED: 688 | return ScopedFactoryTypeProvider(self.concrete_type, self.factory) 689 | 690 | return FactoryTypeProvider(self.concrete_type, self.factory) 691 | 692 | 693 | first_cap_re = re.compile("(.)([A-Z][a-z]+)") 694 | all_cap_re = re.compile("([a-z0-9])([A-Z])") 695 | 696 | 697 | def to_standard_param_name(name): 698 | value = all_cap_re.sub(r"\1_\2", first_cap_re.sub(r"\1_\2", name)).lower() 699 | if value.startswith("i_"): 700 | return "i" + value[2:] 701 | return value 702 | 703 | 704 | class Services: 705 | """ 706 | Provides methods to activate instances of classes, by cached activator functions. 
707 | """ 708 | 709 | __slots__ = ("_map", "_executors", "_scope_cls") 710 | 711 | def __init__( 712 | self, 713 | services_map=None, 714 | scope_cls: Optional[Type[ActivationScope]] = None, 715 | ): 716 | if services_map is None: 717 | services_map = {} 718 | self._map = services_map 719 | self._executors = {} 720 | self._scope_cls = scope_cls or ActivationScope 721 | 722 | def __contains__(self, item): 723 | return item in self._map 724 | 725 | def __getitem__(self, item): 726 | return self.get(item) 727 | 728 | def __setitem__(self, key, value): 729 | self.set(key, value) 730 | 731 | def create_scope( 732 | self, scoped: Optional[Dict[Union[Type, str], Any]] = None 733 | ) -> ActivationScope: 734 | return self._scope_cls(self, scoped) 735 | 736 | def set(self, new_type: Union[Type, str], value: Any): 737 | """ 738 | Sets a new service of desired type, as singleton. 739 | This method exists to increase interoperability of Services class (with dict). 740 | 741 | :param new_type: 742 | :param value: 743 | :return: 744 | """ 745 | type_name = class_name(new_type) 746 | if new_type in self._map or ( 747 | not isinstance(new_type, str) and type_name in self._map 748 | ): 749 | raise OverridingServiceException(self._map[new_type], new_type) 750 | 751 | def resolver(context, desired_type): 752 | return value 753 | 754 | self._map[new_type] = resolver 755 | if not isinstance(new_type, str): 756 | self._map[type_name] = resolver 757 | 758 | def get( 759 | self, 760 | desired_type: Union[Type[T], str], 761 | scope: Optional[ActivationScope] = None, 762 | *, 763 | default: Optional[Any] = ..., 764 | ) -> T: 765 | """ 766 | Gets a service of the desired type, returning an activated instance. 767 | 768 | :param desired_type: desired service type. 769 | :param context: optional context, used to handle scoped services. 
770 | :return: an instance of the desired type 771 | """ 772 | if scope is None: 773 | scope = self.create_scope() 774 | 775 | resolver = self._map.get(desired_type) 776 | scoped_service = scope.scoped_services.get(desired_type) if scope else None 777 | 778 | if not resolver and not scoped_service: 779 | if default is not ...: 780 | return cast(T, default) 781 | raise CannotResolveTypeException(desired_type) 782 | 783 | return cast(T, scoped_service or resolver(scope, desired_type)) 784 | 785 | def _get_getter(self, key, param): 786 | if param.annotation is _empty: 787 | 788 | def getter(context): 789 | return self.get(key, context) 790 | 791 | else: 792 | 793 | def getter(context): 794 | return self.get(param.annotation, context) 795 | 796 | getter.__name__ = f"" 797 | return getter 798 | 799 | def get_executor(self, method: Callable) -> Callable: 800 | sig = Signature.from_callable(method) 801 | params = { 802 | key: Dependency(key, value.annotation) 803 | for key, value in sig.parameters.items() 804 | } 805 | 806 | if sys.version_info >= (3, 10): # pragma: no cover 807 | # Python 3.10 808 | annotations = _get_factory_annotations_or_throw(method) 809 | for key, value in params.items(): 810 | if key in annotations: 811 | value.annotation = annotations[key] 812 | 813 | fns = [] 814 | 815 | for key, value in params.items(): 816 | fns.append(self._get_getter(key, value)) 817 | 818 | if iscoroutinefunction(method): 819 | 820 | async def async_executor( 821 | scoped: Optional[Dict[Union[Type, str], Any]] = None, 822 | ): 823 | with self.create_scope(scoped) as context: 824 | return await method(*[fn(context) for fn in fns]) 825 | 826 | return async_executor 827 | 828 | def executor(scoped: Optional[Dict[Union[Type, str], Any]] = None): 829 | with self.create_scope(scoped) as context: 830 | return method(*[fn(context) for fn in fns]) 831 | 832 | return executor 833 | 834 | def exec( 835 | self, 836 | method: Callable, 837 | scoped: Optional[Dict[Type, Any]] = None, 838 
| ) -> Any: 839 | try: 840 | executor = self._executors[method] 841 | except KeyError: 842 | executor = self.get_executor(method) 843 | self._executors[method] = executor 844 | return executor(scoped) 845 | 846 | 847 | FactoryCallableNoArguments = Callable[[], Any] 848 | FactoryCallableSingleArgument = Callable[[ActivationScope], Any] 849 | FactoryCallableTwoArguments = Callable[[ActivationScope, Type], Any] 850 | FactoryCallableType = Union[ 851 | FactoryCallableNoArguments, 852 | FactoryCallableSingleArgument, 853 | FactoryCallableTwoArguments, 854 | ] 855 | 856 | 857 | class FactoryWrapperNoArgs: 858 | __slots__ = ("factory",) 859 | 860 | def __init__(self, factory): 861 | self.factory = factory 862 | 863 | def __call__(self, context, activating_type): 864 | return self.factory() 865 | 866 | 867 | class FactoryWrapperContextArg: 868 | __slots__ = ("factory",) 869 | 870 | def __init__(self, factory): 871 | self.factory = factory 872 | 873 | def __call__(self, context, activating_type): 874 | return self.factory(context) 875 | 876 | 877 | class Container(ContainerProtocol): 878 | """ 879 | Configuration class for a collection of services. 
880 | """ 881 | 882 | __slots__ = ("_map", "_aliases", "_exact_aliases", "_scope_cls", "strict") 883 | 884 | def __init__( 885 | self, 886 | *, 887 | strict: bool = False, 888 | scope_cls: Optional[Type[ActivationScope]] = None, 889 | ): 890 | self._map: Dict[Type, Callable] = {} 891 | self._aliases: DefaultDict[str, Set[Type]] = defaultdict(set) 892 | self._exact_aliases: Dict[str, Type] = {} 893 | self._provider: Optional[Services] = None 894 | self._scope_cls = scope_cls 895 | self.strict = strict 896 | 897 | @property 898 | def provider(self) -> Services: 899 | if self._provider is None: 900 | self._provider = self.build_provider() 901 | return self._provider 902 | 903 | def __iter__(self): 904 | yield from self._map.items() 905 | 906 | def __contains__(self, key): 907 | return key in self._map 908 | 909 | def bind_types( 910 | self, 911 | obj_type: Any, 912 | concrete_type: Any = None, 913 | life_style: ServiceLifeStyle = ServiceLifeStyle.TRANSIENT, 914 | ): 915 | try: 916 | assert issubclass(concrete_type, obj_type), ( 917 | f"Cannot register {class_name(obj_type)} for abstract class " 918 | f"{class_name(concrete_type)}" 919 | ) 920 | except TypeError: 921 | # ignore, this happens with generic types 922 | pass 923 | self._bind(obj_type, DynamicResolver(concrete_type, self, life_style)) 924 | return self 925 | 926 | def register( 927 | self, 928 | obj_type: Any, 929 | sub_type: Any = None, 930 | instance: Any = None, 931 | *args, 932 | **kwargs, 933 | ) -> "Container": 934 | """ 935 | Registers a type in this container. 
936 | """ 937 | if instance is not None: 938 | self.add_instance(instance, declared_class=obj_type) 939 | return self 940 | 941 | if sub_type is None: 942 | self._add_exact_transient(obj_type) 943 | else: 944 | self.add_transient(obj_type, sub_type) 945 | return self 946 | 947 | def resolve( 948 | self, 949 | obj_type: Union[Type[T], str], 950 | scope: Any = None, 951 | *args, 952 | **kwargs, 953 | ) -> T: 954 | """ 955 | Resolves a service by type, obtaining an instance of that type. 956 | """ 957 | return self.provider.get(obj_type, scope=scope) 958 | 959 | def add_alias(self, name: str, desired_type: Type): 960 | """ 961 | Adds an alias to the set of inferred aliases. 962 | 963 | :param name: parameter name 964 | :param desired_type: desired type by parameter name 965 | :return: self 966 | """ 967 | if self.strict: 968 | raise InvalidOperationInStrictMode() 969 | if name in self._aliases or name in self._exact_aliases: 970 | raise AliasAlreadyDefined(name) 971 | self._aliases[name].add(desired_type) 972 | return self 973 | 974 | def add_aliases(self, values: AliasesTypeHint): 975 | """ 976 | Adds aliases to the set of inferred aliases. 977 | 978 | :param values: mapping object (parameter name: class) 979 | :return: self 980 | """ 981 | for key, value in values.items(): 982 | self.add_alias(key, value) 983 | return self 984 | 985 | def set_alias(self, name: str, desired_type: Type, override: bool = False): 986 | """ 987 | Sets an exact alias for a desired type. 
988 | 989 | :param name: parameter name 990 | :param desired_type: desired type by parameter name 991 | :param override: whether to override existing values, or throw exception 992 | :return: self 993 | """ 994 | if self.strict: 995 | raise InvalidOperationInStrictMode() 996 | if not override and name in self._exact_aliases: 997 | raise AliasAlreadyDefined(name) 998 | self._exact_aliases[name] = desired_type 999 | return self 1000 | 1001 | def set_aliases(self, values: AliasesTypeHint, override: bool = False): 1002 | """Sets many exact aliases for desired types. 1003 | 1004 | :param values: mapping object (parameter name: class) 1005 | :param override: whether to override existing values, or throw exception 1006 | :return: self 1007 | """ 1008 | for key, value in values.items(): 1009 | self.set_alias(key, value, override) 1010 | return self 1011 | 1012 | def _bind(self, key: Type, value: Any) -> None: 1013 | if key in self._map: 1014 | raise OverridingServiceException(key, value) 1015 | self._map[key] = value 1016 | 1017 | if self._provider is not None: 1018 | self._provider = None 1019 | 1020 | key_name = class_name(key) 1021 | 1022 | if self.strict or "." in key_name: 1023 | return 1024 | 1025 | self._aliases[key_name].add(key) 1026 | self._aliases[key_name.lower()].add(key) 1027 | self._aliases[to_standard_param_name(key_name)].add(key) 1028 | 1029 | def add_instance( 1030 | self, instance: Any, declared_class: Optional[Type] = None 1031 | ) -> "Container": 1032 | """ 1033 | Registers an exact instance, optionally by declared class. 
1034 | 1035 | :param instance: singleton to be registered 1036 | :param declared_class: optionally, lets define the class used as reference of 1037 | the singleton 1038 | :return: the service collection itself 1039 | """ 1040 | self._bind( 1041 | instance.__class__ if not declared_class else declared_class, 1042 | InstanceResolver(instance), 1043 | ) 1044 | return self 1045 | 1046 | def add_singleton( 1047 | self, base_type: Type, concrete_type: Optional[Type] = None 1048 | ) -> "Container": 1049 | """ 1050 | Registers a type by base type, to be instantiated with singleton lifetime. 1051 | If a single type is given, the method `add_exact_singleton` is used. 1052 | 1053 | :param base_type: registered type. If a concrete type is provided, it must 1054 | inherit the base type. 1055 | :param concrete_type: concrete class 1056 | :return: the service collection itself 1057 | """ 1058 | if concrete_type is None: 1059 | return self._add_exact_singleton(base_type) 1060 | 1061 | return self.bind_types(base_type, concrete_type, ServiceLifeStyle.SINGLETON) 1062 | 1063 | def add_scoped( 1064 | self, base_type: Type, concrete_type: Optional[Type] = None 1065 | ) -> "Container": 1066 | """ 1067 | Registers a type by base type, to be instantiated with scoped lifetime. 1068 | If a single type is given, the method `add_exact_scoped` is used. 1069 | 1070 | :param base_type: registered type. If a concrete type is provided, it must 1071 | inherit the base type. 1072 | :param concrete_type: concrete class 1073 | :return: the service collection itself 1074 | """ 1075 | if concrete_type is None: 1076 | return self._add_exact_scoped(base_type) 1077 | 1078 | return self.bind_types(base_type, concrete_type, ServiceLifeStyle.SCOPED) 1079 | 1080 | def add_transient( 1081 | self, base_type: Type, concrete_type: Optional[Type] = None 1082 | ) -> "Container": 1083 | """ 1084 | Registers a type by base type, to be instantiated with transient lifetime. 
1085 | If a single type is given, the method `add_exact_transient` is used. 1086 | 1087 | :param base_type: registered type. If a concrete type is provided, it must 1088 | inherit the base type. 1089 | :param concrete_type: concrete class 1090 | :return: the service collection itself 1091 | """ 1092 | if concrete_type is None: 1093 | return self._add_exact_transient(base_type) 1094 | 1095 | return self.bind_types(base_type, concrete_type, ServiceLifeStyle.TRANSIENT) 1096 | 1097 | def _add_exact_singleton(self, concrete_type: Type) -> "Container": 1098 | """ 1099 | Registers an exact type, to be instantiated with singleton lifetime. 1100 | 1101 | :param concrete_type: concrete class 1102 | :return: the service collection itself 1103 | """ 1104 | assert not isabstract(concrete_type) 1105 | self._bind( 1106 | concrete_type, 1107 | DynamicResolver(concrete_type, self, ServiceLifeStyle.SINGLETON), 1108 | ) 1109 | return self 1110 | 1111 | def _add_exact_scoped(self, concrete_type: Type) -> "Container": 1112 | """ 1113 | Registers an exact type, to be instantiated with scoped lifetime. 1114 | 1115 | :param concrete_type: concrete class 1116 | :return: the service collection itself 1117 | """ 1118 | assert not isabstract(concrete_type) 1119 | self._bind( 1120 | concrete_type, DynamicResolver(concrete_type, self, ServiceLifeStyle.SCOPED) 1121 | ) 1122 | return self 1123 | 1124 | def _add_exact_transient(self, concrete_type: Type) -> "Container": 1125 | """ 1126 | Registers an exact type, to be instantiated with transient lifetime. 
1127 | 1128 | :param concrete_type: concrete class 1129 | :return: the service collection itself 1130 | """ 1131 | assert not isabstract(concrete_type) 1132 | self._bind( 1133 | concrete_type, 1134 | DynamicResolver(concrete_type, self, ServiceLifeStyle.TRANSIENT), 1135 | ) 1136 | return self 1137 | 1138 | def add_singleton_by_factory( 1139 | self, factory: FactoryCallableType, return_type: Optional[Type] = None 1140 | ) -> "Container": 1141 | self.register_factory(factory, return_type, ServiceLifeStyle.SINGLETON) 1142 | return self 1143 | 1144 | def add_transient_by_factory( 1145 | self, factory: FactoryCallableType, return_type: Optional[Type] = None 1146 | ) -> "Container": 1147 | self.register_factory(factory, return_type, ServiceLifeStyle.TRANSIENT) 1148 | return self 1149 | 1150 | def add_scoped_by_factory( 1151 | self, factory: FactoryCallableType, return_type: Optional[Type] = None 1152 | ) -> "Container": 1153 | self.register_factory(factory, return_type, ServiceLifeStyle.SCOPED) 1154 | return self 1155 | 1156 | @staticmethod 1157 | def _check_factory(factory, signature, handled_type) -> Callable: 1158 | assert callable(factory), "The factory must be callable" 1159 | 1160 | params_len = len(signature.parameters) 1161 | 1162 | if params_len == 0: 1163 | return FactoryWrapperNoArgs(factory) 1164 | 1165 | if params_len == 1: 1166 | return FactoryWrapperContextArg(factory) 1167 | 1168 | if params_len == 2: 1169 | return factory 1170 | 1171 | raise InvalidFactory(handled_type) 1172 | 1173 | def register_factory( 1174 | self, 1175 | factory: Callable, 1176 | return_type: Optional[Type], 1177 | life_style: ServiceLifeStyle, 1178 | ) -> None: 1179 | if not callable(factory): 1180 | raise InvalidFactory(return_type) 1181 | 1182 | sign = Signature.from_callable(factory) 1183 | if return_type is None: 1184 | if sign.return_annotation is _empty: 1185 | raise MissingTypeException() 1186 | return_type = sign.return_annotation 1187 | 1188 | if isinstance(return_type, 
str): # pragma: no cover 1189 | # Python 3.10 1190 | annotations = _get_factory_annotations_or_throw(factory) 1191 | return_type = annotations["return"] 1192 | 1193 | self._bind( 1194 | return_type, # type: ignore 1195 | FactoryResolver( 1196 | return_type, self._check_factory(factory, sign, return_type), life_style 1197 | ), 1198 | ) 1199 | 1200 | def build_provider(self) -> Services: 1201 | """ 1202 | Builds and returns a service provider that can be used to activate and obtain 1203 | services. 1204 | 1205 | The configuration of services is validated at this point, if any service cannot 1206 | be instantiated due to missing dependencies, an exception is thrown inside this 1207 | operation. 1208 | 1209 | :return: Service provider that can be used to activate and obtain services. 1210 | """ 1211 | with ResolutionContext() as context: 1212 | _map: Dict[Union[str, Type], Type] = {} 1213 | 1214 | for _type, resolver in self._map.items(): 1215 | if isinstance(resolver, DynamicResolver): 1216 | context.dynamic_chain.clear() 1217 | 1218 | if _type in context.resolved: 1219 | # assert _type not in context.resolved, "_map keys must be unique" 1220 | # check if its in the map 1221 | if _type in _map: 1222 | # NB: do not call resolver if one was already prepared for the 1223 | # type 1224 | raise OverridingServiceException(_type, resolver) 1225 | else: 1226 | resolved = context.resolved[_type] 1227 | else: 1228 | # add to context so that we don't repeat operations 1229 | resolved = resolver(context) 1230 | context.resolved[_type] = resolved 1231 | 1232 | _map[_type] = resolved 1233 | 1234 | type_name = class_name(_type) 1235 | if "." 
not in type_name: 1236 | _map[type_name] = _map[_type] 1237 | 1238 | if not self.strict: 1239 | assert self._aliases is not None 1240 | assert self._exact_aliases is not None 1241 | 1242 | # include aliases in the map; 1243 | for name, _types in self._aliases.items(): 1244 | for _type in _types: 1245 | break 1246 | _map[name] = self._get_alias_target_type(name, _map, _type) 1247 | 1248 | for name, _type in self._exact_aliases.items(): 1249 | _map[name] = self._get_alias_target_type(name, _map, _type) 1250 | 1251 | return Services(_map, scope_cls=self._scope_cls) 1252 | 1253 | @staticmethod 1254 | def _get_alias_target_type(name, _map, _type): 1255 | try: 1256 | return _map[_type] 1257 | except KeyError: 1258 | raise AliasConfigurationError(name, _type) 1259 | -------------------------------------------------------------------------------- /rodi/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Neoteroi/rodi/1b9367b743597fc9ec5eab68d5c8e515d7d066bd/rodi/py.typed -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Neoteroi/rodi/1b9367b743597fc9ec5eab68d5c8e515d7d066bd/tests/__init__.py -------------------------------------------------------------------------------- /tests/examples.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from abc import ABC, abstractmethod 3 | from typing import Optional 4 | 5 | 6 | class Cat: 7 | def __init__(self, name: str): 8 | self.name = name 9 | 10 | 11 | # abstract interface 12 | class ICatsRepository(ABC): 13 | @abstractmethod 14 | def get_by_id(self, _id) -> Cat: 15 | pass 16 | 17 | 18 | # one of the possible implementations of ICatsRepository 19 | class InMemoryCatsRepository(ICatsRepository): 20 | def __init__(self): 21 | self._cats = {} 22 
| 23 | def get_by_id(self, _id) -> Cat: 24 | return self._cats.get(_id) 25 | 26 | 27 | # NB: example of business layer class, using interface of repository 28 | class GetCatRequestHandler: 29 | def __init__(self, cats_repository: ICatsRepository): 30 | self.repo = cats_repository 31 | 32 | def get_cat(self, _id): 33 | cat = self.repo.get_by_id(_id) 34 | return cat 35 | 36 | 37 | # NB: example of controller class; 38 | class CatsController: 39 | def __init__(self, get_cat_request_handler: GetCatRequestHandler): 40 | self.cat_request_handler = get_cat_request_handler 41 | 42 | 43 | class IRequestContext(ABC): 44 | @property 45 | @abstractmethod 46 | def id(self): 47 | pass 48 | 49 | @property 50 | @abstractmethod 51 | def user(self): 52 | pass 53 | 54 | 55 | class RequestContext(IRequestContext): 56 | def __init__(self): 57 | pass 58 | 59 | @property 60 | def id(self): 61 | return "Example" 62 | 63 | @property 64 | def user(self): 65 | return "Example" 66 | 67 | 68 | class ServiceSettings: 69 | def __init__(self, foo_db_connection_string): 70 | self.foo_db_connection_string = foo_db_connection_string 71 | 72 | 73 | class FooDBContext: 74 | def __init__(self, service_settings: ServiceSettings): 75 | self.settings = service_settings 76 | self.connection_string = service_settings.foo_db_connection_string 77 | 78 | 79 | class FooDBCatsRepository(ICatsRepository): 80 | def __init__(self, context: FooDBContext): 81 | self.context = context 82 | 83 | def get_by_id(self, _id) -> Cat: 84 | pass 85 | 86 | 87 | class IValueProvider: 88 | @property 89 | @abstractmethod 90 | def value(self): 91 | pass 92 | 93 | 94 | class ValueProvider(IValueProvider): 95 | __slots__ = "_value" 96 | 97 | def __init__(self, value): 98 | self._value = value 99 | 100 | @property 101 | def value(self): 102 | return self._value 103 | 104 | 105 | class IdGetter: 106 | def __init__(self): 107 | self.value = uuid.uuid4() 108 | 109 | def __repr__(self): 110 | return f"" 111 | 112 | def __str__(self): 113 
| return f"" 114 | 115 | 116 | class A: 117 | def __init__(self, id_getter: IdGetter): 118 | self.id_getter = id_getter 119 | 120 | 121 | class B: 122 | def __init__(self, a: A, id_getter: IdGetter): 123 | self.a = a 124 | self.id_getter = id_getter 125 | 126 | 127 | class C: 128 | def __init__(self, a: A, b: B, id_getter: IdGetter): 129 | self.a = a 130 | self.b = b 131 | self.id_getter = id_getter 132 | 133 | 134 | class ICircle(ABC): 135 | pass 136 | 137 | 138 | class Circle(ICircle): 139 | def __init__(self, circular: ICircle): 140 | # NB: this is not supported by DI 141 | self.circular = circular 142 | 143 | 144 | class Circle2(ICircle): 145 | circular: ICircle 146 | 147 | 148 | class Shape: 149 | def __init__(self, circle: Circle): 150 | self.circle = circle 151 | 152 | 153 | class Foo: 154 | def __init__(self): 155 | pass 156 | 157 | 158 | class UfoOne: 159 | def __init__(self): 160 | pass 161 | 162 | 163 | class UfoTwo: 164 | def __init__(self, one: UfoOne): 165 | self.one = one 166 | 167 | 168 | class UfoThree(UfoTwo): 169 | def __init__(self, one: UfoOne, foo: Foo): 170 | super().__init__(one) 171 | self.foo = foo 172 | 173 | 174 | class UfoFour(UfoThree): 175 | def __init__(self, one: UfoOne, foo: Foo): 176 | super().__init__(one, foo) 177 | 178 | 179 | class TypeWithOptional: 180 | def __init__(self, foo: Optional[Foo]): 181 | self.foo = foo 182 | 183 | 184 | class SelfReferencingCircle: 185 | def __init__(self, circle: "SelfReferencingCircle"): 186 | self.circular = circle 187 | 188 | 189 | class TrickyCircle: 190 | def __init__(self, circle: ICircle): 191 | self.circular = circle 192 | 193 | 194 | class ResolveThisByParameterName: 195 | def __init__(self, icats_repository): 196 | self.cats_repository = icats_repository 197 | 198 | 199 | class IByParamName: 200 | pass 201 | 202 | 203 | class FooByParamName(IByParamName): 204 | def __init__(self, foo): 205 | self.foo = foo 206 | 207 | 208 | class Jing: 209 | def __init__(self, jang): 210 | self.jang = 
jang 211 | 212 | 213 | class Jang: 214 | def __init__(self, jing): 215 | self.jing = jing 216 | 217 | 218 | class Q: 219 | def __init__(self): 220 | pass 221 | 222 | 223 | class R: 224 | def __init__(self, p): 225 | self.p = p 226 | 227 | 228 | class P: 229 | def __init__(self): 230 | pass 231 | 232 | 233 | class W: 234 | def __init__(self, x): 235 | self.x = x 236 | 237 | 238 | class X: 239 | def __init__(self, y): 240 | self.y = y 241 | 242 | 243 | class Y: 244 | def __init__(self, z): 245 | self.z = z 246 | 247 | 248 | class Z: 249 | def __init__(self, w): 250 | self.w = w 251 | 252 | 253 | class Ko: 254 | def __init__(self): 255 | pass 256 | 257 | 258 | class Ok: 259 | def __init__(self): 260 | pass 261 | 262 | 263 | class PrecedenceOfTypeHintsOverNames: 264 | def __init__(self, foo: Q, ko: P): 265 | self.q = foo 266 | self.p = ko 267 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import importlib 3 | import sys 4 | 5 | import pytest 6 | 7 | examples = [file for file in glob.glob("./examples/*.py")] 8 | 9 | 10 | sys.path.append("./examples") 11 | 12 | 13 | @pytest.mark.parametrize("file_path", examples) 14 | def test_example(file_path: str): 15 | module_name = ( 16 | # Windows 17 | file_path.replace("./examples\\", "") 18 | # Unix 19 | .replace("./examples/", "").replace(".py", "") 20 | ) 21 | # assertions are in imported modules 22 | importlib.import_module(module_name) 23 | -------------------------------------------------------------------------------- /tests/test_fn_exec.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions exec tests. 3 | exec functions are designed to enable executing any function injecting parameters. 
4 | """ 5 | 6 | import pytest 7 | 8 | from rodi import Container, inject 9 | 10 | 11 | class Example: 12 | def __init__(self, repository): 13 | self.repository = repository 14 | 15 | 16 | class Context: 17 | def __init__(self): 18 | self.trace_id = "1111" 19 | 20 | 21 | class Repository: 22 | def __init__(self, context: Context): 23 | self.context = context 24 | 25 | 26 | def test_execute_function(): 27 | class Example: 28 | def __init__(self, repository): 29 | self.repository = repository 30 | 31 | class Context: 32 | def __init__(self): 33 | self.trace_id = "1111" 34 | 35 | @inject() 36 | class Repository: 37 | def __init__(self, context: Context): 38 | self.context = context 39 | 40 | called = False 41 | 42 | @inject() 43 | def fn(example, context: Context): 44 | nonlocal called 45 | called = True 46 | assert isinstance(example, Example) 47 | assert isinstance(example.repository, Repository) 48 | assert isinstance(context, Context) 49 | # scoped parameter: 50 | assert context is example.repository.context 51 | return context.trace_id 52 | 53 | container = Container() 54 | 55 | container.add_transient(Example) 56 | container.add_transient(Repository) 57 | container.add_scoped(Context) 58 | 59 | provider = container.build_provider() 60 | 61 | result = provider.exec(fn) 62 | 63 | assert called 64 | assert result == Context().trace_id 65 | 66 | 67 | def test_executor(): 68 | called = False 69 | 70 | @inject() 71 | def fn(example, context: Context): 72 | nonlocal called 73 | called = True 74 | assert isinstance(example, Example) 75 | assert isinstance(example.repository, Repository) 76 | assert isinstance(context, Context) 77 | # scoped parameter: 78 | assert context is example.repository.context 79 | return context.trace_id 80 | 81 | container = Container() 82 | 83 | container.add_transient(Example) 84 | container.add_transient(Repository) 85 | container.add_scoped(Context) 86 | 87 | provider = container.build_provider() 88 | 89 | executor = 
provider.get_executor(fn) 90 | 91 | result = executor() 92 | 93 | assert called 94 | assert result == Context().trace_id 95 | 96 | 97 | def test_executor_with_given_scoped_services(): 98 | called = False 99 | 100 | @inject() 101 | def fn(example, context: Context): 102 | nonlocal called 103 | called = True 104 | assert isinstance(example, Example) 105 | assert isinstance(example.repository, Repository) 106 | assert isinstance(context, Context) 107 | # scoped parameter: 108 | assert context is example.repository.context 109 | return context 110 | 111 | container = Container() 112 | 113 | container.add_transient(Example) 114 | container.add_transient(Repository) 115 | container.add_scoped(Context) 116 | 117 | provider = container.build_provider() 118 | 119 | executor = provider.get_executor(fn) 120 | 121 | given_context = Context() 122 | result = executor({Context: given_context}) 123 | 124 | assert called 125 | assert result is given_context 126 | 127 | 128 | @pytest.mark.asyncio 129 | async def test_async_executor(): 130 | called = False 131 | 132 | @inject() 133 | async def fn(example, context: Context): 134 | nonlocal called 135 | called = True 136 | assert isinstance(example, Example) 137 | assert isinstance(example.repository, Repository) 138 | assert isinstance(context, Context) 139 | # scoped parameter: 140 | assert context is example.repository.context 141 | return context.trace_id 142 | 143 | container = Container() 144 | 145 | container.add_transient(Example) 146 | container.add_transient(Repository) 147 | container.add_scoped(Context) 148 | 149 | provider = container.build_provider() 150 | 151 | executor = provider.get_executor(fn) 152 | 153 | result = await executor() 154 | 155 | assert called 156 | assert result == Context().trace_id 157 | -------------------------------------------------------------------------------- /tests/test_services.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import 
sys 3 | from abc import ABC 4 | from dataclasses import dataclass 5 | from typing import ( 6 | Any, 7 | ClassVar, 8 | Dict, 9 | Generic, 10 | Iterable, 11 | List, 12 | Mapping, 13 | Optional, 14 | Sequence, 15 | Tuple, 16 | Type, 17 | TypeVar, 18 | ) 19 | 20 | import pytest 21 | from pytest import raises 22 | 23 | from rodi import ( 24 | ActivationScope, 25 | AliasAlreadyDefined, 26 | AliasConfigurationError, 27 | CannotResolveParameterException, 28 | CannotResolveTypeException, 29 | CircularDependencyException, 30 | Container, 31 | ContainerProtocol, 32 | DynamicResolver, 33 | FactoryMissingContextException, 34 | InstanceResolver, 35 | InvalidFactory, 36 | InvalidOperationInStrictMode, 37 | MissingTypeException, 38 | OverridingServiceException, 39 | ServiceLifeStyle, 40 | Services, 41 | TrackingActivationScope, 42 | _get_factory_annotations_or_throw, 43 | inject, 44 | to_standard_param_name, 45 | ) 46 | from tests.examples import ( 47 | A, 48 | B, 49 | C, 50 | Cat, 51 | CatsController, 52 | Circle, 53 | Circle2, 54 | Foo, 55 | FooByParamName, 56 | FooDBCatsRepository, 57 | FooDBContext, 58 | GetCatRequestHandler, 59 | IByParamName, 60 | ICatsRepository, 61 | ICircle, 62 | IdGetter, 63 | InMemoryCatsRepository, 64 | IRequestContext, 65 | Jang, 66 | Jing, 67 | Ko, 68 | Ok, 69 | P, 70 | PrecedenceOfTypeHintsOverNames, 71 | Q, 72 | R, 73 | RequestContext, 74 | ResolveThisByParameterName, 75 | ServiceSettings, 76 | Shape, 77 | TrickyCircle, 78 | TypeWithOptional, 79 | UfoFour, 80 | UfoOne, 81 | UfoThree, 82 | UfoTwo, 83 | W, 84 | X, 85 | Y, 86 | Z, 87 | ) 88 | 89 | T_1 = TypeVar("T_1") 90 | 91 | 92 | try: 93 | from typing import Protocol 94 | except ImportError: # pragma: no cover 95 | # support for Python 3.7 96 | from typing_extensions import Protocol 97 | 98 | 99 | class LoggedVar(Generic[T_1]): 100 | def __init__(self, value: T_1, name: str) -> None: 101 | self.name = name 102 | self.value = value 103 | 104 | def set(self, new: T_1) -> None: 105 | self.log("Set " + 
repr(self.value)) 106 | self.value = new 107 | 108 | def get(self) -> T_1: 109 | self.log("Get " + repr(self.value)) 110 | return self.value 111 | 112 | def log(self, message: str) -> None: 113 | print(self.name, message) 114 | 115 | 116 | def arrange_cats_example(): 117 | container = Container() 118 | container.add_transient(ICatsRepository, FooDBCatsRepository) 119 | container.add_scoped(IRequestContext, RequestContext) 120 | container._add_exact_transient(GetCatRequestHandler) 121 | container._add_exact_transient(CatsController) 122 | container.add_instance(ServiceSettings("foodb:example;something;")) 123 | container._add_exact_transient(FooDBContext) 124 | return container 125 | 126 | 127 | @pytest.mark.parametrize( 128 | "value,expected_result", 129 | ( 130 | ("CamelCase", "camel_case"), 131 | ("HTTPResponse", "http_response"), 132 | ("ICatsRepository", "icats_repository"), 133 | ("Cat", "cat"), 134 | ("UFO", "ufo"), 135 | ), 136 | ) 137 | def test_standard_param_name(value, expected_result): 138 | snaked = to_standard_param_name(value) 139 | assert snaked == expected_result 140 | 141 | 142 | def test_singleton_by_instance(): 143 | container = Container() 144 | container.add_instance(Cat("Celine")) 145 | provider = container.build_provider() 146 | 147 | cat = provider.get(Cat) 148 | 149 | assert cat is not None 150 | assert cat.name == "Celine" 151 | 152 | 153 | def test_transient_by_type_without_parameters(): 154 | container = Container() 155 | container.add_transient(ICatsRepository, InMemoryCatsRepository) 156 | provider = container.build_provider() 157 | cats_repo = provider.get(ICatsRepository) 158 | 159 | assert isinstance(cats_repo, InMemoryCatsRepository) 160 | other_cats_repo = provider.get(ICatsRepository) 161 | 162 | assert isinstance(other_cats_repo, InMemoryCatsRepository) 163 | assert cats_repo is not other_cats_repo 164 | 165 | 166 | def test_transient_by_type_with_parameters(): 167 | container = Container() 168 | 
container.add_transient(ICatsRepository, FooDBCatsRepository) 169 | 170 | # NB: 171 | container.add_instance(ServiceSettings("foodb:example;something;")) 172 | container._add_exact_transient(FooDBContext) 173 | provider = container.build_provider() 174 | 175 | cats_repo = provider.get(ICatsRepository) 176 | 177 | assert isinstance(cats_repo, FooDBCatsRepository) 178 | assert isinstance(cats_repo.context, FooDBContext) 179 | assert isinstance(cats_repo.context.settings, ServiceSettings) 180 | assert cats_repo.context.connection_string == "foodb:example;something;" 181 | 182 | 183 | def test_add_transient_shortcut(): 184 | container = Container() 185 | container.add_transient(ICatsRepository, FooDBCatsRepository) 186 | 187 | # NB: 188 | container.add_instance(ServiceSettings("foodb:example;something;")) 189 | container.add_transient(FooDBContext) 190 | provider = container.build_provider() 191 | 192 | cats_repo = provider.get(ICatsRepository) 193 | 194 | assert isinstance(cats_repo, FooDBCatsRepository) 195 | assert isinstance(cats_repo.context, FooDBContext) 196 | assert isinstance(cats_repo.context.settings, ServiceSettings) 197 | assert cats_repo.context.connection_string == "foodb:example;something;" 198 | 199 | 200 | def test_raises_for_overriding_service(): 201 | container = Container() 202 | container.add_transient(ICircle, Circle) 203 | 204 | with pytest.raises(OverridingServiceException) as context: 205 | container.add_singleton(ICircle, Circle) 206 | 207 | assert "ICircle" in str(context.value) 208 | 209 | with pytest.raises(OverridingServiceException) as context: 210 | container.add_transient(ICircle, Circle) 211 | 212 | assert "ICircle" in str(context.value) 213 | 214 | with pytest.raises(OverridingServiceException) as context: 215 | container.add_scoped(ICircle, Circle) 216 | 217 | assert "ICircle" in str(context.value) 218 | 219 | 220 | def test_raises_for_circular_dependency(): 221 | container = Container() 222 | container.add_transient(ICircle, 
Circle) 223 | 224 | with pytest.raises(CircularDependencyException) as context: 225 | container.build_provider() 226 | 227 | assert "Circle" in str(context.value) 228 | 229 | 230 | def test_raises_for_circular_dependency_class_annotation(): 231 | container = Container() 232 | container.add_transient(ICircle, Circle2) 233 | 234 | with pytest.raises(CircularDependencyException) as context: 235 | container.build_provider() 236 | 237 | assert "Circle" in str(context.value) 238 | 239 | 240 | def test_raises_for_circular_dependency_with_dynamic_resolver(): 241 | container = Container() 242 | container._add_exact_transient(Jing) 243 | container._add_exact_transient(Jang) 244 | 245 | with pytest.raises(CircularDependencyException): 246 | container.build_provider() 247 | 248 | 249 | def test_raises_for_deep_circular_dependency_with_dynamic_resolver(): 250 | container = Container() 251 | container._add_exact_transient(W) 252 | container._add_exact_transient(X) 253 | container._add_exact_transient(Y) 254 | container._add_exact_transient(Z) 255 | 256 | with pytest.raises(CircularDependencyException): 257 | container.build_provider() 258 | 259 | 260 | def test_does_not_raise_for_deep_circular_dependency_with_one_factory(): 261 | container = Container() 262 | container._add_exact_transient(W) 263 | container._add_exact_transient(X) 264 | container._add_exact_transient(Y) 265 | 266 | def z_factory(_) -> Z: 267 | return Z(None) 268 | 269 | container.add_transient_by_factory(z_factory) 270 | 271 | provider = container.build_provider() 272 | 273 | w = provider.get(W) 274 | 275 | assert isinstance(w, W) 276 | assert isinstance(w.x, X) 277 | assert isinstance(w.x.y, Y) 278 | assert isinstance(w.x.y.z, Z) 279 | assert w.x.y.z.w is None 280 | 281 | 282 | def test_circular_dependency_is_supported_by_factory(): 283 | def get_jang(_) -> Jang: 284 | return Jang(None) 285 | 286 | container = Container() 287 | container._add_exact_transient(Jing) 288 | 
container.add_transient_by_factory(get_jang) 289 | 290 | provider = container.build_provider() 291 | 292 | jing = provider.get(Jing) 293 | assert jing is not None 294 | assert isinstance(jing.jang, Jang) 295 | assert jing.jang.jing is None 296 | 297 | 298 | def test_add_instance_allows_for_circular_classes(): 299 | container = Container() 300 | container.add_instance(Circle(Circle(None))) 301 | 302 | # NB: in this example, Shape requires a Circle 303 | container._add_exact_transient(Shape) 304 | provider = container.build_provider() 305 | 306 | circle = provider.get(Circle) 307 | assert isinstance(circle, Circle) 308 | 309 | shape = provider.get(Shape) 310 | 311 | assert isinstance(shape, Shape) 312 | assert shape.circle is circle 313 | 314 | 315 | def test_add_instance_with_declared_type(): 316 | container = Container() 317 | container.add_instance(Circle(Circle(None)), declared_class=ICircle) 318 | provider = container.build_provider() 319 | 320 | icircle = provider.get(ICircle) 321 | assert isinstance(icircle, Circle) 322 | 323 | 324 | def test_optional_parameter(): 325 | container = Container() 326 | container.add_transient(Optional[Foo], Foo) # type: ignore 327 | container.add_transient(TypeWithOptional) 328 | 329 | a = container.resolve(TypeWithOptional) 330 | assert isinstance(a.foo, Foo) 331 | 332 | 333 | def test_raises_for_nested_circular_dependency(): 334 | container = Container() 335 | container.add_transient(ICircle, Circle) 336 | container._add_exact_transient(TrickyCircle) 337 | 338 | with pytest.raises(CircularDependencyException) as context: 339 | container.build_provider() 340 | 341 | assert "Circle" in str(context.value) 342 | 343 | 344 | def test_interdependencies(): 345 | container = Container() 346 | container._add_exact_transient(A) 347 | container._add_exact_transient(B) 348 | container._add_exact_transient(C) 349 | container._add_exact_transient(IdGetter) 350 | provider = container.build_provider() 351 | 352 | c = provider.get(C) 353 | 354 
| assert isinstance(c, C) 355 | assert isinstance(c.a, A) 356 | assert isinstance(c.b, B) 357 | assert isinstance(c.b.a, A) 358 | 359 | 360 | def test_transient_service(): 361 | container = Container() 362 | container.add_transient(ICatsRepository, InMemoryCatsRepository) 363 | provider = container.build_provider() 364 | 365 | cats_repo = provider.get(ICatsRepository) 366 | assert isinstance(cats_repo, InMemoryCatsRepository) 367 | 368 | other_cats_repo = provider.get(ICatsRepository) 369 | assert cats_repo is not other_cats_repo 370 | 371 | 372 | def test_singleton_services(): 373 | container = Container() 374 | container._add_exact_singleton(IdGetter) 375 | provider = container.build_provider() 376 | 377 | with ActivationScope() as context: 378 | a = provider.get(IdGetter, context) 379 | b = provider.get(IdGetter, context) 380 | c = provider.get(IdGetter, context) 381 | d = provider.get(IdGetter) 382 | 383 | assert a is b 384 | assert a is c 385 | assert b is c 386 | assert d is a 387 | 388 | 389 | def test_scoped_services_context_used_more_than_once(): 390 | container = Container() 391 | 392 | @inject() 393 | class C: 394 | def __init__(self): 395 | pass 396 | 397 | @inject() 398 | class B2: 399 | def __init__(self, c: C): 400 | self.c = c 401 | 402 | @inject() 403 | class B1: 404 | def __init__(self, c: C): 405 | self.c = c 406 | 407 | @inject() 408 | class A: 409 | def __init__(self, b1: B1, b2: B2): 410 | self.b1 = b1 411 | self.b2 = b2 412 | 413 | container._add_exact_scoped(C) 414 | container._add_exact_transient(B1) 415 | container._add_exact_transient(B2) 416 | container._add_exact_transient(A) 417 | 418 | provider = container.build_provider() 419 | 420 | context = ActivationScope(provider) 421 | 422 | with context: 423 | a = provider.get(A, context) 424 | first_c = provider.get(C) 425 | a.b1.c is first_c 426 | a.b2.c is first_c 427 | 428 | with context: 429 | a = provider.get(A, context) 430 | second_c = provider.get(C) 431 | a.b1.c is second_c 432 | 
a.b2.c is second_c 433 | 434 | assert first_c is not None 435 | assert second_c is not None 436 | assert first_c is not second_c 437 | 438 | 439 | def test_scoped_services_context_used_more_than_once_manual_dispose(): 440 | container = Container() 441 | 442 | container.add_instance("value") 443 | 444 | provider = container.build_provider() 445 | context = ActivationScope(provider) 446 | 447 | context.dispose() 448 | assert context.provider is None 449 | 450 | 451 | def test_transient_services(): 452 | container = Container() 453 | container._add_exact_transient(IdGetter) 454 | provider = container.build_provider() 455 | 456 | with ActivationScope() as context: 457 | a = provider.get(IdGetter, context) 458 | b = provider.get(IdGetter, context) 459 | c = provider.get(IdGetter, context) 460 | d = provider.get(IdGetter) 461 | 462 | assert a is not b 463 | assert a is not c 464 | assert b is not c 465 | assert d is not a 466 | assert d is not b 467 | assert d is not c 468 | 469 | 470 | def test_scoped_services_use_scope_context_by_default(): 471 | container = Container() 472 | container._add_exact_scoped(IdGetter) 473 | provider = container.build_provider() 474 | 475 | with ActivationScope(provider) as scoped_provider: 476 | a = scoped_provider.get(IdGetter) 477 | b = scoped_provider.get(IdGetter) 478 | c = scoped_provider.get(IdGetter) 479 | d = provider.get(IdGetter) 480 | 481 | assert a is b 482 | assert b is c 483 | assert a is not d 484 | assert b is not d 485 | 486 | 487 | def test_scoped_services_use_correct_scope_context_by_default_with_multiple_scopes(): 488 | container = Container() 489 | container._add_exact_scoped(IdGetter) 490 | provider = container.build_provider() 491 | 492 | with ActivationScope(provider) as scoped_provider_1: 493 | a = scoped_provider_1.get(IdGetter) 494 | b = scoped_provider_1.get(IdGetter) 495 | with ActivationScope(provider) as scoped_provider_2: 496 | c = scoped_provider_2.get(IdGetter) 497 | d = scoped_provider_2.get(IdGetter) 498 
| e = scoped_provider_2.get(IdGetter, scoped_provider_1) 499 | f = provider.get(IdGetter) 500 | 501 | assert a is b 502 | assert b is e 503 | assert c is d 504 | assert a is not c 505 | assert a is not f 506 | assert c is not f 507 | 508 | 509 | def test_scoped_services_works_with_str_keys(): 510 | container = Container() 511 | container.add_singleton("Id", IdGetter) 512 | provider = container.build_provider() 513 | 514 | with ActivationScope(provider) as scoped_provider: 515 | a = scoped_provider.get("Id") 516 | b = provider.get("id") 517 | 518 | assert a is b 519 | 520 | 521 | def test_scoped_services(): 522 | container = Container() 523 | container._add_exact_scoped(IdGetter) 524 | provider = container.build_provider() 525 | 526 | with ActivationScope() as context: 527 | a = provider.get(IdGetter, context) 528 | b = provider.get(IdGetter, context) 529 | c = provider.get(IdGetter, context) 530 | d = provider.get(IdGetter) 531 | 532 | assert a is b 533 | assert b is c 534 | assert a is not d 535 | assert b is not d 536 | 537 | 538 | def test_scoped_service_from_scoped_services(): 539 | container = Container() 540 | provider = container.build_provider() 541 | 542 | scoped_service = IdGetter() 543 | 544 | with ActivationScope( 545 | provider, 546 | { 547 | IdGetter: scoped_service, 548 | }, 549 | ) as context: 550 | a = provider.get(IdGetter, context) 551 | b = provider.get(IdGetter, default=None) 552 | c = provider.get(IdGetter, default=None) 553 | 554 | with ActivationScope( 555 | scoped_services={ 556 | IdGetter: scoped_service, 557 | } 558 | ) as context: 559 | d = provider.get(IdGetter, context) 560 | e = provider.get(IdGetter, default=None) 561 | 562 | assert a is scoped_service 563 | assert b is None 564 | assert c is None 565 | assert d is scoped_service 566 | assert e is None 567 | 568 | 569 | def test_scoped_services_with_shortcut(): 570 | container = Container() 571 | container.add_scoped(IdGetter) 572 | provider = container.build_provider() 573 | 574 | 
with ActivationScope() as context: 575 | a = provider.get(IdGetter, context) 576 | b = provider.get(IdGetter, context) 577 | c = provider.get(IdGetter, context) 578 | d = provider.get(IdGetter) 579 | 580 | assert a is b 581 | assert b is c 582 | assert a is not d 583 | assert b is not d 584 | 585 | 586 | def test_resolution_by_parameter_name(): 587 | container = Container() 588 | container.add_transient(ICatsRepository, InMemoryCatsRepository) 589 | container._add_exact_transient(ResolveThisByParameterName) 590 | 591 | provider = container.build_provider() 592 | resolved = provider.get(ResolveThisByParameterName) 593 | 594 | assert resolved is not None 595 | 596 | assert isinstance(resolved, ResolveThisByParameterName) 597 | assert isinstance(resolved.cats_repository, InMemoryCatsRepository) 598 | 599 | 600 | def test_resolve_singleton_by_parameter_name(): 601 | container = Container() 602 | container.add_transient(IByParamName, FooByParamName) 603 | 604 | singleton = Foo() 605 | container.add_instance(singleton) 606 | 607 | provider = container.build_provider() 608 | resolved = provider.get(IByParamName) 609 | 610 | assert resolved is not None 611 | 612 | assert isinstance(resolved, FooByParamName) 613 | assert resolved.foo is singleton 614 | 615 | 616 | def test_service_collection_contains(): 617 | container = Container() 618 | container._add_exact_transient(Foo) 619 | 620 | assert Foo in container 621 | assert Cat not in container 622 | 623 | 624 | def test_service_provider_contains(): 625 | container = Container() 626 | container.add_transient(IdGetter) 627 | 628 | provider = container.build_provider() 629 | 630 | assert Foo not in provider 631 | assert IdGetter in provider 632 | 633 | 634 | def test_exact_alias(): 635 | container = arrange_cats_example() 636 | 637 | class UsingAlias: 638 | def __init__(self, example): 639 | self.cats_controller = example 640 | 641 | container.add_transient(UsingAlias) 642 | 643 | # arrange an exact alias for UsingAlias class 
init parameter: 644 | container.set_alias("example", CatsController) 645 | 646 | provider = container.build_provider() 647 | u = provider.get(UsingAlias) 648 | 649 | assert isinstance(u, UsingAlias) 650 | assert isinstance(u.cats_controller, CatsController) 651 | assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler) 652 | 653 | 654 | def test_additional_alias(): 655 | container = arrange_cats_example() 656 | 657 | class UsingAlias: 658 | def __init__(self, example, settings): 659 | self.cats_controller = example 660 | self.settings = settings 661 | 662 | class AnotherUsingAlias: 663 | def __init__(self, cats_controller, service_settings): 664 | self.cats_controller = cats_controller 665 | self.settings = service_settings 666 | 667 | container._add_exact_transient(UsingAlias) 668 | container._add_exact_transient(AnotherUsingAlias) 669 | 670 | # arrange an exact alias for UsingAlias class init parameter: 671 | container.add_alias("example", CatsController) 672 | container.add_alias("settings", ServiceSettings) 673 | 674 | provider = container.build_provider() 675 | u = provider.get(UsingAlias) 676 | 677 | assert isinstance(u, UsingAlias) 678 | assert isinstance(u.settings, ServiceSettings) 679 | assert isinstance(u.cats_controller, CatsController) 680 | assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler) 681 | 682 | u = provider.get(AnotherUsingAlias) 683 | 684 | assert isinstance(u, AnotherUsingAlias) 685 | assert isinstance(u.settings, ServiceSettings) 686 | assert isinstance(u.cats_controller, CatsController) 687 | assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler) 688 | 689 | 690 | def test_get_service_by_name_or_alias(): 691 | container = arrange_cats_example() 692 | container.add_alias("k", CatsController) 693 | 694 | provider = container.build_provider() 695 | 696 | for name in {"CatsController", "cats_controller", "k"}: 697 | service = provider.get(name) 698 | 699 | assert 
isinstance(service, CatsController) 700 | assert isinstance(service.cat_request_handler, GetCatRequestHandler) 701 | assert isinstance(service.cat_request_handler.repo, FooDBCatsRepository) 702 | 703 | 704 | def test_missing_service_raises_exception(): 705 | container = Container() 706 | provider = container.build_provider() 707 | 708 | with pytest.raises(CannotResolveTypeException): 709 | provider.get("not_existing") 710 | 711 | 712 | def test_missing_service_can_return_default(): 713 | container = Container() 714 | provider = container.build_provider() 715 | 716 | service = provider.get("not_existing", default=None) 717 | assert service is None 718 | 719 | 720 | def test_by_factory_type_annotation_simple(): 721 | container = Container() 722 | 723 | def factory() -> Cat: 724 | return Cat("Celine") 725 | 726 | container.add_transient_by_factory(factory) 727 | provider = container.build_provider() 728 | 729 | cat = provider.get(Cat) 730 | assert isinstance(cat, Cat) 731 | assert cat.name == "Celine" 732 | 733 | 734 | def test_by_factory_type_annotation_simple_local(): 735 | container = Container() 736 | 737 | @dataclass 738 | class LocalCat: 739 | name: str 740 | 741 | @inject() 742 | def service_factory() -> LocalCat: 743 | return LocalCat("Celine") 744 | 745 | container.add_transient_by_factory(service_factory) 746 | provider = container.build_provider() 747 | 748 | cat = provider.get(LocalCat) 749 | assert isinstance(cat, LocalCat) 750 | assert cat.name == "Celine" 751 | 752 | 753 | @pytest.mark.parametrize( 754 | "method_name", 755 | ["add_singleton_by_factory", "add_transient_by_factory", "add_scoped_by_factory"], 756 | ) 757 | def test_by_factory_type_annotation(method_name): 758 | container = Container() 759 | 760 | def factory(_) -> Cat: 761 | return Cat("Celine") 762 | 763 | method = getattr(container, method_name) 764 | 765 | method(factory) 766 | 767 | provider = container.build_provider() 768 | 769 | cat = provider.get(Cat) 770 | 771 | assert cat is not 
None 772 | assert cat.name == "Celine" 773 | 774 | if method_name == "add_singleton_by_factory": 775 | cat_2 = provider.get(Cat) 776 | assert cat_2 is cat 777 | 778 | if method_name == "add_transient_by_factory": 779 | assert provider.get(Cat) is not cat 780 | assert provider.get(Cat) is not cat 781 | assert provider.get(Cat) is not cat 782 | 783 | if method_name == "add_scoped_by_factory": 784 | with ActivationScope() as context: 785 | cat_2 = provider.get(Cat, context) 786 | assert cat_2 is not cat 787 | 788 | assert provider.get(Cat, context) is cat_2 789 | assert provider.get(Cat, context) is cat_2 790 | assert provider.get(Cat, context) is cat_2 791 | 792 | 793 | @pytest.mark.parametrize( 794 | "method_name", 795 | ["add_singleton_by_factory", "add_transient_by_factory", "add_scoped_by_factory"], 796 | ) 797 | def test_invalid_factory_too_many_arguments_throws(method_name): 798 | container = Container() 799 | method = getattr(container, method_name) 800 | 801 | def factory(context, activating_type, extra_argument_mistake): 802 | return Cat("Celine") 803 | 804 | with raises(InvalidFactory): 805 | method(factory, Cat) 806 | 807 | def factory(context, activating_type, extra_argument_mistake, two): 808 | return Cat("Celine") 809 | 810 | with raises(InvalidFactory): 811 | method(factory, Cat) 812 | 813 | def factory(context, activating_type, extra_argument_mistake, two, three): 814 | return Cat("Celine") 815 | 816 | with raises(InvalidFactory): 817 | method(factory, Cat) 818 | 819 | 820 | @pytest.mark.parametrize( 821 | "method_name", 822 | ["add_singleton_by_factory", "add_transient_by_factory", "add_scoped_by_factory"], 823 | ) 824 | def test_add_singleton_by_factory_given_type(method_name): 825 | container = Container() 826 | 827 | def factory(a): 828 | return Cat("Celine") 829 | 830 | method = getattr(container, method_name) 831 | 832 | method(factory, Cat) 833 | 834 | provider = container.build_provider() 835 | 836 | cat = provider.get(Cat) 837 | 838 | assert 
cat is not None 839 | assert cat.name == "Celine" 840 | 841 | if method_name == "add_singleton_by_factory": 842 | cat_2 = provider.get(Cat) 843 | assert cat_2 is cat 844 | 845 | if method_name == "add_transient_by_factory": 846 | assert provider.get(Cat) is not cat 847 | assert provider.get(Cat) is not cat 848 | assert provider.get(Cat) is not cat 849 | 850 | if method_name == "add_scoped_by_factory": 851 | with ActivationScope() as context: 852 | cat_2 = provider.get(Cat, context) 853 | assert cat_2 is not cat 854 | 855 | assert provider.get(Cat, context) is cat_2 856 | assert provider.get(Cat, context) is cat_2 857 | assert provider.get(Cat, context) is cat_2 858 | 859 | 860 | @pytest.mark.parametrize( 861 | "method_name", 862 | ["add_singleton_by_factory", "add_transient_by_factory", "add_scoped_by_factory"], 863 | ) 864 | def test_add_singleton_by_factory_raises_for_missing_type(method_name): 865 | container = Container() 866 | 867 | def factory(_): 868 | return Cat("Celine") 869 | 870 | method = getattr(container, method_name) 871 | 872 | with pytest.raises(MissingTypeException): 873 | method(factory) 874 | 875 | 876 | def test_singleton_by_provider(): 877 | container = Container() 878 | container._add_exact_singleton(P) 879 | container._add_exact_transient(R) 880 | 881 | provider = container.build_provider() 882 | 883 | p = provider.get(P) 884 | r = provider.get(R) 885 | 886 | assert p is not None 887 | assert r is not None 888 | assert r.p is p 889 | 890 | 891 | def test_singleton_by_provider_with_shortcut(): 892 | container = Container() 893 | container.add_singleton(P) 894 | container.add_transient(R) 895 | 896 | provider = container.build_provider() 897 | 898 | p = provider.get(P) 899 | r = provider.get(R) 900 | 901 | assert p is not None 902 | assert r is not None 903 | assert r.p is p 904 | 905 | 906 | def test_singleton_by_provider_both_singletons(): 907 | container = Container() 908 | container._add_exact_singleton(P) 909 | 
container._add_exact_singleton(R) 910 | 911 | provider = container.build_provider() 912 | 913 | p = provider.get(P) 914 | r = provider.get(R) 915 | 916 | assert p is not None 917 | assert r is not None 918 | assert r.p is p 919 | 920 | r_2 = provider.get(R) 921 | assert r_2 is r 922 | 923 | 924 | def test_type_hints_precedence(): 925 | container = Container() 926 | container._add_exact_transient(PrecedenceOfTypeHintsOverNames) 927 | container._add_exact_transient(Foo) 928 | container._add_exact_transient(Q) 929 | container._add_exact_transient(P) 930 | container._add_exact_transient(Ko) 931 | container._add_exact_transient(Ok) 932 | 933 | provider = container.build_provider() 934 | 935 | service = provider.get(PrecedenceOfTypeHintsOverNames) 936 | 937 | assert isinstance(service, PrecedenceOfTypeHintsOverNames) 938 | assert isinstance(service.q, Q) 939 | assert isinstance(service.p, P) 940 | 941 | 942 | def test_type_hints_precedence_with_shortcuts(): 943 | container = Container() 944 | container.add_transient(PrecedenceOfTypeHintsOverNames) 945 | container.add_transient(Foo) 946 | container.add_transient(Q) 947 | container.add_transient(P) 948 | container.add_transient(Ko) 949 | container.add_transient(Ok) 950 | 951 | provider = container.build_provider() 952 | 953 | service = provider.get(PrecedenceOfTypeHintsOverNames) 954 | 955 | assert isinstance(service, PrecedenceOfTypeHintsOverNames) 956 | assert isinstance(service.q, Q) 957 | assert isinstance(service.p, P) 958 | 959 | 960 | def test_proper_handling_of_inheritance(): 961 | container = Container() 962 | container._add_exact_transient(UfoOne) 963 | container._add_exact_transient(UfoTwo) 964 | container._add_exact_transient(UfoThree) 965 | container._add_exact_transient(UfoFour) 966 | container._add_exact_transient(Foo) 967 | 968 | provider = container.build_provider() 969 | 970 | ufo_one = provider.get(UfoOne) 971 | ufo_two = provider.get(UfoTwo) 972 | ufo_three = provider.get(UfoThree) 973 | ufo_four = 
provider.get(UfoFour) 974 | 975 | assert isinstance(ufo_one, UfoOne) 976 | assert isinstance(ufo_two, UfoTwo) 977 | assert isinstance(ufo_three, UfoThree) 978 | assert isinstance(ufo_four, UfoFour) 979 | 980 | 981 | def cat_factory_no_args() -> Cat: 982 | return Cat("Celine") 983 | 984 | 985 | def cat_factory_with_context(context) -> Cat: 986 | assert isinstance(context, ActivationScope) 987 | return Cat("Celine") 988 | 989 | 990 | def cat_factory_with_context_and_activating_type(context, activating_type) -> Cat: 991 | assert isinstance(context, ActivationScope) 992 | assert activating_type is Cat 993 | return Cat("Celine") 994 | 995 | 996 | @pytest.mark.parametrize( 997 | "method_name,factory", 998 | [ 999 | (name, method) 1000 | for name in [ 1001 | "add_singleton_by_factory", 1002 | "add_transient_by_factory", 1003 | "add_scoped_by_factory", 1004 | ] 1005 | for method in [ 1006 | cat_factory_no_args, 1007 | cat_factory_with_context, 1008 | cat_factory_with_context_and_activating_type, 1009 | ] 1010 | ], 1011 | ) 1012 | def test_by_factory_with_different_parameters(method_name, factory): 1013 | container = Container() 1014 | 1015 | method = getattr(container, method_name) 1016 | method(factory) 1017 | 1018 | provider = container.build_provider() 1019 | 1020 | cat = provider.get(Cat) 1021 | 1022 | assert cat is not None 1023 | assert cat.name == "Celine" 1024 | 1025 | 1026 | @pytest.mark.parametrize( 1027 | "method_name", ["add_transient_by_factory", "add_scoped_by_factory"] 1028 | ) 1029 | def test_factory_can_receive_activating_type_as_parameter(method_name): 1030 | @inject() 1031 | class Logger: 1032 | def __init__(self, name): 1033 | self.name = name 1034 | 1035 | @inject() 1036 | class HelpController: 1037 | def __init__(self, logger: Logger): 1038 | self.logger = logger 1039 | 1040 | @inject() 1041 | class HomeController: 1042 | def __init__(self, logger: Logger): 1043 | self.logger = logger 1044 | 1045 | @inject() 1046 | class FooController: 1047 | def 
__init__(self, foo: Foo, logger: Logger): 1048 | self.foo = foo 1049 | self.logger = logger 1050 | 1051 | container = Container() 1052 | container._add_exact_transient(Foo) 1053 | 1054 | @inject() 1055 | def factory(_, activating_type) -> Logger: 1056 | return Logger(activating_type.__module__ + "." + activating_type.__name__) 1057 | 1058 | method = getattr(container, method_name) 1059 | method(factory) 1060 | 1061 | container._add_exact_transient(HelpController)._add_exact_transient( 1062 | HomeController 1063 | )._add_exact_transient(FooController) 1064 | 1065 | provider = container.build_provider() 1066 | 1067 | help_controller = provider.get(HelpController) 1068 | 1069 | assert help_controller is not None 1070 | assert help_controller.logger is not None 1071 | assert help_controller.logger.name == "tests.test_services.HelpController" 1072 | 1073 | home_controller = provider.get(HomeController) 1074 | 1075 | assert home_controller is not None 1076 | assert home_controller.logger is not None 1077 | assert home_controller.logger.name == "tests.test_services.HomeController" 1078 | 1079 | foo_controller = provider.get(FooController) 1080 | 1081 | assert foo_controller is not None 1082 | assert foo_controller.logger is not None 1083 | assert foo_controller.logger.name == "tests.test_services.FooController" 1084 | 1085 | 1086 | def test_factory_can_receive_activating_type_as_parameter_nested_resolution(): 1087 | # NB: this scenario can only work when a class is registered as transient service 1088 | 1089 | class Logger: 1090 | def __init__(self, name): 1091 | self.name = name 1092 | 1093 | @inject() 1094 | class HelpRepo: 1095 | def __init__(self, logger: Logger): 1096 | self.logger = logger 1097 | 1098 | @inject() 1099 | class HelpHandler: 1100 | def __init__(self, help_repo: HelpRepo): 1101 | self.repo = help_repo 1102 | 1103 | @inject() 1104 | class HelpController: 1105 | def __init__(self, logger: Logger, handler: HelpHandler): 1106 | self.logger = logger 1107 | 
self.handler = handler 1108 | 1109 | container = Container() 1110 | 1111 | @inject() 1112 | def factory(_, activating_type) -> Logger: 1113 | # NB: this scenario is tested for rolog library 1114 | return Logger(activating_type.__module__ + "." + activating_type.__name__) 1115 | 1116 | container.add_transient_by_factory(factory) 1117 | 1118 | for service_type in {HelpRepo, HelpHandler, HelpController}: 1119 | container._add_exact_transient(service_type) 1120 | 1121 | provider = container.build_provider() 1122 | 1123 | help_controller = provider.get(HelpController) 1124 | 1125 | assert help_controller is not None 1126 | assert help_controller.logger is not None 1127 | assert help_controller.logger.name == "tests.test_services.HelpController" 1128 | assert help_controller.handler.repo.logger.name == "tests.test_services.HelpRepo" 1129 | 1130 | 1131 | def test_factory_can_receive_activating_type_as_parameter_nested_resolution_many(): 1132 | # NB: this scenario can only work when a class is registered as transient service 1133 | 1134 | class Logger: 1135 | def __init__(self, name): 1136 | self.name = name 1137 | 1138 | @inject() 1139 | class HelpRepo: 1140 | def __init__(self, db_context: FooDBContext, logger: Logger): 1141 | self.db_context = db_context 1142 | self.logger = logger 1143 | 1144 | @inject() 1145 | class HelpHandler: 1146 | def __init__(self, help_repo: HelpRepo): 1147 | self.repo = help_repo 1148 | 1149 | @inject() 1150 | class AnotherPathTwo: 1151 | def __init__(self, logger: Logger): 1152 | self.logger = logger 1153 | 1154 | @inject() 1155 | class AnotherPath: 1156 | def __init__(self, another_path_2: AnotherPathTwo): 1157 | self.child = another_path_2 1158 | 1159 | @inject() 1160 | class HelpController: 1161 | def __init__( 1162 | self, handler: HelpHandler, another_path: AnotherPath, logger: Logger 1163 | ): 1164 | self.logger = logger 1165 | self.handler = handler 1166 | self.other = another_path 1167 | 1168 | container = Container() 1169 | 1170 | 
@inject() 1171 | def factory(_, activating_type) -> Logger: 1172 | # NB: this scenario is tested for rolog library 1173 | return Logger(activating_type.__module__ + "." + activating_type.__name__) 1174 | 1175 | container.add_transient_by_factory(factory) 1176 | container.add_instance(ServiceSettings("foo:foo")) 1177 | 1178 | for service_type in { 1179 | HelpRepo, 1180 | HelpHandler, 1181 | HelpController, 1182 | AnotherPath, 1183 | AnotherPathTwo, 1184 | Foo, 1185 | FooDBContext, 1186 | }: 1187 | container._add_exact_transient(service_type) 1188 | 1189 | provider = container.build_provider() 1190 | 1191 | help_controller = provider.get(HelpController) 1192 | 1193 | assert help_controller is not None 1194 | assert help_controller.logger is not None 1195 | assert help_controller.logger.name == "tests.test_services.HelpController" 1196 | assert help_controller.handler.repo.logger.name == "tests.test_services.HelpRepo" 1197 | assert ( 1198 | help_controller.other.child.logger.name == "tests.test_services." 
1199 | "AnotherPathTwo" 1200 | ) 1201 | 1202 | 1203 | def test_service_provider_supports_set_by_class(): 1204 | provider = Services() 1205 | 1206 | singleton_cat = Cat("Celine") 1207 | 1208 | provider.set(Cat, singleton_cat) 1209 | 1210 | cat = provider.get(Cat) 1211 | 1212 | assert cat is not None 1213 | assert cat.name == "Celine" 1214 | 1215 | cat = provider.get("Cat") 1216 | 1217 | assert cat is not None 1218 | assert cat.name == "Celine" 1219 | 1220 | 1221 | def test_service_provider_supports_set_by_name(): 1222 | provider = Services() 1223 | 1224 | singleton_cat = Cat("Celine") 1225 | 1226 | provider.set("my_cat", singleton_cat) 1227 | 1228 | cat = provider.get("my_cat") 1229 | 1230 | assert cat is not None 1231 | assert cat.name == "Celine" 1232 | 1233 | 1234 | def test_service_provider_supports_set_and_get_item_by_class(): 1235 | provider = Services() 1236 | 1237 | singleton_cat = Cat("Celine") 1238 | 1239 | provider[Cat] = singleton_cat 1240 | 1241 | cat = provider[Cat] 1242 | 1243 | assert cat is not None 1244 | assert cat.name == "Celine" 1245 | 1246 | cat = provider["Cat"] 1247 | 1248 | assert cat is not None 1249 | assert cat.name == "Celine" 1250 | 1251 | 1252 | def test_service_provider_supports_set_and_get_item_by_name(): 1253 | provider = Services() 1254 | 1255 | singleton_cat = Cat("Celine") 1256 | 1257 | provider["my_cat"] = singleton_cat 1258 | 1259 | cat = provider["my_cat"] 1260 | 1261 | assert cat is not None 1262 | assert cat.name == "Celine" 1263 | 1264 | 1265 | def test_service_provider_supports_set_simple_values(): 1266 | provider = Services() 1267 | 1268 | provider["one"] = 10 1269 | provider["two"] = 12 1270 | provider["three"] = 16 1271 | 1272 | assert provider["one"] == 10 1273 | assert provider["two"] == 12 1274 | assert provider["three"] == 16 1275 | 1276 | 1277 | def test_container_handles_class_without_init(): 1278 | container = Container() 1279 | 1280 | class WithoutInit: 1281 | pass 1282 | 1283 | 
container._add_exact_singleton(WithoutInit) 1284 | provider = container.build_provider() 1285 | 1286 | instance = provider.get(WithoutInit) 1287 | assert isinstance(instance, WithoutInit) 1288 | 1289 | 1290 | def test_raises_invalid_factory_for_non_callable(): 1291 | container = Container() 1292 | 1293 | with raises(InvalidFactory): 1294 | container.register_factory("Not a factory", Cat, ServiceLifeStyle.SINGLETON) 1295 | 1296 | 1297 | def test_set_alias_raises_in_strict_mode(): 1298 | container = Container(strict=True) 1299 | 1300 | with raises(InvalidOperationInStrictMode): 1301 | container.set_alias("something", Cat) 1302 | 1303 | 1304 | def test_set_alias_raises_if_alias_is_defined(): 1305 | container = Container() 1306 | 1307 | container.set_alias("something", Cat) 1308 | 1309 | with raises(AliasAlreadyDefined): 1310 | container.set_alias("something", Foo) 1311 | 1312 | 1313 | def test_set_alias_requires_configured_type(): 1314 | container = Container() 1315 | 1316 | container.set_alias("something", Cat) 1317 | 1318 | with raises(AliasConfigurationError): 1319 | container.build_provider() 1320 | 1321 | 1322 | def test_set_aliases(): 1323 | container = Container() 1324 | 1325 | container.add_instance(Cat("Celine")) 1326 | container.add_instance(Foo()) 1327 | 1328 | container.set_aliases({"a": Cat, "b": Foo}) 1329 | 1330 | provider = container.build_provider() 1331 | 1332 | x = provider.get("a") 1333 | 1334 | assert isinstance(x, Cat) 1335 | assert x.name == "Celine" 1336 | 1337 | assert isinstance(provider.get("b"), Foo) 1338 | 1339 | 1340 | def test_add_alias_raises_in_strict_mode(): 1341 | container = Container(strict=True) 1342 | 1343 | with raises(InvalidOperationInStrictMode): 1344 | container.add_alias("something", Cat) 1345 | 1346 | 1347 | def test_add_alias_raises_if_alias_is_defined(): 1348 | container = Container() 1349 | 1350 | container.add_alias("something", Cat) 1351 | 1352 | with raises(AliasAlreadyDefined): 1353 | 
container.add_alias("something", Foo) 1354 | 1355 | 1356 | def test_add_aliases(): 1357 | container = Container() 1358 | 1359 | container.add_instance(Cat("Celine")) 1360 | container.add_instance(Foo()) 1361 | 1362 | container.add_aliases({"a": Cat, "b": Foo}) 1363 | 1364 | container.add_aliases({"c": Cat, "d": Foo}) 1365 | 1366 | provider = container.build_provider() 1367 | 1368 | for alias in {"a", "c"}: 1369 | x = provider.get(alias) 1370 | 1371 | assert isinstance(x, Cat) 1372 | assert x.name == "Celine" 1373 | 1374 | for alias in {"b", "d"}: 1375 | assert isinstance(provider.get(alias), Foo) 1376 | 1377 | 1378 | def test_add_alias_requires_configured_type(): 1379 | container = Container() 1380 | 1381 | container.add_alias("something", Cat) 1382 | 1383 | with raises(AliasConfigurationError): 1384 | container.build_provider() 1385 | 1386 | 1387 | def test_build_provider_raises_for_missing_transient_parameter(): 1388 | container = Container() 1389 | 1390 | container._add_exact_transient(CatsController) 1391 | 1392 | with raises(CannotResolveParameterException): 1393 | container.build_provider() 1394 | 1395 | 1396 | def test_build_provider_raises_for_missing_scoped_parameter(): 1397 | container = Container() 1398 | 1399 | container._add_exact_scoped(CatsController) 1400 | 1401 | with raises(CannotResolveParameterException): 1402 | container.build_provider() 1403 | 1404 | 1405 | def test_build_provider_raises_for_missing_singleton_parameter(): 1406 | container = Container() 1407 | 1408 | container._add_exact_singleton(CatsController) 1409 | 1410 | with raises(CannotResolveParameterException): 1411 | container.build_provider() 1412 | 1413 | 1414 | def test_overriding_alias_from_class_name_throws(): 1415 | container = Container() 1416 | 1417 | class A: 1418 | def __init__(self, b): 1419 | self.b = b 1420 | 1421 | class B: 1422 | def __init__(self, c): 1423 | self.c = c 1424 | 1425 | class C: 1426 | def __init__(self): 1427 | pass 1428 | 1429 | 
container._add_exact_transient(A) 1430 | container._add_exact_transient(B) 1431 | container._add_exact_transient(C) 1432 | 1433 | with raises(AliasAlreadyDefined): 1434 | container.add_alias("b", C) # <-- ambiguity 1435 | 1436 | 1437 | def test_cannot_resolve_parameter_in_strict_mode_throws(): 1438 | container = Container(strict=True) 1439 | 1440 | class A: 1441 | def __init__(self, b): 1442 | self.b = b 1443 | 1444 | class B: 1445 | def __init__(self, c): 1446 | self.c = c 1447 | 1448 | container._add_exact_transient(A) 1449 | container._add_exact_transient(B) 1450 | 1451 | with raises(CannotResolveParameterException): 1452 | container.build_provider() 1453 | 1454 | 1455 | def test_services_set_throws_if_service_is_already_defined(): 1456 | services = Services() 1457 | 1458 | services.set("example", {}) 1459 | 1460 | with raises(OverridingServiceException): 1461 | services.set("example", []) 1462 | 1463 | 1464 | def test_scoped_services_exact(): 1465 | container = Container() 1466 | 1467 | class A: 1468 | def __init__(self, b): 1469 | self.b = b 1470 | 1471 | class B: 1472 | def __init__(self, c): 1473 | self.c = c 1474 | 1475 | class C: 1476 | def __init__(self): 1477 | pass 1478 | 1479 | container._add_exact_scoped(A) 1480 | container._add_exact_scoped(B) 1481 | container._add_exact_scoped(C) 1482 | 1483 | provider = container.build_provider() 1484 | context = ActivationScope(provider) 1485 | 1486 | a = provider.get(A, context) 1487 | assert isinstance(a, A) 1488 | assert isinstance(a.b, B) 1489 | assert isinstance(a.b.c, C) 1490 | 1491 | a2 = provider.get(A, context) 1492 | assert a is a2 1493 | assert a.b is a2.b 1494 | assert a.b.c is a2.b.c 1495 | 1496 | 1497 | def test_scoped_services_abstract(): 1498 | container = Container() 1499 | 1500 | class ABase(ABC): 1501 | pass 1502 | 1503 | class BBase(ABC): 1504 | pass 1505 | 1506 | class CBase(ABC): 1507 | pass 1508 | 1509 | @inject() 1510 | class A(ABase): 1511 | def __init__(self, b: BBase): 1512 | self.b = b 
1513 | 1514 | @inject() 1515 | class B(BBase): 1516 | def __init__(self, c: CBase): 1517 | self.c = c 1518 | 1519 | class C(CBase): 1520 | def __init__(self): 1521 | pass 1522 | 1523 | container.add_scoped(ABase, A) 1524 | container.add_scoped(BBase, B) 1525 | container.add_scoped(CBase, C) 1526 | 1527 | provider = container.build_provider() 1528 | context = ActivationScope(provider) 1529 | 1530 | a = provider.get(ABase, context) 1531 | assert isinstance(a, A) 1532 | assert isinstance(a.b, B) 1533 | assert isinstance(a.b.c, C) 1534 | 1535 | a2 = provider.get(ABase, context) 1536 | assert a is a2 1537 | assert a.b is a2.b 1538 | assert a.b.c is a2.b.c 1539 | 1540 | 1541 | def test_instance_resolver_representation(): 1542 | singleton = Foo() 1543 | resolver = InstanceResolver(singleton) 1544 | 1545 | representation = repr(resolver) 1546 | assert representation.startswith(" BBase: 1574 | assert isinstance(context, ActivationScope) 1575 | assert activating_type is A 1576 | return B() 1577 | 1578 | container.add_transient(ABase, A) 1579 | 1580 | method = getattr(container, method_name) 1581 | method(bbase_factory) 1582 | 1583 | provider = container.build_provider() 1584 | context = ActivationScope(provider) 1585 | 1586 | a = provider.get(ABase, context) 1587 | assert isinstance(a, A) 1588 | assert isinstance(a.b, B) 1589 | 1590 | 1591 | @pytest.mark.parametrize( 1592 | "method_name", 1593 | ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], 1594 | ) 1595 | def test_factories_activating_scoped_type_consistency(method_name): 1596 | container = Container() 1597 | 1598 | class ABase(ABC): 1599 | pass 1600 | 1601 | class BBase(ABC): 1602 | pass 1603 | 1604 | @inject() 1605 | class A(ABase): 1606 | def __init__(self, b: BBase): 1607 | self.b = b 1608 | 1609 | class B(BBase): 1610 | def __init__(self): 1611 | pass 1612 | 1613 | @inject() 1614 | def bbase_factory(context: ActivationScope, activating_type: Type) -> BBase: 1615 | assert 
isinstance(context, ActivationScope) 1616 | assert activating_type is A 1617 | return B() 1618 | 1619 | container.add_scoped(ABase, A) 1620 | 1621 | method = getattr(container, method_name) 1622 | method(bbase_factory) 1623 | 1624 | provider = container.build_provider() 1625 | context = ActivationScope(provider) 1626 | 1627 | a = provider.get(ABase, context) 1628 | assert isinstance(a, A) 1629 | assert isinstance(a.b, B) 1630 | 1631 | 1632 | @pytest.mark.parametrize( 1633 | "method_name", 1634 | ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], 1635 | ) 1636 | def test_factories_activating_singleton_type_consistency(method_name): 1637 | container = Container() 1638 | 1639 | class ABase(ABC): 1640 | pass 1641 | 1642 | class BBase(ABC): 1643 | pass 1644 | 1645 | @inject() 1646 | class A(ABase): 1647 | def __init__(self, b: BBase): 1648 | self.b = b 1649 | 1650 | class B(BBase): 1651 | def __init__(self): 1652 | pass 1653 | 1654 | @inject() 1655 | def bbase_factory(context: ActivationScope, activating_type: Type) -> BBase: 1656 | assert isinstance(context, ActivationScope) 1657 | assert activating_type is A 1658 | return B() 1659 | 1660 | container.add_singleton(ABase, A) 1661 | 1662 | method = getattr(container, method_name) 1663 | method(bbase_factory) 1664 | 1665 | provider = container.build_provider() 1666 | context = ActivationScope(provider) 1667 | 1668 | a = provider.get(ABase, context) 1669 | assert isinstance(a, A) 1670 | assert isinstance(a.b, B) 1671 | 1672 | 1673 | @pytest.mark.parametrize( 1674 | "method_name", 1675 | ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], 1676 | ) 1677 | def test_factories_type_transient_consistency_nested(method_name): 1678 | container = Container() 1679 | 1680 | class ABase(ABC): 1681 | pass 1682 | 1683 | class BBase(ABC): 1684 | pass 1685 | 1686 | class CBase(ABC): 1687 | pass 1688 | 1689 | @inject() 1690 | class A(ABase): 1691 | def __init__(self, b: BBase): 
1692 | self.b = b 1693 | 1694 | @inject() 1695 | class B(BBase): 1696 | def __init__(self, c: CBase): 1697 | self.c = c 1698 | 1699 | class C(CBase): 1700 | def __init__(self): 1701 | pass 1702 | 1703 | @inject() 1704 | def cbase_factory(context: ActivationScope, activating_type: Type) -> CBase: 1705 | assert isinstance(context, ActivationScope) 1706 | assert activating_type is B 1707 | return C() 1708 | 1709 | container.add_transient(ABase, A) 1710 | container.add_transient(BBase, B) 1711 | 1712 | method = getattr(container, method_name) 1713 | method(cbase_factory) 1714 | 1715 | provider = container.build_provider() 1716 | context = ActivationScope(provider) 1717 | 1718 | a = provider.get(ABase, context) 1719 | assert isinstance(a, A) 1720 | assert isinstance(a.b, B) 1721 | assert isinstance(a.b.c, C) 1722 | 1723 | 1724 | @pytest.mark.parametrize( 1725 | "method_name", 1726 | ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], 1727 | ) 1728 | def test_factories_type_scoped_consistency_nested(method_name): 1729 | container = Container() 1730 | 1731 | class ABase(ABC): 1732 | pass 1733 | 1734 | class BBase(ABC): 1735 | pass 1736 | 1737 | class CBase(ABC): 1738 | pass 1739 | 1740 | @inject() 1741 | class A(ABase): 1742 | def __init__(self, b: BBase): 1743 | self.b = b 1744 | 1745 | @inject() 1746 | class B(BBase): 1747 | def __init__(self, c: CBase): 1748 | self.c = c 1749 | 1750 | class C(CBase): 1751 | def __init__(self): 1752 | pass 1753 | 1754 | @inject() 1755 | def cbase_factory(context: ActivationScope, activating_type: Type) -> CBase: 1756 | assert isinstance(context, ActivationScope) 1757 | assert activating_type is B 1758 | return C() 1759 | 1760 | container.add_scoped(ABase, A) 1761 | container.add_scoped(BBase, B) 1762 | 1763 | method = getattr(container, method_name) 1764 | method(cbase_factory) 1765 | 1766 | provider = container.build_provider() 1767 | context = ActivationScope(provider) 1768 | 1769 | a = 
provider.get(ABase, context) 1770 | assert isinstance(a, A) 1771 | assert isinstance(a.b, B) 1772 | assert isinstance(a.b.c, C) 1773 | 1774 | 1775 | @pytest.mark.parametrize( 1776 | "method_name", 1777 | ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], 1778 | ) 1779 | def test_factories_type_singleton_consistency_nested(method_name): 1780 | container = Container() 1781 | 1782 | class ABase(ABC): 1783 | pass 1784 | 1785 | class BBase(ABC): 1786 | pass 1787 | 1788 | class CBase(ABC): 1789 | pass 1790 | 1791 | @inject() 1792 | class A(ABase): 1793 | def __init__(self, b: BBase): 1794 | self.b = b 1795 | 1796 | @inject() 1797 | class B(BBase): 1798 | def __init__(self, c: CBase): 1799 | self.c = c 1800 | 1801 | class C(CBase): 1802 | def __init__(self): 1803 | pass 1804 | 1805 | @inject() 1806 | def cbase_factory(context: ActivationScope, activating_type: Type) -> CBase: 1807 | assert isinstance(context, ActivationScope) 1808 | assert activating_type is B 1809 | return C() 1810 | 1811 | container.add_singleton(ABase, A) 1812 | container.add_singleton(BBase, B) 1813 | 1814 | method = getattr(container, method_name) 1815 | method(cbase_factory) 1816 | 1817 | provider = container.build_provider() 1818 | context = ActivationScope(provider) 1819 | 1820 | a = provider.get(ABase, context) 1821 | assert isinstance(a, A) 1822 | assert isinstance(a.b, B) 1823 | assert isinstance(a.b.c, C) 1824 | 1825 | 1826 | def test_annotation_resolution(): 1827 | class B: 1828 | pass 1829 | 1830 | @inject() 1831 | class A: 1832 | dep: B 1833 | 1834 | container = Container() 1835 | 1836 | b_singleton = B() 1837 | container.add_instance(b_singleton) 1838 | container._add_exact_scoped(A) 1839 | 1840 | provider = container.build_provider() 1841 | 1842 | instance = provider.get(A) 1843 | 1844 | assert isinstance(instance, A) 1845 | assert instance.dep is not None 1846 | assert instance.dep is b_singleton 1847 | 1848 | 1849 | def 
test_annotation_resolution_scoped(): 1850 | class B: 1851 | pass 1852 | 1853 | @inject() 1854 | class A: 1855 | dep: B 1856 | 1857 | container = Container() 1858 | 1859 | b_singleton = B() 1860 | container.add_instance(b_singleton) 1861 | container._add_exact_scoped(A) 1862 | 1863 | provider = container.build_provider() 1864 | 1865 | with ActivationScope() as context: 1866 | instance = provider.get(A, context) 1867 | 1868 | assert isinstance(instance, A) 1869 | assert instance.dep is not None 1870 | assert instance.dep is b_singleton 1871 | 1872 | second = provider.get(A, context) 1873 | assert instance is second 1874 | 1875 | third = provider.get(A) 1876 | assert third is not instance 1877 | 1878 | 1879 | def test_annotation_nested_resolution_1(): 1880 | class D: 1881 | pass 1882 | 1883 | class C: 1884 | pass 1885 | 1886 | @inject() 1887 | class B: 1888 | dep_1: C 1889 | dep_2: D 1890 | 1891 | @inject() 1892 | class A: 1893 | dep: B 1894 | 1895 | container = Container() 1896 | 1897 | container.add_instance(C()) 1898 | container.add_instance(D()) 1899 | container._add_exact_transient(B) 1900 | container._add_exact_scoped(A) 1901 | 1902 | provider = container.build_provider() 1903 | 1904 | with ActivationScope(provider) as context: 1905 | instance = provider.get(A, context) 1906 | 1907 | assert isinstance(instance, A) 1908 | assert isinstance(instance.dep, B) 1909 | assert isinstance(instance.dep.dep_1, C) 1910 | assert isinstance(instance.dep.dep_2, D) 1911 | 1912 | second = provider.get(A, context) 1913 | assert instance is second 1914 | 1915 | third = provider.get(A) 1916 | assert third is not instance 1917 | 1918 | 1919 | def test_annotation_nested_resolution_2(): 1920 | class E: 1921 | pass 1922 | 1923 | @inject() 1924 | class D: 1925 | dep: E 1926 | 1927 | @inject() 1928 | class C: 1929 | dep: E 1930 | 1931 | @inject() 1932 | class B: 1933 | dep_1: C 1934 | dep_2: D 1935 | 1936 | @inject() 1937 | class A: 1938 | dep: B 1939 | 1940 | container = Container() 
1941 | 1942 | container.add_scoped_by_factory(E, E) 1943 | container._add_exact_scoped(C) 1944 | container._add_exact_scoped(D) 1945 | container._add_exact_transient(B) 1946 | container._add_exact_scoped(A) 1947 | 1948 | provider = container.build_provider() 1949 | 1950 | with ActivationScope(provider) as context: 1951 | instance = provider.get(A, context) 1952 | 1953 | assert isinstance(instance, A) 1954 | assert isinstance(instance.dep, B) 1955 | assert isinstance(instance.dep.dep_1, C) 1956 | assert isinstance(instance.dep.dep_2, D) 1957 | assert isinstance(instance.dep.dep_1.dep, E) 1958 | assert isinstance(instance.dep.dep_2.dep, E) 1959 | assert instance.dep.dep_1.dep is instance.dep.dep_2.dep 1960 | 1961 | second = provider.get(A, context) 1962 | assert instance is second 1963 | assert instance.dep.dep_1.dep is second.dep.dep_1.dep 1964 | 1965 | third = provider.get(A) 1966 | assert third is not instance 1967 | 1968 | 1969 | def test_annotation_resolution_singleton(): 1970 | class B: 1971 | pass 1972 | 1973 | @inject() 1974 | class A: 1975 | dep: B 1976 | 1977 | container = Container() 1978 | 1979 | b_singleton = B() 1980 | container.add_instance(b_singleton) 1981 | container._add_exact_singleton(A) 1982 | 1983 | provider = container.build_provider() 1984 | 1985 | instance = provider.get(A) 1986 | 1987 | assert isinstance(instance, A) 1988 | assert instance.dep is not None 1989 | assert instance.dep is b_singleton 1990 | 1991 | second = provider.get(A) 1992 | assert instance is second 1993 | 1994 | 1995 | def test_annotation_resolution_transient(): 1996 | class B: 1997 | pass 1998 | 1999 | @inject() 2000 | class A: 2001 | dep: B 2002 | 2003 | container = Container() 2004 | 2005 | b_singleton = B() 2006 | container.add_instance(b_singleton) 2007 | container.add_transient(A) 2008 | 2009 | provider = container.build_provider() 2010 | 2011 | with ActivationScope() as context: 2012 | instance = provider.get(A, context) 2013 | 2014 | assert isinstance(instance, A) 
2015 | assert instance.dep is not None 2016 | assert instance.dep is b_singleton 2017 | 2018 | second = provider.get(A, context) 2019 | assert instance is not second 2020 | 2021 | assert isinstance(second, A) 2022 | assert second.dep is not None 2023 | assert second.dep is b_singleton 2024 | 2025 | 2026 | def test_annotations_abstract_type_transient_service(): 2027 | class FooCatsRepository(ICatsRepository): 2028 | def get_by_id(self, _id) -> Cat: 2029 | return Cat("foo") 2030 | 2031 | class GetCatRequestHandler: 2032 | cats_repository: ICatsRepository 2033 | 2034 | def get_cat(self, _id): 2035 | cat = self.cats_repository.get_by_id(_id) 2036 | return cat 2037 | 2038 | container = Container() 2039 | container.add_transient(ICatsRepository, FooCatsRepository) 2040 | container._add_exact_transient(GetCatRequestHandler) 2041 | provider = container.build_provider() 2042 | 2043 | cats_repo = provider.get(ICatsRepository) 2044 | assert isinstance(cats_repo, FooCatsRepository) 2045 | 2046 | other_cats_repo = provider.get(ICatsRepository) 2047 | assert cats_repo is not other_cats_repo 2048 | 2049 | get_cat_handler = provider.get(GetCatRequestHandler) 2050 | assert isinstance(get_cat_handler, GetCatRequestHandler) 2051 | assert isinstance(get_cat_handler.cats_repository, FooCatsRepository) 2052 | 2053 | 2054 | def test_support_for_dataclasses(): 2055 | @dataclass 2056 | class Settings: 2057 | region: str 2058 | 2059 | @inject() 2060 | @dataclass 2061 | class GetInfoHandler: 2062 | service_settings: Settings 2063 | 2064 | def handle_request(self): 2065 | return {"service_region": self.service_settings.region} 2066 | 2067 | container = Container() 2068 | container.add_instance(Settings(region="Western Europe")) 2069 | container._add_exact_scoped(GetInfoHandler) 2070 | 2071 | provider = container.build_provider() 2072 | 2073 | info_handler = provider.get(GetInfoHandler) 2074 | 2075 | assert isinstance(info_handler, GetInfoHandler) 2076 | assert 
isinstance(info_handler.service_settings, Settings) 2077 | 2078 | 2079 | def test_list(): 2080 | container = Container() 2081 | 2082 | class Foo: 2083 | items: list 2084 | 2085 | container.add_instance(["one", "two", "three"]) 2086 | 2087 | container._add_exact_scoped(Foo) 2088 | 2089 | provider = container.build_provider() 2090 | 2091 | instance = provider.get(Foo) 2092 | 2093 | assert instance.items == ["one", "two", "three"] 2094 | 2095 | 2096 | def test_list_generic_alias(): 2097 | container = Container() 2098 | 2099 | def list_int_factory() -> List[int]: 2100 | return [1, 2, 3] 2101 | 2102 | def list_str_factory() -> List[str]: 2103 | return ["a", "b"] 2104 | 2105 | class C: 2106 | a: List[int] 2107 | b: List[str] 2108 | 2109 | container.add_scoped_by_factory(list_int_factory) 2110 | container.add_scoped_by_factory(list_str_factory) 2111 | container.add_scoped(C) 2112 | 2113 | provider = container.build_provider() 2114 | 2115 | instance = provider.get(C) 2116 | 2117 | assert instance.a == list_int_factory() 2118 | assert instance.b == list_str_factory() 2119 | 2120 | 2121 | def test_mapping_generic_alias(): 2122 | container = Container() 2123 | 2124 | def mapping_int_factory() -> Mapping[int, int]: 2125 | return {1: 1, 2: 2, 3: 3} 2126 | 2127 | def mapping_str_factory() -> Mapping[str, int]: 2128 | return {"a": 1, "b": 2, "c": 3} 2129 | 2130 | class C: 2131 | a: Mapping[int, int] 2132 | b: Mapping[str, int] 2133 | 2134 | container.add_scoped_by_factory(mapping_int_factory) 2135 | container.add_scoped_by_factory(mapping_str_factory) 2136 | container.add_scoped(C) 2137 | 2138 | provider = container.build_provider() 2139 | 2140 | instance = provider.get(C) 2141 | 2142 | assert instance.a == mapping_int_factory() 2143 | assert instance.b == mapping_str_factory() 2144 | 2145 | 2146 | def test_dict_generic_alias(): 2147 | container = Container() 2148 | 2149 | def mapping_int_factory() -> Dict[int, int]: 2150 | return {1: 1, 2: 2, 3: 3} 2151 | 2152 | def 
mapping_str_factory() -> Dict[str, int]: 2153 | return {"a": 1, "b": 2, "c": 3} 2154 | 2155 | class C: 2156 | a: Dict[int, int] 2157 | b: Dict[str, int] 2158 | 2159 | container.add_scoped_by_factory(mapping_int_factory) 2160 | container.add_scoped_by_factory(mapping_str_factory) 2161 | container.add_scoped(C) 2162 | 2163 | provider = container.build_provider() 2164 | 2165 | instance = provider.get(C) 2166 | 2167 | assert instance.a == mapping_int_factory() 2168 | assert instance.b == mapping_str_factory() 2169 | 2170 | 2171 | @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires Python 3.9") 2172 | def test_list_generic_alias_list(): 2173 | container = Container() 2174 | 2175 | def list_int_factory() -> list[int]: 2176 | return [1, 2, 3] 2177 | 2178 | def list_str_factory() -> list[str]: 2179 | return ["a", "b"] 2180 | 2181 | class C: 2182 | a: list[int] 2183 | b: list[str] 2184 | 2185 | container.add_scoped_by_factory(list_int_factory) 2186 | container.add_scoped_by_factory(list_str_factory) 2187 | container.add_scoped(C) 2188 | 2189 | provider = container.build_provider() 2190 | 2191 | instance = provider.get(C) 2192 | 2193 | assert instance.a == list_int_factory() 2194 | assert instance.b == list_str_factory() 2195 | 2196 | 2197 | @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires Python 3.9") 2198 | def test_dict_generic_alias_dict(): 2199 | container = Container() 2200 | 2201 | def mapping_int_factory() -> dict[int, int]: 2202 | return {1: 1, 2: 2, 3: 3} 2203 | 2204 | def mapping_str_factory() -> dict[str, int]: 2205 | return {"a": 1, "b": 2, "c": 3} 2206 | 2207 | class C: 2208 | a: dict[int, int] 2209 | b: dict[str, int] 2210 | 2211 | container.add_scoped_by_factory(mapping_int_factory) 2212 | container.add_scoped_by_factory(mapping_str_factory) 2213 | container.add_scoped(C) 2214 | 2215 | provider = container.build_provider() 2216 | 2217 | instance = provider.get(C) 2218 | 2219 | assert instance.a == mapping_int_factory() 2220 | assert 
instance.b == mapping_str_factory() 2221 | 2222 | 2223 | def test_generic(): 2224 | container = Container() 2225 | 2226 | class A(LoggedVar[int]): 2227 | def __init__(self) -> None: 2228 | super().__init__(10, "example") 2229 | 2230 | class B(LoggedVar[str]): 2231 | def __init__(self) -> None: 2232 | super().__init__("Foo", "example") 2233 | 2234 | class C: 2235 | a: LoggedVar[int] 2236 | b: LoggedVar[str] 2237 | 2238 | container.add_scoped(LoggedVar[int], A) 2239 | container.add_scoped(LoggedVar[str], B) 2240 | container.add_scoped(C) 2241 | 2242 | provider = container.build_provider() 2243 | 2244 | instance = provider.get(C) 2245 | 2246 | assert isinstance(instance.a, A) 2247 | assert isinstance(instance.b, B) 2248 | 2249 | 2250 | ITERABLES = [ 2251 | ( 2252 | Iterable[LoggedVar[int]], 2253 | [LoggedVar(1, "a"), LoggedVar(2, "b"), LoggedVar(3, "c")], 2254 | ), 2255 | (Iterable[str], ["one", "two", "three"]), 2256 | (List[str], ["one", "two", "three"]), 2257 | (Tuple[str, ...], ["one", "two", "three"]), 2258 | (Sequence[str], ["one", "two", "three"]), 2259 | (List[Cat], [Cat("A"), Cat("B"), Cat("C")]), 2260 | ] 2261 | 2262 | 2263 | @pytest.mark.parametrize("annotation,value", ITERABLES) 2264 | def test_iterables_annotations_singleton(annotation, value): 2265 | container = Container() 2266 | 2267 | @inject() 2268 | class Foo: 2269 | items: annotation 2270 | 2271 | container.add_instance(value, declared_class=annotation) 2272 | 2273 | container._add_exact_scoped(Foo) 2274 | 2275 | provider = container.build_provider() 2276 | 2277 | instance = provider.get(Foo) 2278 | 2279 | assert instance.items == value 2280 | 2281 | 2282 | @pytest.mark.parametrize("annotation,value", ITERABLES) 2283 | def test_iterables_annotations_scoped_factory(annotation, value): 2284 | container = Container() 2285 | 2286 | @inject() 2287 | class Foo: 2288 | items: annotation 2289 | 2290 | @inject() 2291 | def factory() -> annotation: 2292 | return value 2293 | 2294 | 
container.add_scoped_by_factory(factory).add_scoped(Foo) 2295 | 2296 | provider = container.build_provider() 2297 | 2298 | instance = provider.get(Foo) 2299 | 2300 | assert instance.items == value 2301 | 2302 | 2303 | @pytest.mark.parametrize("annotation,value", ITERABLES) 2304 | def test_iterables_annotations_transient_factory(annotation, value): 2305 | container = Container() 2306 | 2307 | @inject() 2308 | class Foo: 2309 | items: annotation 2310 | 2311 | @inject() 2312 | def factory() -> annotation: 2313 | return value 2314 | 2315 | container.add_transient_by_factory(factory).add_scoped(Foo) 2316 | 2317 | provider = container.build_provider() 2318 | 2319 | instance = provider.get(Foo) 2320 | 2321 | assert instance.items == value 2322 | 2323 | 2324 | def test_factory_without_locals_raises(): 2325 | def factory_without_context() -> None: ... 2326 | 2327 | with pytest.raises(FactoryMissingContextException): 2328 | _get_factory_annotations_or_throw(factory_without_context) 2329 | 2330 | 2331 | def test_factory_with_locals_get_annotations(): 2332 | @inject() 2333 | def factory_without_context() -> "Cat": ... 2334 | 2335 | annotations = _get_factory_annotations_or_throw(factory_without_context) 2336 | 2337 | assert annotations["return"] is Cat 2338 | 2339 | 2340 | def test_deps_github_scenario(): 2341 | """ 2342 | CLAHandler 2343 | ├── CommentsService --> GitHubCommentsAPI . 2344 | └── ChecksService --> GitHubChecksAPI . 2345 | ├── GitHubAuthHandler - GitHubSettings 2346 | ├── GitHubAuthHandler - GitHubSettings 2347 | └── HTTPClient 2348 | """ 2349 | 2350 | class HTTPClient: ... 2351 | 2352 | class CommentsService: ... 2353 | 2354 | class ChecksService: ... 2355 | 2356 | class CLAHandler: 2357 | comments_service: CommentsService 2358 | checks_service: ChecksService 2359 | 2360 | class GitHubSettings: ... 
2361 | 2362 | class GitHubAuthHandler: 2363 | settings: GitHubSettings 2364 | http_client: HTTPClient 2365 | 2366 | class GitHubCommentsAPI(CommentsService): 2367 | auth_handler: GitHubAuthHandler 2368 | http_client: HTTPClient 2369 | 2370 | class GitHubChecksAPI(ChecksService): 2371 | auth_handler: GitHubAuthHandler 2372 | http_client: HTTPClient 2373 | 2374 | container = Container() 2375 | 2376 | container.add_singleton(HTTPClient) 2377 | container.add_singleton(GitHubSettings) 2378 | container.add_singleton(GitHubAuthHandler) 2379 | container.add_singleton(CommentsService, GitHubCommentsAPI) 2380 | container.add_singleton(ChecksService, GitHubChecksAPI) 2381 | container.add_singleton(CLAHandler) 2382 | 2383 | provider = container.build_provider() 2384 | 2385 | cla_handler = provider.get(CLAHandler) 2386 | assert isinstance(cla_handler, CLAHandler) 2387 | assert isinstance(cla_handler.comments_service, GitHubCommentsAPI) 2388 | assert isinstance(cla_handler.checks_service, GitHubChecksAPI) 2389 | assert ( 2390 | cla_handler.comments_service.auth_handler 2391 | is cla_handler.checks_service.auth_handler 2392 | ) 2393 | 2394 | 2395 | def test_container_protocol(): 2396 | container: ContainerProtocol = arrange_cats_example() 2397 | 2398 | class UsingAlias: 2399 | def __init__(self, example): 2400 | self.cats_controller = example 2401 | 2402 | container.register(UsingAlias) 2403 | 2404 | # arrange an exact alias for UsingAlias class init parameter: 2405 | container.set_alias("example", CatsController) 2406 | 2407 | u = container.resolve(UsingAlias) 2408 | 2409 | assert isinstance(u, UsingAlias) 2410 | assert isinstance(u.cats_controller, CatsController) 2411 | assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler) 2412 | 2413 | 2414 | def test_container_protocol_register(): 2415 | container: ContainerProtocol = Container() 2416 | 2417 | class BaseA: 2418 | pass 2419 | 2420 | class A(BaseA): 2421 | pass 2422 | 2423 | container.register(BaseA, 
A) 2424 | a = container.resolve(BaseA) 2425 | 2426 | assert isinstance(a, A) 2427 | 2428 | 2429 | def test_container_protocol_any_argument(): 2430 | container: ContainerProtocol = Container() 2431 | 2432 | class A: 2433 | pass 2434 | 2435 | container.register(A, None, None, 1, noop="foo") 2436 | a = container.resolve(A, None, None, 1, noop="foo") 2437 | 2438 | assert isinstance(a, A) 2439 | 2440 | 2441 | def test_container_register_instance(): 2442 | container: ContainerProtocol = Container() 2443 | 2444 | singleton = FooDBCatsRepository(FooDBContext(ServiceSettings("example"))) 2445 | 2446 | container.register(ICatsRepository, instance=singleton) 2447 | 2448 | assert container.resolve(ICatsRepository) is singleton 2449 | 2450 | 2451 | def test_import_version(): 2452 | from rodi.__about__ import __version__ # noqa 2453 | 2454 | 2455 | def test_container_iter(): 2456 | container = Container() 2457 | 2458 | class A: 2459 | pass 2460 | 2461 | class B: 2462 | pass 2463 | 2464 | container.register(A) 2465 | container.register(B) 2466 | 2467 | for key, value in container: 2468 | assert key is A or key is B 2469 | assert isinstance(value, DynamicResolver) 2470 | 2471 | 2472 | def test_provide_protocol_with_attribute_dependency() -> None: 2473 | class P(Protocol): 2474 | def foo(self) -> Any: ... 2475 | 2476 | class Dependency: 2477 | pass 2478 | 2479 | class Impl(P): 2480 | # attribute dependency 2481 | dependency: Dependency 2482 | 2483 | def foo(self) -> Any: 2484 | pass 2485 | 2486 | container = Container() 2487 | container.register(Dependency) 2488 | container.register(Impl) 2489 | 2490 | try: 2491 | resolved = container.resolve(Impl) 2492 | except CannotResolveParameterException as e: 2493 | pytest.fail(str(e)) 2494 | 2495 | assert isinstance(resolved, Impl) 2496 | assert isinstance(resolved.dependency, Dependency) 2497 | 2498 | 2499 | def test_provide_protocol_with_init_dependency() -> None: 2500 | class P(Protocol): 2501 | def foo(self) -> Any: ... 
2502 | 2503 | class Dependency: 2504 | pass 2505 | 2506 | class Impl(P): 2507 | def __init__(self, dependency: Dependency) -> None: 2508 | self.dependency = dependency 2509 | 2510 | def foo(self) -> Any: 2511 | pass 2512 | 2513 | container = Container() 2514 | container.register(Dependency) 2515 | container.register(Impl) 2516 | 2517 | try: 2518 | resolved = container.resolve(Impl) 2519 | except CannotResolveParameterException as e: 2520 | pytest.fail(str(e)) 2521 | 2522 | assert isinstance(resolved, Impl) 2523 | assert isinstance(resolved.dependency, Dependency) 2524 | 2525 | 2526 | def test_provide_protocol_generic() -> None: 2527 | T = TypeVar("T") 2528 | 2529 | class P(Protocol[T]): 2530 | def foo(self, t: T) -> T: ... 2531 | 2532 | class A: ... 2533 | 2534 | class Impl(P[A]): 2535 | def foo(self, t: A) -> A: 2536 | return t 2537 | 2538 | container = Container() 2539 | 2540 | container.register(Impl) 2541 | 2542 | try: 2543 | resolved = container.resolve(Impl) 2544 | except CannotResolveParameterException as e: 2545 | pytest.fail(str(e)) 2546 | 2547 | assert isinstance(resolved, Impl) 2548 | 2549 | 2550 | def test_provide_protocol_generic_with_inner_dependency() -> None: 2551 | T = TypeVar("T") 2552 | 2553 | class P(Protocol[T]): 2554 | def foo(self, t: T) -> T: ... 2555 | 2556 | class A: ... 2557 | 2558 | class Dependency: 2559 | pass 2560 | 2561 | class Impl(P[A]): 2562 | dependency: Dependency 2563 | 2564 | def foo(self, t: A) -> A: 2565 | return t 2566 | 2567 | container = Container() 2568 | 2569 | container.register(Impl) 2570 | container.register(Dependency) 2571 | 2572 | try: 2573 | resolved = container.resolve(Impl) 2574 | except CannotResolveParameterException as e: 2575 | pytest.fail(str(e)) 2576 | 2577 | assert isinstance(resolved, Impl) 2578 | assert isinstance(resolved.dependency, Dependency) 2579 | 2580 | 2581 | def test_ignore_class_var(): 2582 | """ 2583 | ClassVar attributes must be ignored, because they are not instance attributes. 
2584 | """ 2585 | 2586 | class A: 2587 | foo: ClassVar[str] = "foo" 2588 | 2589 | class B: 2590 | example: ClassVar[str] = "example" 2591 | dependency: A 2592 | 2593 | container = Container() 2594 | 2595 | container.register(A) 2596 | container.register(B) 2597 | 2598 | b = container.resolve(B) 2599 | 2600 | assert isinstance(b, B) 2601 | assert b.example == "example" 2602 | assert b.dependency.foo == "foo" 2603 | 2604 | 2605 | def test_ignore_subclass_class_var(): 2606 | """ 2607 | Class attributes must be ignored in implementations. 2608 | """ 2609 | 2610 | class A: 2611 | foo = "foo" 2612 | 2613 | container = Container() 2614 | 2615 | container.register(A) 2616 | 2617 | a = container.resolve(A) 2618 | 2619 | assert a.foo == "foo" 2620 | 2621 | 2622 | def test_singleton_register_order_last(): 2623 | """ 2624 | The registration order of singletons should not matter. 2625 | Check that singletons are not registered twice when they are registered 2626 | after their dependents. 2627 | """ 2628 | 2629 | class Bar: 2630 | foo: Foo 2631 | 2632 | class Bar2: 2633 | foo: Foo 2634 | 2635 | container = Container() 2636 | container.register(Bar) 2637 | container.register(Bar2) 2638 | container._add_exact_singleton(Foo) 2639 | 2640 | bar = container.resolve(Bar) 2641 | bar2 = container.resolve(Bar2) 2642 | foo = container.resolve(Foo) 2643 | 2644 | # check that singletons are always the same instance 2645 | assert bar.foo is bar2.foo is foo 2646 | 2647 | 2648 | def test_singleton_register_order_first(): 2649 | """ 2650 | The registration order of singletons should not matter. 2651 | Check that singletons are not registered twice when they are registered 2652 | before their dependents. 
2653 | """ 2654 | 2655 | class Bar: 2656 | foo: Foo 2657 | 2658 | class Bar2: 2659 | foo: Foo 2660 | 2661 | container = Container() 2662 | container._add_exact_singleton(Foo) 2663 | container.register(Bar) 2664 | container.register(Bar2) 2665 | 2666 | bar = container.resolve(Bar) 2667 | bar2 = container.resolve(Bar2) 2668 | foo = container.resolve(Foo) 2669 | 2670 | # check that singletons are always the same instance 2671 | assert bar.foo is bar2.foo is foo 2672 | 2673 | 2674 | def test_ignore_class_variable_if_already_initialized(): 2675 | """ 2676 | if a class variable is already initialized, it should not be overridden by 2677 | resolving a new instance nor fail if rodi can't resolve it. 2678 | """ 2679 | 2680 | foo_instance = Foo() 2681 | 2682 | class A: 2683 | foo: Foo = foo_instance 2684 | 2685 | class B: 2686 | example: ClassVar[str] = "example" 2687 | dependency: A 2688 | 2689 | container = Container() 2690 | 2691 | container.register(A) 2692 | container.register(B) 2693 | container._add_exact_singleton(Foo) 2694 | 2695 | b = container.resolve(B) 2696 | a = container.resolve(A) 2697 | foo = container.resolve(Foo) 2698 | 2699 | assert isinstance(a, A) 2700 | assert isinstance(a.foo, Foo) 2701 | assert foo_instance is a.foo 2702 | 2703 | assert isinstance(b, B) 2704 | assert b.example == "example" 2705 | assert b.dependency.foo is foo_instance 2706 | 2707 | # check that is not being overridden by resolving a new instance 2708 | assert foo is not a.foo 2709 | 2710 | 2711 | def test_nested_scope_1(): 2712 | container = Container(scope_cls=TrackingActivationScope) 2713 | container.add_scoped(Ok) 2714 | provider = container.build_provider() 2715 | 2716 | with provider.create_scope() as context_1: 2717 | a = provider.get(Ok, context_1) 2718 | 2719 | with provider.create_scope() as context_2: 2720 | b = provider.get(Ok, context_2) 2721 | 2722 | assert a is b 2723 | 2724 | 2725 | def test_nested_scope_2(): 2726 | container = 
Container(scope_cls=TrackingActivationScope) 2727 | container.add_scoped(Ok) 2728 | provider = container.build_provider() 2729 | 2730 | with provider.create_scope(): 2731 | with provider.create_scope() as context: 2732 | a = provider.get(Ok, context) 2733 | 2734 | with provider.create_scope() as context: 2735 | b = provider.get(Ok, context) 2736 | 2737 | assert a is not b 2738 | 2739 | 2740 | async def nested_scope_async(): 2741 | container = Container(scope_cls=TrackingActivationScope) 2742 | container.add_scoped(Ok) 2743 | provider = container.build_provider() 2744 | 2745 | with provider.create_scope() as context_1: 2746 | a = provider.get(Ok, context_1) 2747 | 2748 | await asyncio.sleep(0.01) 2749 | with provider.create_scope() as context_2: 2750 | b = provider.get(Ok, context_2) 2751 | 2752 | assert a is b 2753 | 2754 | 2755 | @pytest.mark.asyncio 2756 | async def test_nested_scope_async_1(): 2757 | await asyncio.gather( 2758 | nested_scope_async(), 2759 | nested_scope_async(), 2760 | nested_scope_async(), 2761 | nested_scope_async(), 2762 | ) 2763 | --------------------------------------------------------------------------------