├── limbus_config ├── __init__.py └── config.py ├── environment.yml ├── limbus ├── __init__.py ├── widgets │ ├── __init__.py │ ├── viz.py │ ├── widget_component.py │ └── types.py └── core │ ├── __init__.py │ ├── app.py │ ├── states.py │ ├── async_utils.py │ ├── params.py │ ├── component.py │ ├── pipeline.py │ └── param.py ├── tests ├── core │ ├── conftest.py │ ├── test_app.py │ ├── test_params.py │ ├── test_component.py │ ├── test_pipeline.py │ └── test_param.py ├── widgets │ ├── test_widget_component.py │ └── test_viz.py └── config │ └── test_config.py ├── path.bash.inc ├── .github ├── ISSUE_TEMPLATE │ ├── feature-request.md │ └── bug-report.md ├── workflows │ ├── release.yml │ └── ci.yml └── pull-request-template.md ├── setup.cfg ├── setup.py ├── setup_dev_env.sh ├── examples ├── default_cmps.py ├── webcam_app.py └── defining_cmps.py ├── .gitignore ├── README.md └── LICENSE /limbus_config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: limbus 2 | 3 | dependencies: 4 | - pip 5 | - python 6 | -------------------------------------------------------------------------------- /limbus/__init__.py: -------------------------------------------------------------------------------- 1 | from limbus import widgets 2 | from limbus.core import Pipeline, Component 3 | -------------------------------------------------------------------------------- /limbus_config/config.py: -------------------------------------------------------------------------------- 1 | """Configuration file for Limbus. 
2 |
3 | Usage:
4 |     Before importing any other Limbus module, set the COMPONENT_TYPE variable as:
5 |     > from limbus_config import config
6 |     > config.COMPONENT_TYPE = "torch"
7 |
8 |
9 | """
10 |
11 | COMPONENT_TYPE = "generic"  # generic or torch
12 |
--------------------------------------------------------------------------------
/tests/core/conftest.py:
--------------------------------------------------------------------------------
1 | """Pytest fixtures."""
2 | import pytest
3 | import asyncio
4 |
5 | from limbus.core import async_utils
6 |
7 |
8 | @pytest.fixture
9 | def event_loop_instance():
10 |     """Set async_utils.loop (recreated if closed) as the current event loop, yield it and close it afterwards."""
11 |     if async_utils.loop.is_closed():
12 |         async_utils.reset_loop()
13 |     asyncio.set_event_loop(async_utils.loop)
14 |     yield async_utils.loop
15 |     async_utils.loop.close()
16 |
--------------------------------------------------------------------------------
/limbus/widgets/__init__.py:
--------------------------------------------------------------------------------
1 | from limbus.widgets.types import Viz, Visdom, Console
2 | from limbus.widgets.viz import get, delete, set_type
3 | from limbus.widgets.widget_component import WidgetComponent, BaseWidgetComponent, is_disabled, WidgetState
4 |
5 | __all__ = [
6 |     "is_disabled",
7 |     "get",
8 |     "delete",
9 |     "set_type",
10 |     "WidgetComponent",
11 |     "BaseWidgetComponent",
12 |     "WidgetState",
13 |     "Viz",
14 |     "Visdom",
15 |     "Console",
16 | ]
17 |
--------------------------------------------------------------------------------
/tests/core/test_app.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from limbus.core import App
4 | from limbus_components.base import Constant, Printer
5 |
6 |
7 | @pytest.mark.usefixtures("event_loop_instance")
8 | class TestApp:
9 |     def test_app(self):
10 |         class MyApp(App):
11 |             def create_components(self):  # noqa: D102
12 |                 self._constant = Constant("constant", "xyz")  # type: ignore
13 |                 self._print = Printer("print")  # type: ignore
14 |
15 |             def connect_components(self):  # noqa: D102
16 |                 self._constant.outputs.out >> self._print.inputs.inp
17 |
18 |         app = MyApp()
19 |         app.run(1)
20 |
--------------------------------------------------------------------------------
/tests/widgets/test_widget_component.py:
--------------------------------------------------------------------------------
1 | from limbus.widgets import WidgetState, WidgetComponent
2 |
3 |
4 | class TestWidgetComponent:
5 |     def test_smoke(self):
6 |         cmp = WidgetComponent("yuhu")
7 |         assert cmp.name == "yuhu"
8 |         assert cmp.inputs is not None
9 |         assert cmp.outputs is not None
10 |         assert cmp.properties is not None
11 |         assert cmp.widget_state == WidgetState.ENABLED
12 |
13 |     def test_widget_state(self):
14 |         cmp = WidgetComponent("yuhu")
15 |         assert cmp.widget_state == WidgetState.ENABLED
16 |         cmp.widget_state = WidgetState.DISABLED
17 |         assert cmp.widget_state == WidgetState.DISABLED
18 |
--------------------------------------------------------------------------------
/limbus/core/__init__.py:
--------------------------------------------------------------------------------
1 | from limbus.core.component import Component, executions_manager
2 | from limbus.core.states import ComponentState, PipelineState, VerboseMode
3 | from limbus.core.param import NoValue, Reference, InputParam, OutputParam, PropertyParam
4 | from limbus.core.params import PropertyParams, InputParams, OutputParams
5 | from limbus.core.pipeline import Pipeline
6 | from limbus.core.app import App
7 |
8 |
9 | __all__ = [
10 |     "App",
11 |     "Pipeline",
12 |     "PipelineState",
13 |     "VerboseMode",
14 |     "Component",
15 |     "executions_manager",
16 |     "ComponentState",
17 |     "Reference",
18 |     "PropertyParams",
19 |     "InputParams",
20 |     "OutputParams",
21 |     "PropertyParam",
22 |     "InputParam",
23 |     "OutputParam",
24 |     "NoValue"]
25 |
--------------------------------------------------------------------------------
/path.bash.inc:
-------------------------------------------------------------------------------- 1 | # The purpose of this script is simplify running scripts inside of our 2 | # dev_env docker container. It mounts the workspace and the 3 | # workspace/../build directory inside of the container, and executes 4 | # any arguments passed to the dev_env.sh 5 | script_link="$( readlink "$BASH_SOURCE" )" || script_link="$BASH_SOURCE" 6 | apparent_sdk_dir="${script_link%/*}" 7 | if [ "$apparent_sdk_dir" = "$script_link" ]; then 8 | apparent_sdk_dir=. 9 | fi 10 | sdk_dir="$(cd -P "$apparent_sdk_dir" > /dev/null && pwd -P )" 11 | if [ ! -e $sdk_dir/.dev_env/bin/conda ]; then 12 | $sdk_dir/setup_dev_env.sh 13 | fi 14 | 15 | cmd="source $sdk_dir/.dev_env/bin/activate $sdk_dir/.dev_env/envs/limbus" 16 | if [ -z $CI ]; then 17 | eval $cmd 18 | else 19 | echo $cmd >> "$BASH_ENV" 20 | fi 21 | 22 | export PYTHONPATH=$PYTHONPATH:$sdk_dir 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680Feature Request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | ## Motivation 13 | 14 | 15 | 16 | ## Pitch 17 | 18 | 19 | 20 | ## Alternatives 21 | 22 | 23 | 24 | ## Additional context 25 | 26 | 27 | -------------------------------------------------------------------------------- /tests/config/test_config.py: -------------------------------------------------------------------------------- 1 | """Config tests.""" 2 | import sys 3 | 4 | from torch import nn 5 | 6 | 7 | def remove_limbus_imports(): 8 | """Remove limbus dependencies from sys.modules.""" 9 | for key in list(sys.modules.keys()): 10 | if key.startswith("limbus"): 11 | del sys.modules[key] 12 | 13 | 14 | def test_torch_base_class(): 15 | remove_limbus_imports() 
16 | from limbus_config import config 17 | config.COMPONENT_TYPE = "torch" 18 | import limbus 19 | mro = limbus.Component.__mro__ 20 | remove_limbus_imports() 21 | assert len(mro) == 3 22 | assert nn.Module in mro 23 | 24 | 25 | def test_generic_base_class(): 26 | remove_limbus_imports() 27 | import limbus 28 | mro = limbus.Component.__mro__ 29 | remove_limbus_imports() 30 | assert len(mro) == 2 31 | assert nn.Module not in mro 32 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [published] 7 | 8 | jobs: 9 | pypi: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@master 13 | - uses: actions/setup-python@v2 14 | - name: Install dependencies 15 | run: python3 -m pip install --upgrade setuptools wheel 16 | - name: Compile project 17 | run: python3 -m pip install -e . 
18 | - name: Build distribution package 19 | run: python3 setup.py sdist bdist_wheel 20 | # - name: Publish package to 📦 Test PyPI 21 | # uses: pypa/gh-action-pypi-publish@release/v1 22 | # with: 23 | # password: ${{ secrets.PYPI_TEST_PASSWORD_LIMBUS }} 24 | # repository_url: https://test.pypi.org/legacy/ 25 | - name: Publish distribution 📦 to PyPI 26 | uses: pypa/gh-action-pypi-publish@release/v1 27 | with: 28 | user: __token__ 29 | password: ${{ secrets.PYPI_PASSWORD_LIMBUS }} 30 | -------------------------------------------------------------------------------- /.github/pull-request-template.md: -------------------------------------------------------------------------------- 1 | #### Changes 2 | 3 | 4 | 5 | 6 | Fixes # (issue) 7 | 8 | 9 | #### Type of change 10 | 11 | - [ ] 📚 Documentation Update 12 | - [ ] 🧪 Tests Cases 13 | - [ ] 🐞 Bug fix (non-breaking change which fixes an issue) 14 | - [ ] 🔬 New feature (non-breaking change which adds functionality) 15 | - [ ] 🚨 Breaking change (fix or feature that would cause existing functionality to not work as expected) 16 | - [ ] 📝 This change requires a documentation update 17 | 18 | 19 | #### Checklist 20 | 21 | - [ ] My code follows the style guidelines of this project 22 | - [ ] I have performed a self-review of my own code 23 | - [ ] I have commented my code, particularly in hard-to-understand areas 24 | - [ ] I have made corresponding changes to the documentation 25 | - [ ] My changes generate no new warnings 26 | - [ ] Did you update CHANGELOG in case of a major change? 
27 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = 6 | -v --color=yes 7 | --pydocstyle --mypy --flake8 8 | --cov=limbus --cov-report=term-missing --cov-config=setup.cfg 9 | norecursedirs = 10 | .git 11 | .github 12 | dist 13 | build 14 | asyncio_mode=auto 15 | 16 | [bdist_wheel] 17 | universal=1 18 | 19 | [metadata] 20 | license_file = LICENSE 21 | 22 | [flake8] 23 | max-line-length = 120 24 | ignore = 25 | W504 26 | E722 27 | F401 28 | F541 29 | 30 | exclude = docs/src 31 | 32 | [mypy] 33 | files = examples, limbus, tests, limbus_config 34 | show_error_codes = True 35 | ignore_missing_imports = True 36 | 37 | [pydocstyle] 38 | match=(?!test_|__|setup).*\.py 39 | ignore = D105, D107, D203, D204, D213, D406, D407 40 | 41 | [coverage:report] 42 | exclude_lines = 43 | pragma: no cover 44 | def __repr__ 45 | if self.debug: 46 | raise 47 | if 0: 48 | if __name__ == .__main__.: 49 | 50 | 51 | [isort] 52 | line_length = 120 53 | known_first_party = 54 | examples 55 | limbus 56 | tests 57 | order_by_type = False 58 | # 3 - Vertical Hanging Indent 59 | multi_line_output = 3 60 | include_trailing_comma = True 61 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41BBug report" 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | --- 8 | 9 | Thanks for taking the time to fill out this bug report! 
10 | 11 | ## 🐛 Describe the bug 12 | 13 | 14 | ## Reproduction steps 15 | 21 | 22 | ## Expected behavior 23 | 24 | 25 | ## Environment 26 | 41 | 42 | ## Additional context 43 | 44 | -------------------------------------------------------------------------------- /limbus/core/app.py: -------------------------------------------------------------------------------- 1 | """High level template to create apps.""" 2 | from __future__ import annotations 3 | from abc import abstractmethod 4 | 5 | from limbus.core import Pipeline, VerboseMode, Component 6 | 7 | 8 | class App: 9 | """High level template to create an app.""" 10 | def __init__(self): 11 | self.create_components() 12 | self.connect_components() 13 | # Create the pipeline 14 | self._pipeline = Pipeline() 15 | self._pipeline.add_nodes(self._get_component_attrs()) 16 | self._pipeline.set_verbose_mode(VerboseMode.DISABLED) 17 | 18 | def _get_component_attrs(self) -> list[Component]: 19 | """Get the component attribute by name.""" 20 | return [getattr(self, attr) for attr in dir(self) if isinstance(getattr(self, attr), Component)] 21 | 22 | @abstractmethod 23 | def create_components(self): 24 | """Create the components of the app.""" 25 | pass 26 | 27 | @abstractmethod 28 | def connect_components(self): 29 | """Connect the components of the app.""" 30 | pass 31 | 32 | def run(self, iters: int = 0): 33 | """Run the app. 34 | 35 | Args: 36 | iters (optional): number of iters to be run. By default (0) all of them are run. 
37 | 38 | """ 39 | self._pipeline.run(iters) 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='limbus', 4 | version='0.1.6', 5 | description='High level interface to create Pytorch Graphs.', 6 | long_description=open('README.md').read(), 7 | long_description_content_type='text/markdown', 8 | author='Luis Ferraz', 9 | url='https://github.com/kornia/limbus', 10 | install_requires=[ 11 | 'typeguard<3.0.0', 12 | ], 13 | extras_require={ 14 | 'dev': [ 15 | 'pytest', 16 | 'pytest-flake8', 17 | 'pytest-cov', 18 | 'pytest-mypy', 19 | 'pytest-pydocstyle', 20 | 'pytest-asyncio', 21 | 'mypy', # TODO: check if we can remove the deps without pytest-* 22 | 'pydocstyle', 23 | 'flake8<5.0.0', # last versions of flake8 are not compatible with pytest-flake8==1.1.1 (lastest version) 24 | 'pep8-naming', 25 | ], 26 | 'components': [ 27 | 'limbus-components' 28 | ], 29 | 'widgets': [ 30 | 'kornia', 31 | 'torch', 32 | 'numpy', 33 | 'visdom', 34 | 'opencv-python', 35 | ] 36 | }, 37 | packages=find_packages(where='.'), 38 | package_dir={'': '.'}, 39 | package_data={'': ['*.yml']}, 40 | include_package_data=True 41 | ) 42 | -------------------------------------------------------------------------------- /setup_dev_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | script_link="$( readlink "$BASH_SOURCE" )" || script_link="$BASH_SOURCE" 3 | apparent_sdk_dir="${script_link%/*}" 4 | if [ "$apparent_sdk_dir" == "$script_link" ]; then 5 | apparent_sdk_dir=. 
6 | fi
7 | sdk_dir="$( command cd -P "$apparent_sdk_dir" > /dev/null && pwd -P )"
8 | # work from the repository root: `conda env create` (environment.yml) and `pip install -e .` below resolve
9 | # their inputs against the current directory, which differs when invoked from path.bash.inc elsewhere
10 | cd "$sdk_dir"
11 |
12 | # create root directory to install miniconda
13 | dev_env_dir="$sdk_dir/.dev_env"
14 | mkdir -p "$dev_env_dir"
15 |
16 | # define miniconda paths
17 | conda_bin_dir="$dev_env_dir/bin"
18 | conda_bin="$conda_bin_dir/conda"
19 |
20 | # download and install miniconda
21 | # check the operating system: Mac or Linux
22 | platform="$(uname)"
23 | if [[ "$platform" == "Darwin" ]];
24 | then
25 |     download_link=https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
26 | else
27 |     download_link=https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
28 | fi
29 |
30 | if [ ! -e "$dev_env_dir/miniconda.sh" ]; then
31 |     # -f: fail on HTTP errors instead of saving the error page; -L: follow redirects.
32 |     # Download to a temporary name so an interrupted download is retried on the next run.
33 |     curl -fL -o "$dev_env_dir/miniconda.sh.part" "$download_link"
34 |     mv "$dev_env_dir/miniconda.sh.part" "$dev_env_dir/miniconda.sh"
35 |     chmod +x "$dev_env_dir/miniconda.sh"
36 | fi
37 | if [ ! -e "$conda_bin" ]; then
38 |     "$dev_env_dir/miniconda.sh" -b -u -p "$dev_env_dir"
39 | fi
40 |
41 | # create the environment
42 | "$conda_bin" update -n base -c defaults conda -y
43 | source "$conda_bin_dir/activate" "$dev_env_dir"
44 | "$conda_bin" env create
45 | # $conda_bin clean -ya
46 |
47 | # activate local virtual environment
48 | source "$conda_bin_dir/activate" "$dev_env_dir/envs/limbus"
49 |
50 | # install dev requirements (quoted so bash never glob-expands the extras brackets)
51 | pip install -e ".[dev,components,widgets]"
52 | # note that limbus-components is not installed in editable mode
53 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | on:
2 |   # Trigger the workflow on push or
3 |   # pull request, but only for the
4 |   # main branch.
5 |   push:
6 |     branches:
7 |       - main
8 |   pull_request:
9 |     branches:
10 |       - main
11 |   # Weekly run to account for
12 |   # changed dependencies.
13 | schedule: 14 | - cron: '17 03 * * 0' 15 | 16 | name: CI 17 | jobs: 18 | build: 19 | name: Build and test 20 | runs-on: ${{ matrix.os }} 21 | strategy: 22 | matrix: 23 | os: [ubuntu-latest] 24 | python-version: 25 | - '3.8' 26 | - '3.9' 27 | include: 28 | - os: ubuntu-20.04 29 | python-version: '3.8' 30 | installTyping: ${{ true }} 31 | fail-fast: true 32 | 33 | steps: 34 | - name: Checkout 35 | uses: actions/checkout@v2 36 | 37 | - name: Set up Python ${{ matrix.python-version }} 38 | uses: actions/setup-python@v2 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | 42 | - name: Cache pip 43 | uses: actions/cache@v2 44 | with: 45 | # This path is specific to Ubuntu 46 | path: ~/.cache/pip 47 | # Look to see if there is a cache hit for the corresponding requirements file 48 | key: v1-pip-${{ runner.os }}-${{ matrix.python-version }} 49 | restore-keys: | 50 | v1-pip-${{ runner.os }} 51 | v1-pip- 52 | 53 | #- name: Lint 54 | # run: python -m flake8 55 | 56 | - name: Build the code and Install dependencies 57 | run: pip install -e .[dev,components,widgets] 58 | 59 | - name: Run tests 60 | run: python -m pytest -v -------------------------------------------------------------------------------- /examples/default_cmps.py: -------------------------------------------------------------------------------- 1 | """Basic example with predefined cmps.""" 2 | import asyncio 3 | from sys import version_info 4 | 5 | import torch 6 | 7 | from limbus.core.pipeline import Pipeline 8 | try: 9 | import limbus_components as components 10 | except ImportError: 11 | raise ImportError("limbus-components is required to run this script." 
12 | "Install the package with: " 13 | "'pip install limbus-components@git+https://git@github.com/kornia/limbus-components.git'") 14 | 15 | 16 | # define your components 17 | c1 = components.base.Constant("c1", 0) # type: ignore 18 | t1 = components.base.Constant("t1", torch.ones(1, 3)) # type: ignore 19 | t2 = components.base.Constant("t2", torch.ones(1, 3) * 2) # type: ignore 20 | stack = components.torch.Stack("stack") # type: ignore 21 | show = components.base.Printer("print") # type: ignore 22 | 23 | # connect the components 24 | c1.outputs.out >> stack.inputs.dim 25 | t1.outputs.out >> stack.inputs.tensors.select(0) 26 | t2.outputs.out >> stack.inputs.tensors.select(1) 27 | stack.outputs.out >> show.inputs.inp 28 | 29 | USING_PIPELINE = True 30 | if USING_PIPELINE: 31 | # run your pipeline (only one iteration, note that this pipeline can run forever) 32 | print("Run with pipeline:") 33 | # create the pipeline and add its nodes 34 | pipeline = Pipeline() 35 | pipeline.add_nodes([c1, t1, t2, stack, show]) 36 | pipeline.run(1) 37 | # You can rerun the pipeline as many times as you want and will continue from the last iteration 38 | pipeline.run(1) 39 | else: 40 | # run 1 iteration using the asyncio loop 41 | print("Run with loop:") 42 | 43 | async def f(): # noqa: D103 44 | await asyncio.gather(c1(), t1(), t2(), stack(), show()) 45 | 46 | if version_info.minor < 10: 47 | # for python <3.10 the loop must be run in this way to avoid creating a new loop. 48 | loop = asyncio.get_event_loop() 49 | loop.run_until_complete(f()) 50 | elif version_info.minor >= 10: 51 | # for python >=3.10 the loop should be run in this way. 
52 | asyncio.run(f()) 53 | -------------------------------------------------------------------------------- /limbus/core/states.py: -------------------------------------------------------------------------------- 1 | """Define the states for components/pipelines.""" 2 | from __future__ import annotations 3 | from enum import Enum 4 | 5 | 6 | class ComponentStoppedError(Exception): 7 | """Raised when trying to interact with a stopped component. 8 | 9 | Properties: 10 | state: state of the component when the error was raised. 11 | message: explanation of the error. 12 | 13 | """ 14 | def __init__(self, state: "ComponentState", message: None | str = None): 15 | self.state: ComponentState = state 16 | self.message: None | str = message 17 | super().__init__() 18 | 19 | 20 | class VerboseMode(Enum): 21 | """Possible states for the verbose in the pipeline objects.""" 22 | DISABLED = 0 23 | PIPELINE = 1 24 | COMPONENT = 2 25 | 26 | 27 | class ComponentState(Enum): 28 | """Possible states for the components.""" 29 | STOPPED = 0 # when the stop is because of the component stops internally 30 | PAUSED = 1 # when the component is paused because the user requires it 31 | OK = 2 # when the iteration is executed normaly 32 | ERROR = 3 # when the stop is because of an error 33 | DISABLED = 4 # when the component is disabled for some reason (e.g. 
viz cannot be done) 34 | FORCED_STOP = 5 # when the stop is because the user requires it 35 | INITIALIZED = 6 # whe it is created 36 | RUNNING = 7 37 | RECEIVING_PARAMS = 8 38 | SENDING_PARAMS = 9 39 | STOPPED_AT_ITER = 10 # when the stop is because of the iteration number 40 | READY = 11 # when the component is ready to be executed at the beginning of each iteration 41 | STOPPED_BY_COMPONENT = 12 # when the stop is because another component forces it 42 | 43 | 44 | class PipelineState(Enum): 45 | """Possible states for the pipeline.""" 46 | STARTED = 0 47 | ENDED = 1 48 | PAUSED = 2 49 | ERROR = 3 50 | EMPTY = 4 51 | RUNNING = 5 52 | INITIALIZING = 6 53 | FORCED_STOP = 7 54 | 55 | 56 | class IterationState(Enum): 57 | """Internal state to control the pipeline iterations.""" 58 | COMPONENT_EXECUTED = 0 59 | COMPONENT_NOT_EXECUTED = 1 60 | COMPONENT_IN_EXECUTION = 2 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # vscode 132 | .vscode/ 133 | 134 | .dev_env 135 | components.json -------------------------------------------------------------------------------- /limbus/widgets/viz.py: -------------------------------------------------------------------------------- 1 | """Module to manage the visualization tools in limbus.""" 2 | from __future__ import annotations 3 | from typing import Type 4 | import inspect 5 | 6 | from limbus.widgets import types 7 | 8 | 9 | # global var to store the visualization backend. We want a single instance. 10 | _viz: None | types.Viz = None 11 | # global var to store the type used for the visualization backend. 12 | _viz_cls: Type[types.Viz] = types.Console # Default value is Console. 13 | 14 | 15 | def set_type(viz_cls: Type[types.Viz] | str) -> None: 16 | """Set the visualization class that will be used to create the visualization object. 17 | 18 | Args: 19 | viz_cls: The visualization class that will be used to create the visualization object. 20 | It can be a string with the name of the class (defined in limbus/viz/types.py) or the class itself. 21 | 22 | """ 23 | global _viz_cls 24 | if isinstance(viz_cls, str): 25 | # the param viz_cls must be the name of the class. 26 | # So, we go through all the posible classes and find the one with the name. 
27 | found = False 28 | for cls in inspect.getmembers(types, inspect.isclass): 29 | if (issubclass(cls[1], types.Viz) and 30 | cls[1].__name__ != types.Viz.__name__ and 31 | cls[0].lower() == viz_cls.lower()): 32 | _viz_cls = cls[1] 33 | found = True 34 | break 35 | if not found: 36 | raise ValueError(f"Unknown visualization type: {viz_cls}.") 37 | elif issubclass(viz_cls, types.Viz): 38 | if viz_cls == types.Viz: 39 | raise ValueError(f"Invalid visualization type. The Viz base class cannot be setted.") 40 | _viz_cls = viz_cls 41 | else: 42 | raise ValueError(f"Invalid visualization type: {viz_cls}. Must be a subclass of Viz.") 43 | # delete and recreate the current viz object with the new type 44 | delete() 45 | get() 46 | 47 | 48 | def get(reconnect: bool = True) -> types.Viz: 49 | """Get the visualization object. 50 | 51 | Args: 52 | reconnect (optional): If True, tries to reconnect the visualization object if it is not connected. 53 | Default: True. 54 | 55 | Returns: 56 | The visualization object. 57 | 58 | """ 59 | global _viz 60 | if _viz is None or not isinstance(_viz, _viz_cls): 61 | _viz = _viz_cls() 62 | if reconnect: 63 | _viz.check_status() 64 | return _viz 65 | 66 | 67 | def delete() -> None: 68 | """Remove the visualization object.""" 69 | global _viz 70 | _viz = None 71 | -------------------------------------------------------------------------------- /limbus/core/async_utils.py: -------------------------------------------------------------------------------- 1 | """Async utils for limbus.""" 2 | from __future__ import annotations 3 | import asyncio 4 | import inspect 5 | from typing import Coroutine, TYPE_CHECKING 6 | 7 | if TYPE_CHECKING: 8 | from limbus.core.component import Component 9 | 10 | # Get the loop that is going to run the pipeline. Doing it in this way allows to rerun the pipeline. 
loop = asyncio.new_event_loop()

# Strong references to the tasks spawned by create_task_if_needed(). The event loop only keeps weak references to
# tasks, so a task that nobody references may be garbage collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


def reset_loop() -> asyncio.AbstractEventLoop:
    """Replace the module-level ``loop`` with a brand-new event loop.

    Returns:
        asyncio.AbstractEventLoop: the newly created loop.

    """
    global loop
    loop = asyncio.new_event_loop()
    return loop


def run_coroutine(coro: Coroutine) -> None:
    """Run a coroutine to completion in the module-level event loop.

    A fresh loop is created first if the current one has been closed.

    Args:
        coro: coroutine to run.

    """
    global loop
    if loop.is_closed():
        loop = reset_loop()
    loop.run_until_complete(coro)


def get_task_if_exists(component: Component) -> None | asyncio.Task:
    """Get the task associated to a given component if it exists.

    A task belongs to ``component`` when the ``self`` local of its coroutine is ``component``.
    Must be called with a running event loop: ``asyncio.all_tasks()`` raises ``RuntimeError`` otherwise.

    Args:
        component: component to check.

    Returns:
        None | asyncio.Task: task associated to the component if it exists, None otherwise.

    """
    task: asyncio.Task
    for task in asyncio.all_tasks():
        coro = task.get_coro()
        if not isinstance(coro, Coroutine):
            # get_coro() may return None (e.g. eagerly completed tasks in Python 3.12+); such a task cannot be
            # inspected, so skip it instead of failing.
            continue
        cr_locals = inspect.getcoroutinelocals(coro)
        # check if the coroutine of the component object already exists in the tasks list
        if "self" in cr_locals and cr_locals["self"] is component:
            return task
    return None


def check_if_task_exists(component: Component) -> bool:
    """Check if the coroutine of the parent object already exists in the tasks list.

    Args:
        component: parent component object to check.

    Returns:
        True if the coroutine of the component object already exists in the tasks list, False otherwise.

    """
    return get_task_if_exists(component) is not None


def create_task_if_needed(ref_component: Component, component: Component) -> None:
    """Create the task for the component if it is not created yet.

    Args:
        ref_component: reference component; ``component`` is initialised from it before starting.
        component: component to create the task.

    """
    if not check_if_task_exists(component):
        # start the execution of the component if it is not started yet
        component.init_from_component(ref_component)
        task = asyncio.create_task(component())
        # keep the task alive until it finishes (see _background_tasks)
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)


# --------------------------------------------------------------------------------
# /limbus/widgets/widget_component.py:
# --------------------------------------------------------------------------------
"""Module containing the base component for visualization components."""
import functools
from abc import abstractmethod
from typing import Callable
from enum import Enum

from limbus import widgets
from limbus.core import Component, ComponentState, PropertyParams


class WidgetState(Enum):
    """Possible states for the viz."""
    DISABLED = 0  # viz is disabled but can be enabled.
    ENABLED = 1  # viz is enabled.
    NO = 2  # viz cannot be used.


def is_disabled(func: Callable) -> Callable:
    """Decorate an async method so it returns ComponentState.DISABLED when the viz is not enabled.

    The wrapped method is only awaited when ``widgets.get(False)`` returns a viz object whose ``enabled`` is truthy.

    Args:
        func: async method to wrap.

    Returns:
        Callable: the wrapped async method.

    """
    @functools.wraps(func)
    async def wrapper_check_component_disabled(self, *args, **kwargs):
        vz = widgets.get(False)
        if vz is None or not vz.enabled:
            return ComponentState.DISABLED
        return await func(self, *args, **kwargs)
    return wrapper_check_component_disabled


class WidgetComponent(Component):
    """Allow to use widgets in Limbus Components.

    Args:
        name (str): component name.

    """
    # Default widget state used by every new instance of this class.
    # It must be a class attribute because it is accessed when the class is not instantiated.
    # To change the widget state of an instance, use the widget_state property.
    # NOTE(review): an older comment stated that components have no viz by default, but the default is ENABLED —
    # confirm which one is intended.
    WIDGET_STATE: WidgetState = WidgetState.ENABLED

    def __init__(self, name: str):
        super().__init__(name)
        self._widget_state: WidgetState = self.__class__.WIDGET_STATE

    @property
    def widget_state(self) -> WidgetState:
        """Get the viz state for this component."""
        return self._widget_state

    @widget_state.setter
    def widget_state(self, state: WidgetState) -> None:
        """Set the viz state for this component."""
        self._widget_state = state


class BaseWidgetComponent(WidgetComponent):
    """Base class for only visualization components.

    Args:
        name (str): component name.

    """
    # by default Widget Components have the viz enabled, to disable it use the widget_state property.
    WIDGET_STATE: WidgetState = WidgetState.ENABLED

    @staticmethod
    def register_properties(properties: PropertyParams) -> None:
        """Register the properties.

        Args:
            properties: object to register the properties.

        """
        # this line is like super() but for static methods.
        Component.register_properties(properties)
        properties.declare("title", str, "")

    @abstractmethod
    async def _show(self, title: str) -> None:
        """Show the data.

        Args:
            title: same as self._properties["title"].value.

        """
        raise NotImplementedError

    @is_disabled
    async def forward(self) -> ComponentState:  # noqa: D102
        await self._show(self._properties["title"].value)
        return ComponentState.OK


# --------------------------------------------------------------------------------
# /examples/webcam_app.py:
# --------------------------------------------------------------------------------
"""Example with an app managing the pipeline."""
import asyncio

try:
    import aioconsole
except ImportError as err:
    raise ImportError("aioconsole is required to run this script. Install it with: pip install aioconsole") from err

from limbus.core import VerboseMode, App, async_utils
import limbus.widgets
try:
    import limbus_components as components
except ImportError as err:
    raise ImportError("limbus-components is required to run this script. "
                      "Install the package with: "
                      "'pip install limbus-components@git+https://git@github.com/kornia/limbus-components.git'") from err

# Init the widgets backend
limbus.widgets.set_type("OpenCV")


class WebcamApp(App):
    """Example with an app managing a pipeline using a webcam."""
    def create_components(self):  # noqa: D102
        self._webcam = components.base.Webcam(name="webcam", batch_size=1)  # type: ignore
        self._show = components.base.ImageShow(name="show")  # type: ignore
        self._accum = components.base.Accumulator(name="acc", elements=2)  # type: ignore
        self._cat = components.torch.Cat(name="stack")  # type: ignore

    def connect_components(self):  # noqa: D102
        self._webcam.outputs.image >> self._accum.inputs.inp
        self._accum.outputs.out >> self._cat.inputs.tensors
        self._cat.outputs.out >> self._show.inputs.image

    def run(self, iters: int = 0):  # noqa: D102
        self._pipeline.set_verbose_mode(VerboseMode.PIPELINE)
        # NOTE: iters is not used here; the interactive menu in _app decides how many iterations are run.
        async_utils.run_coroutine(self._app(self._pipeline))

    @staticmethod
    def _print_help() -> None:
        """Print the help message."""
        print(
            '\n\nOPTIONS MENU:\n'
            'Press "o" to run one pipeline iteration.\n'
            'Press "f" to run the pipeline forever.\n'
            'Press "r" to resume the pipeline.\n'
            'Press "p" to pause the pipeline.\n'
            'Press "vc" COMPONENT verbose state.\n'
            'Press "vp" PIPELINE verbose state.\n'
            'Press "vd" DISABLED verbose state.\n'
            'Press "q" to stop and quit.')

    async def _app(self, pipeline) -> None:
        """Run the interactive menu that drives the pipeline until the user presses "q"."""
        # The event loop only keeps weak references to tasks: hold them here so they are not garbage collected.
        running: set[asyncio.Task] = set()

        def spawn(coro) -> None:
            task = asyncio.create_task(coro)
            running.add(task)
            task.add_done_callback(running.discard)

        while True:
            self._print_help()
            key_in = await aioconsole.ainput('Option:')
            if key_in == 'o':
                spawn(pipeline.async_run(1))
            elif key_in == 'f':
                spawn(pipeline.async_run())
            elif key_in == 'r':
                pipeline.resume()
            elif key_in == 'p':
                pipeline.pause()
            elif key_in == 'vc':
                pipeline.set_verbose_mode(VerboseMode.COMPONENT)
            elif key_in == 'vp':
                pipeline.set_verbose_mode(VerboseMode.PIPELINE)
            elif key_in == 'vd':
                pipeline.set_verbose_mode(VerboseMode.DISABLED)
            elif key_in == 'q':
                pipeline.stop()
                break


if __name__ == "__main__":
    WebcamApp().run()


# --------------------------------------------------------------------------------
# /tests/core/test_params.py:
# --------------------------------------------------------------------------------
import pytest

import torch

from limbus.core import PropertyParams, NoValue, InputParams, OutputParams
from limbus.core.param import Param, InputParam, OutputParam, PropertyParam


class TParams(InputParams):
    """Test class to test the Params class with all the possible args for Param.

    NOTE: Inherits from InputParams because it is the only one that allows to use all the args in Param.

    """
    pass


class TestParams:
    def test_smoke(self):
        p = TParams()
        assert p is not None

    def test_declare(self):
        p = TParams()
        p.declare("x")
        assert isinstance(p.x.value, NoValue)
        assert isinstance(p["x"].value, NoValue)

        p.declare("y", float, 1.)
        assert p.y.value == 1.
        assert p["y"].value == 1.
        assert isinstance(p["y"], Param)
        assert isinstance(p.y, Param)
        assert isinstance(p["y"].value, float)
        assert p["y"].type == float
        assert p["y"].name == "y"
        assert p["y"].arg is None
        assert p.y.arg is None

    def test_tensor(self):
        p1 = TParams()
        p2 = TParams()

        p1.declare("x", torch.Tensor, torch.tensor(1.))
        assert isinstance(p1["x"].value, torch.Tensor)

        p2.declare("y", torch.Tensor, p1.x)
        assert p1.x.value == p2.y.value

    def test_get_params(self):
        p = TParams()
        p.declare("x")
        p.declare("y", float, 1.)
        assert len(p) == 2
        assert p.get_params() == ["x", "y"]
        assert isinstance(p.x.value, NoValue)
        assert p.y.value == 1.
        p.x.value = "xyz"
        assert p.x.value == "xyz"

    def test_wrong_set_param_type(self):
        p = TParams()
        with pytest.raises(TypeError):
            p.declare("x", int, 1.)
        p.declare("x", int)
        with pytest.raises(TypeError):
            p.x.value = "xyz"


class TestInputParams:
    def test_declare(self):
        p = InputParams()
        p.declare("x", float, 1.)
        assert isinstance(p.x, InputParam)

    def test_declare_with_param(self):
        p = InputParams()
        p0 = Param("x", float, 1.)
        p.declare("x", float, p0)
        assert p.x.value == p0.value
        assert p.z is None  # Intellisense assumes p.z exists as an InputParam


class TestOutputParams:
    def test_declare(self):
        p = OutputParams()
        p.declare("x", float)
        assert isinstance(p.x, OutputParam)
        assert p.z is None  # Intellisense assumes p.z exists as an OutputParam


class TestPropertyParams:
    def test_declare(self):
        p = PropertyParams()
        p.declare("x", float, 1.)
        assert isinstance(p.x, PropertyParam)
        assert p.z is None  # Intellisense assumes p.z exists as a PropertyParam

    def test_declare_with_param(self):
        p = PropertyParams()
        p0 = Param("x", float, 1.)
        p.declare("x", float, p0)
        assert p.x.value == p0.value


# --------------------------------------------------------------------------------
# /tests/widgets/test_viz.py:
# --------------------------------------------------------------------------------
import pytest
import asyncio

from limbus.core.component import ComponentState
from limbus.widgets import viz
from limbus.widgets import is_disabled


@pytest.fixture
def class_with_viz():
    class A:
        @is_disabled
        async def f(self) -> ComponentState:
            return ComponentState.OK
    return A


def _init_viz_module_params(viz_cls):
    viz._viz = None
    viz._viz_cls = viz_cls


VALID_TYPES = [viz.types.Console, viz.types.Visdom]


@pytest.mark.parametrize("value", (VALID_TYPES))
def test_get(value):
    _init_viz_module_params(value)

    # create viz object
    v = viz.get()
    assert v == viz._viz
    assert isinstance(v, viz._viz_cls)

    # validate that viz object is a singleton
    v2 = viz.get()
    assert v == v2


def test_get_with_change_of_cls():
    _init_viz_module_params(viz.types.Console)

    # create viz object
    v = viz.get()
    assert v == viz._viz
    assert isinstance(v, viz._viz_cls)

    viz._viz_cls = viz.types.Visdom

    # changing the viz class must create a new viz object of the new class
    v2 = viz.get()
    assert v != v2
    assert v2 == viz._viz
    assert isinstance(v2, viz._viz_cls)


@pytest.mark.parametrize("value", (VALID_TYPES))
def test_delete(value):
    _init_viz_module_params(value)

    # create the viz object if it is not created yet
    viz.get()

    # delete the viz object
    viz.delete()
    assert viz._viz is None


@pytest.mark.parametrize("value", (VALID_TYPES))
def test_set_types_class(value):
    viz.set_type(value)
    assert viz._viz_cls == value
    assert isinstance(viz._viz, value)


@pytest.mark.parametrize("value, viz_type", ([("Console", viz.types.Console), ("Visdom", viz.types.Visdom)]))
def test_set_types_str(value, viz_type):
    viz.set_type(value)
    assert viz._viz_cls == viz_type
    assert isinstance(viz._viz, viz_type)


def test_set_types_new_viz_class():
    class A(viz.types.Viz):
        def check_status(self):
            pass

        def show_image(self):
            pass

        def show_images(self):
            pass

        def show_text(self):
            pass

    viz.set_type(A)
    assert viz._viz_cls == A
    assert isinstance(viz._viz, A)


def test_invalid_viz_type():
    # invalid string
    with pytest.raises(ValueError):
        viz.set_type("InvalidVizType")

    # invalid class type
    class A:
        pass
    with pytest.raises(ValueError):
        viz.set_type(A)

    # base class cannot be passed
    with pytest.raises(ValueError):
        viz.set_type(viz.types.Viz)


# THE NEXT TESTS ARE FOR CODE IN widget_component.py (is_disabled decorator)

@pytest.mark.parametrize("value", (VALID_TYPES))
def test_decorator_is_disabled_with_viz_disabled(class_with_viz, value):
    _init_viz_module_params(value)
    viz.get()
    # by default, the viz is always enabled in Console class
    viz._viz._enabled = False
    assert asyncio.run(class_with_viz().f()) == ComponentState.DISABLED


@pytest.mark.parametrize("value", (VALID_TYPES))
def test_decorator_is_disabled_with_valid_viz(class_with_viz, value):
    _init_viz_module_params(value)
    if value == viz.types.Visdom:
        # visdom server is not enabled so the state is DISABLED
        assert asyncio.run(class_with_viz().f()) == ComponentState.DISABLED
    else:
        assert asyncio.run(class_with_viz().f()) == ComponentState.OK


# --------------------------------------------------------------------------------
# /limbus/core/params.py:
# --------------------------------------------------------------------------------
"""Classes to define set of parameters."""
from __future__ import annotations
from typing import Any, Iterator, Iterable, Callable
from abc import ABC, abstractmethod

# Note that Component class cannot be imported to avoid circular dependencies.
# Since it is only used for type hints we import the module and use "component.Component" for typing.
from limbus.core import component
from limbus.core.param import Param, NoValue, InputParam, OutputParam, PropertyParam


class Params(Iterable, ABC):
    """Class to store parameters.

    Each declared parameter is stored as an instance attribute named after the parameter.

    Args:
        parent_component (optional): component owning these parameters. Default: None.

    """

    def __init__(self, parent_component: None | "component.Component" = None):
        super().__init__()
        self._parent = parent_component

    @abstractmethod
    def declare(self, *args, **kwargs) -> None:
        """Add or modify a param."""
        raise NotImplementedError

    def get_params(self, only_connected: bool = False) -> list[str]:
        """Return the name of all the params.

        Args:
            only_connected: If True, only return the params that are connected.

        Returns:
            list[str]: names of the params, in declaration order.

        """
        params = []
        for name in self.__dict__:
            param = getattr(self, name)
            if isinstance(param, Param) and (not only_connected or param.ref_counter()):
                params.append(name)
        return params

    def __len__(self) -> int:
        return len(self.get_params())

    def __getitem__(self, name: str) -> Param:
        return getattr(self, name)

    def __iter__(self) -> Iterator[Param]:
        for name in self.__dict__:
            attr = getattr(self, name)
            if isinstance(attr, Param):
                yield attr

    def __repr__(self) -> str:
        # Only Param attributes are listed: other public attributes have no `.value` and would make repr() crash.
        return ''.join(
            (
                f'{type(self).__name__}(',
                ', '.join(
                    f'{name}={getattr(self, name).value}'
                    for name in sorted(self.get_params()) if not name.startswith('_')
                ),
                ')',
            )
        )


class InputParams(Params):
    """Class to manage input parameters."""

    def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), callback: Callable | None = None) -> None:
        """Add or modify a param.

        Args:
            name: name of the parameter.
            tp: type (e.g. str, int, list, str | int,...). Default: typing.Any
            value (optional): value for the parameter. Default: NoValue(). If a Param is passed, its value is copied.
            callback (optional): async callback function to be called when the parameter value changes.
                Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:`
                - MUST return the value to be finally used.
                Default: None.

        """
        if isinstance(value, Param):
            value = value.value
        setattr(self, name, InputParam(name, tp, value, None, self._parent, callback))

    def __getattr__(self, name: str) -> InputParam:  # type: ignore  # it should return an InputParam
        """Trick to avoid mypy issues with dynamic attributes.

        NOTE: Python only calls this for attributes that do not exist, and it returns None for them.

        """
        ...
class PropertyParams(Params):
    """Class to manage property parameters."""

    def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), callback: Callable | None = None) -> None:
        """Add or modify a param.

        Args:
            name: name of the parameter.
            tp: type (e.g. str, int, list, str | int,...). Default: typing.Any
            value (optional): value for the parameter. Default: NoValue(). If a Param is passed, its value is copied.
            callback (optional): async callback function to be called when the parameter value changes.
                Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:`
                - MUST return the value to be finally used.
                Default: None.

        """
        if isinstance(value, Param):
            value = value.value
        setattr(self, name, PropertyParam(name, tp, value, None, self._parent, callback))

    def __getattr__(self, name: str) -> PropertyParam:  # type: ignore  # it should return a PropertyParam
        """Trick to avoid mypy issues with dynamic attributes.

        NOTE: Python only calls this for attributes that do not exist, and it returns None for them.

        """
        ...


class OutputParams(Params):
    """Class to manage output parameters."""

    def declare(self, name: str, tp: Any = Any, arg: None | str = None, callback: Callable | None = None) -> None:
        """Add or modify a param.

        Args:
            name: name of the parameter.
            tp: type (e.g. str, int, list, str | int,...). Default: typing.Any
            arg (optional): Component argument directly related with the value of the parameter. Default: None.
                E.g. this is useful to propagate datatypes and values from a pin with a default value to an argument
                in a Component (GUI).
            callback (optional): async callback function to be called when the parameter value changes.
                Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:`
                - MUST return the value to be finally used.
                Default: None.

        """
        # outputs never take an initial value: they always start as NoValue()
        setattr(self, name, OutputParam(name, tp, NoValue(), arg, self._parent, callback))

    def __getattr__(self, name: str) -> OutputParam:  # type: ignore  # it should return an OutputParam
        """Trick to avoid mypy issues with dynamic attributes.

        NOTE: Python only calls this for attributes that do not exist, and it returns None for them.

        """
        ...


# --------------------------------------------------------------------------------
# /examples/defining_cmps.py:
# --------------------------------------------------------------------------------
"""Basic example defining components and connecting them."""
from typing import List, Any
import asyncio

# If you want to change the limbus config you need to do it before importing any limbus module!!!
from limbus_config import config
config.COMPONENT_TYPE = "torch"

from limbus.core import (Component, InputParams, OutputParams, PropertyParams, Pipeline, VerboseMode,  # noqa: E402
                         ComponentState, OutputParam, InputParam, async_utils)  # noqa: E402


# define the components
# ---------------------
class Add(Component):
    """Add two numbers."""
    # NOTE: type definition is optional, but it helps with the intellisense. ;)
    class InputsTyping(InputParams):  # noqa: D106
        a: InputParam
        b: InputParam

    class OutputsTyping(OutputParams):  # noqa: D106
        out: OutputParam

    inputs: InputsTyping  # type: ignore
    outputs: OutputsTyping  # type: ignore

    async def val_rec_a(self, value: Any) -> Any:  # noqa: D102
        print(f"CALLBACK: Add.a: {value}.")
        return value

    async def val_rec_b(self, value: Any) -> Any:  # noqa: D102
        print(f"CALLBACK: Add.b: {value}.")
        return value

    async def val_sent(self, value: Any) -> Any:  # noqa: D102
        print(f"CALLBACK: Add.out: {value}.")
        return value

    @staticmethod
    def register_inputs(inputs: InputParams) -> None:  # noqa: D102
        inputs.declare("a", int, callback=Add.val_rec_a)
        inputs.declare("b", int, callback=Add.val_rec_b)

    @staticmethod
    def register_outputs(outputs: OutputParams) -> None:  # noqa: D102
        outputs.declare("out", int, callback=Add.val_sent)

    async def forward(self) -> ComponentState:  # noqa: D102
        a, b = await asyncio.gather(self._inputs.a.receive(), self._inputs.b.receive())
        print(f"Add: {a} + {b}")
        await self._outputs.out.send(a + b)
        return ComponentState.OK


class Printer(Component):
    """Prints the input to the console."""
    # NOTE: type definition is optional, but it helps with the intellisense. ;)
    class InputsTyping(InputParams):  # noqa: D106
        inp: InputParam

    inputs: InputsTyping  # type: ignore

    async def val_changed(self, value: Any) -> Any:  # noqa: D102
        print(f"CALLBACK: Printer.inp: {value}.")
        return value

    @staticmethod
    def register_inputs(inputs: InputParams) -> None:  # noqa: D102
        inputs.declare("inp", Any, callback=Printer.val_changed)

    async def forward(self) -> ComponentState:  # noqa: D102
        value = await self._inputs.inp.receive()
        print(f"Printer: {value}")
        return ComponentState.OK


class Data(Component):
    """Data source of inf numbers."""
    # NOTE: type definition is optional, but it helps with the intellisense. ;)
    class OutputsTyping(OutputParams):  # noqa: D106
        out: OutputParam

    outputs: OutputsTyping  # type: ignore

    def __init__(self, name: str, initial_value: int = 0):
        super().__init__(name)
        self._initial_value: int = initial_value

    @staticmethod
    def register_outputs(outputs: OutputParams) -> None:  # noqa: D102
        outputs.declare("out", int)

    async def forward(self) -> ComponentState:  # noqa: D102
        print(f"Read: {self._initial_value}")
        await self._outputs.out.send(self._initial_value)
        self._initial_value += 1
        return ComponentState.OK


class Acc(Component):
    """Accumulate data in a list."""
    # NOTE: type definition is optional, but it helps with the intellisense. ;)
    class InputsTyping(InputParams):  # noqa: D106
        inp: InputParam

    class OutputsTyping(OutputParams):  # noqa: D106
        out: OutputParam

    inputs: InputsTyping  # type: ignore
    outputs: OutputsTyping  # type: ignore

    def __init__(self, name: str, elements: int = 1):
        super().__init__(name)
        self._elements: int = elements

    async def set_elements(self, value: int) -> int:  # noqa: D102
        print(f"CALLBACK: Acc.elements: {value}.")
        # this is a bit tricky since the value is stored in 2 places: the property and the variable.
        # Since the acc uses the _elements variable in the forward method we need to update it here
        # as well. Thanks to the callback we do not need to worry about both sources.
        self._elements = value
        return value

    @staticmethod
    def register_properties(properties: PropertyParams) -> None:  # noqa: D102
        properties.declare("elements", int, callback=Acc.set_elements)

    @staticmethod
    def register_inputs(inputs: InputParams) -> None:  # noqa: D102
        inputs.declare("inp", int)

    @staticmethod
    def register_outputs(outputs: OutputParams) -> None:  # noqa: D102
        outputs.declare("out", List[int])

    async def forward(self) -> ComponentState:  # noqa: D102
        res: List[int] = []
        while len(res) < self._elements:
            res.append(await self._inputs.inp.receive())
            print(f"Acc {len(res)}: {res}")

        print(f"Acc: {res}")
        await self._outputs.out.send(res)
        return ComponentState.OK


# create the components
# ---------------------
data0 = Data("data0", 0)
data10 = Data("data10", 10)
add = Add("add")
acc = Acc(name="acc", elements=2)
printer0 = Printer("printer0")
printer1 = Printer("printer1")
printer2 = Printer("printer2")

# connect the components
# ----------------------
data0.outputs.out >> add.inputs.a 161 | data10.outputs.out >> add.inputs.b 162 | add.outputs.out >> acc.inputs.inp 163 | acc.outputs.out >> printer2.inputs.inp # print the accumulated values once all are received 164 | data0.outputs.out >> printer0.inputs.inp # print the first value (data0) 165 | add.outputs.out >> printer1.inputs.inp # print the sum of the values (data10 + data0) 166 | 167 | # create and run the pipeline 168 | # --------------------------- 169 | engine: Pipeline = Pipeline() 170 | # at least we need to add one node, the others are added automatically 171 | engine.add_nodes([add, printer0]) 172 | # there are several states for each component, with this verbose mode we can see them 173 | engine.set_verbose_mode(VerboseMode.COMPONENT) 174 | # run all the components at least 2 times (since there is an accumulator, some components will be run more than once) 175 | 176 | 177 | async def run() -> None: # noqa: D103 178 | await engine.async_run(1) 179 | await acc.properties.elements.set_property(3) # change the number of elements to accumulate 180 | await engine.async_run(1) 181 | 182 | async_utils.run_coroutine(run()) 183 | -------------------------------------------------------------------------------- /tests/core/test_component.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | 4 | from limbus.core import Component, ComponentState, Pipeline 5 | from limbus.core.component import _ComponentState 6 | 7 | 8 | class TestState: 9 | def test_smoke(self): 10 | cmp = Component("test") 11 | state = _ComponentState(cmp, ComponentState.RUNNING, True) 12 | assert state.state == [ComponentState.RUNNING] 13 | assert state.message(ComponentState.RUNNING) is None 14 | assert state.verbose is True 15 | assert state._component == cmp 16 | 17 | def test_call_no_params(self): 18 | cmp = Component("test") 19 | state = _ComponentState(cmp, ComponentState.RUNNING) 20 | assert state() == 
[ComponentState.RUNNING] 21 | assert state.message(ComponentState.RUNNING) is None 22 | assert state.verbose is False 23 | 24 | @pytest.mark.parametrize("verbose", [True, False]) 25 | def test_call(self, caplog, verbose): 26 | cmp = Component("test") 27 | state = _ComponentState(cmp, ComponentState.RUNNING) 28 | state.verbose = verbose 29 | assert state.verbose == verbose 30 | with caplog.at_level(logging.INFO): 31 | assert state(ComponentState.DISABLED, "message") == [ComponentState.DISABLED] 32 | assert state.message(ComponentState.DISABLED) == "message" 33 | if verbose: 34 | assert len(caplog.records) == 1 35 | assert "message" in caplog.text 36 | assert caplog.records[0].levelname == "INFO" 37 | else: 38 | assert len(caplog.records) == 0 39 | 40 | 41 | class TestComponent: 42 | def test_smoke(self): 43 | cmp = Component("yuhu") 44 | assert cmp.name == "yuhu" 45 | assert cmp.inputs is not None 46 | assert cmp.outputs is not None 47 | assert cmp.properties is not None 48 | 49 | def test_set_state(self): 50 | cmp = Component("yuhu") 51 | cmp.set_state(ComponentState.PAUSED) 52 | cmp.state == ComponentState.PAUSED 53 | cmp.state_message(ComponentState.PAUSED) is None 54 | cmp.state_message(ComponentState.FORCED_STOP) is None 55 | cmp.set_state(ComponentState.ERROR, "error") 56 | cmp.state == ComponentState.ERROR 57 | cmp.state_message(ComponentState.ERROR) == "error" 58 | cmp.state_message(ComponentState.FORCED_STOP) is None 59 | 60 | def test_register_properties(self): 61 | class A(Component): 62 | @staticmethod 63 | def register_properties(properties): 64 | Component.register_properties(properties) 65 | properties.declare("a", float, 1.) 66 | properties.declare("b", float, 2.) 67 | 68 | cmp = A("yuhu") 69 | assert len(cmp.properties) == 2 70 | assert len(cmp.inputs) == 0 71 | assert len(cmp.outputs) == 0 72 | assert cmp.properties.a.value == 1. 73 | assert cmp.properties.b.value == 2. 
74 | 75 | def test_register_inputs(self): 76 | class A(Component): 77 | @staticmethod 78 | def register_inputs(inputs): 79 | inputs.declare("a", float, 1.) 80 | inputs.declare("b", float, 2.) 81 | 82 | cmp = A("yuhu") 83 | assert len(cmp.properties) == 0 84 | assert len(cmp.outputs) == 0 85 | assert len(cmp.inputs) == 2 86 | assert cmp.inputs.a.value == 1. 87 | assert cmp.inputs.b.value == 2. 88 | 89 | def test_register_outputs(self): 90 | class A(Component): 91 | @staticmethod 92 | def register_outputs(outputs): 93 | outputs.declare("a", float) 94 | outputs.declare("b", float) 95 | 96 | cmp = A("yuhu") 97 | assert len(cmp.properties) == 0 98 | assert len(cmp.inputs) == 0 99 | assert len(cmp.outputs) == 2 100 | assert cmp.outputs.a.type is float 101 | assert cmp.outputs.b.type is float 102 | 103 | def test_init_from_component(self): 104 | class A(Component): 105 | pass 106 | 107 | cmp = A("yuhu") 108 | cmp.verbose is True 109 | cmp2 = A("yuhu2") 110 | assert cmp2.verbose is False 111 | cmp2.init_from_component(cmp) 112 | assert cmp2.verbose == cmp.verbose 113 | assert cmp2.pipeline == cmp.pipeline # None 114 | 115 | 116 | @pytest.mark.usefixtures("event_loop_instance") 117 | class TestComponentWithPipeline: 118 | def test_init_from_component_with_pipeline(self): 119 | class A(Component): 120 | @staticmethod 121 | def register_outputs(outputs): 122 | outputs.declare("out", int) 123 | 124 | async def forward(): 125 | self._outputs.out.send(1) 126 | return ComponentState.OK 127 | 128 | class B(Component): 129 | @staticmethod 130 | def register_inputs(inputs): 131 | inputs.declare("inp", int) 132 | 133 | async def forward(): 134 | self._inputs.inp.receive() 135 | return ComponentState.OK 136 | 137 | a = A("a") 138 | b = B("b") 139 | a.outputs.out >> b.inputs.inp 140 | pipeline = Pipeline() 141 | pipeline.add_nodes(a) 142 | assert len(pipeline._nodes) == 1 143 | assert a in pipeline._nodes 144 | a.set_pipeline(pipeline) 145 | a.verbose = True 146 | 
b.init_from_component(a) 147 | assert a.verbose == b.verbose 148 | assert a.pipeline == b.pipeline # None 149 | assert len(pipeline._nodes) == 2 150 | assert a in pipeline._nodes 151 | assert b in pipeline._nodes 152 | 153 | @pytest.mark.parametrize("iters", [0, 1, 2]) 154 | def test_get_stopping_iteration(self, iters): 155 | class A(Component): 156 | async def forward(self): 157 | return ComponentState.STOPPED 158 | 159 | cmp = A("yuhu") 160 | pipeline = Pipeline() 161 | pipeline.add_nodes(cmp) 162 | pipeline.run(iters) 163 | assert cmp.executions_counter == 1 164 | assert pipeline.min_iteration_in_progress == 1 165 | assert cmp.stopping_execution == iters 166 | 167 | def test_stop_after_exception(self): 168 | class A(Component): 169 | async def forward(self): 170 | raise Exception("test") 171 | 172 | cmp = A("yuhu") 173 | pipeline = Pipeline() 174 | pipeline.add_nodes(cmp) 175 | pipeline.run(2) 176 | assert cmp.executions_counter == 1 177 | assert pipeline.min_iteration_in_progress == 1 178 | assert cmp.state == [ComponentState.ERROR] 179 | assert cmp.state_message(ComponentState.ERROR) == "Exception - test" 180 | 181 | def test_stop_after_stop(self): 182 | class A(Component): 183 | async def forward(self): 184 | return ComponentState.STOPPED 185 | 186 | cmp = A("yuhu") 187 | pipeline = Pipeline() 188 | pipeline.add_nodes(cmp) 189 | pipeline.run(2) 190 | assert cmp.executions_counter == 1 191 | assert pipeline.min_iteration_in_progress == 1 192 | assert cmp.state == [ComponentState.STOPPED] 193 | assert cmp.state_message(ComponentState.STOPPED) is None 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Limbus: Computer Vision pipelining for PyTorch 2 | 3 | [![CI](https://github.com/kornia/limbus/actions/workflows/ci.yml/badge.svg)](https://github.com/kornia/limbus/actions/workflows/ci.yml) 4 | [![PyPI 
version](https://badge.fury.io/py/limbus.svg)](https://pypi.org/project/limbus) 5 | 6 | Similar to the eye [*corneal limbus*](https://en.wikipedia.org/wiki/Corneal_limbus) - **Limbus** is a framework to create Computer Vision pipelines within the context of Deep Learning and written in terms of differentiable tensor message passing on top of Kornia and PyTorch. 7 | 8 | ## Overview 9 | 10 | You can create pipelines using `limbus.Component`s as follows: 11 | 12 | ```python 13 | # define your components 14 | c1 = Constant("c1", 1.) 15 | c2 = Constant("c2", torch.ones(1, 3)) 16 | add = Adder("add") 17 | show = Printer("print") 18 | 19 | # connect the components 20 | c1.outputs.out >> add.inputs.a 21 | c2.outputs.out >> add.inputs.b 22 | add.outputs.out >> show.inputs.inp 23 | 24 | # create the pipeline and add its nodes 25 | pipeline = Pipeline() 26 | pipeline.add_nodes([c1, c2, add, show]) 27 | 28 | # run your pipeline 29 | pipeline.run(1) 30 | 31 | assert torch.allclose(add.outputs.out.value, torch.ones(1, 3) * 2.)
32 | ``` 33 | 34 | Example using the `stack` torch method: 35 | 36 | ```python 37 | # define your components 38 | c1 = Constant("c1", 0) 39 | t1 = Constant("t1", torch.ones(1, 3)) 40 | t2 = Constant("t2", torch.ones(1, 3) * 2) 41 | stack = Stack("stack") 42 | show = Printer("print") 43 | 44 | # connect the components 45 | c1.outputs.out >> stack.inputs.dim 46 | t1.outputs.out >> stack.inputs.tensors.select(0) 47 | t2.outputs.out >> stack.inputs.tensors.select(1) 48 | stack.outputs.out >> show.inputs.inp 49 | 50 | # create the pipeline and add its nodes 51 | pipeline = Pipeline() 52 | pipeline.add_nodes([c1, t1, t2, stack, show]) 53 | 54 | # run your pipeline 55 | pipeline.run(1) 56 | 57 | assert torch.allclose(stack.outputs.out.value, torch.tensor([[1., 1., 1.], [2., 2., 2.]])) 58 | ``` 59 | 60 | Remember that the components can also be run without the `Pipeline`, e.g. in the last example you can run (note that `asyncio.run` needs a coroutine, so `asyncio.gather` is wrapped in one): 61 | 62 | ```python 63 | async def main(): await asyncio.gather(c1(), t1(), t2(), stack(), show()) 64 | asyncio.run(main()) 65 | ``` 66 | Basically, `Pipeline` objects allow you to control the execution flow, e.g. you can stop, pause and resume the execution, or set the number of executions to run. 67 | 68 | A higher-level API on top of `Pipeline` is `App`, which lets you encapsulate some code.
E.g.: 69 | 70 | ```python 71 | class MyApp(App): 72 | def create_components(self): 73 | self.c1 = Constant("c1", 0) 74 | self.t1 = Constant("t1", torch.ones(1, 3)) 75 | self.t2 = Constant("t2", torch.ones(1, 3) * 2) 76 | self.stack = Stack("stack") 77 | self.show = Printer("print") 78 | 79 | def connect_components(self): 80 | self.c1.outputs.out >> self.stack.inputs.dim 81 | self.t1.outputs.out >> self.stack.inputs.tensors.select(0) 82 | self.t2.outputs.out >> self.stack.inputs.tensors.select(1) 83 | self.stack.outputs.out >> self.show.inputs.inp 84 | 85 | MyApp().run(1) 86 | ``` 87 | 88 | ## Component definition 89 | 90 | Creating your own components is pretty easy: you just need to inherit from `limbus.Component` and implement some methods (see some examples in `examples/defining_cmps.py`). 91 | 92 | The `Component` class has the following main methods: 93 | - `__init__`: where you can add class parameters to your component. 94 | - `register_inputs`: where you need to declare the input pins of your component. 95 | - `register_outputs`: where you need to declare the output pins of your component. 96 | - `register_properties`: where you can declare properties that can be changed during the execution. 97 | - `forward`: where you must define the logic of your component (mandatory). 98 | 99 | For a detailed list of `Component` methods and attributes, please check `limbus/core/component.py`. 100 | 101 | **Note** that if you want IntelliSense (at least in `VSCode`), you will need to define the `inputs` and `outputs` types. 102 | 103 | Let's see a very simple example that sums 2 integers: 104 | 105 | ```python 106 | class Add(Component): 107 | """Add two numbers.""" 108 | # NOTE: type definition is optional, but it helps with the intellisense.
;) 109 | class InputsTyping(InputParams): 110 | a: InputParam 111 | b: InputParam 112 | 113 | class OutputsTyping(OutputParams): 114 | out: OutputParam 115 | 116 | inputs: InputsTyping 117 | outputs: OutputsTyping 118 | 119 | @staticmethod 120 | def register_inputs(inputs: InputParams) -> None: 121 | # Here you need to declare the input parameters and their default values (if they have any). 122 | inputs.declare("a", int) 123 | inputs.declare("b", int) 124 | 125 | @staticmethod 126 | def register_outputs(outputs: OutputParams) -> None: 127 | # Here you need to declare the output parameters. 128 | outputs.declare("out", int) 129 | 130 | async def forward(self) -> ComponentState: 131 | # Here you must define the logic of your component. 132 | a, b = await asyncio.gather( 133 | self.inputs.a.receive(), 134 | self.inputs.b.receive() 135 | ) 136 | await self.outputs.out.send(a + b) 137 | return ComponentState.OK 138 | ``` 139 | 140 | **Note** that `Component` can inherit from `nn.Module`. By default, it inherits from `object`. 141 | 142 | To change the inheritance, before importing any other `limbus` module, set the `COMPONENT_TYPE` variable as: 143 | 144 | ```python 145 | from limbus_config import config 146 | config.COMPONENT_TYPE = "torch" 147 | ``` 148 | 149 | ## Ecosystem 150 | 151 | Limbus is a core technology to easily build different components and create generic pipelines.
In the following list, you can find different examples 152 | of how to use Limbus with some first/third-party projects containing components: 153 | 154 | - Official examples: 155 | - Basic pipeline generation: https://github.com/kornia/limbus/blob/main/examples/default_cmps.py 156 | - Define custom components: https://github.com/kornia/limbus/blob/main/examples/defining_cmps.py 157 | - Create a web camera application: https://github.com/kornia/limbus/blob/main/examples/webcam_app.py 158 | - Official repository with a set of basic components: https://github.com/kornia/limbus-components 159 | - Example combining limbus and the farm-ng Amiga: https://github.com/edgarriba/amiga-limbus-examples 160 | - Example implementing a Kornia face detection pipeline: https://github.com/edgarriba/limbus-face-detector 161 | 162 | ## Installation 163 | 164 | ### from PyPI: 165 | ```bash 166 | pip install limbus # limbus alone 167 | # or 168 | pip install limbus[components] # limbus + some predefined components 169 | ``` 170 | 171 | Note that to use widgets you need to install their dependencies: 172 | ```bash 173 | pip install limbus[widgets] 174 | ``` 175 | 176 | ### from the repository: 177 | 178 | ```bash 179 | pip install limbus@git+https://git@github.com/kornia/limbus.git # limbus alone 180 | # or 181 | pip install limbus[components]@git+https://git@github.com/kornia/limbus.git # limbus + some predefined components 182 | ``` 183 | 184 | ### for development 185 | 186 | You can install the environment with the following commands: 187 | 188 | ```bash 189 | git clone https://github.com/kornia/limbus 190 | cd limbus 191 | source path.bash.inc 192 | ``` 193 | 194 | In order to regenerate the development environment: 195 | ```bash 196 | cd limbus 197 | rm -rf .dev_env 198 | source path.bash.inc 199 | ``` 200 | 201 | ## Testing 202 | 203 | Run `pytest` and it will automatically test: `cov`, `pydocstyle`, `mypy` and `flake8` 204 |
-------------------------------------------------------------------------------- /tests/core/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | import asyncio 4 | 5 | import torch 6 | 7 | from limbus.core import ( 8 | Pipeline, PipelineState, VerboseMode, ComponentState, Component, OutputParams, OutputParam, InputParam) 9 | from limbus_components.base import Constant, Printer, Adder 10 | from limbus_components.torch import Unbind 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | 15 | # TODO: test in detail the functions 16 | @pytest.mark.usefixtures("event_loop_instance") 17 | class TestPipeline: 18 | def test_smoke(self): 19 | man = Pipeline() 20 | man is not None 21 | 22 | def test_pipeline(self): 23 | c1 = Constant("c1", 2 * torch.ones(1, 3)) 24 | c2 = Constant("c2", torch.ones(1, 3)) 25 | add = Adder("add") 26 | show = Printer("print") 27 | 28 | c1.outputs.out >> add.inputs.a 29 | c2.outputs.out >> add.inputs.b 30 | add.outputs.out >> show.inputs.inp 31 | 32 | pipeline = Pipeline() 33 | pipeline.add_nodes([c1, c2, add, show]) 34 | out = pipeline.run(1) 35 | assert isinstance(out, PipelineState) 36 | 37 | torch.allclose(add.outputs.out.value, torch.ones(1, 3) * 3.) 
38 | 39 | def test_pipeline_simple_graph(self): 40 | c1 = Constant("c1", torch.rand(2, 3)) 41 | show0 = Printer("print0") 42 | c1.outputs.out.connect(show0.inputs.inp) 43 | pipeline = Pipeline() 44 | pipeline.add_nodes([c1, show0]) 45 | out = pipeline.run(1) 46 | assert isinstance(out, PipelineState) 47 | 48 | def test_pipeline_disconnected_components(self): 49 | c1 = Constant("c1", torch.rand(2, 3)) 50 | show0 = Printer("print0") 51 | c1.outputs.out.connect(show0.inputs.inp) 52 | c1.outputs.out.disconnect(show0.inputs.inp) 53 | pipeline = Pipeline() 54 | pipeline.add_nodes([c1, show0]) 55 | out = pipeline.run(1) 56 | assert isinstance(out, PipelineState) 57 | 58 | def test_pipeline_iterable(self): 59 | c1 = Constant("c1", torch.rand(2, 3)) 60 | c2 = Constant("c2", 0) 61 | unbind = Unbind("unbind") 62 | show0 = Printer("print0") 63 | c1.outputs.out.connect(unbind.inputs.input) 64 | c2.outputs.out.connect(unbind.inputs.dim) 65 | unbind.outputs.out.select(0).connect(show0.inputs.inp) 66 | pipeline = Pipeline() 67 | pipeline.add_nodes([c1, c2, unbind, show0]) 68 | out = pipeline.run(1) 69 | assert isinstance(out, PipelineState) 70 | 71 | def test_pipeline_counter(self): 72 | c1 = Constant("c1", torch.rand(2, 3)) 73 | show0 = Printer("print0") 74 | c1.outputs.out.connect(show0.inputs.inp) 75 | pipeline = Pipeline() 76 | pipeline.add_nodes([c1, show0]) 77 | out = pipeline.run(2) 78 | assert isinstance(out, PipelineState) 79 | assert pipeline.min_iteration_in_progress == 2 80 | 81 | async def test_pipeline_flow(self): 82 | c1 = Constant("c1", torch.rand(2, 3)) 83 | show0 = Printer("print0") 84 | c1.outputs.out.connect(show0.inputs.inp) 85 | pipeline = Pipeline() 86 | pipeline.add_nodes([c1, show0]) 87 | # tests before running 88 | assert pipeline._resume_event.is_set() is False 89 | pipeline.pause() 90 | assert pipeline._resume_event.is_set() is False 91 | pipeline.resume() 92 | assert pipeline._resume_event.is_set() is True 93 | pipeline.pause() 94 | 95 | # tests while 
running 96 | async def task(): 97 | t = asyncio.create_task(pipeline.async_run()) 98 | # wait for the pipeline to start (requires at least 2 iterations) 99 | await asyncio.sleep(0) 100 | assert pipeline._resume_event.is_set() is True 101 | assert pipeline._state.state == PipelineState.RUNNING 102 | pipeline.pause() 103 | assert pipeline._state.state == PipelineState.PAUSED 104 | assert pipeline._resume_event.is_set() is False 105 | assert pipeline._stop_event.is_set() is False 106 | await asyncio.sleep(0) 107 | pipeline.resume() 108 | assert pipeline._resume_event.is_set() is True 109 | # add some awaits to allow the pipeline to execute some components 110 | await asyncio.sleep(0) 111 | await asyncio.sleep(0) 112 | await asyncio.sleep(0) 113 | await asyncio.sleep(0) 114 | pipeline.stop() 115 | assert pipeline._resume_event.is_set() is True 116 | assert pipeline._stop_event.is_set() is True 117 | assert pipeline.state == PipelineState.FORCED_STOP 118 | await asyncio.gather(t) 119 | assert len(c1.state) > 1 120 | assert ComponentState.FORCED_STOP in c1.state 121 | # Could be up to 3 states: 122 | # ComponentState.STOPPED_BY_COMPONENT 123 | # ComponentState.FORCED_STOP 124 | # E.g. 
ComponentState.OK 125 | assert c1.state_message(ComponentState.FORCED_STOP) is None 126 | assert len(show0.state) > 1 127 | assert ComponentState.FORCED_STOP in show0.state 128 | assert show0.state_message(ComponentState.FORCED_STOP) is None 129 | await task() 130 | assert pipeline.min_iteration_in_progress > 0 131 | assert pipeline.min_iteration_in_progress < 5 132 | 133 | def test_pipeline_verbose(self): 134 | c1 = Constant("c1", torch.rand(2, 3)) 135 | show0 = Printer("print0") 136 | c1.outputs.out.connect(show0.inputs.inp) 137 | pipeline = Pipeline() 138 | pipeline.add_nodes([c1, show0]) 139 | assert pipeline._state.verbose == VerboseMode.DISABLED 140 | assert c1.verbose is False 141 | assert show0.verbose is False 142 | pipeline.set_verbose_mode(VerboseMode.COMPONENT) 143 | assert pipeline._state.verbose == VerboseMode.COMPONENT 144 | assert c1.verbose is True 145 | assert show0.verbose is True 146 | pipeline.set_verbose_mode(VerboseMode.PIPELINE) 147 | assert pipeline._state.verbose == VerboseMode.PIPELINE 148 | assert c1.verbose is False 149 | assert show0.verbose is False 150 | 151 | def my_testing_pipeline(self): 152 | class C(Component): 153 | @staticmethod 154 | def register_outputs(outputs: OutputParams) -> None: # noqa: D102 155 | outputs.declare("out", int, arg="value") 156 | 157 | async def forward(self) -> ComponentState: # noqa: D102 158 | if self.executions_counter == 2: 159 | return ComponentState.STOPPED 160 | await self._outputs.out.send(1) 161 | return ComponentState.OK 162 | 163 | c1 = C("c1") 164 | show0 = Printer("print0") 165 | c1.outputs.out.connect(show0.inputs.inp) 166 | pipeline = Pipeline() 167 | pipeline.add_nodes([c1, show0]) 168 | return pipeline 169 | 170 | def test_before_pipeline_user_hook(self, caplog): 171 | async def pipeline_hook(state: PipelineState): 172 | log.info(f"state: {state}") 173 | pipeline = self.my_testing_pipeline() 174 | pipeline.set_before_pipeline_user_hook(pipeline_hook) 175 | with 
caplog.at_level(logging.INFO): 176 | pipeline.run(1) 177 | assert "state: PipelineState.STARTED" in caplog.text 178 | 179 | def test_after_pipeline_user_hook(self, caplog): 180 | async def pipeline_hook(state: PipelineState): 181 | log.info(f"state: {state}") 182 | pipeline = self.my_testing_pipeline() 183 | pipeline.set_after_pipeline_user_hook(pipeline_hook) 184 | with caplog.at_level(logging.INFO): 185 | pipeline.run(3) 186 | assert "state: PipelineState.ENDED" in caplog.text 187 | 188 | def test_before_iteration_user_hook(self, caplog): 189 | async def iteration_hook(iter: int, state: PipelineState): 190 | log.info(f"iteration: {iter} ({state})") 191 | pipeline = self.my_testing_pipeline() 192 | pipeline.set_before_iteration_user_hook(iteration_hook) 193 | with caplog.at_level(logging.INFO): 194 | pipeline.run(1) 195 | assert "iteration: 1 (PipelineState.RUNNING)" in caplog.text 196 | 197 | def test_after_iteration_user_hook(self, caplog): 198 | async def iteration_hook(state: PipelineState): 199 | log.info(f"state: {state}") 200 | pipeline = self.my_testing_pipeline() 201 | pipeline.set_after_iteration_user_hook(iteration_hook) 202 | with caplog.at_level(logging.INFO): 203 | pipeline.run(1) 204 | assert "state: PipelineState.RUNNING" in caplog.text 205 | 206 | def test_before_component_user_hook(self, caplog): 207 | async def component_hook(cmp: Component): 208 | log.info(f"before component: {cmp.name}") 209 | pipeline = self.my_testing_pipeline() 210 | pipeline.set_before_component_user_hook(component_hook) 211 | with caplog.at_level(logging.INFO): 212 | pipeline.run(1) 213 | assert "before component" in caplog.text 214 | 215 | def test_after_component_user_hook(self, caplog): 216 | async def component_hook(cmp: Component): 217 | log.info(f"after component: {cmp.name}") 218 | pipeline = self.my_testing_pipeline() 219 | pipeline.set_after_component_user_hook(component_hook) 220 | with caplog.at_level(logging.INFO): 221 | pipeline.run(1) 222 | assert "after 
component" in caplog.text 223 | 224 | def test_param_received_user_hook(self, caplog): 225 | async def param_hook(param: InputParam): 226 | log.info(f"param: {param.name}") 227 | pipeline = self.my_testing_pipeline() 228 | pipeline.set_param_received_user_hook(param_hook) 229 | with caplog.at_level(logging.INFO): 230 | pipeline.run(1) 231 | assert "param: inp" in caplog.text 232 | 233 | def test_param_sent_user_hook(self, caplog): 234 | async def param_hook(param: OutputParam): 235 | log.info(f"param: {param.name}") 236 | pipeline = self.my_testing_pipeline() 237 | pipeline.set_param_sent_user_hook(param_hook) 238 | with caplog.at_level(logging.INFO): 239 | pipeline.run(1) 240 | assert "param: out" in caplog.text 241 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /limbus/core/component.py: -------------------------------------------------------------------------------- 1 | """Component definition.""" 2 | from __future__ import annotations 3 | from abc import abstractmethod 4 | from functools import partial 5 | from typing import TYPE_CHECKING, Callable, Type, Any, Coroutine 6 | import logging 7 | import asyncio 8 | import traceback 9 | import functools 10 | 11 | try: 12 | import torch.nn as nn 13 | except ImportError: 14 | pass 15 | 16 | from limbus_config import config 17 | from limbus.core.params import InputParams, OutputParams, PropertyParams 18 | from limbus.core.states import ComponentState, ComponentStoppedError 19 | # Note that Pipeline class cannot be imported to avoid circular dependencies. 
20 | if TYPE_CHECKING: 21 | from limbus.core.pipeline import Pipeline 22 | 23 | log = logging.getLogger(__name__) 24 | 25 | 26 | base_class: Type = object 27 | if config.COMPONENT_TYPE == "generic": 28 | pass 29 | elif config.COMPONENT_TYPE == "torch": 30 | try: 31 | base_class = nn.Module 32 | except NameError: 33 | log.error("Torch not installed. Using generic base class.") 34 | else: 35 | log.error("Invalid component type. Using generic base class.") 36 | 37 | 38 | # this is a decorator that will determine how many iterations must be run 39 | def executions_manager(func: Callable) -> Callable: 40 | """Update the last execution to be run by the component.""" 41 | @functools.wraps(func) 42 | async def wrapper_set_iteration(self, *args, **kwargs): 43 | if self.pipeline is not None: 44 | self.stopping_execution = self.pipeline.get_component_stopping_iteration(self) 45 | return await func(self, *args, **kwargs) 46 | return wrapper_set_iteration 47 | 48 | 49 | class _ComponentState(): 50 | """Manage the state of the component. 51 | 52 | Note that the state can be multiple. 53 | The user interactions are the ones allowing simultaneous states, concretelly: 54 | - ComponentState.STOPPED_AT_ITER 55 | - ComponentState.STOPPED_BY_COMPONENT (it is generated by the STOPPED_AT_ITER state in other components) 56 | - ComponentState.FORCED_STOP 57 | E.g., a component can be properly executed and at the same time stopped by the user. 58 | 59 | Args: 60 | component (Component): component to manage. 61 | state (ComponentState): initial state. 62 | verbose (bool, optional): verbose state. Default: False. 
63 | 64 | """ 65 | def __init__(self, component: Component, state: ComponentState, verbose: bool = False): 66 | self._states: list[ComponentState] = [state] 67 | self._messages: dict[ComponentState, None | str] = {state: None} 68 | self._component: Component = component 69 | self._verbose: bool = verbose 70 | 71 | def __call__(self, state: None | ComponentState = None, msg: None | str = None, add: bool = False 72 | ) -> list[ComponentState]: 73 | """Set or add a new state for the component. 74 | 75 | If no args are passed, it returns the current state. 76 | 77 | Args: 78 | state (optional): state to set. Default: None. 79 | msg (optional): message to log. Default: None. 80 | add (optional): if True, the state is added to the list of states. Default: False. 81 | 82 | Returns: 83 | The state or states of the component. 84 | 85 | """ 86 | if state is not None: 87 | if add: 88 | self._states.append(state) 89 | self._messages[state] = msg 90 | else: 91 | self._states = [state] 92 | self._messages = {state: msg} 93 | self._logger() 94 | return self._states 95 | 96 | def _logger(self) -> None: 97 | """Log the message with the component name, iteration number and state.""" 98 | if self._verbose: 99 | num_states = len(self._states) 100 | for idx, state in enumerate(self._states): 101 | msg = self._messages.get(state, None) 102 | msg_str: str = f"" 103 | if num_states > 1: 104 | msg_str = f" {idx + 1}/{num_states}" 105 | msg_str = f"{msg_str} {self._component.name}({self._component.executions_counter}): {state.name}" 106 | if msg is not None: 107 | # concat the message 108 | msg_str = f"{msg_str} ({msg})" 109 | log.info(msg_str) 110 | 111 | def message(self, state: ComponentState) -> None | str: 112 | """Get the message associated to the state. 
If state is not found, returns None.""" 113 | return self._messages.get(state, None) 114 | 115 | @property 116 | def state(self) -> list[ComponentState]: 117 | """Get the state of the component.""" 118 | return self._states 119 | 120 | @property 121 | def verbose(self) -> bool: 122 | """Get the verbose state.""" 123 | return self._verbose 124 | 125 | @verbose.setter 126 | def verbose(self, value: bool) -> None: 127 | """Set the verbose state.""" 128 | self._verbose = value 129 | 130 | 131 | class Component(base_class): 132 | """Base class to define a Limbus Component. 133 | 134 | Args: 135 | name (str): component name. 136 | 137 | """ 138 | 139 | def __init__(self, name: str): 140 | super().__init__() 141 | self._name = name 142 | self._inputs = InputParams(self) 143 | self.__class__.register_inputs(self._inputs) 144 | self._outputs = OutputParams(self) 145 | self.__class__.register_outputs(self._outputs) 146 | self._properties = PropertyParams(self) 147 | self.__class__.register_properties(self._properties) 148 | self.__state: _ComponentState = _ComponentState(self, ComponentState.INITIALIZED) 149 | self.__pipeline: None | Pipeline = None 150 | self.__exec_counter: int = 0 # Counter of executions. 151 | # Last execution to be run in the __call__ loop. 
152 | self.__stopping_execution: int = 0 # 0 means run forever 153 | self.__num_params_waiting_to_receive: int = 0 # updated from InputParam 154 | 155 | # method called in __run_with_hooks to execute the component forward method 156 | self.__run_forward: Callable[..., Coroutine[Any, Any, ComponentState]] = self.forward 157 | try: 158 | if nn.Module in Component.__mro__: 159 | # If the component inherits from nn.Module, the forward method is called by the __call__ method 160 | self.__run_forward = partial(nn.Module.__call__, self) 161 | except NameError: 162 | pass 163 | 164 | def __del__(self): 165 | self.release() 166 | 167 | def release(self) -> None: 168 | """Event executed when the component ends its execution.""" 169 | pass 170 | 171 | def init_from_component(self, ref_component: Component) -> None: 172 | """Init basic execution params from another component. 173 | 174 | Args: 175 | ref_component: reference component. 176 | 177 | """ 178 | self.__pipeline = ref_component.__pipeline 179 | if self.__pipeline is not None: 180 | self.__pipeline.add_nodes(self) 181 | self.verbose = ref_component.verbose 182 | 183 | @property 184 | def executions_counter(self) -> int: 185 | """Get the executions counter.""" 186 | return self.__exec_counter 187 | 188 | @property 189 | def stopping_execution(self) -> int: 190 | """Get the last execution to be run by the component in the __call__ loop. 191 | 192 | Note that extra executions can be forced by other components to be able to run their executions. 193 | 194 | """ 195 | return self.__stopping_execution 196 | 197 | @stopping_execution.setter 198 | def stopping_execution(self, value: int) -> None: 199 | """Set the last execution to be run by the component in the __call__ loop. 200 | 201 | Note that extra executions can be forced by other components to be able to run their executions. 
202 | 203 | """ 204 | self.__stopping_execution = value 205 | 206 | @property 207 | def state(self) -> list[ComponentState]: 208 | """Get the current state/s of the component.""" 209 | return self.__state.state 210 | 211 | def state_message(self, state: ComponentState) -> None | str: 212 | """Get the message associated a given current state of the component.""" 213 | return self.__state.message(state) 214 | 215 | def set_state(self, state: ComponentState, msg: None | str = None, add: bool = False) -> None: 216 | """Set the state of the component. 217 | 218 | Args: 219 | state: state to set. 220 | msg (optional): message to log. Default: None. 221 | add (optional): if True, the state is added to the list of states. Default: False. 222 | 223 | """ 224 | self.__state(state, msg, add) 225 | 226 | @property 227 | def verbose(self) -> bool: 228 | """Get the verbose state.""" 229 | return self.__state.verbose 230 | 231 | @verbose.setter 232 | def verbose(self, value: bool) -> None: 233 | """Set the verbose state.""" 234 | self.__state.verbose = value 235 | 236 | @property 237 | def name(self) -> str: 238 | """Name of the component.""" 239 | return self._name 240 | 241 | @property 242 | def inputs(self) -> InputParams: 243 | """Get the set of component inputs.""" 244 | return self._inputs 245 | 246 | @property 247 | def outputs(self) -> OutputParams: 248 | """Get the set of component outputs.""" 249 | return self._outputs 250 | 251 | @property 252 | def properties(self) -> PropertyParams: 253 | """Get the set of properties for this component.""" 254 | return self._properties 255 | 256 | @staticmethod 257 | def register_inputs(inputs: InputParams) -> None: 258 | """Register the input params. 259 | 260 | Args: 261 | inputs: object to register the inputs. 262 | 263 | """ 264 | pass 265 | 266 | @staticmethod 267 | def register_outputs(outputs: OutputParams) -> None: 268 | """Register the output params. 269 | 270 | Args: 271 | outputs: object to register the outputs. 
272 | 273 | """ 274 | pass 275 | 276 | @staticmethod 277 | def register_properties(properties: PropertyParams) -> None: 278 | """Register the properties. 279 | 280 | These params are optional. 281 | 282 | Args: 283 | properties: object to register the properties. 284 | 285 | """ 286 | pass 287 | 288 | @property 289 | def pipeline(self) -> None | Pipeline: 290 | """Get the pipeline object.""" 291 | return self.__pipeline 292 | 293 | def set_pipeline(self, pipeline: None | Pipeline) -> None: 294 | """Set the pipeline running the component.""" 295 | self.__pipeline = pipeline 296 | 297 | def __stop_component(self) -> None: 298 | """Prepare the component to be stopped.""" 299 | for input in self._inputs.get_params(): 300 | for ref in self._inputs[input].references: 301 | assert ref.sent is not None 302 | assert ref.consumed is not None 303 | # unblock the events 304 | ref.sent.set() 305 | ref.consumed.set() 306 | for output in self._outputs.get_params(): 307 | for ref in self._outputs[output].references: 308 | assert ref.sent is not None 309 | assert ref.consumed is not None 310 | # unblock the events 311 | ref.sent.set() 312 | ref.consumed.set() 313 | 314 | @executions_manager 315 | async def __call__(self) -> None: 316 | """Execute the forward method. 317 | 318 | If the component is executed in a pipeline, the component runs forever. However, 319 | if the component is run alone it will run only once. 320 | 321 | NOTE 1: If you want to use `async for...` instead of `while True` this method must be overridden. 322 | E.g.: 323 | async for x in xyz: 324 | if await self.__run_with_hooks(x): 325 | break 326 | 327 | Note that in this example the forward method will require 1 parameter. 328 | 329 | NOTE 2: if you override this method you must add the `executions_manager` decorator. 
330 | 331 | """ 332 | while True: 333 | if await self.__run_with_hooks(): 334 | break 335 | 336 | def is_stopped(self) -> bool: 337 | """Check if the component is stopped or is going to be stopped.""" 338 | if len(set(self.state).intersection(set([ComponentState.STOPPED, ComponentState.STOPPED_AT_ITER, 339 | ComponentState.ERROR, ComponentState.FORCED_STOP, 340 | ComponentState.STOPPED_BY_COMPONENT]))) > 0: 341 | return True 342 | return False 343 | 344 | def __stop_if_needed(self) -> bool: 345 | """Stop the component if it is required.""" 346 | if self.is_stopped(): 347 | if ComponentState.STOPPED_AT_ITER not in self.state: 348 | # in this case we need to force the stop of the component. When it is stopped at a given iter 349 | # the pipeline ends without forcing anything. 350 | self.__stop_component() 351 | return True 352 | return False 353 | 354 | async def __run_with_hooks(self, *args, **kwargs) -> bool: 355 | self.__exec_counter += 1 356 | if self.__pipeline is not None: 357 | await self.__pipeline.before_component_hook(self) 358 | if self.__pipeline.before_component_user_hook: 359 | await self.__pipeline.before_component_user_hook(self) 360 | if self.__stop_if_needed(): # just in case the component state is changed in the before_component_hook 361 | return True 362 | # run the component 363 | try: 364 | if len(self._inputs) == 0: 365 | # RUNNING state is set once the input params are received, if there are not inputs the state is set here 366 | self.set_state(ComponentState.RUNNING) 367 | self.set_state(await self.__run_forward(*args, **kwargs)) 368 | except ComponentStoppedError as e: 369 | self.set_state(e.state, e.message, add=True) 370 | except Exception as e: 371 | self.set_state(ComponentState.ERROR, f"{type(e).__name__} - {str(e)}") 372 | log.error(f"Error in component {self.name}.\n" 373 | f"{''.join(traceback.format_exception(None, e, e.__traceback__))}") 374 | if self.__pipeline is not None: 375 | # after component hook 376 | await 
self.__pipeline.after_component_hook(self) 377 | if self.__pipeline.after_component_user_hook: 378 | await self.__pipeline.after_component_user_hook(self) 379 | if self.__stop_if_needed(): 380 | return True 381 | return False 382 | # if there is not a pipeline, the component is executed only once 383 | return True 384 | 385 | @abstractmethod 386 | async def forward(self, *args, **kwargs) -> ComponentState: 387 | """Run the component, this method shouldn't be called, instead call __call__.""" 388 | raise NotImplementedError 389 | -------------------------------------------------------------------------------- /limbus/widgets/types.py: -------------------------------------------------------------------------------- 1 | """Module containing the visualization interfaces.""" 2 | from __future__ import annotations 3 | from abc import abstractmethod 4 | import math 5 | from typing import Callable, Any 6 | import logging 7 | import functools 8 | 9 | try: 10 | # NOTE: we import these modules here to avoid having it as a dependency 11 | # for the whole project. 12 | import cv2 13 | import visdom 14 | import torch 15 | import kornia 16 | import numpy as np 17 | except ImportError: 18 | pass 19 | 20 | 21 | from limbus.core import Component 22 | from limbus import widgets 23 | 24 | log = logging.getLogger(__name__) 25 | 26 | 27 | def _get_component_from_args(*args, **kwargs) -> Component: 28 | # NOTE: this is a hack to get the component from the args. We know that the first argument is the component in all 29 | # the methods. 30 | if len(args) > 0: 31 | return args[0] 32 | elif "component" in kwargs: 33 | return kwargs["component"] 34 | else: 35 | raise ValueError("No component found in args or kwargs.") 36 | 37 | 38 | def _get_title_from_args(*args, **kwargs) -> str: 39 | # NOTE: this is a hack to get the title from the args. We know that the second argument is the title in all the 40 | # methods. 
41 | if len(args) > 1: 42 | return args[1] 43 | elif "title" in kwargs: 44 | return kwargs["title"] 45 | else: 46 | raise ValueError("No title found in args or kwargs.") 47 | 48 | 49 | def _set_title_in_args(title: str, args: tuple[Any, ...], kwargs: dict[Any, Any] 50 | ) -> tuple[tuple[Any, ...], dict[Any, Any]]: 51 | # NOTE: this is a hack to update the title from the args. We know that the second argument is the title in all the 52 | # methods. 53 | if len(args) > 1: 54 | new_args: list[Any] = list(args) 55 | new_args[1] = title 56 | return (tuple(new_args), kwargs) 57 | elif "title" in kwargs: 58 | kwargs.update({"title": title}) 59 | return (args, kwargs) 60 | 61 | 62 | # This is a decorator that will disable the method if the visualization is not enabled. 63 | def is_enabled(func: Callable) -> Callable: 64 | """Return None if viz is not enabled.""" 65 | @functools.wraps(func) 66 | def wrapper_check_component_disabled(self, *args, **kwargs) -> Any: 67 | vz = widgets.get(False) 68 | if vz is None or not vz.enabled: 69 | return None 70 | if not vz.force_viz and _get_component_from_args(*args, **kwargs).widget_state != widgets.WidgetState.ENABLED: 71 | return None 72 | return func(self, *args, **kwargs) 73 | return wrapper_check_component_disabled 74 | 75 | 76 | def set_title(func: Callable) -> Callable: 77 | """Set the title to be used in the viz if title param is empty.""" 78 | @functools.wraps(func) 79 | def wrapper_set_title(self, *args, **kwargs) -> Any: 80 | comp_name: str = _get_component_from_args(*args, **kwargs).name 81 | title: str = _get_title_from_args(*args, **kwargs) 82 | if title == "": 83 | title = comp_name 84 | args, kwargs = _set_title_in_args(title, args, kwargs) 85 | return func(self, *args, **kwargs) 86 | return wrapper_set_title 87 | 88 | 89 | class Viz: 90 | """Base class containing the method definitions for the visualization backends. 
91 | 92 | IMPORTANT NOTE to create or add new visualization backends/methods: 93 | All the methods showing data should be decorated with @is_enabled and @set_title 94 | All the methods showing data must have as first argument the "component" and as second argument the "title". 95 | With those argument names since the decorators use them. 96 | 97 | """ 98 | def __init__(self) -> None: 99 | self._enabled: bool = False 100 | # by default the components control if they will viz or not. However, in some cases with this parameter we can 101 | # force to viz even if the viz is not enabled. 102 | self._force_viz: bool = False 103 | 104 | @property 105 | def enabled(self) -> bool: 106 | """Whether the visualization is enabled.""" 107 | return self._enabled 108 | 109 | @property 110 | def force_viz(self) -> bool: 111 | """Whether the visualization is forced for all the components.""" 112 | return self._force_viz 113 | 114 | @force_viz.setter 115 | def force_viz(self, force_viz: bool) -> None: 116 | """Force viz for all the components.""" 117 | self._force_viz = force_viz 118 | 119 | @abstractmethod 120 | def check_status(self) -> bool: 121 | """Check if the connection is alive and try to reconnect if connection is lost.""" 122 | raise NotImplementedError 123 | 124 | @abstractmethod 125 | def show_image(self, component: Component, title: str, image: "torch.Tensor"): 126 | """Show an image. 127 | 128 | Args: 129 | component: component that calls this method. 130 | title: Title of the window. 131 | image: Tensor with shape ([1, 3] x H x W) or (H x W). Values can be float in [0, 1] or uint8 in [0, 255]. 132 | 133 | """ 134 | raise NotImplementedError 135 | 136 | @abstractmethod 137 | def show_images(self, component: Component, title: str, 138 | images: "torch.Tensor" | list["torch.Tensor"], 139 | nrow: None | int = None 140 | ) -> None: 141 | """Show a batch of images. 142 | 143 | Args: 144 | component: component that calls this method. 145 | title: Title of the window. 
146 | images: 4D Tensor with shape (B x [1, 3] x H x W) or a list of tensors with the same shape 147 | ([1, 3] x H x W) or (H x W). 148 | nrow (optional): Number of images in each row. Default: None -> sqrt(len(images)). 149 | 150 | """ 151 | raise NotImplementedError 152 | 153 | @abstractmethod 154 | def show_text(self, component: Component, title: str, text: str, append: bool = False): 155 | """Show text. 156 | 157 | Args: 158 | component: component that calls this method. 159 | title: Title of the window. 160 | text: Text to be displayed. 161 | append (optional): If True, the text is appended to the previous text. Default: False. 162 | 163 | """ 164 | raise NotImplementedError 165 | 166 | 167 | class Visdom(Viz): 168 | """Visdom visualization backend.""" 169 | VISDOM_PORT = 8097 170 | 171 | def __init__(self) -> None: 172 | super().__init__() 173 | try: 174 | import visdom 175 | except: 176 | raise ImportError("To use Visdom as backend install the widgets extras: " 177 | "pip install limbus[widgets]") 178 | self._vis: None | visdom.Visdom = None 179 | self._try_init() 180 | 181 | def _try_init(self) -> None: 182 | try: 183 | self._vis = visdom.Visdom(port=Visdom.VISDOM_PORT, raise_exceptions=True) 184 | self._enabled = True 185 | except: 186 | self._enabled = False 187 | 188 | if not self._enabled: 189 | log.warning("Visualization is disabled!!!") 190 | return 191 | 192 | assert self._vis is not None, "Visdom is not initialized." 
193 | if not self._vis.check_connection(): 194 | self._enabled = False 195 | log.warning("Error connecting with the visdom server.") 196 | 197 | def check_status(self) -> bool: 198 | """Check if the connection is alive and try to reconnect if connection is lost.""" 199 | if self._vis is None: 200 | self._try_init() 201 | else: 202 | self._enabled = self._vis.check_connection() 203 | return self._enabled 204 | 205 | @is_enabled 206 | @set_title 207 | def show_image(self, component: Component, title: str, image: "torch.Tensor") -> None: 208 | """Show an image. 209 | 210 | Args: 211 | component: component that calls this method. 212 | title: Title of the window. 213 | image: Tensor with shape ([1, 3] x H x W) or (H x W). Values can be float in [0, 1] or uint8 in [0, 255]. 214 | 215 | """ 216 | opts = {"title": title} 217 | assert self._vis is not None, "Visdom is not initialized." 218 | self._vis.image(image, win=title, opts=opts) 219 | 220 | @is_enabled 221 | @set_title 222 | def show_images(self, component: Component, title: str, 223 | images: "torch.Tensor" | list["torch.Tensor"], 224 | nrow: None | int = None 225 | ) -> None: 226 | """Show a batch of images. 227 | 228 | Args: 229 | component: component that calls this method. 230 | title: Title of the window. 231 | images: 4D Tensor with shape (B x [1, 3] x H x W) or a list of tensors with the same shape 232 | ([1, 3] x H x W) or (H x W). 233 | nrow (optional): Number of images in each row. Default: None -> sqrt(len(images)). 234 | 235 | """ 236 | opts = {"title": title} 237 | if nrow is None: 238 | l: int = images.shape[0] if isinstance(images, torch.Tensor) else len(images) 239 | nrow = math.ceil(math.sqrt(l)) 240 | assert self._vis is not None, "Visdom is not initialized." 241 | self._vis.images(images, win=title, opts=opts, nrow=nrow) 242 | 243 | @is_enabled 244 | @set_title 245 | def show_text(self, component: Component, title: str, text: str, append: bool = False): 246 | """Show text. 
247 | 248 | Args: 249 | component: component that calls this method. 250 | title: Title of the window. 251 | text: Text to be displayed. 252 | append (optional): If True, the text is appended to the previous text. Default: False. 253 | 254 | """ 255 | assert self._vis is not None, "Visdom is not initialized." 256 | opts = {"title": title} 257 | self._vis.text(text, win=title, append=append, opts=opts) 258 | 259 | 260 | class Console(Viz): 261 | """COnsole visualization backend.""" 262 | 263 | def __init__(self) -> None: 264 | super().__init__() 265 | self._enabled = True 266 | 267 | def check_status(self) -> bool: 268 | """Check if the connection is alive and try to reconnect if connection is lost.""" 269 | return self._enabled 270 | 271 | @is_enabled 272 | @set_title 273 | def show_image(self, component: Component, title: str, image: "torch.Tensor"): 274 | """Show an image. 275 | 276 | Args: 277 | component: component that calls this method. 278 | title: Title of the window. 279 | image: Tensor with shape ([1, 3] x H x W) or (H x W). Values can be float in [0, 1] or uint8 in [0, 255]. 280 | 281 | """ 282 | log.warning("Console visualization does not show images.") 283 | 284 | @is_enabled 285 | @set_title 286 | def show_images(self, component: Component, title: str, 287 | images: "torch.Tensor" | list["torch.Tensor"], 288 | nrow: None | int = None 289 | ) -> None: 290 | """Show a batch of images. 291 | 292 | Args: 293 | component: component that calls this method. 294 | title: Title of the window. 295 | images: 4D Tensor with shape (B x [1, 3] x H x W) or a list of tensors with the same shape 296 | ([1, 3] x H x W) or (H x W). 297 | nrow (optional): Number of images in each row. Default: None -> sqrt(len(images)). 298 | 299 | """ 300 | log.warning("Console visualization does not show images.") 301 | 302 | @is_enabled 303 | @set_title 304 | def show_text(self, component: Component, title: str, text: str, append: bool = False): 305 | """Show text. 
306 | 307 | Args: 308 | component: component that calls this method. 309 | title: Title of the window. 310 | text: Text to be displayed. 311 | append (optional): If True, the text is appended to the previous text. Default: False. 312 | 313 | """ 314 | log.info(f" {title}: {text}") 315 | 316 | 317 | class OpenCV(Console): 318 | """Console visualization backend + openCV for images.""" 319 | 320 | def __init__(self) -> None: 321 | super().__init__() 322 | try: 323 | import cv2 324 | except: 325 | raise ImportError("To use OpenCV as backend install the widgets extras: " 326 | "pip install limbus[widgets]") 327 | 328 | @is_enabled 329 | @set_title 330 | def show_image(self, component: Component, title: str, image: "torch.Tensor"): 331 | """Show an image. 332 | 333 | Args: 334 | component: component that calls this method. 335 | title: Title of the window. 336 | image: Tensor with shape ([1, 3] x H x W) or (H x W). Values can be float in [0, 1] or uint8 in [0, 255]. 337 | 338 | """ 339 | # inspired by the image() function in visdom.__init__.py. 340 | # convert image type to uint8 [0, 255] 341 | if image.dtype in [torch.float, torch.float32, torch.float64]: 342 | if image.max() <= 1: 343 | image = image * 255.0 344 | image = image.byte() 345 | 346 | if image.ndim == 3 and image.shape[0] == 1: 347 | image = image.repeat(3, 1, 1) 348 | 349 | # convert image shape to 3xHxW 350 | if image.ndim == 2: 351 | image = image.unsqueeze(0) 352 | image = image.repeat(3, 1, 1) 353 | 354 | np_img: np.ndarray = kornia.tensor_to_image(image) 355 | cv2.imshow(title, cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)) 356 | cv2.waitKey(1) 357 | 358 | @is_enabled 359 | @set_title 360 | def show_images(self, component: Component, title: str, 361 | images: "torch.Tensor" | list["torch.Tensor"], 362 | nrow: None | int = None 363 | ) -> None: 364 | """Show a batch of images. 365 | 366 | Args: 367 | component: component that calls this method. 368 | title: Title of the window. 
369 | images: 4D Tensor with shape (B x [1, 3] x H x W) or a list of tensors with the same shape 370 | ([1, 3] x H x W) or (H x W). 371 | nrow (optional): Number of images in each row. Default: None -> sqrt(len(images)). 372 | 373 | """ 374 | # inspired by the images() function in visdom.__init__.py. 375 | if isinstance(images, list): 376 | # NOTE that stack adds a dim to the result even if there was a single image. 377 | # so, if images is [(Hx W)] then the stack shape is (1 x H x W) 378 | images = torch.stack(images, 0) 379 | if images.ndim == 3: 380 | # if images are (H x W) convert tensor to (B x 1 x H x W) 381 | images = images.unsqueeze(1) 382 | else: 383 | # in this case we will assume a single image was passed. 384 | if images.ndim == 3: 385 | images = images.unsqueeze(0) 386 | 387 | # convert tensor shape from (B x C x H x W) to a grid with shape (C x H' x W') 388 | if nrow is None: 389 | nrow = math.ceil(math.sqrt(images.shape[0])) 390 | elif nrow > images.shape[0]: 391 | nrow = images.shape[0] 392 | ncol: int = math.ceil(images.shape[0] / nrow) 393 | padding: int = 4 394 | grid: torch.Tensor = torch.zeros((images.shape[1], 395 | ncol * (images.shape[2] + padding) - padding, 396 | nrow * (images.shape[3] + padding) - padding)).to(images) 397 | j: int = 0 398 | h: int = images.shape[2] + padding 399 | w: int = images.shape[3] + padding 400 | for idx in range(images.shape[0]): 401 | i = idx % nrow 402 | j = idx // nrow 403 | grid[:, j * h:j * h + images.shape[2], i * w:i * w + images.shape[3]] = images[idx] 404 | self.show_image(component, title, grid) 405 | -------------------------------------------------------------------------------- /limbus/core/pipeline.py: -------------------------------------------------------------------------------- 1 | """Components manager to connect, traverse and execute pipelines.""" 2 | from __future__ import annotations 3 | from typing import Coroutine, Any, Callable 4 | import logging 5 | import asyncio 6 | 7 | from 
limbus.core.component import Component, ComponentState 8 | from limbus.core.states import PipelineState, VerboseMode, IterationState 9 | from limbus.core import async_utils 10 | 11 | logging.basicConfig(level=logging.INFO) 12 | log = logging.getLogger(__name__) 13 | 14 | 15 | class _PipelineState(): 16 | """Manage the state of the pipeline.""" 17 | def __init__(self, state: PipelineState, verbose: VerboseMode = VerboseMode.DISABLED): 18 | self._state: PipelineState = state 19 | self._verbose: VerboseMode = verbose 20 | 21 | def __call__(self, state: PipelineState, msg: None | str = None) -> None: 22 | """Set the state of the pipeline. 23 | 24 | Args: 25 | state: state to set. 26 | msg (optional): message to log. Default: None. 27 | 28 | """ 29 | self._state = state 30 | self._logger(self._state, msg) 31 | 32 | def _logger(self, state: PipelineState, msg: None | str) -> None: 33 | """Log the message with the pipeline state.""" 34 | if self._verbose != VerboseMode.DISABLED: 35 | if msg is None: 36 | log.info(f" {state.name}") 37 | else: 38 | log.info(f" {state.name}: {msg}") 39 | 40 | @property 41 | def state(self) -> PipelineState: 42 | """Get the state of the pipeline.""" 43 | return self._state 44 | 45 | @property 46 | def verbose(self) -> VerboseMode: 47 | """Get the verbose mode.""" 48 | return self._verbose 49 | 50 | @verbose.setter 51 | def verbose(self, value: VerboseMode) -> None: 52 | """Set the verbose mode.""" 53 | self._verbose = value 54 | 55 | 56 | class Pipeline: 57 | """Class to create and execute a pipeline of Limbus Components.""" 58 | def __init__(self) -> None: 59 | self._nodes: set[Component] = set() # note that it does not contain all the nodes 60 | self._resume_event: asyncio.Event = asyncio.Event() 61 | self._stop_event: asyncio.Event = asyncio.Event() 62 | self._state: _PipelineState = _PipelineState(PipelineState.INITIALIZING) 63 | # flag x component denoting if it was run in the iteration 64 | self._iteration_component_state: 
dict[Component, tuple[IterationState, int]] = {} 65 | # number of iterations executed in the pipeline (== component with less executions). 66 | self._min_iteration_in_progress: int = 0 67 | # Number of times each component will be run at least. 68 | # This feature should be mainly used for debugging purposes. It can make the processing a bit slower and 69 | # depending on the graph to be executed it can require to recreate tasks (e.g. when a given component requires 70 | # several runs from a previous one). 71 | self._min_number_of_iters_to_run: int = 0 # 0 means until the end of the pipeline 72 | # user defined hooks 73 | self._before_component_user_hook: None | Callable = None 74 | self._after_component_user_hook: None | Callable = None 75 | self._before_iteration_user_hook: None | Callable = None 76 | self._after_iteration_user_hook: None | Callable = None 77 | self._before_pipeline_user_hook: None | Callable = None 78 | self._after_pipeline_user_hook: None | Callable = None 79 | self._param_sent_user_hook: None | Callable = None 80 | self._param_sent_and_consumed_user_hook: None | Callable = None 81 | self._param_received_user_hook: None | Callable = None 82 | self._pipeline_updates_from_component_lock = asyncio.Lock() 83 | 84 | def set_param_received_user_hook(self, hook: None | Callable) -> None: 85 | """Set a hook to be executed once a parameter is received. 86 | 87 | This callable must have a single parameter which is the parameter received. 88 | Moreover it must be async. 89 | 90 | Prototype: async def hook_name(param: InputParam). 91 | """ 92 | self._param_received_user_hook = hook 93 | 94 | @property 95 | def param_received_user_hook(self) -> None | Callable: 96 | """Get the param received user hook.""" 97 | return self._param_received_user_hook 98 | 99 | def set_param_sent_user_hook(self, hook: None | Callable) -> None: 100 | """Set a hook to be executed once a parameter is sent. 
101 | 102 | This callable must have a single parameter which is the parameter sent. 103 | Moreover it must be async. 104 | 105 | Prototype: async def hook_name(param: OutputParam). 106 | """ 107 | self._param_sent_user_hook = hook 108 | 109 | @property 110 | def param_sent_user_hook(self) -> None | Callable: 111 | """Get the param sent user hook.""" 112 | return self._param_sent_user_hook 113 | 114 | def set_before_pipeline_user_hook(self, hook: None | Callable) -> None: 115 | """Set a hook to be executed before the pipeline execution. 116 | 117 | This callable must have a single parameter which is the state of the pipeline at the begining of the pipeline. 118 | Moreover it must be async. 119 | 120 | Prototype: async def hook_name(state: PipelineState). 121 | """ 122 | self._before_pipeline_user_hook = hook 123 | 124 | @property 125 | def before_pipeline_user_hook(self) -> None | Callable: 126 | """Get the before pipeline user hook.""" 127 | return self._before_pipeline_user_hook 128 | 129 | def set_after_pipeline_user_hook(self, hook: None | Callable) -> None: 130 | """Set a hook to be executed after the pipeline execution. 131 | 132 | This callable must have a single parameter which is the state of the pipeline at the end of the pipeline. 133 | Moreover it must be async. 134 | 135 | Prototype: async def hook_name(state: PipelineState). 136 | """ 137 | self._after_pipeline_user_hook = hook 138 | 139 | @property 140 | def after_pipeline_user_hook(self) -> None | Callable: 141 | """Get the after pipeline user hook.""" 142 | return self._after_pipeline_user_hook 143 | 144 | def set_before_iteration_user_hook(self, hook: None | Callable) -> None: 145 | """Set a hook to be executed before each iteration. 146 | 147 | This callable must have a single parameter which is an int denoting the iter being executed. 148 | Moreover it must be async. 149 | 150 | Prototype: async def hook_name(counter: int, state: PipelineState). 
151 | """ 152 | self._before_iteration_user_hook = hook 153 | 154 | @property 155 | def before_iteration_user_hook(self) -> None | Callable: 156 | """Get the before iteration user hook.""" 157 | return self._before_iteration_user_hook 158 | 159 | def set_after_iteration_user_hook(self, hook: None | Callable) -> None: 160 | """Set a hook to be executed after each iteration (next iter can be already in execution). 161 | 162 | This callable must have a single parameter which is the state of the pipeline at the end of the iteration. 163 | Moreover it must be async. 164 | 165 | Prototype: async def hook_name(state: PipelineState). 166 | """ 167 | self._after_iteration_user_hook = hook 168 | 169 | @property 170 | def after_iteration_user_hook(self) -> None | Callable: 171 | """Get the after iteration user hook.""" 172 | return self._after_iteration_user_hook 173 | 174 | def set_before_component_user_hook(self, hook: None | Callable) -> None: 175 | """Set a hook to be executed before each component. 176 | 177 | This callable must have a single parameter which is the component being executed. 178 | Moreover it must be async. 179 | 180 | Prototype: async def hook_name(obj: Componet). 181 | """ 182 | self._before_component_user_hook = hook 183 | 184 | @property 185 | def before_component_user_hook(self) -> None | Callable: 186 | """Get the before component user hook.""" 187 | return self._before_component_user_hook 188 | 189 | def set_after_component_user_hook(self, hook: None | Callable) -> None: 190 | """Set a hook to be executed after each component. 191 | 192 | This callable must have a single parameter which is the component being executed. 193 | Moreover it must be async. 194 | 195 | Prototype: async def hook_name(obj: Componet). 
196 | """ 197 | self._after_component_user_hook = hook 198 | 199 | @property 200 | def after_component_user_hook(self) -> None | Callable: 201 | """Get the after component user hook.""" 202 | return self._after_component_user_hook 203 | 204 | def get_component_stopping_iteration(self, component: Component) -> int: 205 | """Compute the iteration where the __call__ loop of the component will be stopped. 206 | 207 | Args: 208 | component: component to be run. 209 | 210 | Returns: 211 | int denoting the iteration where the _call__ loop will be stopped. 212 | 0 means that it will run forever. 213 | 214 | """ 215 | if self._min_number_of_iters_to_run > 0: 216 | return component.executions_counter + self._min_number_of_iters_to_run 217 | return 0 218 | 219 | def _get_iteration_status(self) -> tuple[bool, bool | None]: 220 | """Get the status of the iteration. 221 | 222 | Returns: 223 | bool: True if all the components have finished the current iteration. 224 | bool | None: True if the next iteration has started. None if it is undertemined. 225 | Next iter status cannot be determined until the previous iter is finished. 226 | 227 | """ 228 | values = list(self._iteration_component_state.values()) 229 | # get the state of the iteration 230 | prev_iteration_status = [ 231 | # this cond is not always correct x component but it is correct for the pipeline. 232 | # state[1] can be > self._min_iteration_in_progress but remaining in the same iter however the min 233 | # exec_counter of all the components will be equal to the pipeline iter. 234 | # NOTE: this will not be true once we allow adding components to the pipeline during the execution. 
235 | (state[0] == IterationState.COMPONENT_EXECUTED or state[1] > self._min_iteration_in_progress 236 | ) for state in values 237 | ] 238 | prev_status = sum(prev_iteration_status) == len(prev_iteration_status) 239 | if prev_status: 240 | next_iteration_status = [state[0] == IterationState.COMPONENT_IN_EXECUTION for state in values] 241 | return prev_status, sum(next_iteration_status) > 0 242 | return prev_status, None 243 | 244 | async def before_component_hook(self, component: Component) -> None: 245 | """Run before the execution of each component. 246 | 247 | Args: 248 | component: component to be executed. 249 | 250 | """ 251 | # just in case in the future several components run in parallel (not now) 252 | await self._pipeline_updates_from_component_lock.acquire() 253 | try: 254 | if not self._resume_event.is_set(): 255 | component.set_state(ComponentState.PAUSED) 256 | await self._resume_event.wait() 257 | component.set_state(ComponentState.READY) 258 | # state of the iteration 259 | is_prev_iter_finished, _ = self._get_iteration_status() 260 | 261 | # denote that this component is being executed in the current iteration 262 | self._iteration_component_state[component] = (IterationState.COMPONENT_IN_EXECUTION, 263 | component.executions_counter) 264 | if self._min_iteration_in_progress == 0: 265 | # denote that the first iteration is starting 266 | self._min_iteration_in_progress = 1 267 | 268 | if is_prev_iter_finished: 269 | # previous iteration has finished but the current one started before or just now 270 | self._min_iteration_in_progress += 1 271 | finally: 272 | self._pipeline_updates_from_component_lock.release() 273 | 274 | async def after_component_hook(self, component: Component) -> None: 275 | """Run after the execution of each component. 276 | 277 | Args: 278 | component: executed component. 
279 | 280 | """ 281 | # just in case in the future several components run in parallel (not now) 282 | await self._pipeline_updates_from_component_lock.acquire() 283 | try: 284 | # determine when the component must be stopped 285 | # when the pipeline claims that it must be stopped... 286 | if self._stop_event.is_set(): 287 | component.set_state(ComponentState.FORCED_STOP, add=True) 288 | # denote that this component was already executed in the current iteration 289 | self._iteration_component_state[component] = (IterationState.COMPONENT_EXECUTED, 290 | component.executions_counter) 291 | 292 | # NEXT CODE is disabled because we cannot know when an iteration starts. 293 | # get the state of the iteration 294 | # is_prev_iter_finished, _ = self._get_iteration_status() 295 | # if is_prev_iter_finished: 296 | # if self._after_iteration_user_hook is not None: 297 | # # Since the last component being executed changes its state in this method the 298 | # # min iteration in progress is correct. 299 | # await self._after_iteration_user_hook(self.state, self._min_iteration_in_progress) 300 | 301 | # when the number of iters to run is reached... 302 | # NOTE: component could be stopped before finishing the number of iterations since execution != iteration. 303 | # In that case the other components will force rerunning this one to run the required iterations. 
304 | if (not component.is_stopped() and 305 | self._min_number_of_iters_to_run != 0 and 306 | component.executions_counter >= component.stopping_execution): 307 | component.set_state(ComponentState.STOPPED_AT_ITER, add=True) 308 | finally: 309 | self._pipeline_updates_from_component_lock.release() 310 | 311 | @property 312 | def min_iteration_in_progress(self) -> int: 313 | """Get the number of the oldest iteration still being executed.""" 314 | return self._min_iteration_in_progress 315 | 316 | @property 317 | def state(self) -> PipelineState: 318 | """Get the state of the pipeline.""" 319 | return self._state.state 320 | 321 | def add_nodes(self, components: Component | list[Component]) -> None: 322 | """Add components to the pipeline. 323 | 324 | Note: At least one component per graph must be added to be able to run the pipeline. The pipeline will 325 | automatically add the nodes that are missing at the begining. 326 | 327 | Args: 328 | components: Component or list of components to be added. 329 | 330 | """ 331 | if isinstance(components, Component): 332 | components = [components] 333 | for component in components: 334 | self._nodes.add(component) 335 | 336 | def pause(self) -> None: 337 | """Pause the execution of the pipeline. 338 | 339 | Note: Components will be paused as soon as posible, if the pipeline is running will be done inmediatelly after 340 | sending the outputs. Some components waiting for inputs will remain in that state since the previous components 341 | can be paused. 
342 | """ 343 | if self._resume_event.is_set(): 344 | self._state(PipelineState.PAUSED) 345 | self._resume_event.clear() 346 | 347 | def stop(self) -> None: 348 | """Force the stop of the pipeline.""" 349 | self.resume() # if the pipeline is paused it is blocked 350 | self._stop_event.set() # stop the forever loop inside each component 351 | self._state(PipelineState.FORCED_STOP) 352 | 353 | def resume(self) -> None: 354 | """Resume the execution of the pipeline.""" 355 | if not self._resume_event.is_set(): 356 | self._state(PipelineState.RUNNING) 357 | self._resume_event.set() 358 | 359 | def set_verbose_mode(self, state: VerboseMode) -> None: 360 | """Set the verbose mode. 361 | 362 | Args: 363 | state: verbose mode to be set. 364 | 365 | """ 366 | if self._state.verbose == state: 367 | return 368 | self._state.verbose = state 369 | for node in self._nodes: 370 | node.verbose = self._state.verbose == VerboseMode.COMPONENT 371 | 372 | async def async_run(self, iters: int = 0) -> PipelineState: 373 | """Run the components graph. 374 | 375 | Args: 376 | iters (optional): number of iters to be run. By default (0) all of them are run. 377 | 378 | Returns: 379 | PipelineState with the current pipeline status. 
380 | 381 | """ 382 | self._iteration_component_state = {} 383 | self._stop_event.clear() 384 | 385 | async def start() -> None: 386 | tasks: list[Coroutine[Any, Any, None]] = [] 387 | for node in self._nodes: 388 | node.set_pipeline(self) 389 | tasks.append(node()) 390 | # set the initial state of the components if they are not already set 391 | if self._iteration_component_state.get(node, None) is None: 392 | self._iteration_component_state[node] = (IterationState.COMPONENT_NOT_EXECUTED, 0) 393 | await asyncio.gather(*tasks) 394 | # check if there are pending tasks 395 | pending_tasks: list = [] 396 | for node in self._nodes: 397 | t = async_utils.get_task_if_exists(node) 398 | if t is not None: 399 | pending_tasks.append(t) 400 | await asyncio.gather(*pending_tasks) 401 | 402 | # if it was previously run then the state is not changed to STARTED 403 | if self._state.state == PipelineState.INITIALIZING: 404 | self._state(PipelineState.STARTED) 405 | if len(self._nodes) == 0: 406 | self._state(PipelineState.EMPTY, "No components added to the pipeline") 407 | 408 | if self.before_pipeline_user_hook is not None: 409 | # even if the pipeline is empty we run the hook 410 | await self.before_pipeline_user_hook(self.state) 411 | 412 | if self._state.state == PipelineState.EMPTY: 413 | # if it is empty we do not run the pipeline 414 | return self._state.state 415 | 416 | # NOTE about limitting the number of iterations and using iteration hooks. 417 | # In order to achieve both we need to block asyncio execution. The selected mechanism is using a loop in this 418 | # method. This means that it is not efficient and this feature should be mainly used for debugging purposes. 419 | # It can make the processing a bit slower and depending on the graph to be executed it can require to recreate 420 | # tasks (e.g. when a given component requires several runs to finish one iteration). 
421 | # Even if you do not limit the number of iterations but you want to run the pipeline forever using 422 | # iteration hooks then the iterations must be run independently to be able to know when each iteration starts 423 | # and ends. NOTE that after_iteration_user_hook() could be run in after_component_hook() but without control on 424 | # when the next iteration starts, so we disabled. 425 | # ATTENTION: We recommend to use this feature only for debugging!!! 426 | self._min_number_of_iters_to_run = 0 427 | # if there are hooks then iters must be run one by one forever 428 | if self._before_iteration_user_hook is not None or self._after_iteration_user_hook is not None: 429 | self._min_number_of_iters_to_run = 1 430 | 431 | # If there are no hooks but there is a limit in the number of iters we only set 1 iters but running all the 432 | # required iters. 433 | if self._min_number_of_iters_to_run == 0: 434 | self._min_number_of_iters_to_run = iters 435 | iters = 1 436 | 437 | # run the pipeline as independent iterations. The loop is only run ince if there are no hooks. 438 | self.resume() # change the state to running 439 | forever = iters == 0 440 | while forever or iters > 0: # run until the pipeline is completed or there are no iters to run 441 | iters -= 1 if not forever else 0 442 | if self._before_iteration_user_hook is not None: 443 | # If there are iteration hooks the min_iteration_in_progress is the last iteration that 444 | # was run, so we need to add 1 to get the next iteration. 
445 | await self._before_iteration_user_hook(self._min_iteration_in_progress + 1, self.state) 446 | await start() 447 | if self._after_iteration_user_hook is not None: 448 | await self._after_iteration_user_hook(self.state) 449 | 450 | states = [] 451 | for component in self._nodes: 452 | states.extend(component.state) 453 | if ComponentState.STOPPED in states: 454 | self._state(PipelineState.ENDED) 455 | elif ComponentState.ERROR in states: 456 | self._state(PipelineState.ERROR) 457 | 458 | if self._state.state in [PipelineState.FORCED_STOP, PipelineState.ERROR, PipelineState.ENDED]: 459 | break 460 | 461 | if not forever and self._state.state == PipelineState.RUNNING: 462 | self.pause() 463 | 464 | if self._state.state in [PipelineState.FORCED_STOP, PipelineState.ERROR, PipelineState.ENDED]: 465 | for node in self._nodes: 466 | node.release() 467 | if self.after_pipeline_user_hook is not None: 468 | await self.after_pipeline_user_hook(self.state) 469 | return self._state.state 470 | 471 | def run(self, iters: int = 0) -> PipelineState: 472 | """Run the components graph. 473 | 474 | Args: 475 | iters (optional): number of iters to be run. By default (0) all of them are run. 476 | 477 | Returns: 478 | PipelineState with the current pipeline status. 
479 | 480 | """ 481 | async_utils.run_coroutine(self.async_run(iters)) 482 | return self._state.state 483 | -------------------------------------------------------------------------------- /tests/core/test_param.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from typing import Any, List, Sequence, Iterable, Tuple 3 | import asyncio 4 | import logging 5 | 6 | import torch 7 | 8 | import limbus.core.param 9 | from limbus.core import NoValue, Component, ComponentState 10 | from limbus.core.param import (Container, Param, InputParam, OutputParam, PropertyParam, 11 | IterableContainer, IterableInputContainers, IterableParam, Reference) 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | 16 | class TestContainer: 17 | def test_smoke(self): 18 | v = Container(None) 19 | assert isinstance(v, Container) 20 | assert v.value is None 21 | 22 | 23 | class TestIterableContainer: 24 | def test_smoke(self): 25 | c = IterableContainer(Container(None), 0) 26 | assert isinstance(c, IterableContainer) 27 | assert isinstance(c.container, Container) 28 | assert c.index == 0 29 | assert c.container.value is None 30 | # None is not a valid value inside the container. This is not controlled by the container. 
31 | with pytest.raises(TypeError): 32 | c.value 33 | 34 | def test_value(self): 35 | c = IterableContainer(Container([1, 2]), 0) 36 | assert c.value == 1 37 | assert c.index == 0 38 | assert c.container.value == [1, 2] 39 | 40 | def test_value_with_iterable(self): 41 | c = IterableContainer(IterableContainer(Container([1, 2]), 1), 0) 42 | assert c.value == 2 43 | assert c.index == 0 44 | assert c.container.value == 2 45 | 46 | 47 | class TestIterableInputContainers: 48 | def test_smoke(self): 49 | c = IterableInputContainers() 50 | assert isinstance(c, IterableInputContainers) 51 | assert c._containers == [] 52 | 53 | def test_init_with_container(self): 54 | ic = IterableContainer(Container(None), 0) 55 | c = IterableInputContainers(ic) 56 | assert c._containers[0] is ic 57 | assert len(c) == 1 58 | 59 | def test_add(self): 60 | ic0 = IterableContainer(Container(None), 0) 61 | ic1 = IterableContainer(Container(None), 1) 62 | c = IterableInputContainers() 63 | c.add(ic0) 64 | c.add(ic1) 65 | assert c._containers[0] is ic0 66 | assert c._containers[1] is ic1 67 | assert len(c) == 2 68 | 69 | def test_remove(self): 70 | ic = IterableContainer(Container(None), 0) 71 | c = IterableInputContainers() 72 | c.add(ic) 73 | assert len(c) == 1 74 | c.remove(0) 75 | assert len(c) == 0 76 | 77 | def test_remove_not_found(self): 78 | ic = IterableContainer(Container(None), 0) 79 | c = IterableInputContainers() 80 | c.add(ic) 81 | assert len(c) == 1 82 | c.remove(1) 83 | assert len(c) == 1 84 | 85 | def test_get_ordered(self): 86 | ic0 = IterableContainer(Container(1), 0) 87 | ic1 = IterableContainer(Container(2), 2) 88 | c = IterableInputContainers() 89 | c.add(ic1) 90 | c.add(ic0) 91 | 92 | res = c.get_ordered() 93 | assert len(res) == 2 94 | assert res[0] == 1 95 | assert res[1] == 2 96 | 97 | def test_get_ordered_iterable(self): 98 | ic0 = IterableContainer(IterableContainer(Container([1, 2]), 1), 0) 99 | ic1 = IterableContainer(Container(3), 1) 100 | c = 
IterableInputContainers() 101 | c.add(ic1) 102 | c.add(ic0) 103 | 104 | res = c.get_ordered() 105 | assert len(res) == 2 106 | assert res[0] == 2 107 | assert res[1] == 3 108 | 109 | 110 | @pytest.mark.usefixtures("event_loop_instance") 111 | class TestParam: 112 | def test_smoke(self): 113 | p = Param("a") 114 | assert isinstance(p, Param) 115 | assert p.name == "a" 116 | assert p.parent is None 117 | assert p.type is Any 118 | assert isinstance(p.value, NoValue) 119 | assert p.references == set() 120 | assert p.arg is None 121 | assert p._is_subscriptable is False 122 | assert p.is_subscriptable is False 123 | 124 | def test_subcriptability(self): 125 | p = Param("a", List[torch.Tensor], value=[torch.tensor(1), torch.tensor(1)]) 126 | assert p._is_subscriptable 127 | assert p.is_subscriptable 128 | p.set_as_non_subscriptable() 129 | assert not p._is_subscriptable 130 | p.reset_is_subscriptable() 131 | assert p._is_subscriptable 132 | 133 | def test_init_with_type(self): 134 | p = Param("a", tp=int) 135 | assert p.type is int 136 | assert isinstance(p.value, NoValue) 137 | 138 | def test_init_with_value(self): 139 | p = Param("a", tp=int, value=1) 140 | assert p.value == 1 141 | 142 | def test_init_with_invalid_value_raise_error(self): 143 | with pytest.raises(TypeError): 144 | Param("a", tp=int, value=1.) 
145 | 146 | def test_init_with_arg(self): 147 | p = Param("a", arg="arg0") 148 | assert p.arg == "arg0" 149 | 150 | def test_set_value(self): 151 | p = Param("a", tp=int, value=1) 152 | p.value = 2 153 | assert p.value == 2 154 | with pytest.raises(TypeError): 155 | p.value = "a" 156 | 157 | def test_get_iterable_value(self): 158 | # only torch.Tensor can be iterable at this moment 159 | iter_container = IterableContainer(Container([torch.tensor(1), torch.tensor(2)]), 0) 160 | p = Param("a", value=iter_container) 161 | assert isinstance(p.container, Container) 162 | assert p.container.value is iter_container 163 | assert p.value == torch.tensor(1) 164 | 165 | def test_get_iterable_input_value(self): 166 | # only torch.Tensor can be iterable at this moment 167 | p = Param("a", tp=List[torch.Tensor]) 168 | p.container = IterableInputContainers(IterableContainer(Container(torch.tensor(1)), 1)) 169 | p.container.add(IterableContainer(Container(torch.tensor(2)), 0)) 170 | assert p.value == [torch.tensor(2), torch.tensor(1)] 171 | 172 | def test_select(self): 173 | p = Param("a", List[torch.Tensor], value=[torch.tensor(1), torch.tensor(1)]) 174 | assert p._is_subscriptable 175 | assert p.is_subscriptable 176 | iter_param = p.select(0) 177 | assert isinstance(iter_param, IterableParam) 178 | assert isinstance(iter_param.iter_container, IterableContainer) 179 | assert iter_param.iter_container.index == 0 180 | assert iter_param.iter_container.value == 1 181 | 182 | def test_subscriptable(self): 183 | p = Param("a", List[torch.Tensor], value=[torch.tensor(1), torch.tensor(1)]) 184 | assert p.is_subscriptable 185 | p.set_as_non_subscriptable() 186 | assert p.is_subscriptable is False 187 | p.reset_is_subscriptable() 188 | assert p.is_subscriptable 189 | 190 | def test_connect_iterparam_param_no_select_raise_error(self): 191 | p0 = Param("a", List[torch.Tensor]) 192 | p1 = Param("b") 193 | # TODO: this check is temporary disabled because we should allow connect iterparam with 
iterparam 194 | # with pytest.raises(ValueError): 195 | # # mandatory connect with an index 196 | # p0.connect(p1) 197 | p0.select(0).connect(p1) 198 | 199 | def test_connect_param_iterparam_no_select_raise_error(self): 200 | p0 = Param("a", tp=List[torch.Tensor]) 201 | p1 = Param("b", value=torch.Tensor(2)) 202 | # TODO: this check is temporary disabled because we should allow connect iterparam with iterparam 203 | # with pytest.raises(ValueError): 204 | # # mandatory connect with an index 205 | # p1.connect(p0) 206 | p1.connect(p0.select(0)) 207 | 208 | def test_connect_param_iterparam_no_valid_type_raise_error(self): 209 | p0 = Param("a", tp=List[torch.Tensor]) 210 | p1 = Param("b", value=2) 211 | with pytest.raises(TypeError): 212 | p1.connect(p0.select(0)) 213 | 214 | def test_connect_param_param_no_valid_type_raise_error(self): 215 | p0 = Param("a", tp=int) 216 | p1 = Param("b", value=2.) 217 | with pytest.raises(TypeError): 218 | p1.connect(p0) 219 | 220 | def test_connect_param_param(self): 221 | p0 = Param("a", value=1) 222 | p1 = Param("b") 223 | p0.connect(p1) 224 | assert isinstance(p1.container, Container) 225 | assert p1.value == 1 226 | assert p0.references == {Reference(p1, ori_param=p0)} 227 | assert list(p0._refs.keys()) == [None] 228 | assert list(p0._refs[None])[0] == Reference(p1, ori_param=p0) 229 | assert p1.references == {Reference(p0, ori_param=p1)} 230 | assert list(p1._refs.keys()) == [None] 231 | assert list(p1._refs[None])[0] == Reference(p0, ori_param=p1) 232 | 233 | def test_disconnect_param_param(self): 234 | p0 = Param("a", value=1) 235 | p1 = Param("b") 236 | p0.connect(p1) 237 | p0.disconnect(p1) 238 | assert isinstance(p1.container, Container) 239 | assert isinstance(p1.value, NoValue) 240 | assert p0.references == set() 241 | assert list(p0._refs.keys()) == [None] 242 | assert list(p0._refs[None]) == [] 243 | assert p1.references == set() 244 | assert list(p1._refs.keys()) == [None] 245 | assert list(p1._refs[None]) == [] 246 | 
247 | def test_connect_disconnect_iterparam_param(self): 248 | p0 = Param("a", tp=List[torch.Tensor], value=[torch.tensor(1), torch.tensor(2)]) 249 | p1 = Param("b") 250 | p0.select(1).connect(p1) 251 | assert isinstance(p1.container, Container) 252 | assert p1.value == torch.tensor(2) 253 | assert list(p0._refs.keys()) == [1] 254 | assert list(p0._refs[1])[0] == Reference(p1, ori_param=p0, ori_index=1) 255 | assert list(p1._refs.keys()) == [None] 256 | assert list(p1._refs[None])[0] == Reference(p0, p1, 1) 257 | 258 | def test_disconnect_iterparam_param(self): 259 | p0 = Param("a", tp=List[torch.Tensor], value=[torch.tensor(1), torch.tensor(2)]) 260 | p1 = Param("b") 261 | p0.select(1).connect(p1) 262 | p0.select(1).disconnect(p1) 263 | assert isinstance(p1.container, Container) 264 | assert isinstance(p1.value, NoValue) 265 | assert list(p0._refs.keys()) == [1] 266 | assert list(p0._refs[1]) == [] 267 | assert list(p1._refs.keys()) == [None] 268 | assert list(p1._refs[None]) == [] 269 | 270 | def test_connect_param_iterparam(self): 271 | p0 = Param("a", tp=List[torch.Tensor]) 272 | p1 = Param("b", value=torch.tensor(1)) 273 | p2 = Param("c", value=torch.tensor(2)) 274 | p1.connect(p0.select(1)) 275 | p2.connect(p0.select(0)) 276 | assert isinstance(p0.container, IterableInputContainers) 277 | assert p0.value == [torch.tensor(2), torch.tensor(1)] 278 | assert sorted(list(p0._refs.keys())) == sorted([0, 1]) 279 | assert list(p0._refs[0])[0] == Reference(p2, ori_param=p0, ori_index=0) 280 | assert list(p0._refs[1])[0] == Reference(p1, ori_param=p0, ori_index=1) 281 | assert list(p1._refs.keys()) == [None] 282 | assert list(p1._refs[None])[0] == Reference(p0, p1, 1) 283 | assert list(p2._refs.keys()) == [None] 284 | assert list(p2._refs[None])[0] == Reference(p0, p2, 0) 285 | 286 | def test_disconnect_param_iterparam(self): 287 | p0 = Param("a", tp=List[torch.Tensor]) 288 | p1 = Param("b", value=torch.tensor(1)) 289 | p2 = Param("c", value=torch.tensor(2)) 290 | 
p1.connect(p0.select(1)) 291 | p2.connect(p0.select(0)) 292 | p1.disconnect(p0.select(1)) 293 | p2.disconnect(p0.select(0)) 294 | assert isinstance(p0.container, Container) 295 | assert isinstance(p0.value, NoValue) 296 | assert sorted(list(p0._refs.keys())) == sorted([0, 1]) 297 | assert list(p0._refs[0]) == [] 298 | assert list(p0._refs[1]) == [] 299 | assert list(p1._refs.keys()) == [None] 300 | assert list(p1._refs[None]) == [] 301 | assert list(p2._refs.keys()) == [None] 302 | assert list(p2._refs[None]) == [] 303 | 304 | def test_connect_iterparam_iterparam(self): 305 | p0 = Param("a", tp=List[torch.Tensor], value=[torch.tensor(1), torch.tensor(2)]) 306 | p1 = Param("b", tp=List[torch.Tensor]) 307 | p0.select(0).connect(p1.select(1)) 308 | p0.select(1).connect(p1.select(0)) 309 | assert isinstance(p1.container, IterableInputContainers) 310 | assert p1.value == [torch.tensor(2), torch.tensor(1)] 311 | assert sorted(list(p0._refs.keys())) == sorted([0, 1]) 312 | assert list(p0._refs[0])[0] == Reference(p1, p0, 1, 0) 313 | assert list(p0._refs[1])[0] == Reference(p1, p0, 0, 1) 314 | assert sorted(list(p1._refs.keys())) == sorted([0, 1]) 315 | assert list(p1._refs[0])[0] == Reference(p0, p1, 1, 0) 316 | assert list(p1._refs[1])[0] == Reference(p0, p1, 0, 1) 317 | 318 | def test_disconnect_iterparam_iterparam(self): 319 | p0 = Param("a", tp=List[torch.Tensor], value=[torch.tensor(1), torch.tensor(2)]) 320 | p1 = Param("b", tp=List[torch.Tensor]) 321 | p0.select(0).connect(p1.select(1)) 322 | p0.select(1).connect(p1.select(0)) 323 | p0.select(0).disconnect(p1.select(1)) 324 | p0.select(1).disconnect(p1.select(0)) 325 | assert isinstance(p1.container, Container) 326 | assert isinstance(p1.value, NoValue) 327 | assert sorted(list(p0._refs.keys())) == sorted([0, 1]) 328 | assert list(p0._refs[0]) == [] 329 | assert list(p0._refs[1]) == [] 330 | assert sorted(list(p1._refs.keys())) == sorted([0, 1]) 331 | assert list(p1._refs[0]) == [] 332 | assert list(p1._refs[1]) == 
[] 333 | 334 | def test_ref_count(self): 335 | p0 = Param("a") 336 | p1 = Param("b") 337 | p2 = Param("c") 338 | p0.connect(p1) 339 | p0.connect(p2) 340 | assert p1.ref_counter(None) == 1 341 | assert p0.ref_counter(None) == 2 342 | assert p2.ref_counter(None) == 1 343 | 344 | def test_ref_counter(self): 345 | p0 = Param("a") 346 | p1 = Param("b") 347 | p2 = Param("c") 348 | p0.connect(p1) 349 | p0.connect(p2) 350 | assert p1.ref_counter() == 1 351 | assert p0.ref_counter() == 2 352 | assert p2.ref_counter() == 1 353 | assert p1.ref_counter(None) == 1 354 | assert p0.ref_counter(None) == 2 355 | assert p2.ref_counter(None) == 1 356 | 357 | def test_ref_counter_iterable(self): 358 | p0 = Param("a", List[torch.Tensor], value=[torch.tensor(1), torch.tensor(2)]) 359 | p1 = Param("b") 360 | p2 = Param("c") 361 | p0.select(0).connect(p1) 362 | p0.select(1).connect(p2) 363 | assert p1.ref_counter() == 1 364 | assert p2.ref_counter() == 1 365 | assert p0.ref_counter() == 2 366 | assert p0.ref_counter(0) == 1 367 | assert p0.ref_counter(1) == 1 368 | 369 | 370 | @pytest.mark.usefixtures("event_loop_instance") 371 | class TestIterableParam: 372 | def test_smoke(self): 373 | p = Param("a", tp=List[int], value=[1, 2]) 374 | ip = IterableParam(p, 0) 375 | assert ip.param is p 376 | assert ip.index == 0 377 | assert ip.value == 1 378 | assert isinstance(ip.iter_container, IterableContainer) 379 | assert ip.iter_container.index == 0 380 | assert ip.iter_container.value == 1 381 | 382 | def test_ref_counter(self): 383 | p0 = Param("a", tp=List[int], value=[1, 2]) 384 | ip = IterableParam(p0, 0) 385 | assert ip.ref_counter() == 0 386 | p1 = Param("b", tp=int) 387 | ip.connect(p1) 388 | assert ip.ref_counter() == 1 389 | 390 | 391 | def test_check_subscriptable(): 392 | # only torch.Tensor is subscriptable 393 | assert not limbus.core.param._check_subscriptable(Sequence[int]) 394 | assert not limbus.core.param._check_subscriptable(Iterable[int]) 395 | assert not 
limbus.core.param._check_subscriptable(List[int]) 396 | assert not limbus.core.param._check_subscriptable(Tuple[int]) 397 | assert limbus.core.param._check_subscriptable(Sequence[torch.Tensor]) 398 | assert limbus.core.param._check_subscriptable(Iterable[torch.Tensor]) 399 | assert limbus.core.param._check_subscriptable(List[torch.Tensor]) 400 | assert limbus.core.param._check_subscriptable(Tuple[torch.Tensor]) 401 | assert not limbus.core.param._check_subscriptable(Tuple[torch.Tensor, torch.Tensor]) 402 | assert not limbus.core.param._check_subscriptable(int) 403 | assert not limbus.core.param._check_subscriptable(torch.Tensor) 404 | assert not limbus.core.param._check_subscriptable(str) 405 | 406 | 407 | class A(Component): 408 | async def forward(self) -> ComponentState: 409 | return ComponentState.OK 410 | 411 | 412 | class B(Component): 413 | def __init__(self, name): 414 | super().__init__(name) 415 | self.stopping_execution = 1 416 | 417 | async def forward(self) -> ComponentState: 418 | return ComponentState.OK 419 | 420 | 421 | class TestInputParam: 422 | def test_smoke(self): 423 | p = InputParam("a") 424 | assert isinstance(p, Param) 425 | 426 | async def test_receive_without_parent(self): 427 | p = InputParam("a") 428 | with pytest.raises(AssertionError): 429 | await p.receive() 430 | 431 | async def test_receive_without_refs(self): 432 | p = InputParam("a", value=1, parent=A("a")) 433 | assert await p.receive() == 1 434 | 435 | async def test_receive_with_refs(self): 436 | po = OutputParam("b", parent=A("b")) 437 | pi = InputParam("a", parent=A("a")) 438 | po >> pi 439 | res = await asyncio.gather(po.send(1), pi.receive()) 440 | assert res == [None, 1] 441 | assert [ref.consumed.is_set() for ref in pi.references] == [True] 442 | assert [ref.sent.is_set() for ref in po.references] == [False] 443 | 444 | async def test_receive_with_refs_and_stopping_iteration(self): 445 | po = OutputParam("b", parent=B("b")) 446 | pi = InputParam("a", parent=B("a")) 447 
| po >> pi 448 | res = await asyncio.gather(po.send(1), pi.receive()) 449 | assert res == [None, 1] 450 | assert [ref.consumed.is_set() for ref in pi.references] == [True] 451 | assert [ref.sent.is_set() for ref in po.references] == [False] 452 | assert pi.parent.executions_counter == 1 453 | 454 | async def test_receive_from_iterable_param(self): 455 | po0 = OutputParam("b", torch.Tensor, parent=A("b")) 456 | po1 = OutputParam("c", torch.Tensor, parent=A("c")) 457 | pi = InputParam("a", List[torch.Tensor], parent=A("a")) 458 | po0 >> pi.select(0) 459 | po1 >> pi.select(1) 460 | t0 = asyncio.create_task(pi.receive()) 461 | await asyncio.sleep(0) # exec t0 without blocking 462 | assert list(pi._refs[0])[0].consumed.is_set() is False 463 | assert list(pi._refs[1])[0].consumed.is_set() is False 464 | assert list(pi._refs[0])[0].sent.is_set() is False 465 | assert list(pi._refs[1])[0].sent.is_set() is False 466 | t1 = asyncio.create_task(po1.send(torch.tensor(1))) 467 | await asyncio.sleep(0) # exec t1 without blocking 468 | assert list(pi._refs[0])[0].consumed.is_set() is False 469 | assert list(pi._refs[1])[0].consumed.is_set() is False 470 | assert list(pi._refs[0])[0].sent.is_set() is False 471 | assert list(pi._refs[1])[0].sent.is_set() is True 472 | t2 = asyncio.create_task(po0.send(torch.tensor(2))) 473 | await asyncio.sleep(0) # exec t2 without blocking 474 | assert list(pi._refs[0])[0].consumed.is_set() is False 475 | assert list(pi._refs[1])[0].consumed.is_set() is False 476 | assert list(pi._refs[0])[0].sent.is_set() is True 477 | assert list(pi._refs[1])[0].sent.is_set() is True 478 | await asyncio.gather(t0, t1, t2) 479 | assert list(pi._refs[0])[0].consumed.is_set() is True 480 | assert list(pi._refs[1])[0].consumed.is_set() is True 481 | assert list(pi._refs[0])[0].sent.is_set() is False 482 | assert list(pi._refs[1])[0].sent.is_set() is False 483 | assert pi.value == [torch.tensor(2), torch.tensor(1)] 484 | 485 | async def 
test_receive_with_callback(self, caplog): 486 | async def callback(self, value): 487 | assert self.name == "a" 488 | log.info(f"callback: {value}") 489 | return 2 490 | po = OutputParam("b", parent=A("b")) 491 | pi = InputParam("a", parent=A("a"), callback=callback) 492 | po >> pi 493 | with caplog.at_level(logging.INFO): 494 | res = await asyncio.gather(po.send(1), pi.receive()) 495 | assert pi.value == 1 # the callback does not change the internal param value 496 | assert res[1] == 2 # onl changes the return value 497 | assert "callback: 1" in caplog.text 498 | 499 | 500 | class TestOutputParam: 501 | def test_smoke(self): 502 | p = OutputParam("a") 503 | assert isinstance(p, Param) 504 | 505 | async def test_send_without_parent(self): 506 | p = OutputParam("a") 507 | with pytest.raises(AssertionError): 508 | await p.send(1) 509 | 510 | async def test_send_without_refs(self): 511 | p = OutputParam("a", value=1, parent=A("a")) 512 | await p.send(1) 513 | 514 | async def test_send_with_refs(self): 515 | po = OutputParam("b", parent=A("b")) 516 | pi = InputParam("a", parent=A("a")) 517 | po >> pi 518 | assert [ref.sent.is_set() for ref in po.references] == [False] 519 | asyncio.create_task(po.send(1)) 520 | await asyncio.sleep(0) 521 | assert [ref.sent.is_set() for ref in po.references] == [True] 522 | assert [ref.consumed.is_set() for ref in po.references] == [False] 523 | await pi.receive() 524 | await asyncio.sleep(0) 525 | 526 | async def test_send_from_iterable_param(self): 527 | po = OutputParam("c", List[torch.Tensor], parent=A("c")) 528 | pi0 = InputParam("a", torch.Tensor, parent=A("a")) 529 | pi1 = InputParam("b", torch.Tensor, parent=A("b")) 530 | po.select(0) >> pi0 531 | po.select(1) >> pi1 532 | t0 = asyncio.create_task(pi0.receive()) 533 | await asyncio.sleep(0) # exec t0 without blocking 534 | assert list(po._refs[0])[0].consumed.is_set() is False 535 | assert list(po._refs[1])[0].consumed.is_set() is False 536 | assert 
list(po._refs[0])[0].sent.is_set() is False 537 | assert list(po._refs[1])[0].sent.is_set() is False 538 | t1 = asyncio.create_task(po.send([torch.tensor(1), torch.tensor(2)])) 539 | await asyncio.sleep(0) # exec t1 without blocking 540 | assert list(po._refs[0])[0].consumed.is_set() is False 541 | assert list(po._refs[1])[0].consumed.is_set() is False 542 | assert list(po._refs[0])[0].sent.is_set() is True 543 | assert list(po._refs[1])[0].sent.is_set() is True 544 | t2 = asyncio.create_task(pi1.receive()) 545 | await asyncio.gather(t0, t1, t2) 546 | assert list(po._refs[0])[0].consumed.is_set() is True 547 | assert list(po._refs[1])[0].consumed.is_set() is True 548 | assert list(po._refs[0])[0].sent.is_set() is False 549 | assert list(po._refs[1])[0].sent.is_set() is False 550 | assert pi0.value == torch.tensor(1) 551 | assert pi1.value == torch.tensor(2) 552 | 553 | async def test_send_with_callback(self, caplog): 554 | async def callback(self, value): 555 | assert self.name == "b" 556 | log.info(f"callback: {value}") 557 | return 2 558 | po = OutputParam("b", parent=A("b"), callback=callback) 559 | pi = InputParam("a", parent=A("a")) 560 | po >> pi 561 | with caplog.at_level(logging.INFO): 562 | await asyncio.gather(po.send(1), pi.receive()) 563 | assert pi.value == 2 564 | assert "callback: 1" in caplog.text 565 | 566 | 567 | class TestPropertyParam: 568 | def test_smoke(self): 569 | p = PropertyParam("a") 570 | assert isinstance(p, Param) 571 | 572 | def test_init_without_parent(self): 573 | p = PropertyParam("a") 574 | p.init_property(1) 575 | assert p.value == 1 576 | 577 | def test_init_with_parent(self): 578 | p = PropertyParam("a", parent=A("b")) 579 | p.init_property(1) 580 | assert p.value == 1 581 | 582 | async def test_set_without_parent(self): 583 | p = PropertyParam("a") 584 | with pytest.raises(AssertionError): 585 | await p.set_property(1) 586 | 587 | async def test_set_property_with_parent(self): 588 | p = PropertyParam("b", parent=A("b")) 589 | 
await p.set_property(1) 590 | assert p.value == 1 591 | 592 | async def test_set_property_with_callback(self, caplog): 593 | async def callback(self, value): 594 | assert self.name == "b" 595 | log.info(f"callback: {value}") 596 | return 2 597 | p = PropertyParam("b", parent=A("b"), callback=callback) 598 | with caplog.at_level(logging.INFO): 599 | await p.set_property(1) 600 | assert p.value == 2 601 | assert "callback: 1" in caplog.text 602 | -------------------------------------------------------------------------------- /limbus/core/param.py: -------------------------------------------------------------------------------- 1 | """Classes to define parameters.""" 2 | from __future__ import annotations 3 | from dataclasses import dataclass 4 | from collections import defaultdict 5 | import typing 6 | from typing import Any, TYPE_CHECKING, Callable 7 | import inspect 8 | import collections 9 | import asyncio 10 | import contextlib 11 | from abc import ABC 12 | 13 | import typeguard 14 | 15 | from limbus.core.states import ComponentState, ComponentStoppedError 16 | from limbus.core import async_utils 17 | # Note that Component class cannot be imported to avoid circular dependencies. 18 | if TYPE_CHECKING: 19 | from limbus.core.component import Component 20 | 21 | SUBSCRIPTABLE_TYPES: list[type] = [] 22 | try: 23 | import torch 24 | SUBSCRIPTABLE_TYPES.append(torch.Tensor) 25 | except ImportError: 26 | pass 27 | 28 | try: 29 | import numpy as np 30 | SUBSCRIPTABLE_TYPES.append(np.ndarray) 31 | except ImportError: 32 | pass 33 | 34 | 35 | class NoValue: 36 | """Denote that a param does not have a value.""" 37 | pass 38 | 39 | 40 | @dataclass 41 | class Container: 42 | """Denote that a param has a value.""" 43 | value: Any 44 | 45 | 46 | @dataclass 47 | class IterableContainer: 48 | """Denote that a param has an indexed value. 49 | 50 | Note: In our use case the maximum number of nested IterableContainers is 2. 51 | This number is not explicitly controlled.
It is implicitly controlled in the Param class. 52 | 53 | """ 54 | container: Container | "IterableContainer" 55 | index: int 56 | 57 | @property 58 | def value(self) -> Any: 59 | """Get the value of the container.""" 60 | if isinstance(self.container, Container): 61 | # return the value of the container at the index 62 | return self.container.value[self.index] 63 | else: 64 | # look for the container value. 65 | # If it is an IterableContainer, the value is nested: delegate to the inner container (self.index is not applied here). 66 | assert isinstance(self.container, IterableContainer) 67 | return self.container.value 68 | 69 | 70 | class IterableInputContainers: 71 | """Denote that an input param is a sequence of Containers.""" 72 | def __init__(self, container: None | IterableContainer = None): 73 | containers = [] 74 | if container is not None: 75 | containers = [container] 76 | self._containers: list[IterableContainer] = containers 77 | 78 | def __len__(self) -> int: 79 | return len(self._containers) 80 | 81 | def add(self, container: IterableContainer) -> None: 82 | """Add an IterableContainer to the list of containers.""" 83 | self._containers.append(container) 84 | 85 | def remove(self, index: int) -> None: 86 | """Remove the first IterableContainer with the given index (no-op if none matches).""" 87 | for container in self._containers: 88 | if container.index == index: 89 | self._containers.remove(container) 90 | return 91 | 92 | def get_ordered(self) -> list[Any]: 93 | """Return a list with the values sorted by the index stored in each IterableContainer.""" 94 | indices: list[int] = [] 95 | for container in self._containers: 96 | assert isinstance(container, IterableContainer) 97 | indices.append(container.index) 98 | 99 | containers: list[Any] = [] 100 | for pos_idx in sorted(range(len(indices)), key=indices.__getitem__): # argsort 101 | obj: Container | IterableContainer = self._containers[pos_idx].container 102 | if isinstance(obj, IterableContainer): 103 | obj = obj.container.value[obj.index] 
# type: ignore # Iterable[Any] is not
indexable [index] 104 | else: 105 | assert isinstance(obj, Container) 106 | obj = obj.value 107 | containers.append(obj) 108 | return containers 109 | 110 | 111 | def _check_subscriptable(datatype: type) -> bool: 112 | """Check if datatype is subscriptable with tensors inside. 113 | 114 | Args: 115 | datatype (type): type to be analysed. 116 | 117 | Returns: 118 | bool: True if datatype is a subscriptable sequence of tensors, False otherwise. 119 | 120 | """ 121 | # we need to know if it is a variable size datatype, we assume that all the sequences are variable size 122 | # if they contain tensors. E.g. list[Tensor], tuple[Tensor], Sequence[Tensor]. 123 | # Note that e.g. for the case tuple[Tensor, Tensor] we don't assume it is variable since the size is known. 124 | origin = typing.get_origin(datatype) 125 | if origin is None: # discard datatypes that are not typing expressions 126 | return False 127 | datatype_args: tuple = typing.get_args(datatype) 128 | if inspect.isclass(origin): 129 | is_abstract: bool = inspect.isabstract(origin) 130 | is_abstract_seq: bool = origin is collections.abc.Sequence or origin is collections.abc.Iterable 131 | # mypy complains in the case origin is NoneType 132 | if is_abstract_seq or (not is_abstract and isinstance(origin(), typing.Iterable)): # type: ignore 133 | if (len(datatype_args) == 1 or (len(datatype_args) == 2 and Ellipsis in datatype_args)): 134 | if datatype_args[0] in SUBSCRIPTABLE_TYPES: 135 | return True 136 | return False 137 | 138 | 139 | class IterableParam: 140 | """Temporary view to manage indexing inside a parameter (see Param.select).""" 141 | def __init__(self, param: "Param", index: int) -> None: 142 | self._param: Param = param 143 | # TODO: validate that _iter_container can be an IterableInputContainers, I feel it cannot!! NOTE(review): if param.container is an IterableContainer neither branch below runs and _iter_container stays unset.
144 | self._iter_container: IterableContainer | IterableInputContainers 145 | if isinstance(param.container, Container): 146 | self._iter_container = IterableContainer(param.container, index) 147 | elif isinstance(param.container, IterableInputContainers): 148 | # since it is an input, the pointer to the value is not relevant at this stage; _connect later replaces this placeholder container 149 | self._iter_container = IterableContainer(Container(None), index) 150 | 151 | @property 152 | def param(self) -> "Param": 153 | """Return the base parameter.""" 154 | return self._param 155 | 156 | @property 157 | def index(self) -> int: 158 | """Return the selected index in the sequence.""" 159 | if isinstance(self._iter_container, IterableInputContainers): 160 | raise TypeError("Cannot get the index of a list of input containers.") 161 | return self._iter_container.index 162 | 163 | @property 164 | def value(self) -> Any | list[Any]: 165 | """Get the value of the parameter. 166 | 167 | It can be a list of values if the parameter is an IterableInputContainers.
168 | 169 | """ 170 | if isinstance(self._iter_container, IterableContainer): 171 | return self._iter_container.value 172 | else: 173 | assert isinstance(self._iter_container, IterableInputContainers) 174 | return self._iter_container.get_ordered() 175 | 176 | @property 177 | def iter_container(self) -> IterableContainer | IterableInputContainers: 178 | """Get the container of the parameter.""" 179 | return self._iter_container 180 | 181 | def ref_counter(self) -> int: 182 | """Return the number of references stored for this index of the base parameter.""" 183 | if isinstance(self._iter_container, IterableInputContainers): 184 | raise TypeError("At this moment the number of references for IterableInputContainers cannot be retrieved.") 185 | return self._param.ref_counter(self._iter_container.index) 186 | 187 | def connect(self, dst: "Param" | "IterableParam") -> None: 188 | """Connect this parameter (output) with the dst (input) parameter.""" 189 | self._param._connect(self, dst) 190 | 191 | def __rshift__(self, rvalue: "Param" | "IterableParam") -> None: 192 | """Allow to connect params using the >> operator.""" 193 | self.connect(rvalue) 194 | 195 | def disconnect(self, dst: "Param" | "IterableParam") -> None: 196 | """Disconnect this parameter (output) from the dst (input) parameter.""" 197 | self._param._disconnect(self, dst) 198 | 199 | 200 | @dataclass 201 | class Reference: 202 | """Reference to a parameter. 203 | 204 | It is used to keep track of the references to a parameter. 205 | 206 | """ 207 | param: "Param" 208 | ori_param: "Param" # added to avoid duplicated references, it is rare but it could happen. 209 | index: None | int = None 210 | ori_index: None | int = None # added to avoid duplicated references, it is rare but it could happen.
211 | # allow to know if there is a new value for the parameter 212 | sent: None | asyncio.Event = None 213 | # allow to know if the value has been consumed 214 | consumed: None | asyncio.Event = None 215 | 216 | def __hash__(self) -> int: 217 | # this method is required to be able to use Reference in a set. 218 | # Note that we don't use the sent/consumed events in the hash since they are dynamic. 219 | return hash((self.param, self.index, self.ori_param, self.ori_index)) 220 | 221 | def __eq__(self, other: Any) -> bool: 222 | # this method is required to be able to use Reference in a set. 223 | # Note that we don't compare the sent/consumed events since they are dynamic (keeps __eq__ consistent with __hash__). 224 | if isinstance(other, Reference): 225 | return (self.param == other.param and self.index == other.index and 226 | self.ori_param == other.ori_param and self.ori_index == other.ori_index) 227 | return False 228 | 229 | 230 | class Param(ABC): 231 | """Base class to store data for each parameter. 232 | 233 | Args: 234 | name: name of the parameter. 235 | tp (optional): type of the parameter. Mandatory for subscriptable params. Default: Any. 236 | value (optional): value of the parameter. Default: NoValue(). 237 | arg (optional): name of the argument in the component constructor related with this param. Default: None. 238 | parent (optional): parent component. Default: None. 239 | callback (optional): async callback to be called when the value of the parameter changes. 240 | Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` 241 | - MUST return the value to be finally used. 242 | Default: None.
243 | 244 | """ 245 | def __init__(self, name: str, tp: Any = Any, value: Any = NoValue(), arg: None | str = None, 246 | parent: None | Component = None, callback: Callable | None = None) -> None: 247 | # validate that the type is coherent with the value 248 | if not isinstance(value, NoValue): 249 | typeguard.check_type(name, value, tp) 250 | 251 | self._name: str = name 252 | self._type: Any = tp 253 | self._arg: None | str = arg 254 | # We store all the references for each param. 255 | # The key is the index (slice) of the current param, or None when the param is connected as a whole. 256 | self._refs: dict[Any, set[Reference]] = defaultdict(set) 257 | self._value: Container | IterableContainer | IterableInputContainers = Container(value) 258 | # only sequences with tensors inside are subscriptable 259 | self._is_subscriptable = _check_subscriptable(tp) 260 | self._parent: None | Component = parent 261 | self._callback: None | Callable = callback 262 | 263 | @property 264 | def is_subscriptable(self) -> bool: 265 | """Return whether the parameter is subscriptable.""" 266 | return self._is_subscriptable 267 | 268 | def reset_is_subscriptable(self) -> None: 269 | """Recompute the subscriptable flag from the param type.""" 270 | self._is_subscriptable = _check_subscriptable(self._type) 271 | 272 | def set_as_non_subscriptable(self) -> None: 273 | """Set the subscriptable flag to False.""" 274 | self._is_subscriptable = False 275 | 276 | @property 277 | def parent(self) -> None | Component: 278 | """Get the parent component.""" 279 | return self._parent 280 | 281 | @property 282 | def arg(self) -> None | str: 283 | """Get the argument in the Component constructor related with this param. 284 | 285 | This is a trick to pass a value and type of an argument in the Component constructor to this parameter.
286 | 287 | """ 288 | return self._arg 289 | 290 | @property 291 | def type(self) -> Any: 292 | """Return the type of the parameter.""" 293 | return self._type 294 | 295 | @property 296 | def name(self) -> str: 297 | """Get the name of the parameter.""" 298 | return self._name 299 | 300 | @property 301 | def references(self) -> set[Reference]: 302 | """Get all the references for the parameter (union over every index).""" 303 | refs: set[Reference] = set() 304 | for ref_set in self._refs.values(): 305 | refs = refs.union(ref_set) 306 | return refs 307 | 308 | @property 309 | def value(self) -> Any: 310 | """Get the value of the parameter.""" 311 | if isinstance(self._value, Container): 312 | if isinstance(self._value.value, IterableContainer): 313 | # mypy error: Iterable[Any] is not indexable [index] 314 | return self._value.value.container.value[self._value.value.index] # type: ignore 315 | else: 316 | return self._value.value 317 | elif isinstance(self._value, IterableInputContainers): 318 | assert self._is_subscriptable 319 | origin = typing.get_origin(self._type) 320 | assert origin is not None 321 | res_value: list[Any] = self._value.get_ordered() 322 | return origin(res_value) 323 | 324 | @value.setter 325 | def value(self, value: Any) -> None: 326 | """Set the value of the parameter. 327 | 328 | Args: 329 | value (Any): The value to set.
330 | 331 | """ 332 | self._set_value(value) 333 | 334 | def _set_value(self, value: Any) -> None: 335 | # trick to easily override the setter of the value property. NOTE(review): `set` is also rejected below although it has no 'value' attribute - confirm intent. 336 | if isinstance(value, Param): 337 | value = value.value 338 | if not isinstance(self._value, Container): 339 | raise TypeError(f"Param '{self.name}' cannot be assigned.") 340 | if isinstance(value, (Container, IterableContainer, set)): 341 | raise TypeError( 342 | f"The type of the value to be assigned to param '{self.name}' cannot have a 'value' attribute.") 343 | typeguard.check_type(self._name, value, self._type) 344 | self._value.value = value 345 | 346 | @property 347 | def container(self) -> Container | IterableContainer | IterableInputContainers: 348 | """Get the container for this parameter.""" 349 | return self._value 350 | 351 | @container.setter 352 | def container(self, value: Container | IterableContainer | IterableInputContainers) -> None: 353 | """Set the container for this parameter. 354 | 355 | Args: 356 | value (Container, IterableContainer or IterableInputContainers): The container to set. 357 | 358 | """ 359 | self._value = value 360 | 361 | def ref_counter(self, index: None | int = None) -> int: 362 | """Return the number of references for the given index, or for all indices if index is None (querying an index creates an empty entry in the internal defaultdict).""" 363 | if index is not None: 364 | return len(self._refs[index]) 365 | else: 366 | return len(self.references) 367 | 368 | def select(self, index: int) -> IterableParam: 369 | """Select a slice of the parameter. 370 | 371 | Args: 372 | index (int): The index of the slice. 373 | 374 | Returns: 375 | IterableParam: The selected slice.
376 | 377 | """ 378 | if not self._is_subscriptable: 379 | raise ValueError(f"The param '{self.name}' is not subscriptable (it must be a sequence of tensors).") 380 | # NOTE: we cannot check if the index is valid because the length of the sequence is not known at this point 381 | # return a view (IterableParam) over the selected index of this param 382 | return IterableParam(self, index) 383 | 384 | def _connect(self, ori: "Param" | IterableParam, dst: "Param" | IterableParam) -> None: 385 | """Connect the ori parameter (output) with the dst (input) parameter.""" 386 | # Disable this check until a better solution is found to connect 2 lists. 387 | # if isinstance(ori, Param) and ori._is_subscriptable: 388 | # raise ValueError(f"The param '{ori.name}' must be connected using indexes.") 389 | 390 | # if isinstance(dst, Param) and dst._is_subscriptable: 391 | # raise ValueError(f"The param '{dst.name}' must be connected using indexes.") 392 | 393 | # NOTE that there is no full type validation, we trust the user to connect params. 394 | # We only check when there is an explicit value in the ori param. 395 | if isinstance(ori, Param) and not isinstance(ori.value, NoValue): 396 | if isinstance(dst, Param): 397 | typeguard.check_type(self._name, ori.value, dst.type) 398 | else: 399 | typeguard.check_type(self._name, ori.value, typing.get_args(dst.param.type)[0]) 400 | 401 | # TODO: check that dst param is an input param 402 | # TODO: check type compatibility 403 | if (isinstance(dst, Param) and dst.ref_counter() > 0): 404 | raise ValueError(f"An input parameter can only be connected to 1 param. " 405 | f"Dst param '{dst.name}' is connected to {dst._refs}.") 406 | 407 | if isinstance(dst, IterableParam) and dst.param.ref_counter(dst.index) > 0: 408 | raise ValueError(f"An input parameter can only be connected to 1 param. 
" 409 | f"Dst param '{dst.param.name}' is connected to {dst.param._refs}.") 410 | 411 | # share the ori container with dst so that dst reads the values stored by ori 412 | if isinstance(dst, Param) and isinstance(ori, Param): 413 | assert isinstance(dst.container, Container) 414 | assert isinstance(ori.container, Container) 415 | dst.container = ori.container 416 | elif isinstance(dst, IterableParam) and isinstance(ori, Param): 417 | assert isinstance(dst.iter_container, IterableContainer) 418 | assert isinstance(ori.container, Container) 419 | dst.iter_container.container = ori.container 420 | elif isinstance(dst, Param) and isinstance(ori, IterableParam): 421 | assert isinstance(dst.container, Container) 422 | assert isinstance(ori.iter_container, IterableContainer) 423 | dst.container.value = ori.iter_container 424 | else: 425 | assert isinstance(dst, IterableParam) 426 | assert isinstance(ori, IterableParam) 427 | assert isinstance(dst.iter_container, IterableContainer) 428 | assert isinstance(ori.iter_container, IterableContainer) 429 | dst.iter_container.container = ori.iter_container 430 | 431 | # if dst is an IterableParam, several ori params can be connected to different dst indexes, 432 | # so their containers are collected in an IterableInputContainers 433 | if isinstance(dst, IterableParam): 434 | assert isinstance(dst.iter_container, IterableContainer) 435 | if isinstance(dst.param.container, IterableInputContainers): 436 | dst.param.container.add(dst.iter_container) 437 | else: 438 | dst.param.container = IterableInputContainers(dst.iter_container) 439 | 440 | self._update_references('add', ori, dst) 441 | 442 | def connect(self, dst: "Param" | IterableParam) -> None: 443 | """Connect this parameter (output) with the dst (input) parameter.""" 444 | self._connect(self, dst) 445 | 446 | def __rshift__(self, rvalue: "Param" | IterableParam) -> None: 447 | """Allow to connect params using the >> operator.""" 448 | self.connect(rvalue) 449 | 450 | def _disconnect(self, ori: 
"Param" | IterableParam, dst: "Param" | 
IterableParam) -> None: 451 | """Disconnect the ori parameter from the dst parameter, resetting the dst container when nothing else is connected to it.""" 452 | if isinstance(dst, Param): 453 | assert isinstance(dst.container, Container) 454 | dst.container = Container(NoValue()) 455 | elif isinstance(dst, IterableParam): 456 | if isinstance(dst.param.container, IterableInputContainers): 457 | assert isinstance(dst.iter_container, IterableContainer) 458 | dst.param.container.remove(dst.iter_container.index) 459 | if len(dst.param.container) == 0: 460 | dst.param.container = Container(NoValue()) 461 | else: 462 | dst.param.container = Container(NoValue()) 463 | 464 | self._update_references('remove', ori, dst) 465 | 466 | def _update_references(self, type: str, ori: "Param" | IterableParam, dst: "Param" | IterableParam 467 | ) -> None: 468 | # add/remove the Reference entries on both ori and dst ('type' is 'add' or 'remove') 469 | ori_idx = None 470 | dst_idx = None 471 | if isinstance(ori, IterableParam): 472 | ori_idx = ori.index 473 | ori = ori.param 474 | if isinstance(dst, IterableParam): 475 | dst_idx = dst.index 476 | dst = dst.param 477 | if type == 'add': 478 | # Set events denoting that the param is sent/consumed. Note that the same events are set in the 479 | # references of both params. 
480 | consumed_event = asyncio.Event() 481 | sent_event = asyncio.Event() 482 | ori._refs[ori_idx].add(Reference(dst, ori, dst_idx, ori_idx, sent_event, consumed_event)) 483 | dst._refs[dst_idx].add(Reference(ori, dst, ori_idx, dst_idx, sent_event, consumed_event)) 484 | elif type == 'remove': 485 | ori._refs[ori_idx].remove(Reference(dst, ori, dst_idx, ori_idx)) 486 | dst._refs[dst_idx].remove(Reference(ori, dst, ori_idx, dst_idx)) 487 | 488 | def disconnect(self, dst: "Param" | IterableParam) -> None: 489 | """Disconnect this parameter (output) from the dst (input) parameter.""" 490 | self._disconnect(self, dst) 491 | 492 | 493 | class PropertyParam(Param): 494 | """Class to manage the communication for each property parameter.""" 495 | 496 | def init_property(self, value: Any) -> None: 497 | """Initialize the property with the given value. 498 | 499 | This method should be called before running the component to init the property. 500 | It does not run the callback function. 501 | 502 | """ 503 | # ComponentState.INITIALIZED means that the component was just created 504 | if self._parent is not None and ComponentState.INITIALIZED not in self._parent.state: 505 | raise RuntimeError("The property can only be initialized before running the component.") 506 | self.value = value 507 | 508 | async def set_property(self, value: Any) -> None: 509 | """Set the value of the property. 510 | 511 | Note: using this method is the only way to run the callback function. 512 | 513 | """ 514 | assert self._parent is not None 515 | if self._callback is None: 516 | self.value = value 517 | else: 518 | self.value = await self._callback(self._parent, value) 519 | 520 | 521 | class InputParam(Param): 522 | """Class to manage the communication for each input parameter.""" 523 | 524 | async def receive(self) -> Any: 525 | """Wait until the input param receives a value from the connected output param.
526 | 527 | Note that using this method will run the callback function as soon as a new value is received. 528 | Note that the callback changes the result returned by the receive method, not the value inside the 529 | param (Param.value). This is done this way because the param can be shared between several input params, 530 | so each callback call could change its value. 531 | 532 | """ 533 | assert self._parent is not None 534 | self._parent._Component__num_params_waiting_to_receive += 1 535 | if self.references: 536 | for ref in self.references: 537 | # NOTE: each input param can be connected to 0 or 1 output param (N output params if it is iterable). 538 | # ref: Reference = next(iter(self.references)) 539 | # ensure the component related with the output param exists 540 | assert ref.param is not None 541 | ori_param: Param = ref.param 542 | assert isinstance(ori_param, OutputParam) # they must be of type OutputParam 543 | assert ori_param.parent is not None 544 | self._parent.set_state(ComponentState.RECEIVING_PARAMS, 545 | f"{ori_param.parent.name}.{ori_param.name} -> {self._parent.name}.{self.name}") 546 | async_utils.create_task_if_needed(self._parent, ori_param.parent) 547 | 548 | if self._parent.stopping_execution == 0: 549 | # fast way, in contrast with the while loop below, to wait for the input param. 550 | # wait until all the output params send the values 551 | await asyncio.gather(*[ref.sent.wait() for ref in self.references if ref.sent is not None]) 552 | else: 553 | sent: int = 0 554 | while sent < len(self.references): 555 | for ref in self.references: 556 | # Trick to avoid issues due to setting a concrete number of iters to be executed. E.g.: if at 557 | # least 1 iter is requested from each component and there is a component requesting 2 iters 558 | # from a previous one this trick will recreate the tasks. 559 | # This is mainly useful for debugging purposes since it slows down the execution.
560 | assert ref.param is not None 561 | assert ref.param.parent is not None 562 | async_utils.create_task_if_needed(self._parent, ref.param.parent) 563 | assert isinstance(ref.sent, asyncio.Event) 564 | with contextlib.suppress(asyncio.TimeoutError): 565 | await asyncio.wait_for(ref.sent.wait(), timeout=0.1) 566 | sent = sum([ref.sent.is_set() for ref in self.references if ref.sent is not None]) 567 | 568 | for ref in self.references: 569 | assert ref.param is not None 570 | assert ref.param.parent is not None 571 | # if we want to stop at a given min iter then it is possible to require more iters 572 | if ComponentState.STOPPED_AT_ITER not in ref.param.parent.state and ref.param.parent.is_stopped(): 573 | raise ComponentStoppedError(ComponentState.STOPPED_BY_COMPONENT) 574 | 575 | for ref in self.references: 576 | # NOTE: depending on how the value is consumed we should apply a copy here. 577 | # - We assume components do not modify the value. (this can happen) 578 | # - When the value is set reusing the same memory, instead of creating a new var, then 579 | # the changes will also be propagated to components consuming the previous value. (in theory 580 | # this cannot happen) 581 | # TODO: add a flag to allow to determine if we want to copy the value.
582 | value = self.value # get the value before allowing to send again 583 | assert isinstance(ref.sent, asyncio.Event) 584 | assert isinstance(ref.consumed, asyncio.Event) 585 | ref.consumed.set() # denote that the param is consumed 586 | ref.sent.clear() # let the sender know that it can send again 587 | else: 588 | value = self.value 589 | if self._callback is not None: 590 | # specific callback for this param 591 | value = await self._callback(self._parent, value) 592 | await self._are_all_waiting_params_received() 593 | if self._parent.pipeline and self._parent.pipeline.param_received_user_hook: 594 | # hook from the pipeline, all the components and input params run the same code 595 | await self._parent.pipeline.param_received_user_hook(self) 596 | return value 597 | 598 | async def _are_all_waiting_params_received(self) -> None: 599 | """Decrement the waiting-params counter and set the component to RUNNING when no param is pending.""" 600 | assert self._parent is not None 601 | self._parent._Component__num_params_waiting_to_receive -= 1 602 | if self._parent._Component__num_params_waiting_to_receive == 0: 603 | self._parent.set_state(ComponentState.RUNNING) 604 | 605 | 606 | class OutputParam(Param): 607 | """Class to manage the communication for each output parameter.""" 608 | 609 | async def send(self, value: Any) -> None: 610 | """Send the value of this param to the connected input params. 611 | 612 | Note that using this method will run the callback function before the value is sent.
613 | 614 | """ 615 | assert self._parent is not None 616 | if self._callback is None: 617 | self.value = value # set the value for the param 618 | else: 619 | self.value = await self._callback(self._parent, value) 620 | 621 | for ref in self.references: 622 | assert isinstance(ref.sent, asyncio.Event) 623 | assert isinstance(ref.consumed, asyncio.Event) 624 | ref.consumed.clear() # reset the consumed event before signalling a new value 625 | ref.sent.set() # denote that the param is ready to be consumed 626 | 627 | # ensure the component related with the input param exists 628 | assert ref.param is not None 629 | dst_param: Param = ref.param 630 | assert isinstance(dst_param, InputParam) # they must be of type InputParam 631 | assert dst_param.parent is not None 632 | self._parent.set_state(ComponentState.SENDING_PARAMS, 633 | f"{self._parent.name}.{self.name} -> {dst_param.parent.name}.{dst_param.name}") 634 | async_utils.create_task_if_needed(self._parent, dst_param.parent) 635 | 636 | if self._parent.pipeline and self._parent.pipeline.param_sent_user_hook: 637 | await self._parent.pipeline.param_sent_user_hook(self) 638 | 639 | # wait until all the input params read the value 640 | await asyncio.gather(*[ref.consumed.wait() for ref in self.references if ref.consumed is not None]) 641 | for ref in self.references: 642 | assert ref.param is not None 643 | assert ref.param.parent is not None 644 | # if we want to stop at a given min iter then it is possible to require more iters 645 | if ComponentState.STOPPED_AT_ITER not in ref.param.parent.state and ref.param.parent.is_stopped(): 646 | raise ComponentStoppedError(ComponentState.STOPPED_BY_COMPONENT) 647 | --------------------------------------------------------------------------------