├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── README.md ├── aws_runtime_gateway.py ├── aws_runtime_kinesis.py ├── bin ├── flink-kafka-1.8.jar ├── flink-kafka-bytes-serializer.jar └── flink-sql-connector-kafka_2.11-1.13.0.jar ├── demo_client.py ├── demo_client_universalis.py ├── demo_common.py ├── demo_runtime.py ├── demo_runtime_universalis.py ├── demo_runtime_universalis_ycsb.py ├── demo_ycsb.py ├── deployment ├── docker-compose.yml └── statefun │ ├── docker-compose.yml │ └── module.yaml ├── fastapi_client.py ├── img ├── fun_address.svg └── stateflow_to_stream.svg ├── package.json ├── requirements.txt ├── runtime_aws.py ├── serverless.yml ├── serverless_gateway.yml ├── serverless_kinesis.yml ├── setup.py ├── stateflow ├── __init__.py ├── analysis │ ├── __init__.py │ ├── ast_utils.py │ ├── extract_class_descriptor.py │ └── extract_method_descriptor.py ├── client │ ├── __init__.py │ ├── aws_client.py │ ├── aws_gateway_client.py │ ├── class_ref.py │ ├── fastapi │ │ ├── __init__.py │ │ ├── aws_gateway.py │ │ ├── aws_lambda.py │ │ ├── fastapi.py │ │ └── kafka.py │ ├── future.py │ ├── kafka_client.py │ ├── stateflow_client.py │ └── universalis_client.py ├── core.py ├── dataflow │ ├── __init__.py │ ├── address.py │ ├── args.py │ ├── dataflow.py │ ├── event.py │ ├── event_flow.py │ ├── state.py │ └── stateful_operator.py ├── descriptors │ ├── __init__.py │ ├── class_descriptor.py │ └── method_descriptor.py ├── runtime │ ├── KafkaConsumer.py │ ├── __init__.py │ ├── aws │ │ ├── __init__.py │ │ ├── abstract_lambda.py │ │ ├── gateway_lambda.py │ │ └── kinesis_lambda.py │ ├── beam_runtime.py │ ├── cloudburst │ │ ├── __init__.py │ │ └── cloudburst.py │ ├── dataflow │ │ ├── __init__.py │ │ └── remote_lambda.py │ ├── flink │ │ ├── __init__.py │ │ ├── pyflink.py │ │ └── statefun.py │ ├── runtime.py │ └── universalis │ │ ├── __init__.py │ │ └── universalis_runtime.py ├── serialization │ ├── __init__.py │ ├── cloudpickle_serializer.py │ ├── json_serde.py │ ├── pickle_serializer.py │ ├── proto │ │ ├── __init__.py │ │ ├── event.proto │ │ ├── event_pb2.py │ │ └── proto_serde.py │ └── serde.py ├── split │ ├── __init__.py │ ├── conditional_block.py │ ├── execution_plan_merging.py │ ├── for_block.py │ ├── split_analyze.py │ ├── split_block.py │ └── split_transform.py ├── util │ ├── __init__.py │ ├── dataflow_operator_generator.py │ ├── dataflow_visualizer.py │ ├── local_runtime.py │ ├── stateflow_test.py │ └── statefun_module_generator.py └── wrappers │ ├── __init__.py │ ├── class_wrapper.py │ └── meta_wrapper.py ├── tests ├── __init__.py ├── analysis │ ├── ast_utils_test.py │ └── extract_stateful_test.py ├── client │ ├── class_ref_test.py │ ├── future_test.py │ └── kafka_client_test.py ├── common │ └── common_classes.py ├── context.py ├── dataflow │ ├── arguments_test.py │ └── stateful_operator_test.py ├── kafka │ └── KafkaImage.py ├── local_runtime_test.py ├── runtime │ ├── aws_runtime_test.py │ └── beam_runtime_test.py ├── serialization │ └── proto_serializer_test.py ├── split │ └── split_test.py ├── stateflow_test.py └── wrapper │ ├── class_wrapper_test.py │ └── meta_wrapper_test.py └── zipfian_generator.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: build 5 | 6 | on: 7 | push: 
8 |     branches: [ main ]
9 |   pull_request:
10 |     branches: [ main ]
11 | 
12 | jobs:
13 |   build:
14 | 
15 |     runs-on: ubuntu-latest
16 | 
17 |     steps:
18 |     - uses: actions/checkout@v2
19 |     - name: Set up Python 3.8
20 |       uses: actions/setup-python@v2
21 |       with:
22 |         python-version: 3.8
23 |     - name: Install dependencies
24 |       run: |
25 |         python -m pip install --upgrade pip
26 |         pip install flake8 pytest coverage
27 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
28 |     - name: Lint with flake8
29 |       run: |
30 |         # stop the build if there are Python syntax errors or undefined names
31 |         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
32 |         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
33 |         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
34 |     - name: Test with pytest
35 |       run: |
36 |         coverage run --source=stateflow -m pytest tests
37 |     - name: Generate coverage report
38 |       run: |
39 |         coverage xml
40 |     - name: "Upload coverage to Codecov"
41 |       uses: codecov/codecov-action@v1
42 |       with:
43 |         fail_ci_if_error: true
44 | 
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | .idea
30 | 
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 | 
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 | 
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | 
55 | # Translations
56 | *.mo
57 | *.pot
58 | 
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 | 
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 | 
69 | # Scrapy stuff:
70 | .scrapy
71 | 
72 | # Sphinx documentation
73 | docs/_build/
74 | 
75 | # PyBuilder
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | .python-version
87 | 
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 | 
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 | 
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 | 
102 | # SageMath parsed files
103 | *.sage.py
104 | 
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 | 
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 | 
118 | # Rope project settings
119 | .ropeproject
120 | 
121 | # mkdocs documentation
122 | /site
123 | 
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 | 
129 | # Pyre type checker
130 | .pyre/
131 | *.serverless/*
132 | node_modules/*
133 | 
134 | benchmark/
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # StateFlow | Object-Oriented Code to Distributed Stateful Dataflows
2 | [![CI](https://github.com/wzorgdrager/stateful_dataflows/actions/workflows/python-app.yml/badge.svg)](https://github.com/wzorgdrager/stateful_dataflows/actions/workflows/python-app.yml)
3 | [![codecov](https://codecov.io/gh/delftdata/stateflow/branch/main/graph/badge.svg?token=AUL4CXQQJX)](https://codecov.io/gh/delftdata/stateflow)
4 | [![Python 3.8](https://img.shields.io/badge/python-3.8-blue.svg)](https://www.python.org/downloads/release/python-380/)
5 | 
6 | StateFlow is a framework that compiles object-oriented Python code into distributed stateful dataflows.
7 | These dataflows can be executed on different target systems. At the moment, we support the following runtime systems:
8 | 
9 | | **Runtime** | **Local execution** | **Cluster execution** | **Notes** |
10 | |:---:|:---:|:---:|:---:|
11 | | PyFlink | :white_check_mark: | :white_check_mark: | - |
12 | | Stateflow (Universalis) | :white_check_mark: | :white_check_mark: | - |
13 | | Apache Beam | :white_check_mark: | :x: | Beam [suffers from a bug with Kafka](https://issues.apache.org/jira/browse/BEAM-11998), which can be worked around locally. Deployment on the Dataflow runner does not work. |
14 | | Flink Statefun | :white_check_mark: | :white_check_mark: | - |
15 | | AWS Lambda | :white_check_mark: | :white_check_mark: | - |
16 | | CloudBurst | :x: | :x: | CloudBurst has never been officially released. Due to missing Docker files and documentation, execution does not work. |
17 | 
18 | An evaluation of StateFlow can be found at [delftdata/stateflow-evaluation](https://github.com/delftdata/stateflow-evaluation).
19 | 
20 | The stateflow runtime system can be found at [delftdata/stateflow-runtime](https://github.com/delftdata/stateflow-runtime).
21 | ## Features
22 | - Analysis and transformation of Python classes to distributed stateful dataflows. These dataflows can be ported to cloud services and dataflow systems.
23 | - Due to the nature of dataflow systems, stateful entities cannot interact with each other directly. Therefore, direct calls to other objects, as done in object-oriented code, do not work in stateful dataflows. StateFlow splits such functions at the AST level to get rid of the remote call:
24 | it splits a function into several parts, such that the dataflow system can move back and forth between the different stateful entities (e.g. dataflow operators).
25 | - Support for compilation to several (cloud) services and dataflow systems, including AWS Lambda, Apache Beam, Flink Statefun, and PyFlink.
26 | - Support for several client-side connectivity services, including Apache Kafka, AWS Kinesis, and AWS API Gateway. Depending on the runtime system, a compatible client has to be used.
27 | A developer can use either StateFlow futures or asyncio to interact with the remote stateful entities.
28 | - Integration with FastAPI: each class method is exposed through an HTTP endpoint. Developers can easily add their own.
29 | 
30 | ## Walkthrough
31 | To work with StateFlow, a developer annotates their classes with the `@stateflow.stateflow` decorator.
32 | ```python
33 | import stateflow
34 | 
35 | @stateflow.stateflow
36 | class Item:
37 |     def __init__(self, item_name: str, price: int):
38 |         self.item_name: str = item_name
39 |         self.stock: int = 0
40 |         self.price: int = price
41 | 
42 |     def update_stock(self, amount: int) -> bool:
43 |         if (self.stock + amount) < 0:  # We can't get a stock < 0.
44 |             return False
45 | 
46 |         self.stock += amount
47 |         return True
48 | 
49 |     def set_stock(self, amount: int):
50 |         self.stock = amount
51 | 
52 |     def __key__(self):
53 |         return self.item_name
54 | 
55 | 
56 | @stateflow.stateflow
57 | class User:
58 |     def __init__(self, username: str):
59 |         self.username: str = username
60 |         self.balance: int = 1
61 | 
62 |     def update_balance(self, x: int):
63 |         self.balance += x
64 | 
65 |     def buy_item(self, amount: int, item: Item) -> bool:
66 |         total_price = amount * item.price
67 | 
68 |         if self.balance < total_price:
69 |             return False
70 | 
71 |         # Decrease the stock.
72 |         decrease_stock = item.update_stock(-amount)
73 | 
74 |         if not decrease_stock:
75 |             return False  # For some reason, stock couldn't be decreased.
76 | 
77 |         self.balance -= total_price
78 |         return True
79 | 
80 |     def __key__(self):
81 |         return self.username
82 | ```
83 | Each stateful entity has to implement the `__key__` method to define its static partitioning key. This key cannot change during execution
84 | and ensures the entity is addressable in the distributed runtime. Think of this key as a primary key in a database.
85 | 
86 | To deploy this as, for example, a Flink job, simply import the annotated classes and initialize StateFlow.
87 | ```python
88 | from demo_common import User, Item, stateflow
89 | from stateflow.runtime.flink.pyflink import FlinkRuntime, Runtime
90 | 
91 | # Initialize stateflow
92 | flow = stateflow.init()
93 | 
94 | runtime: FlinkRuntime = FlinkRuntime(flow)
95 | runtime.run()
96 | ```
97 | This code generates a streaming dataflow graph compatible with Apache Flink.
98 | Finally, to interact with these stateful entities:
99 | ```python
100 | from demo_common import User, Item, stateflow
101 | from stateflow.client.kafka_client import StateflowKafkaClient, StateflowClient, StateflowFuture
102 | 
103 | # Initialize stateflow
104 | flow = stateflow.init()
105 | client: StateflowClient = StateflowKafkaClient(
106 |     flow, brokers="localhost:9092", statefun_mode=False
107 | )
108 | 
109 | future_user: StateflowFuture[User] = User("new-user")
110 | user: User = future_user.get()
111 | 
112 | user.update_balance(10).get()
113 | ```
114 | 
115 | ## Demo
116 | To run a (full) demo:
117 | 1. Launch a Kafka cluster:
118 | ```
119 | cd deployment
120 | docker-compose up
121 | ```
122 | 2. Run `demo_client.py`; this starts a client that can interact with the stateful entities.
123 | It also creates the appropriate Kafka topics `client_request`, `client_reply`, and `internal` (a quick way to verify them is shown below).
124 | 3. Run `demo_runtime.py`; this deploys the stateful dataflow on Apache Beam. The stateful entities are defined in `demo_common.py`.
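
As a quick sanity check for step 2, you can list the topics on the running broker. This is a minimal sketch: it assumes the Compose default container name `deployment_kafka-broker_1` (as started from the `deployment` folder); adjust it to whatever `docker ps` reports.
```
docker exec -it deployment_kafka-broker_1 kafka-topics.sh --list --bootstrap-server localhost:9092
```
Once `demo_client.py` has run, the list should include `client_request`, `client_reply`, and `internal`.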
125 | 
126 | ## Demo (with FastAPI)
127 | 1. Launch a Kafka cluster:
128 | ```
129 | cd deployment
130 | docker-compose up
131 | ```
132 | 2. Run `uvicorn fastapi_client:app`; this starts a FastAPI client on http://localhost:8000
133 | that interacts with the stateful entities through Kafka. To see all (generated) endpoints, visit http://localhost:8000/docs.
134 | New endpoints can be added in `fastapi_client.py`.
135 | 3. Run `demo_runtime.py`; this deploys the stateful dataflow on Apache Beam. The stateful entities are defined in `demo_common.py`.
136 | 
137 | ## Credits
138 | This repository is part of the research conducted at the [Delft Data Management Lab](http://www.wis.ewi.tudelft.nl/data-management.html).
139 | Contributors:
140 | - [Wouter Zorgdrager](https://github.com/wzorgdrager)
141 | - [All contributors](https://github.com/delftdata/stateflow/graphs/contributors)
-------------------------------------------------------------------------------- /aws_runtime_gateway.py: --------------------------------------------------------------------------------
1 | from stateflow.runtime.aws.gateway_lambda import AWSGatewayLambdaRuntime
2 | from demo_common import stateflow
3 | 
4 | 
5 | flow = stateflow.init()
6 | print("Called init code!")
7 | 
8 | runtime, handler = AWSGatewayLambdaRuntime.get_handler(flow, gateway=False)
9 | 
-------------------------------------------------------------------------------- /aws_runtime_kinesis.py: --------------------------------------------------------------------------------
1 | from stateflow.runtime.aws.kinesis_lambda import AWSKinesisLambdaRuntime
2 | from demo_common import stateflow
3 | 
4 | 
5 | flow = stateflow.init()
6 | print("Called init code!")
7 | 
8 | runtime, handler = AWSKinesisLambdaRuntime.get_handler(flow)
9 | 
-------------------------------------------------------------------------------- /bin/flink-kafka-1.8.jar: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/bin/flink-kafka-1.8.jar
-------------------------------------------------------------------------------- /bin/flink-kafka-bytes-serializer.jar: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/bin/flink-kafka-bytes-serializer.jar
-------------------------------------------------------------------------------- /bin/flink-sql-connector-kafka_2.11-1.13.0.jar: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/bin/flink-sql-connector-kafka_2.11-1.13.0.jar
-------------------------------------------------------------------------------- /demo_client.py: --------------------------------------------------------------------------------
1 | from demo_common import User, Item, stateflow
2 | from stateflow.client.kafka_client import StateflowClient
3 | from stateflow.util.local_runtime import LocalRuntime
4 | from stateflow.client.future import StateflowFuture, StateflowFailure
5 | from stateflow.client.aws_gateway_client import AWSGatewayClient
6 | from stateflow.client.kafka_client import StateflowKafkaClient
7 | import time
8 | import datetime
9 | from stateflow.util import statefun_module_generator
10 | 
11 | flow = stateflow.init()
12 | client: StateflowClient = StateflowKafkaClient(
13 |     flow, brokers="localhost:9092", statefun_mode=False
14 | )
15 | client.create_all_topics()
16 | client.wait_until_healthy(timeout=10)
17 | 
18 | 
19 | print("Creating a user: ")
20 | start = datetime.datetime.now()
21 | future_user: StateflowFuture[User] = User("wouter-user")
22 | 
23 | try:
24 |     user: User = future_user.get()
25 | except StateflowFailure:
26 |     user: User = client.find(User, "wouter-user").get()
27 | 
28 | end = datetime.datetime.now()
29 | delta = end - start
30 | print(f"Creating user took {delta.total_seconds() * 1000}ms")
31 | 
32 | print("Creating another user")
33 | start = datetime.datetime.now()
34 | future_user2: StateflowFuture[User] = User("wouter-user2")
35 | 
36 | try:
37 |     user2: User = future_user2.get()
38 | except StateflowFailure:
39 |     user2: User = client.find(User, "wouter-user2").get()
40 | end = datetime.datetime.now()
41 | delta = end - start
42 | print(f"Creating another user took {delta.total_seconds() * 1000}ms")
43 | 
44 | print("Done!")
45 | start = datetime.datetime.now()
46 | for_loop: int = user.simple_for_loop([user, user2]).get(timeout=5)
47 | end = datetime.datetime.now()
48 | delta = end - start
49 | print(f"Simple for loop took {delta.total_seconds() * 1000}ms")
50 | 
51 | print(user.balance.get())
52 | print(user2.balance.get())
53 | # print(for_loop)
54 | # print("")
55 | print("Creating an item: ")
56 | start = datetime.datetime.now()
57 | future_item: StateflowFuture[Item] = Item("coke", 10)
58 | 
59 | try:
60 |     item: Item = future_item.get()
61 | except StateflowFailure:
62 |     item: Item = client.find(Item, "coke").get()
63 | end = datetime.datetime.now()
64 | delta = end - start
65 | print(f"Creating coke took {delta.total_seconds() * 1000}ms")
66 | 
67 | 
68 | start = datetime.datetime.now()
69 | future_item2: StateflowFuture[Item] = Item("pepsi", 10)
70 | 
71 | try:
72 |     item2: Item = future_item2.get()
73 | except StateflowFailure:
74 |     item2: Item = client.find(Item, "pepsi").get()
75 | end = datetime.datetime.now()
76 | delta = end - start
77 | print(f"Creating pepsi took {delta.total_seconds() * 1000}ms")
78 | 
79 | 
80 | start = datetime.datetime.now()
81 | hi = user.state_requests([item, item2]).get(timeout=10)
82 | print(hi)
83 | end = datetime.datetime.now()
84 | delta = end - start
85 | print(f"State requests took {delta.total_seconds() * 1000}ms")
86 | item.stock = 5
87 | user.balance = 10
88 | 
89 | 
90 | start_time = time.time()
91 | print(f"User balance: {user.balance.get()}")
92 | print(f"Item stock: {item.stock.get()} and price {item.price.get()}")
93 | 
94 | print()
95 | # This is impossible.
96 | start = datetime.datetime.now()
97 | print(f"Let's try to buy 100 cokes at 10EU each?: {user.buy_item(100, item).get()}")
98 | end = datetime.datetime.now()
99 | delta = end - start
100 | print(f"Buy item took {delta.total_seconds() * 1000}ms")
101 | 
102 | # Yay, we buy one: the user ends up with 0 balance and there are 4 left in stock.
103 | start = datetime.datetime.now()
104 | print(f"Let's try to buy 1 coke at 10EU?: {user.buy_item(1, item).get()}")
105 | end = datetime.datetime.now()
106 | delta = end - start
107 | print(f"Another buy item took {delta.total_seconds() * 1000}ms")
108 | 
109 | print()
110 | # user balance 0, stock 4.
111 | print(f"Final user balance: {user.balance.get()}") 112 | print(f"Final item stock: {item.stock.get()}") 113 | end_time = time.time() 114 | diff = (end_time - start_time) * 1000 115 | 116 | print(f"\nThat took {diff}ms") 117 | -------------------------------------------------------------------------------- /demo_client_universalis.py: -------------------------------------------------------------------------------- 1 | from demo_common import User, Item, stateflow 2 | from stateflow.client.future import StateflowFuture, StateflowFailure 3 | from stateflow.client.kafka_client import StateflowKafkaClient 4 | import time 5 | import datetime 6 | 7 | from stateflow.client.universalis_client import UniversalisClient 8 | 9 | flow = stateflow.init() 10 | client: UniversalisClient = UniversalisClient(flow, 11 | brokers="localhost:9092", 12 | statefun_mode=False) 13 | client.wait_until_healthy(timeout=10) 14 | 15 | 16 | # print("Creating a user: ") 17 | # start = datetime.datetime.now() 18 | # future_user: StateflowFuture[User] = User("wouter-user") 19 | # 20 | # try: 21 | # user: User = future_user.get() 22 | # except StateflowFailure: 23 | # user: User = client.find(User, "wouter-user").get() 24 | # 25 | # end = datetime.datetime.now() 26 | # delta = end - start 27 | # print(f"Creating user took {delta.total_seconds() * 1000}ms") 28 | # 29 | # print("Creating another user") 30 | # start = datetime.datetime.now() 31 | # future_user2: StateflowFuture[User] = User("wouter-user2") 32 | # 33 | # try: 34 | # user2: User = future_user2.get() 35 | # except StateflowFailure: 36 | # user2: User = client.find(User, "wouter-user2").get() 37 | # end = datetime.datetime.now() 38 | # delta = end - start 39 | # print(f"Creating another user took {delta.total_seconds() * 1000}ms") 40 | # 41 | # print("Done!") 42 | # start = datetime.datetime.now() 43 | # for_loop: int = user.simple_for_loop([user, user2]).get(timeout=5) 44 | # end = datetime.datetime.now() 45 | # delta = end - start 46 | # print(f"Simple for loop took {delta.total_seconds() * 1000}ms") 47 | # 48 | # print(user.balance.get()) 49 | # print(user2.balance.get()) 50 | # # print(for_loop) 51 | # # print("") 52 | # print("Creating an item: ") 53 | # start = datetime.datetime.now() 54 | # future_item: StateflowFuture[Item] = Item("coke", 10) 55 | # 56 | # try: 57 | # item: Item = future_item.get() 58 | # except StateflowFailure: 59 | # item: Item = client.find(Item, "coke").get() 60 | # end = datetime.datetime.now() 61 | # delta = end - start 62 | # print(f"Creating coke took {delta.total_seconds() * 1000}ms") 63 | # 64 | # 65 | # start = datetime.datetime.now() 66 | # future_item2: StateflowFuture[Item] = Item("pepsi", 10) 67 | # 68 | # try: 69 | # item2: Item = future_item2.get() 70 | # except StateflowFailure: 71 | # item2: Item = client.find(Item, "pepsi").get() 72 | # end = datetime.datetime.now() 73 | # delta = end - start 74 | # print(f"Creating another pepsi took {delta.total_seconds() * 1000}ms") 75 | # 76 | # 77 | # start = datetime.datetime.now() 78 | # hi = user.state_requests([item, item2]).get(timeout=10) 79 | # print(hi) 80 | # end = datetime.datetime.now() 81 | # delta = end - start 82 | # print(f"State requests took {delta.total_seconds() * 1000}ms") 83 | # item.stock = 5 84 | # user.balance = 10 85 | # 86 | # 87 | # start_time = time.time() 88 | # print(f"User balance: {user.balance.get()}") 89 | # print(f"Item stock: {item.stock.get()} and price {item.price.get()}") 90 | # 91 | # print() 92 | # # This is impossible. 
93 | # start = datetime.datetime.now() 94 | # print(f"Let's try to buy 100 coke's of 10EU?: {user.buy_item(100, item).get()}") 95 | # end = datetime.datetime.now() 96 | # delta = end - start 97 | # print(f"Buy item took {delta.total_seconds() * 1000}ms") 98 | # 99 | # # Jeej we buy one, user will end up with 0 balance and there is 4 left in stock. 100 | # start = datetime.datetime.now() 101 | # print(f"Lets' try to buy 1 coke's of 10EU?: {user.buy_item(1, item).get()}") 102 | # end = datetime.datetime.now() 103 | # delta = end - start 104 | # print(f"Another buy item took {delta.total_seconds() * 1000}ms") 105 | # 106 | # print() 107 | # # user balance 0, stock 4. 108 | # print(f"Final user balance: {user.balance.get()}") 109 | # print(f"Final item stock: {item.stock.get()}") 110 | # end_time = time.time() 111 | # diff = (end_time - start_time) * 1000 112 | # 113 | # print(f"\nThat took {diff}ms") 114 | -------------------------------------------------------------------------------- /demo_common.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import List 3 | import stateflow 4 | 5 | 6 | @stateflow.stateflow 7 | class Item: 8 | def __init__(self, item_name: str, price: int): 9 | self.item_name: str = item_name 10 | self.stock: int = 0 11 | self.price: int = price 12 | 13 | def update_stock(self, amount: int) -> bool: 14 | if (self.stock + amount) < 0: # We can't get a stock < 0. 15 | return False 16 | 17 | self.stock += amount 18 | return True 19 | 20 | def set_stock(self, amount: int): 21 | self.stock = amount 22 | 23 | def __key__(self): 24 | return self.item_name 25 | 26 | 27 | @stateflow.stateflow 28 | class User: 29 | def __init__(self, username: str): 30 | self.username: str = username 31 | self.balance: int = 1 32 | self.items: List[Item] = [] 33 | 34 | def update_balance(self, x: int): 35 | self.balance += x 36 | 37 | def get_balance(self): 38 | return self.balance 39 | 40 | def buy_item(self, amount: int, item: Item) -> bool: 41 | total_price = amount * item.price 42 | 43 | if self.balance < total_price: 44 | return False 45 | 46 | # Decrease the stock. 47 | decrease_stock = item.update_stock(-amount) 48 | 49 | if not decrease_stock: 50 | return False # For some reason, stock couldn't be decreased. 
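        # NOTE: update_stock above is invoked on *another* stateful entity (an Item).
        # StateFlow splits buy_item around such remote calls (see the README's splitting
        # feature), so the dataflow can move from the User operator to Item and back.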
51 | 52 | self.balance -= total_price 53 | return True 54 | 55 | def simple_for_loop(self, users: List["User"]): 56 | i = 0 57 | for user in users: 58 | if i > 0: 59 | user.update_balance(9) 60 | else: 61 | user.update_balance(4) 62 | i += 1 63 | 64 | return i 65 | 66 | def state_requests(self, items: List[Item]): 67 | total: int = 0 68 | first_item: Item = items[0] 69 | print(f"Total is now {total}.") 70 | total += first_item.stock # Total = 0 71 | first_item.set_stock(10) 72 | total += first_item.stock # total = 10 73 | first_item.set_stock(0) 74 | for x in items: 75 | total += x.stock # total = 10 76 | x.set_stock(5) 77 | total += x.stock # total = 10 + 5 + 5 = 20 78 | 79 | print(f"Total is now {total}.") 80 | total += first_item.stock # total = 25 81 | if total > 0: 82 | first_item.set_stock(1) 83 | 84 | print(f"Total is now {total}.") 85 | 86 | first_item: Item = first_item 87 | total += first_item.stock # total = 26 88 | return total 89 | 90 | def __key__(self): 91 | return self.username 92 | -------------------------------------------------------------------------------- /demo_runtime.py: -------------------------------------------------------------------------------- 1 | from demo_common import stateflow 2 | from stateflow.runtime.beam_runtime import BeamRuntime, Runtime 3 | from stateflow.serialization.pickle_serializer import PickleSerializer 4 | 5 | # Initialize stateflow 6 | flow = stateflow.init() 7 | 8 | runtime: BeamRuntime = BeamRuntime(flow, serializer=PickleSerializer()) 9 | runtime.run() 10 | -------------------------------------------------------------------------------- /demo_runtime_universalis.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import datetime 3 | import time 4 | 5 | from universalis.common.stateflow_ingress import IngressTypes 6 | from universalis.universalis import Universalis 7 | import demo_common 8 | from demo_common import User, Item, stateflow 9 | from stateflow.client.future import StateflowFuture, StateflowFailure 10 | from stateflow.client.universalis_client import UniversalisClient 11 | from stateflow.runtime.universalis.universalis_runtime import UniversalisRuntime 12 | 13 | 14 | UNIVERSALIS_HOST: str = 'localhost' 15 | UNIVERSALIS_PORT: int = 8886 16 | KAFKA_URL = 'localhost:9093' 17 | 18 | 19 | async def main(): 20 | # Initialize stateflow 21 | flow = stateflow.init() 22 | 23 | universalis_interface = Universalis(UNIVERSALIS_HOST, UNIVERSALIS_PORT, 24 | ingress_type=IngressTypes.KAFKA, 25 | kafka_url=KAFKA_URL) 26 | 27 | runtime: UniversalisRuntime = UniversalisRuntime(flow, 28 | universalis_interface, 29 | "Stateflow", 30 | n_partitions=6) 31 | 32 | universalis_operators = await runtime.run((demo_common, )) 33 | 34 | print(universalis_operators.keys()) 35 | 36 | flow = stateflow.init() 37 | client: UniversalisClient = UniversalisClient(flow=flow, 38 | universalis_client=universalis_interface, 39 | kafka_url=KAFKA_URL, 40 | operators=universalis_operators) 41 | 42 | client.wait_until_healthy(timeout=1) 43 | 44 | print("Creating a user: ") 45 | start = datetime.datetime.now() 46 | future_user: StateflowFuture[User] = User("wouter-user") 47 | 48 | try: 49 | user: User = future_user.get() 50 | except StateflowFailure: 51 | user: User = client.find(User, "wouter-user").get() 52 | 53 | end = datetime.datetime.now() 54 | delta = end - start 55 | print(f"Creating user took {delta.total_seconds() * 1000}ms") 56 | 57 | print("Creating another user") 58 | start = datetime.datetime.now() 59 | 
future_user2: StateflowFuture[User] = User("wouter-user2")
60 | 
61 |     try:
62 |         user2: User = future_user2.get()
63 |     except StateflowFailure:
64 |         user2: User = client.find(User, "wouter-user2").get()
65 |     end = datetime.datetime.now()
66 |     delta = end - start
67 |     print(f"Creating another user took {delta.total_seconds() * 1000}ms")
68 | 
69 |     print("Done!")
70 | 
71 |     start = datetime.datetime.now()
72 |     print('Running for loop')
73 |     for_loop: int = user.simple_for_loop([user, user2]).get()
74 |     print(f'For loop result: {for_loop}')
75 |     end = datetime.datetime.now()
76 |     delta = end - start
77 |     print(f"Simple for loop took {delta.total_seconds() * 1000}ms")
78 | 
79 |     if user.balance.get() != 5:
80 |         print(f"\nThis should have been 5 but it is: {user.balance.get()}\n")
81 |         exit()
82 | 
83 |     if user2.balance.get() != 10:
84 |         print(f"\nThis should have been 10 but it is: {user2.balance.get()}\n")
85 |         exit()
86 | 
87 |     print("Creating an item: ")
88 |     start = datetime.datetime.now()
89 |     future_item: StateflowFuture[Item] = Item("coke", 10)
90 | 
91 |     try:
92 |         item: Item = future_item.get()
93 |     except StateflowFailure:
94 |         item: Item = client.find(Item, "coke").get()
95 |     end = datetime.datetime.now()
96 |     delta = end - start
97 |     print(f"Creating coke took {delta.total_seconds() * 1000}ms")
98 | 
99 |     start = datetime.datetime.now()
100 |     future_item2: StateflowFuture[Item] = Item("pepsi", 10)
101 | 
102 |     try:
103 |         item2: Item = future_item2.get()
104 |     except StateflowFailure:
105 |         item2: Item = client.find(Item, "pepsi").get()
106 |     end = datetime.datetime.now()
107 |     delta = end - start
108 |     print(f"Creating pepsi took {delta.total_seconds() * 1000}ms")
109 |     print(item.item_name.get())
110 |     print(item2.item_name.get())
111 | 
112 |     print('-------------------------------------------------')
113 |     print('Starting complex logic')
114 |     print('-------------------------------------------------')
115 |     start = datetime.datetime.now()
116 |     hi = user.state_requests([item, item2]).get(timeout=10)
117 |     print(hi)
118 |     end = datetime.datetime.now()
119 |     delta = end - start
120 |     print(f"State requests took {delta.total_seconds() * 1000}ms")
121 |     item.stock = 5
122 |     user.balance = 10
123 | 
124 |     start_time = time.time()
125 |     print(f"User balance: {user.balance.get()}")
126 |     print(f"Item stock: {item.stock.get()} and price {item.price.get()}")
127 | 
128 |     print()
129 |     # This is impossible.
130 |     start = datetime.datetime.now()
131 |     print(f"Let's try to buy 100 cokes at 10EU each?: {user.buy_item(100, item).get()}")
132 |     end = datetime.datetime.now()
133 |     delta = end - start
134 |     print(f"Buy item took {delta.total_seconds() * 1000}ms")
135 | 
136 |     # Yay, we buy one: the user ends up with 0 balance and there are 4 left in stock.
137 |     start = datetime.datetime.now()
138 |     print(f"Let's try to buy 1 coke at 10EU?: {user.buy_item(1, item).get()}")
139 |     end = datetime.datetime.now()
140 |     delta = end - start
141 |     print(f"Another buy item took {delta.total_seconds() * 1000}ms")
142 | 
143 |     print()
144 |     # user balance 0, stock 4.
145 | print(f"Final user balance: {user.balance.get()}") 146 | print(f"Final item stock: {item.stock.get()}") 147 | end_time = time.time() 148 | diff = (end_time - start_time) * 1000 149 | 150 | print(f"\nThat took {diff}ms") 151 | 152 | 153 | asyncio.run(main()) 154 | -------------------------------------------------------------------------------- /demo_runtime_universalis_ycsb.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import random 4 | 5 | from universalis.common.stateflow_ingress import IngressTypes 6 | from universalis.universalis import Universalis 7 | import demo_ycsb 8 | from demo_ycsb import YCSBEntity, stateflow 9 | from stateflow.client.universalis_client import UniversalisClient 10 | from stateflow.runtime.universalis.universalis_runtime import UniversalisRuntime 11 | from zipfian_generator import ZipfGenerator 12 | 13 | UNIVERSALIS_HOST: str = 'localhost' 14 | UNIVERSALIS_PORT: int = 8886 15 | KAFKA_URL = 'localhost:9093' 16 | N_ENTITIES = 100 17 | keys: list[int] = list(range(N_ENTITIES)) 18 | STARTING_AMOUNT = 100 19 | N_TASKS = 1000 20 | WORKLOAD = 'a' 21 | 22 | 23 | async def main(): 24 | 25 | # Initialize the zipf generator 26 | zipf_gen = ZipfGenerator(items=N_ENTITIES) 27 | operations = ["r", "u", "t"] 28 | operation_mix_a = [0.5, 0.5, 0.0] 29 | operation_mix_b = [0.95, 0.05, 0.0] 30 | operation_mix_t = [0.0, 0.0, 1.0] 31 | 32 | if WORKLOAD == 'a': 33 | operation_mix = operation_mix_a 34 | elif WORKLOAD == 'b': 35 | operation_mix = operation_mix_b 36 | else: 37 | operation_mix = operation_mix_t 38 | 39 | # Initialize stateflow 40 | flow = stateflow.init() 41 | 42 | universalis_interface = Universalis(UNIVERSALIS_HOST, UNIVERSALIS_PORT, 43 | ingress_type=IngressTypes.KAFKA, 44 | kafka_url=KAFKA_URL) 45 | 46 | runtime: UniversalisRuntime = UniversalisRuntime(flow, 47 | universalis_interface, 48 | "Stateflow", 49 | n_partitions=6) 50 | 51 | universalis_operators = await runtime.run((demo_ycsb, )) 52 | 53 | print(universalis_operators.keys()) 54 | 55 | flow = stateflow.init() 56 | client: UniversalisClient = UniversalisClient(flow=flow, 57 | universalis_client=universalis_interface, 58 | kafka_url=KAFKA_URL, 59 | operators=universalis_operators) 60 | 61 | time.sleep(1) 62 | 63 | client.wait_until_healthy(timeout=1) 64 | 65 | entities: dict[int, YCSBEntity] = {} 66 | print("Creating the entities...") 67 | for i in keys: 68 | print(f'Creating: {i}') 69 | entities[i] = YCSBEntity(str(i), STARTING_AMOUNT).get() 70 | 71 | client.stop_consumer_thread() 72 | 73 | operation_counts: dict[str, int] = {"r": 0, "u": 0, "t": 0} 74 | time.sleep(10) 75 | client.start_result_consumer_process() 76 | time.sleep(10) 77 | 78 | for i in range(N_TASKS): 79 | key = keys[next(zipf_gen)] 80 | op = random.choices(operations, weights=operation_mix, k=1)[0] 81 | operation_counts[op] += 1 82 | if op == "r": 83 | entities[key].read() 84 | elif op == "u": 85 | entities[key].update(STARTING_AMOUNT) 86 | else: 87 | key2 = keys[next(zipf_gen)] 88 | while key2 == key: 89 | key2 = keys[next(zipf_gen)] 90 | entities[key].transfer(1, entities[key2]) 91 | 92 | client.store_request_csv() 93 | print(operation_counts) 94 | time.sleep(10) 95 | print("Stopping") 96 | client.stop_result_consumer_process() 97 | print("Done") 98 | 99 | asyncio.run(main()) 100 | -------------------------------------------------------------------------------- /demo_ycsb.py: -------------------------------------------------------------------------------- 1 | 
import stateflow 2 | 3 | 4 | @stateflow.stateflow 5 | class YCSBEntity: 6 | 7 | def __init__(self, key: str, value: int): 8 | # insert 9 | self.key: str = key 10 | self.value: int = value 11 | 12 | def read(self): 13 | return self.key, self.value 14 | 15 | def update(self, new_value: int): 16 | self.value = new_value 17 | 18 | def add_funds(self, transfer_amount: int): 19 | self.value += transfer_amount 20 | 21 | def transfer(self, transfer_amount: int, other_entity: "YCSBEntity") -> bool: 22 | new_value: int = self.value - transfer_amount 23 | if new_value < 0: 24 | return False 25 | self.value = new_value 26 | other_entity.add_funds(transfer_amount) 27 | return True 28 | 29 | def __key__(self): 30 | return self.key 31 | -------------------------------------------------------------------------------- /deployment/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | services: 3 | zookeeper: 4 | image: wurstmeister/zookeeper:latest 5 | ports: 6 | - "2181:2181" 7 | kafka-broker: 8 | image: wurstmeister/kafka:2.12-2.2.0 9 | ports: 10 | - "9092:9092" 11 | expose: 12 | - "9093" 13 | environment: 14 | KAFKA_ADVERTISED_LISTENERS: INSIDE://kafka-broker:9093,OUTSIDE://localhost:9092 15 | KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: INSIDE:PLAINTEXT,OUTSIDE:PLAINTEXT 16 | KAFKA_LISTENERS: INSIDE://0.0.0.0:9093,OUTSIDE://0.0.0.0:9092 17 | KAFKA_INTER_BROKER_LISTENER_NAME: INSIDE 18 | KAFKA_ZOOKEEPER_CONNECT: zookeeper:2181 -------------------------------------------------------------------------------- /deployment/statefun/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "2.1" 2 | 3 | services: 4 | 5 | ############################################################### 6 | # StateFun runtime 7 | ############################################################### 8 | 9 | statefun-manager: 10 | image: apache/flink-statefun:3.0.0 11 | expose: 12 | - "6123" 13 | ports: 14 | - "8081:8081" 15 | environment: 16 | ROLE: master 17 | MASTER_HOST: statefun-manager 18 | volumes: 19 | - ./module.yaml:/opt/statefun/modules/stateflow/module.yaml 20 | 21 | statefun-worker: 22 | image: apache/flink-statefun:3.0.0 23 | expose: 24 | - "6121" 25 | - "6122" 26 | environment: 27 | ROLE: worker 28 | MASTER_HOST: statefun-manager 29 | volumes: 30 | - ./module.yaml:/opt/statefun/modules/stateflow/module.yaml 31 | 32 | host-machine: 33 | image: qoomon/docker-host@sha256:e0f021dd77c7c26d37b825ab2cbf73cd0a77ca993417da80a14192cb041937b0 34 | cap_add: ['NET_ADMIN', 'NET_RAW'] 35 | mem_limit: 8M 36 | restart: on-failure 37 | environment: 38 | PORTS: 8000 -------------------------------------------------------------------------------- /deployment/statefun/module.yaml: -------------------------------------------------------------------------------- 1 | version: '3.0' 2 | module: 3 | meta: 4 | type: remote 5 | spec: 6 | egresses: 7 | - egress: 8 | meta: 9 | id: stateflow/kafka-egress 10 | type: io.statefun.kafka/egress 11 | spec: 12 | address: kafka-broker:9093 13 | endpoints: 14 | - endpoint: 15 | meta: 16 | kind: http 17 | spec: 18 | functions: globals/ping 19 | urlPathTemplate: http://host-machine:8000/statefun 20 | - endpoint: 21 | meta: 22 | kind: http 23 | spec: 24 | functions: global/YCSBEntity 25 | urlPathTemplate: http://host-machine:8000/statefun 26 | - endpoint: 27 | meta: 28 | kind: http 29 | spec: 30 | functions: global/YCSBEntity_create 31 | urlPathTemplate: http://host-machine:8000/statefun 32 | 
ingresses: 33 | - ingress: 34 | meta: 35 | id: stateflow/kafka-ingress 36 | type: io.statefun.kafka/ingress 37 | spec: 38 | address: kafka-broker:9093 39 | consumerGroupId: stateflow-statefun-consumer 40 | topics: 41 | - topic: globals_ping 42 | targets: 43 | - globals/ping 44 | valueType: stateflow/byte_type 45 | - topic: global_YCSBEntity 46 | targets: 47 | - global/YCSBEntity 48 | valueType: stateflow/byte_type 49 | - topic: global_YCSBEntity_create 50 | targets: 51 | - global/YCSBEntity_create 52 | valueType: stateflow/byte_type -------------------------------------------------------------------------------- /fastapi_client.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from stateflow.client.fastapi.kafka import KafkaFastAPIClient, StateflowFailure 4 | from stateflow.client.fastapi.aws_lambda import AWSLambdaFastAPIClient 5 | from demo_common import stateflow, User 6 | 7 | client = KafkaFastAPIClient(stateflow.init()) 8 | app = client.get_app() 9 | -------------------------------------------------------------------------------- /img/fun_address.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 |
[SVG diagram (draw.io export residue removed): a Function Address consists of a Key and a Function Type, where the Function Type comprises a Namespace, a Name, and a Stateful flag.]
-------------------------------------------------------------------------------- /package.json: --------------------------------------------------------------------------------
1 | {
2 |   "name": "stateflow",
3 |   "description": "",
4 |   "version": "0.1.0",
5 |   "dependencies": {},
6 |   "devDependencies": {
7 |     "serverless-python-requirements": "^5.1.1"
8 |   }
9 | }
10 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | aiokafka==0.7.2
2 | anyio==3.6.1
3 | apache-beam==2.40.0
4 | apache-flink==1.9.3
5 | asn1crypto==1.5.1
6 | attrs==21.4.0
7 | beam-nuggets==0.18.1
8 | boto3==1.24.22
9 | botocore==1.27.22
10 | charset-normalizer==2.1.0
11 | click==8.1.3
12 | cloudpickle==2.1.0
13 | confluent-kafka==1.9.0
14 | crcmod==1.7
15 | dill==0.3.1.1
16 | docopt==0.6.2
17 | fastapi==0.78.0
18 | fastavro==1.5.2
19 | graphviz==0.20
20 | greenlet==1.1.2
21 | grpcio==1.47.0
22 | h11==0.13.0
23 | hdfs==2.7.0
24 | httplib2==0.20.4
25 | idna==3.3
26 | iniconfig==1.1.1
27 | jmespath==1.0.1
28 | kafka-python==2.0.2
29 | libcst==0.4.5
30 | mypy-extensions==0.4.3
31 | numpy==1.22.4
32 | orjson==3.7.5
33 | packaging==21.3
34 | pg8000==1.29.1
35 | pluggy==1.0.0
36 | proto-plus==1.20.6
37 | protobuf==3.20.1
38 | py==1.11.0
39 | py4j==0.10.8.1
40 | pyarrow==7.0.0
41 | pydantic==1.9.1
42 | pydot==1.4.2
43 | pymongo==3.12.3
44 | PyMySQL==1.0.2
45 | pynamodb==5.2.1
46 | pyparsing==3.0.9
47 | pytest==7.1.2
48 | python-dateutil==2.8.2
49 | pytz==2022.1
50 | PyYAML==6.0
51 | requests==2.28.1
52 | s3transfer==0.6.0
53 | scramp==1.4.1
54 | six==1.16.0
55 | sniffio==1.2.0
56 | SQLAlchemy==1.4.39
57 | SQLAlchemy-Utils==0.38.2
58 | starlette==0.19.1
59 | tomli==2.0.1
60 | typing-inspect==0.7.1
61 | typing_extensions==4.3.0
62 | ujson==5.4.0
63 | urllib3==1.26.9
64 | uvicorn==0.18.2
65 | 
-------------------------------------------------------------------------------- /runtime_aws.py: --------------------------------------------------------------------------------
1 | from benchmark.hotel.user import User
2 | from benchmark.hotel.search import Search, Geo, Rate
3 | from benchmark.hotel.reservation import Reservation
4 | from benchmark.hotel.profile import Profile, HotelProfile
5 | from benchmark.hotel.recommend import RecommendType, Recommend, stateflow
6 | from stateflow.runtime.aws.gateway_lambda import AWSGatewayLambdaRuntime
7 | 
8 | # Initialize stateflow
9 | flow = stateflow.init()
10 | 
11 | runtime, handler = AWSGatewayLambdaRuntime.get_handler(flow, gateway=True)
12 | 
-------------------------------------------------------------------------------- /serverless.yml: --------------------------------------------------------------------------------
1 | service: stateflow
2 | provider:
3 |   name: aws
4 |   region: eu-west-2
5 |   runtime: python3.8
6 |   iam:
7 |     role: arn:aws:iam::958167380706:role/stateflow-dev-eu-west-1-lambdaRole
8 | plugins:
9 |   - serverless-python-requirements
10 | package:
11 |   exclude:
12 |     - 'venv/**'
13 | custom:
14 |   pythonRequirements:
15 |     usePipenv: false
16 |     slim: true
17 |     noDeploy:
18 |       - Flask
19 |       - flake8
20 |       - apache-beam
21 |       - coverage
22 |       - confluent-kafka
23 |       - beam-nuggets
24 |       - httplib2
25 |       - google-api-python-client
26 |       - pytest-mock
27 |       - PyHamcrest
28 |       - pytest-docker-fixtures
29 |       - pytest-timeout
30 |       - apache-flink
31 |       - boto3
32 |       - apache-flink-statefun
33 |       - pyflink
34 |       - pandas
35 | functions:
36 |   stateflow:
37 |     handler: runtime_aws.handler
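    # runtime_aws.py builds this entrypoint: `runtime, handler = AWSGatewayLambdaRuntime.get_handler(flow, gateway=True)`.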
38 |     memorySize: 1024 # optional, in MB, default is 1024
39 |     events:
40 |       - http:
41 |           path: stateflow
42 |           method: post
-------------------------------------------------------------------------------- /serverless_gateway.yml: --------------------------------------------------------------------------------
1 | service: stateflow
2 | provider:
3 |   name: aws
4 |   region: eu-west-1
5 |   runtime: python3.8
6 |   iam:
7 |     role: arn:aws:iam::958167380706:role/stateflow-dev-eu-west-1-lambdaRole
8 | plugins:
9 |   - serverless-python-requirements
10 | package:
11 |   exclude:
12 |     - 'venv/**'
13 | custom:
14 |   pythonRequirements:
15 |     usePipenv: false
16 |     slim: true
17 |     noDeploy:
18 |       - pytest
19 |       - Flask
20 |       - flake8
21 |       - apache-beam
22 |       - coverage
23 |       - confluent-kafka
24 |       - beam-nuggets
25 |       - httplib2
26 |       - google-api-python-client
27 |       - pytest-mock
28 |       - PyHamcrest
29 |       - pytest-docker-fixtures
30 |       - pytest-timeout
31 |       - apache-flink
32 |       - boto3
33 | functions:
34 |   stateflow:
35 |     handler: aws_runtime_gateway.handler
36 |     memorySize: 1024 # optional, in MB, default is 1024
37 |     events:
38 |       - http:
39 |           path: stateflow
40 |           method: post
-------------------------------------------------------------------------------- /serverless_kinesis.yml: --------------------------------------------------------------------------------
1 | service: stateflow
2 | provider:
3 |   name: aws
4 |   region: eu-west-1
5 |   runtime: python3.8
6 |   iam:
7 |     role: arn:aws:iam::958167380706:role/stateflow-dev-eu-west-1-lambdaRole
8 | plugins:
9 |   - serverless-python-requirements
10 | package:
11 |   exclude:
12 |     - 'venv/**'
13 | custom:
14 |   pythonRequirements:
15 |     usePipenv: false
16 |     slim: true
17 |     noDeploy:
18 |       - pytest
19 |       - Flask
20 |       - flake8
21 |       - apache-beam
22 |       - coverage
23 |       - confluent-kafka
24 |       - beam-nuggets
25 |       - httplib2
26 |       - google-api-python-client
27 |       - pytest-mock
28 |       - PyHamcrest
29 |       - pytest-docker-fixtures
30 |       - pytest-timeout
31 |       - apache-flink
32 |       - boto3
33 | functions:
34 |   stateflow:
35 |     handler: aws_runtime_kinesis.handler
36 |     memorySize: 1024 # optional, in MB, default is 1024
37 |     events:
38 |       - stream:
39 |           arn: arn:aws:kinesis:eu-west-1:958167380706:stream/stateflow-request
40 |           batchSize: 1
41 |           startingPosition: LATEST
42 |           maximumRetryAttempts: 1
43 |           batchWindow: 1
44 |           enabled: true
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name="stateflow",
5 |     version="0.0.1",
6 |     author="Wouter Zorgdrager",
7 |     author_email="zorgdragerw@gmail.com",
8 |     python_requires=">=3.6",
9 |     packages=find_packages(exclude=("tests",)),  # trailing comma: ("tests") without it is just the string "tests"
10 |     install_requires=[
11 |         "graphviz",
12 |         "libcst",
13 |         "apache-beam",
14 |         "ujson",
15 |         "confluent-kafka",
16 |         "apache-flink",
17 |         "pynamodb",
18 |         "boto3",
19 |         "fastapi",
20 |         "uvicorn",
21 |         "aiokafka",
22 |     ],
23 | )
-------------------------------------------------------------------------------- /stateflow/__init__.py: --------------------------------------------------------------------------------
1 | from .core import stateflow, init, service_by_id
2 | from .util.stateflow_test import stateflow_test
3 | 
-------------------------------------------------------------------------------- /stateflow/analysis/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/analysis/__init__.py
-------------------------------------------------------------------------------- /stateflow/analysis/ast_utils.py: --------------------------------------------------------------------------------
1 | import libcst as cst
2 | import libcst.matchers as m
3 | from typing import Any, List
4 | 
5 | 
6 | def is_self(node: cst.CSTNode) -> bool:
7 |     """
8 |     Verifies that an Attribute node refers to 'self'.
9 |     This is used to find the self attributes of a class.
10 | 
11 |     :param node: the Attribute node to verify.
12 |     :return: True if the node is an Attribute and its value is 'self'.
13 |     """
14 |     if m.matches(node, m.Attribute(value=m.Name(value="self"))):
15 |         return True
16 | 
17 |     return False
18 | 
19 | 
20 | def extract_types(
21 |     module_node: cst.CSTNode, node: cst.Annotation, unpack: bool = False
22 | ) -> Any:
23 |     """
24 |     Extracts the (evaluated) type of an Annotation object.
25 | 
26 |     We always remove quotations from the final type. If you define:
27 |     fun() -> "Item":
28 |     the type is actually Item and not "Item". However, the quotation marks are necessary for forward declarations.
29 | 
30 |     :param module_node: the 'context' of this annotation: i.e. the module in which it was declared.
31 |     :param node: the actual annotation object.
32 |     :param unpack: if the annotation holds a tuple, it is unpacked and each element is individually evaluated.
33 |     :return: the evaluated type annotation.
34 |     """
35 |     if unpack and m.matches(node.annotation, m.Subscript(value=m.Name(value="Tuple"))):
36 |         types: List[Any] = []
37 | 
38 |         # Unpacks the Tuple as:
39 |         # Subscript(value=Name(value="Tuple"), slice=[SubscriptElement(slice=Index(value=Name)),..])
40 |         for tuple_element in node.annotation.slice:
41 |             if m.matches(
42 |                 tuple_element, m.SubscriptElement(slice=m.Index(value=m.Name()))
43 |             ):
44 |                 types.append(tuple_element.slice.value.value.replace('"', ""))
45 |             elif m.matches(tuple_element, m.SubscriptElement(slice=m.Index())):
46 |                 types.append(
47 |                     module_node.code_for_node(tuple_element.slice.value).replace(
48 |                         '"', ""
49 |                     )
50 |                 )
51 | 
52 |         return types
53 | 
54 |     return module_node.code_for_node(node.annotation).replace('"', "")
-------------------------------------------------------------------------------- /stateflow/analysis/extract_class_descriptor.py: --------------------------------------------------------------------------------
1 | from typing import List, Tuple, Any, Optional, Dict
2 | import libcst as cst
3 | from stateflow.dataflow.stateful_operator import NoType
4 | from stateflow.dataflow.state import StateDescriptor
5 | from stateflow.analysis.extract_method_descriptor import ExtractMethodDescriptor
6 | from stateflow.descriptors.method_descriptor import MethodDescriptor
7 | from stateflow.descriptors.class_descriptor import ClassDescriptor
8 | import libcst.helpers as helpers
9 | import libcst.matchers as m
10 | 
11 | 
12 | class ExtractClassDescriptor(cst.CSTVisitor):
13 |     """Visits a ClassDefinition and extracts information to create a StatefulFunction."""
14 | 
15 |     def __init__(
16 |         self, module_node: cst.CSTNode, decorated_class_name: str, expression_provider
17 |     ):
18 |         self.module_node = module_node
19 |         self.decorated_class_name = decorated_class_name
20 | 
21 |         # This maps an AST node to its expression context (i.e. LOAD, STORE, DEL); we will use this in downstream tasks,
22 |         # especially for splitting methods.
23 |         self.expression_provider = expression_provider
24 | 
25 |         # Name of the class and whether it is already defined.
26 |         self.is_defined: bool = False
27 |         self.class_name: str = None
28 |         self.class_node: cst.ClassDef = None
29 | 
30 |         # Used to extract state.
31 |         self.self_attributes: List[Tuple[str, Any]] = []
32 | 
33 |         # Keep track of all extracted methods.
34 |         self.method_descriptors: List[MethodDescriptor] = []
35 | 
36 |         # We maintain a stack to keep track of which classes we are currently 'in'.
37 |         self.class_stack = []
38 | 
39 |     def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
40 |         """Visits a function definition and analyzes it.
41 | 
42 |         Extracts the following properties of a function:
43 |         1. The declared self variables (i.e. state).
44 |         2. The input variables of a function.
45 |         3. The output variables of a function.
46 |         4. Whether a function is read-only.
47 | 
48 |         :param node: the node to analyze.
49 |         :return: always returns False.
50 |         """
51 |         if m.matches(node.asynchronous, m.Asynchronous()):
52 |             raise AttributeError(
53 |                 "Function within a stateful function cannot be defined asynchronous."
54 |             )
55 | 
56 |         method_extractor: ExtractMethodDescriptor = ExtractMethodDescriptor(
57 |             self.module_node, node
58 |         )
59 |         node.visit(method_extractor)
60 | 
61 |         # Get self attributes of the function and add them to the attributes list of the class.
62 |         self.self_attributes.extend(method_extractor.self_attributes)
63 | 
64 |         duplicates: List[MethodDescriptor] = [
65 |             method
66 |             for method in self.method_descriptors
67 |             if method.method_name == node.name.value
68 |         ]
69 | 
70 |         if (
71 |             len(duplicates) > 0
72 |         ):  # We remove duplicate method definitions, and keep the last.
73 |             self.method_descriptors.remove(duplicates[0])
74 | 
75 |         # Create a wrapper for this analyzed class method.
76 |         self.method_descriptors.append(
77 |             ExtractMethodDescriptor.create_method_descriptor(method_extractor)
78 |         )
79 | 
80 |         # We don't need to visit the children of this FunctionDef; they were already analyzed by ExtractMethodDescriptor.
81 |         return False
82 | 
83 |     def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
84 |         """Visits a class and extracts useful information.
85 | 
86 |         This retrieves the name of the class and ensures that no nested classes are defined.
87 | 
88 |         :param node: the class definition to analyze.
89 |         """
90 | 
91 |         # If the original class is already defined and we are nested, we throw an exception.
92 |         if len(self.class_stack) and self.is_defined:  # We don't allow nested classes.
93 |             raise AttributeError("Nested classes are not allowed.")
94 | 
95 |         if self.decorated_class_name != helpers.get_full_name_for_node(node):
96 |             return False
97 | 
98 |         self.is_defined = True
99 |         self.class_name = helpers.get_full_name_for_node(node)
100 |         self.class_node = node
101 |         self.class_stack.append(self.class_name)
102 | 
103 |         return True
104 | 
105 |     def leave_ClassDef(self, node: cst.ClassDef):
106 |         if len(self.class_stack) > 0:
107 |             self.class_stack.pop()
108 | 
109 |     def merge_self_attributes(self) -> Dict[str, any]:
110 |         """Merges all self attributes.
111 | 
112 |         Merges all collected declarations attributed to 'self' into a dictionary. A key can only exist once.
113 |         Type hints are stored as the value for the key. Conflicting type hints for the same key will throw an error.
114 |         Keys without type hints are valued as 'NoType'. Example:
115 |         ```
116 |         self.x : int = 3
117 |         self.y, self.z = 4, 5
118 |         ```
119 |         will be stored as: `{"x": "int", "y": "NoType", "z": "NoType"}`
120 | 
121 |         :return: the merged attributes.
122 |         """
123 |         attributes = {}
124 | 
125 |         for var_name, typ in self.self_attributes:
126 |             if var_name in attributes:
127 |                 if typ == NoType:  # Skip NoTypes.
128 |                     continue
129 |                 elif (
130 |                     attributes[var_name] == "NoType"
131 |                 ):  # If the current type is NoType, update it to an actual type.
132 |                     attributes[var_name] = typ
133 |                 elif (
134 |                     typ != attributes[var_name]
135 |                 ):  # Throw an error when type hints conflict.
136 |                     raise AttributeError(
137 |                         f"Stateful Function {self.class_name} has two declarations of {var_name} with different types {typ} != {attributes[var_name]}."
138 |                     )
139 | 
140 |             else:
141 |                 if typ == NoType:
142 |                     typ = "NoType"  # Rename NoType to a proper str.
143 | 
144 |                 attributes[var_name] = typ
145 | 
146 |         return attributes
147 | 
148 |     @staticmethod
149 |     def create_class_descriptor(
150 |         analyzed_visitor: "ExtractClassDescriptor",
151 |     ) -> ClassDescriptor:
152 |         """Creates a class descriptor.
153 | 
154 |         Leverages the analyzed visitor to create a ClassDescriptor.
155 | 
156 |         :param analyzed_visitor: the visitor that walked the ClassDef tree.
157 |         :return: a ClassDescriptor object.
158 |         """
159 |         state_desc: StateDescriptor = StateDescriptor(
160 |             analyzed_visitor.merge_self_attributes()
161 |         )
162 |         return ClassDescriptor(
163 |             class_name=analyzed_visitor.class_name,
164 |             module_node=analyzed_visitor.module_node,
165 |             class_node=analyzed_visitor.class_node,
166 |             state_desc=state_desc,
167 |             methods_dec=analyzed_visitor.method_descriptors,
168 |             expression_provider=analyzed_visitor.expression_provider,
169 |         )
-------------------------------------------------------------------------------- /stateflow/client/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/client/__init__.py
-------------------------------------------------------------------------------- /stateflow/client/aws_client.py: --------------------------------------------------------------------------------
1 | from stateflow.client.stateflow_client import StateflowClient
2 | from stateflow.serialization.pickle_serializer import PickleSerializer, SerDe
3 | from stateflow.dataflow.dataflow import Dataflow
4 | from stateflow.dataflow.event import Event, EventType
5 | from stateflow.dataflow.address import FunctionType, FunctionAddress
6 | from stateflow.client.future import StateflowFuture, T
7 | from typing import Dict, Optional, Any
8 | import boto3
9 | import time
10 | import threading
11 | import uuid
12 | 
13 | 
14 | class AWSKinesisClient(StateflowClient):
15 |     def __init__(
16 |         self,
17 |         flow: Dataflow,
18 |         request_stream: str = "stateflow-request",
19 |         reply_stream: str = "stateflow-reply",
20 |         serializer: SerDe = PickleSerializer(),
21 |     ):
22 |         self.flow = flow
23 |         self.request_stream = request_stream
24 |         self.reply_stream = reply_stream
25 |         self.serializer = serializer
26 | 
27 |         self.kinesis = boto3.client("kinesis")
28 |         self.request_stream: str = request_stream
29 |         self.reply_stream: str = reply_stream
30 | 
31 |         # The futures still to complete.
32 |         self.futures: Dict[str, StateflowFuture] = {}
33 | 
34 |         # Set the wrapper.
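        # Registering this client on each operator's meta wrapper makes the generated
        # entity proxies route their calls through this client.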
35 |         for op in flow.operators:
36 |             op.meta_wrapper.set_client(self)
37 | 
38 |         self.running = True
39 |         consume_thread = threading.Thread(target=self.consume)
40 |         consume_thread.start()
41 | 
42 |     def consume(self):
43 |         iterator = self.kinesis.get_shard_iterator(
44 |             StreamName=self.reply_stream, ShardId="0", ShardIteratorType="LATEST"
45 |         )["ShardIterator"]
46 | 
47 |         while self.running:
48 |             response = self.kinesis.get_records(ShardIterator=iterator)
49 |             iterator = response["NextShardIterator"]
50 |             for msg in response["Records"]:
51 |                 event_serialized = msg["Data"]
52 |                 event = self.serializer.deserialize_event(event_serialized)
53 |                 key = event.event_id
54 | 
55 |                 # Complete the matching pending future (if any) and drop it from the map.
56 |                 if key in self.futures:
57 |                     self.futures[key].complete(event)
58 |                     del self.futures[key]
59 | 
60 |             time.sleep(0.01)
61 | 
62 |     def find(self, clasz, key: str) -> StateflowFuture[Optional[Any]]:
63 |         event_id = str(uuid.uuid4())
64 |         event_type = EventType.Request.FindClass
65 |         fun_address = FunctionAddress(FunctionType.create(clasz.descriptor), key)
66 |         payload = {}
67 | 
68 |         return self.send(Event(event_id, fun_address, event_type, payload), clasz)
69 | 
70 |     def send(self, event: Event, return_type: T = None) -> StateflowFuture[T]:
71 |         self.kinesis.put_record(
72 |             StreamName=self.request_stream,
73 |             Data=self.serializer.serialize_event(event),
74 |             PartitionKey=event.event_id,
75 |         )
76 | 
77 |         future = StateflowFuture(
78 |             event.event_id, time.time(), event.fun_address, return_type
79 |         )
80 | 
81 |         self.futures[event.event_id] = future
82 | 
83 |         return future
84 | 
--------------------------------------------------------------------------------
/stateflow/client/aws_gateway_client.py:
--------------------------------------------------------------------------------
1 | import base64
2 | 
3 | import boto3
4 | 
5 | from stateflow.client.stateflow_client import StateflowClient
6 | from stateflow.dataflow.dataflow import Dataflow
7 | from stateflow.serialization.pickle_serializer import SerDe, PickleSerializer
8 | from stateflow.dataflow.event import Event
9 | from stateflow.client.future import StateflowFuture, T
10 | import time
11 | import requests
12 | import json
13 | 
14 | 
15 | class AWSGatewayClient(StateflowClient):
16 |     def __init__(
17 |         self,
18 |         flow: Dataflow,
19 |         api_gateway_url: str,
20 |         serde: SerDe = PickleSerializer(),
21 |     ):
22 |         super().__init__(flow, serde)
23 | 
24 |         # Set the wrapper.
25 | [op.meta_wrapper.set_client(self) for op in flow.operators] 26 | 27 | self.api_gateway_url = api_gateway_url 28 | 29 | def send(self, event: Event, return_type: T = None) -> StateflowFuture[T]: 30 | event_serialized: bytes = self.serializer.serialize_event(event) 31 | event_encoded = base64.b64encode(event_serialized).decode() 32 | 33 | result = requests.post(self.api_gateway_url, json={"event": event_encoded}) 34 | result_json = result.json() 35 | result_event = base64.b64decode(result_json["event"]) 36 | 37 | fut = StateflowFuture( 38 | event.event_id, time.time(), event.fun_address, return_type 39 | ) 40 | 41 | fut.complete(self.serializer.deserialize_event(result_event)) 42 | 43 | return fut 44 | -------------------------------------------------------------------------------- /stateflow/client/fastapi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/client/fastapi/__init__.py -------------------------------------------------------------------------------- /stateflow/client/fastapi/aws_gateway.py: -------------------------------------------------------------------------------- 1 | from stateflow.client.fastapi.fastapi import ( 2 | FastAPIClient, 3 | Dataflow, 4 | PickleSerializer, 5 | SerDe, 6 | Event, 7 | T, 8 | StateflowFuture, 9 | StateflowFailure, 10 | FunctionAddress, 11 | FunctionType, 12 | EventType, 13 | ) 14 | import httpx 15 | import uuid 16 | import base64 17 | import time 18 | 19 | 20 | class AWSGatewayFastAPIClient(FastAPIClient): 21 | def __init__( 22 | self, 23 | flow: Dataflow, 24 | api_gateway_url: str, 25 | serializer: SerDe = PickleSerializer(), 26 | timeout: int = 5, 27 | root: str = "stateflow", 28 | ): 29 | super().__init__(flow, serializer, timeout, root) 30 | self.api_gateway_url: str = api_gateway_url 31 | self.http_client = httpx.AsyncClient() 32 | 33 | def setup_init(self): 34 | super().setup_init() 35 | 36 | async def send(self, event: Event, return_type: T = None): 37 | event_serialized: bytes = self.serializer.serialize_event(event) 38 | event_encoded = base64.b64encode(event_serialized).decode() 39 | 40 | result = await self.http_client.post( 41 | self.api_gateway_url, json={"event": event_encoded}, timeout=self.timeout 42 | ) 43 | result_json = result.json() 44 | 45 | result_event = base64.b64decode(result_json["event"]) 46 | 47 | future = StateflowFuture( 48 | event.event_id, time.time(), event.fun_address, return_type 49 | ) 50 | 51 | future.complete(self.serializer.deserialize_event(result_event)) 52 | 53 | try: 54 | result = future.get() 55 | except StateflowFailure as exc: 56 | return exc 57 | 58 | return result 59 | 60 | async def send_and_wait_with_future( 61 | self, 62 | event: Event, 63 | future: StateflowFuture, 64 | timeout_msg: str = "Event timed out.", 65 | ): 66 | event_serialized: bytes = self.serializer.serialize_event(event) 67 | event_encoded = base64.b64encode(event_serialized).decode() 68 | 69 | try: 70 | result = await self.http_client.post( 71 | self.api_gateway_url, 72 | json={"event": event_encoded}, 73 | timeout=self.timeout, 74 | ) 75 | except httpx.TimeoutException: 76 | future.complete_with_failure(timeout_msg) 77 | return 78 | 79 | result_json = result.json() 80 | result_event = base64.b64decode(result_json["event"]) 81 | future.complete(self.serializer.deserialize_event(result_event)) 82 | -------------------------------------------------------------------------------- 
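Both gateway clients above and the Lambda runtimes further down agree on one small envelope: a pickled event, base64-encoded, wrapped in a JSON object under the `"event"` key. A minimal sketch of that round trip, assuming a hypothetical endpoint URL and with plain `pickle` standing in for stateflow's `PickleSerializer`:

```python
import base64
import pickle

import requests

# Hypothetical endpoint; a deployed stack exposes this via API Gateway.
API_GATEWAY_URL = "https://abc123.execute-api.eu-west-2.amazonaws.com/dev/stateflow"


def call_gateway(event) -> object:
    # Serialize and base64-encode the event so the raw bytes survive the JSON body.
    encoded = base64.b64encode(pickle.dumps(event)).decode()
    reply = requests.post(API_GATEWAY_URL, json={"event": encoded}, timeout=5)
    # The Lambda replies with the same envelope: {"event": "<base64 bytes>"}.
    return pickle.loads(base64.b64decode(reply.json()["event"]))
```

The same encode/decode pair appears on the server side in `AWSGatewayLambdaRuntime.handle`, which is what keeps the two halves interoperable.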
/stateflow/client/fastapi/aws_lambda.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | from stateflow.client.fastapi.fastapi import (
4 |     FastAPIClient,
5 |     Dataflow,
6 |     PickleSerializer,
7 |     SerDe,
8 |     Event,
9 |     T,
10 |     StateflowFuture,
11 |     StateflowFailure,
12 | )
13 | import base64
14 | import time
15 | import boto3
16 | 
17 | 
18 | class AWSLambdaFastAPIClient(FastAPIClient):
19 |     def __init__(
20 |         self,
21 |         flow: Dataflow,
22 |         function_name: str,
23 |         serializer: SerDe = PickleSerializer(),
24 |         timeout: int = 5,
25 |         root: str = "stateflow",
26 |     ):
27 |         super().__init__(flow, serializer, timeout, root)
28 |         self.client = boto3.client("lambda")
29 |         self.function_name = function_name
30 | 
31 |     def setup_init(self):
32 |         super().setup_init()
33 | 
34 |     async def send(self, event: Event, return_type: T = None):
35 |         event_serialized: bytes = self.serializer.serialize_event(event)
36 |         # Wrap in the {"event": <base64>} JSON envelope the Lambda handler expects.
37 |         event_encoded = base64.b64encode(event_serialized).decode()
38 | 
39 |         result = self.client.invoke(
40 |             FunctionName=self.function_name,
41 |             Payload=json.dumps({"event": event_encoded}),
42 |         )
43 | 
44 |         result = result["Payload"].read()
45 |         result_json = json.loads(result)["body"]
46 |         result_event = base64.b64decode(json.loads(result_json)["event"])
47 | 
48 |         future = StateflowFuture(
49 |             event.event_id, time.time(), event.fun_address, return_type
50 |         )
51 | 
52 |         future.complete(self.serializer.deserialize_event(result_event))
53 | 
54 |         try:
55 |             result = future.get()
56 |         except StateflowFailure as exc:
57 |             return exc
58 | 
59 |         return result
60 | 
61 |     async def send_and_wait_with_future(
62 |         self,
63 |         event: Event,
64 |         future: StateflowFuture,
65 |         timeout_msg: str = "Event timed out.",
66 |     ):
67 |         event_serialized: bytes = self.serializer.serialize_event(event)
68 |         event_encoded = base64.b64encode(event_serialized).decode()
69 | 
70 |         result = self.client.invoke(
71 |             FunctionName=self.function_name,
72 |             Payload=json.dumps({"event": event_encoded}),
73 |         )
74 |         result = result["Payload"].read()
75 |         result_json = json.loads(result)["body"]
76 |         result_event = json.loads(result_json)["event"]
77 |         future.complete(
78 |             self.serializer.deserialize_event(base64.b64decode(result_event))
79 |         )
80 | 
--------------------------------------------------------------------------------
/stateflow/client/fastapi/kafka.py:
--------------------------------------------------------------------------------
1 | from stateflow.client.fastapi.fastapi import (
2 |     FastAPIClient,
3 |     Dataflow,
4 |     SerDe,
5 |     PickleSerializer,
6 |     Event,
7 |     StateflowFailure,
8 |     StateflowFuture,
9 |     FunctionAddress,
10 |     FunctionType,
11 |     EventType,
12 |     T,
13 |     Dict,
14 | )
15 | from stateflow.dataflow.dataflow import IngressRouter
16 | from aiokafka.helpers import create_ssl_context
17 | from aiokafka import AIOKafkaProducer, AIOKafkaConsumer
18 | import asyncio
19 | import uuid
20 | import time
21 | 
22 | 
23 | class KafkaFastAPIClient(FastAPIClient):
24 |     def __init__(
25 |         self,
26 |         flow: Dataflow,
27 |         serializer: SerDe = PickleSerializer(),
28 |         timeout: int = 5,
29 |         root: str = "stateflow",
30 |         statefun_mode: bool = False,
31 |         producer_config: Dict = {},
32 |         consumer_config: Dict = {},
33 |     ):
34 |         super().__init__(flow, serializer, timeout, root)
35 | 
36 |         self.producer: AIOKafkaProducer = None
37 |         self.consumer: AIOKafkaConsumer = None
38 |         # Copy the configs: the mutable default dicts above would otherwise be shared across instances.
39 |         self.producer_config = dict(producer_config)
40 |         self.consumer_config = dict(consumer_config)
41 | 
42 |         self.statefun_mode: bool = statefun_mode
43 | 
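        # In StateFun mode, requests are routed to per-operator ingress topics via
        # the IngressRouter (see send() below) instead of the single shared
        # 'client_request' topic used by the plain Kafka runtime.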
44 |         if self.statefun_mode:
45 |             self.ingress_router = IngressRouter(self.serializer)
46 | 
47 |     def setup_init(self):
48 |         super().setup_init()
49 | 
50 |         @self.app.on_event("startup")
51 |         async def setup_kafka():
52 |             if "bootstrap_servers" not in self.producer_config:
53 |                 self.producer_config["bootstrap_servers"] = "localhost:9092"
54 | 
55 |             self.producer = AIOKafkaProducer(
56 |                 ssl_context=create_ssl_context(), **self.producer_config
57 |             )
58 |             await self.producer.start()
59 | 
60 |             if "bootstrap_servers" not in self.consumer_config:
61 |                 self.consumer_config["bootstrap_servers"] = "localhost:9092"
62 | 
63 |             if "group_id" not in self.consumer_config:
64 |                 self.consumer_config["group_id"] = str(uuid.uuid4())
65 | 
66 |             if "auto_offset_reset" not in self.consumer_config:
67 |                 self.consumer_config["auto_offset_reset"] = "latest"
68 | 
69 |             self.consumer = AIOKafkaConsumer(
70 |                 "client_reply",
71 |                 ssl_context=create_ssl_context(),
72 |                 loop=asyncio.get_event_loop(),
73 |                 **self.consumer_config
74 |             )
75 |             await self.consumer.start()
76 |             asyncio.create_task(self.consume_forever())
77 | 
78 |         @self.app.on_event("shutdown")
79 |         async def stop_kafka():
80 |             await self.producer.stop()
81 |             await self.consumer.stop()
82 |         return setup_kafka
83 | 
84 |     async def consume_forever(self):
85 |         """Consumes replies from the 'client_reply' Kafka topic forever.
86 | 
87 |         Completes the pending request future that matches each reply's event id.
88 |         """
89 |         async for msg in self.consumer:
90 |             if msg.key and msg.key.decode("utf-8") not in self.request_map:
91 |                 continue
92 | 
93 |             return_event: Event = self.serializer.deserialize_event(msg.value)
94 |             if return_event.event_id in self.request_map:
95 |                 self.request_map[return_event.event_id].set_result(return_event)
96 |                 del self.request_map[return_event.event_id]
97 | 
98 |     async def send_and_wait_with_future(
99 |         self,
100 |         event: Event,
101 |         future: StateflowFuture,
102 |         timeout_msg: str = "Event timed out.",
103 |     ):
104 | 
105 |         if not self.statefun_mode:
106 |             await self.producer.send_and_wait(
107 |                 "client_request", self.serializer.serialize_event(event)
108 |             )
109 |         elif event.event_type == EventType.Request.Ping:
110 |             await self.producer.send_and_wait(
111 |                 "globals_ping",
112 |                 self.serializer.serialize_event(event),
113 |                 key=bytes(event.event_id, "utf-8"),
114 |             )
115 |         else:
116 |             route = self.ingress_router.route(event)
117 |             topic = route.route_name.replace("/", "_")
118 |             key = route.key or event.event_id
119 | 
120 |             if not route.key:
121 |                 topic = topic + "_create"
122 | 
123 |             await self.producer.send_and_wait(
124 |                 topic,
125 |                 value=self.serializer.serialize_event(event),
126 |                 key=bytes(key, "utf-8"),
127 |             )
128 | 
129 |         loop = asyncio.get_running_loop()
130 |         asyncio_future = loop.create_future()
131 | 
132 |         self.request_map[event.event_id] = asyncio_future
133 | 
134 |         try:
135 |             result = await asyncio.wait_for(asyncio_future, timeout=self.timeout)
136 |         except asyncio.TimeoutError:
137 |             del self.request_map[event.event_id]
138 |             future.complete_with_failure(timeout_msg)
139 |         else:
140 |             future.complete(result)
141 | 
142 |     async def send(self, event: Event, return_type: T = None):
143 |         if not self.statefun_mode:
144 |             await self.producer.send_and_wait(
145 |                 "client_request", self.serializer.serialize_event(event)
146 |             )
147 |         else:
148 |             route = self.ingress_router.route(event)
149 |             topic = route.route_name.replace("/", "_")
150 |             key = route.key or event.event_id
151 | 
152 |             if not route.key:
153 |                 topic = topic + "_create"
154 | 
155 |             await self.producer.send_and_wait(
156 |                 topic,
157 |                 value=self.serializer.serialize_event(event),
158 |
key=bytes(key, "utf-8"), 159 | ) 160 | 161 | loop = asyncio.get_running_loop() 162 | 163 | fut = loop.create_future() 164 | future = StateflowFuture( 165 | event.event_id, time.time(), event.fun_address, return_type 166 | ) 167 | 168 | self.request_map[event.event_id] = fut 169 | 170 | try: 171 | result = await asyncio.wait_for(fut, timeout=self.timeout) 172 | except asyncio.TimeoutError: 173 | del self.request_map[event.event_id] 174 | raise StateflowFailure("Request timed out!") 175 | 176 | future.complete(result) 177 | 178 | try: 179 | result = future.get() 180 | except StateflowFailure as exc: 181 | return exc 182 | 183 | return result 184 | -------------------------------------------------------------------------------- /stateflow/client/future.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar, Optional 2 | from stateflow.dataflow.event import FunctionAddress, Event 3 | import time 4 | from stateflow.dataflow.event import EventType 5 | from asyncio import Future 6 | from collections.abc import Iterable 7 | 8 | # Type variable used to represent the return value of a StateflowFuture. 9 | T = TypeVar("T") 10 | 11 | 12 | class StateflowFailure(Exception): 13 | """Wrapper for an exception upon completion of a StateflowFuture.""" 14 | 15 | def __init__(self, error_msg: str): 16 | """Initializes a StateflowFailure. 17 | 18 | :param error_msg: the error message. 19 | """ 20 | self.error_msg = error_msg 21 | 22 | def __repr__(self): 23 | """Representation of this StateflowFailure.""" 24 | return f"StateflowFailure: {self.error_msg}" 25 | 26 | def __str__(self): 27 | """String representation of this StateflowFailure.""" 28 | return f"StateflowFailure: {self.error_msg}" 29 | 30 | 31 | class StateflowFuture(Generic[T]): 32 | def __init__( 33 | self, id: str, timestamp: float, function_addr: FunctionAddress, return_type: T 34 | ): 35 | """Initializes a Stateflow future which needs to be completed. 36 | 37 | :param id: the id of the request. The reply will have the same id. 38 | :param timestamp: the timestamp for this future. Could be used to compute a timeout. 39 | :param function_addr: the function address of this future. 40 | :param return_type: the type of the return value. 41 | """ 42 | self.id: str = id 43 | self.timestamp: float = timestamp 44 | self.function_addr = function_addr 45 | self.return_type = return_type 46 | 47 | # To be completed later on. 48 | self.is_completed: bool = False 49 | self.result: Optional[T] = None 50 | 51 | def complete(self, event: Event): 52 | """Completes the future given a 'reply' event. 53 | 54 | :param event: the reply event from the runtime. 55 | """ 56 | self.is_completed = event 57 | 58 | if event.event_type == EventType.Reply.FailedInvocation: 59 | self.result = StateflowFailure(event.payload["error_message"]) 60 | elif event.event_type == EventType.Reply.SuccessfulCreateClass: 61 | if self.return_type: 62 | self.result = self.return_type(__key=event.fun_address.key) 63 | else: 64 | self.result = ( 65 | f"Created {self.function_addr.function_type.get_full_name()} " 66 | f"instance with key = {event.fun_address.key}." 
67 |             )
68 |         elif event.event_type == EventType.Reply.SuccessfulInvocation:
69 |             self.result = event.payload["return_results"]
70 |         elif event.event_type == EventType.Reply.SuccessfulStateRequest:
71 |             if "state" in event.payload:
72 |                 self.result = event.payload["state"]
73 |         elif event.event_type == EventType.Reply.FoundClass:
74 |             if self.return_type:
75 |                 self.result = self.return_type(__key=event.fun_address.key)
76 |             else:
77 |                 self.result = (
78 |                     f"Found {self.function_addr.function_type.get_full_name()} "
79 |                     f"instance with key = {event.fun_address.key}."
80 |                 )
81 |         elif event.event_type == EventType.Reply.Pong:
82 |             self.result = None
83 |         elif event.event_type == EventType.Reply.KeyNotFound:
84 |             self.result = StateflowFailure(event.payload["error_message"])
85 |         else:
86 |             raise AttributeError(
87 |                 f"Can't complete unknown event type: {event.event_type}"
88 |             )
89 | 
90 |     def complete_with_failure(self, msg: str):
91 |         self.result = StateflowFailure(msg)
92 |         self.is_completed = True  # Mark completed, otherwise get() would wait forever.
93 | 
94 |     def get(self, timeout=-1) -> T:
95 |         """Gets the return value of this future, waiting until it is completed.
96 | 
97 |         NOTE: This might block forever if the future is never completed.
98 | 
99 |         :return: the return value.
100 |         """
101 |         timeout_time = time.time() + timeout
102 |         while not self.is_completed:
103 |             if timeout != -1 and time.time() >= timeout_time:
104 |                 raise AttributeError(
105 |                     f"Timeout for the future {self} after {timeout} seconds."
106 |                 )
107 |             time.sleep(0.01)
108 | 
109 |         if isinstance(self.result, list):
110 |             if (
111 |                 len(self.result) == 1
112 |             ):  # If there is a list with only one element, we return that element.
113 |                 return self.result[0]
114 |             else:
115 |                 return tuple(
116 |                     self.result
117 |                 )  # We return lists as tuples, so they can be unpacked.
118 | 
119 |         if isinstance(
120 |             self.result, StateflowFailure
121 |         ):  # If it is an error, we raise a failure.
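            # (Re-raising here gives callers one place to `try/except StateflowFailure`
            # around get(), instead of type-checking the returned value.)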
122 | raise StateflowFailure(self.result.error_msg) 123 | 124 | return self.result 125 | -------------------------------------------------------------------------------- /stateflow/client/stateflow_client.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any, List 2 | from stateflow.client.future import StateflowFuture, T 3 | from stateflow.serialization.json_serde import SerDe, JsonSerializer 4 | from stateflow.dataflow.event import Event, EventType 5 | from stateflow.dataflow.address import FunctionAddress, FunctionType 6 | import uuid 7 | 8 | 9 | class StateflowClient: 10 | from stateflow.dataflow.dataflow import Dataflow 11 | 12 | def __init__(self, flow: Dataflow, serializer: SerDe = JsonSerializer()): 13 | self.flow = flow 14 | self.serializer: SerDe = serializer 15 | 16 | def send(self, event: Event) -> StateflowFuture[T]: 17 | pass 18 | 19 | def find(self, clasz, key: str) -> StateflowFuture[Optional[Any]]: 20 | event_id = str(uuid.uuid4()) 21 | event_type = EventType.Request.FindClass 22 | fun_address = FunctionAddress(FunctionType.create(clasz.descriptor), key) 23 | payload = {} 24 | 25 | return self.send(Event(event_id, fun_address, event_type, payload), clasz) 26 | 27 | def await_futures(self, future_list: List[StateflowFuture[T]]): 28 | waiting_for = [fut for fut in future_list if not fut.is_completed] 29 | while len(waiting_for): 30 | waiting_for = [fut for fut in future_list if not fut.is_completed] 31 | -------------------------------------------------------------------------------- /stateflow/core.py: -------------------------------------------------------------------------------- 1 | from inspect import isclass, getsource, getfile 2 | import libcst as cst 3 | from typing import List, Dict 4 | from stateflow.wrappers.class_wrapper import ClassWrapper 5 | from stateflow.wrappers.meta_wrapper import MetaWrapper 6 | from stateflow.dataflow.dataflow import Dataflow, Ingress, Egress 7 | from stateflow.dataflow.stateful_operator import StatefulOperator, Edge, Operator 8 | from stateflow.dataflow.event import EventType 9 | from stateflow.dataflow.address import FunctionType 10 | from stateflow.analysis.extract_class_descriptor import ( 11 | ExtractClassDescriptor, 12 | ClassDescriptor, 13 | ) 14 | from stateflow.split.split_analyze import Split 15 | import textwrap 16 | 17 | parse_cache: Dict[str, cst.Module] = {} 18 | 19 | registered_classes: List[ClassWrapper] = [] 20 | meta_classes: List = [] 21 | 22 | 23 | def stateflow(cls, parse_file=True): 24 | if not isclass(cls): 25 | raise AttributeError(f"Expected a class but got an {cls}.") 26 | 27 | # Parse source. 28 | if parse_file: 29 | class_file_name = getfile(cls) 30 | if class_file_name not in parse_cache: 31 | with open(getfile(cls), "r") as file: 32 | to_parse_file = file.read() 33 | 34 | parsed_cls = cst.parse_module(to_parse_file) 35 | parse_cache[class_file_name] = parsed_cls 36 | else: 37 | parsed_cls = parse_cache[class_file_name] 38 | else: 39 | class_source = getsource(cls) 40 | parsed_cls = cst.parse_module(textwrap.dedent(class_source)) 41 | 42 | wrapper = cst.metadata.MetadataWrapper(parsed_cls) 43 | expression_provider = wrapper.resolve(cst.metadata.ExpressionContextProvider) 44 | 45 | # Extract class description. 
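    # The MetadataWrapper pass above resolves, for each expression, whether it is
    # a LOAD or a STORE; the extractor below needs that context to tell reads of
    # self.<attr> apart from writes when it derives the state description.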
46 |     extraction: ExtractClassDescriptor = ExtractClassDescriptor(
47 |         parsed_cls, cls.__name__, expression_provider
48 |     )
49 |     wrapper.visit(extraction)
50 | 
51 |     # Create the ClassDescriptor.
52 |     class_desc: ClassDescriptor = ExtractClassDescriptor.create_class_descriptor(
53 |         extraction
54 |     )
55 | 
56 |     # Register the class.
57 |     registered_classes.append(ClassWrapper(cls, class_desc))
58 | 
59 |     # Create a meta class.
60 |     meta_class = MetaWrapper(
61 |         str(cls.__name__),
62 |         tuple(cls.__bases__),
63 |         dict(cls.__dict__),
64 |         descriptor=class_desc,
65 |     )
66 |     meta_classes.append(meta_class)
67 | 
68 |     return meta_class
69 | 
70 | 
71 | def _build_dataflow(
72 |     registered_classes: List[ClassWrapper], meta_classes: List[MetaWrapper]
73 | ) -> Dataflow:
74 |     operators: List[Operator] = []
75 |     edges: List[Edge] = []
76 | 
77 |     for wrapper, meta_class in zip(registered_classes, meta_classes):
78 |         name: str = wrapper.class_desc.class_name
79 |         fun_type: FunctionType = FunctionType.create(wrapper.class_desc)
80 | 
81 |         # Create the operator; we will add the edges later.
82 |         operator: StatefulOperator = StatefulOperator(
83 |             [], [], fun_type, wrapper, meta_class
84 |         )
85 | 
86 |         incoming_edges: List[Edge] = []
87 |         outgoing_edges: List[Edge] = []
88 | 
89 |         # Every operator gets an incoming ingress and an outgoing egress edge.
90 |         ingress: Ingress = Ingress(f"{name}-input", operator, EventType.Request)
91 |         egress: Egress = Egress(f"{name}-output", operator, EventType.Reply)
92 | 
93 |         incoming_edges.append(ingress)
94 |         outgoing_edges.append(egress)
95 | 
96 |         operator.incoming_edges = incoming_edges
97 |         operator.outgoing_edges = outgoing_edges
98 | 
99 |         operators.append(operator)
100 |         edges.extend(incoming_edges + outgoing_edges)
101 | 
102 |     return Dataflow(operators, edges)
103 | 
104 | 
105 | def init():
106 |     if len(registered_classes) == 0 or len(meta_classes) == 0:
107 |         raise AttributeError(
108 |             "Trying to initialize stateflow without any registered classes. "
109 |             "Please register one using the @stateflow decorator."
110 |         )
111 | 
112 |     # We now link classes to each other.
113 |     class_descs: List[ClassDescriptor] = [
114 |         wrapper.class_desc for wrapper in registered_classes
115 |     ]
116 | 
117 |     for desc in class_descs:
118 |         desc.link_to_other_classes(class_descs)
119 | 
120 |     # We execute the split phase.
121 |     split: Split = Split(class_descs, registered_classes)
122 |     split.split_methods()
123 | 
124 |     flow: Dataflow = _build_dataflow(registered_classes, meta_classes)
125 | 
126 |     ### DEBUG
127 |     operator_names: List[str] = [
128 |         op.class_wrapper.class_desc.class_name for op in flow.operators
129 |     ]
130 |     print(
131 |         f"Registered {len(flow.operators)} operators with the names: {operator_names}."
132 | ) 133 | ### 134 | return flow 135 | 136 | 137 | def service_by_id(cls, service_id: str): 138 | from stateflow.dataflow.event_flow import InternalClassRef, FunctionAddress 139 | 140 | for clasz in registered_classes: 141 | if clasz.class_desc.class_name == cls.__name__: 142 | fun_ty = clasz.class_desc.to_function_type() 143 | return InternalClassRef(FunctionAddress(fun_ty, service_id)) 144 | 145 | 146 | def clear(): 147 | global parse_cache, registered_classes, meta_classes 148 | parse_cache.clear() 149 | registered_classes.clear() 150 | meta_classes.clear() 151 | -------------------------------------------------------------------------------- /stateflow/dataflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/dataflow/__init__.py -------------------------------------------------------------------------------- /stateflow/dataflow/address.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | 4 | class FunctionType: 5 | 6 | __slots__ = "namespace", "name", "stateful" 7 | 8 | def __init__(self, namespace: str, name: str, stateful: bool): 9 | self.namespace = namespace 10 | self.name = name 11 | self.stateful = stateful 12 | 13 | def is_stateless(self): 14 | return not self.stateful 15 | 16 | def get_full_name(self): 17 | return f"{self.namespace}/{self.name}" 18 | 19 | def get_safe_full_name(self): 20 | return f"{self.namespace}_{self.name}" 21 | 22 | def __eq__(self, other): 23 | if not isinstance(other, FunctionType): 24 | return False 25 | 26 | namespace_eq = self.namespace == other.namespace 27 | name_eq = self.name == other.name 28 | stateful_eq = self.stateful == other.stateful 29 | 30 | return namespace_eq and name_eq and stateful_eq 31 | 32 | def to_dict(self) -> Dict: 33 | return { 34 | "namespace": self.namespace, 35 | "name": self.name, 36 | "stateful": self.stateful, 37 | } 38 | 39 | def to_address(self) -> "FunctionAddress": 40 | return FunctionAddress(self, None) 41 | 42 | @staticmethod 43 | def create(desc) -> "FunctionType": 44 | name = desc.class_name 45 | namespace = "global" # for now we have a global namespace 46 | stateful = True # for now we only cover stateful functions 47 | 48 | return FunctionType(namespace, name, stateful) 49 | 50 | def __eq__(self, other): 51 | if not isinstance(other, FunctionType): 52 | return False 53 | 54 | return ( 55 | self.name == other.name 56 | and self.namespace == other.namespace 57 | and self.stateful == other.stateful 58 | ) 59 | 60 | 61 | class FunctionAddress: 62 | """The address of a stateful or stateless function. 63 | 64 | Consists of two parts: 65 | - a FunctionType: the namespace and name of the function, and a flag to specify it as stateful 66 | - a key: an optional key, in case we deal with a stateful function. 67 | 68 | This address can be used to route an event correctly through a dataflow. 
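    For example, FunctionAddress(FunctionType("global", "Item", True), "item-42")
    addresses the stateful 'Item' instance keyed 'item-42'; for a stateless
    function the key is None.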
69 | """ 70 | 71 | __slots__ = "function_type", "key" 72 | 73 | def __init__(self, function_type: FunctionType, key: Optional[str]): 74 | self.function_type = function_type 75 | self.key = key 76 | 77 | def is_stateless(self): 78 | return self.function_type.is_stateless() 79 | 80 | def to_dict(self): 81 | return {"function_type": self.function_type.to_dict(), "key": self.key} 82 | 83 | @staticmethod 84 | def from_dict(dictionary: Dict) -> "FunctionAddress": 85 | return FunctionAddress( 86 | FunctionType( 87 | dictionary["function_type"]["namespace"], 88 | dictionary["function_type"]["name"], 89 | dictionary["function_type"]["stateful"], 90 | ), 91 | dictionary["key"], 92 | ) 93 | 94 | def __eq__(self, other): 95 | if not isinstance(other, FunctionAddress): 96 | return False 97 | 98 | return self.key == other.key and self.function_type == other.function_type 99 | -------------------------------------------------------------------------------- /stateflow/dataflow/args.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Optional 2 | 3 | 4 | class Arguments: 5 | 6 | __slots__ = "_args" 7 | 8 | def __init__(self, args: Dict[str, Any]): 9 | self._args = args 10 | 11 | def __getitem__(self, item): 12 | return self._args[item] 13 | 14 | def __setitem__(self, key, value): 15 | self._args[key] = value 16 | 17 | def get(self) -> Dict[str, Any]: 18 | return self._args 19 | 20 | def get_keys(self) -> List[str]: 21 | return self._args.keys() 22 | 23 | def to_dict(self) -> Dict: 24 | return self._args 25 | 26 | @staticmethod 27 | def from_dict(dictionary: Dict): 28 | return Arguments(dictionary) 29 | 30 | @staticmethod 31 | def from_args_and_kwargs( 32 | desc: Dict[str, Any], *args, **kwargs 33 | ) -> Optional["Arguments"]: 34 | args_dict = {} 35 | for arg, name in zip(list(args), desc.keys()): 36 | args_dict[name] = arg 37 | 38 | for key, value in kwargs.items(): 39 | args_dict[key] = value 40 | 41 | arguments = Arguments(args_dict) 42 | 43 | if not desc.keys() == args_dict.keys(): 44 | raise AttributeError(f"Expected arguments: {desc} but got {args_dict}.") 45 | 46 | return arguments 47 | -------------------------------------------------------------------------------- /stateflow/dataflow/event.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | from enum import Enum, EnumMeta 3 | from stateflow.dataflow.address import FunctionAddress 4 | 5 | 6 | class MetaEnum(EnumMeta): 7 | def __contains__(cls, item): 8 | try: 9 | cls(item) 10 | except ValueError: 11 | return False 12 | return True 13 | 14 | 15 | class _Request(Enum, metaclass=MetaEnum): 16 | InvokeStateless = "InvokeStateless" 17 | InvokeStateful = "InvokeStateful" 18 | InitClass = "InitClass" 19 | 20 | FindClass = "FindClass" 21 | 22 | GetState = "GetState" 23 | SetState = "SetState" 24 | UpdateState = "UpdateState" 25 | DeleteState = "DeleteState" 26 | 27 | EventFlow = "EventFlow" 28 | 29 | Ping = "Ping" 30 | 31 | def __str__(self): 32 | return f"Request.{self.value}" 33 | 34 | 35 | class _Reply(Enum, metaclass=MetaEnum): 36 | SuccessfulInvocation = "SuccessfulInvocation" 37 | SuccessfulCreateClass = "SuccessfulCreateClass" 38 | 39 | FoundClass = "FoundClass" 40 | KeyNotFound = "KeyNotFound" 41 | 42 | SuccessfulStateRequest = "SuccessfulStateRequest" 43 | FailedInvocation = "FailedInvocation" 44 | 45 | Pong = "Pong" 46 | 47 | def __str__(self): 48 | return f"Reply.{self.value}" 49 | 50 | 51 | class 
EventType: 52 | Request = _Request 53 | Reply = _Reply 54 | 55 | @staticmethod 56 | def from_str(input_str: str) -> Optional["EventType"]: 57 | if input_str in EventType.Request: 58 | return EventType.Request[input_str] 59 | elif input_str in EventType.Reply: 60 | return EventType.Reply[input_str] 61 | else: 62 | return None 63 | 64 | 65 | class Event: 66 | from stateflow.dataflow.args import Arguments 67 | 68 | __slots__ = "event_id", "fun_address", "event_type", "payload" 69 | 70 | def __init__( 71 | self, 72 | event_id: str, 73 | fun_address: FunctionAddress, 74 | event_type: EventType, 75 | payload: Dict, 76 | ): 77 | self.event_id: str = event_id 78 | self.fun_address: FunctionAddress = fun_address 79 | self.event_type: EventType = event_type 80 | self.payload: Dict = payload 81 | 82 | def get_arguments(self) -> Optional[Arguments]: 83 | if "args" in self.payload: 84 | return self.payload["args"] 85 | else: 86 | return None 87 | 88 | def copy(self, **kwargs) -> "Event": 89 | new_args = {} 90 | for key, value in kwargs.items(): 91 | if key in self.__slots__: 92 | new_args[key] = value 93 | 94 | for key in self.__slots__: 95 | if key not in new_args: 96 | new_args[key] = getattr(self, key) 97 | 98 | return Event(**new_args) 99 | -------------------------------------------------------------------------------- /stateflow/dataflow/state.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | import ujson 3 | 4 | 5 | class State: 6 | __slots__ = "_data" 7 | 8 | def __init__(self, data: dict): 9 | self._data = data 10 | 11 | def __getitem__(self, item): 12 | return self._data[item] 13 | 14 | def __setitem__(self, key, value): 15 | self._data[key] = value 16 | 17 | def __str__(self): 18 | return str(self._data) 19 | 20 | def get_keys(self): 21 | return self._data.keys() 22 | 23 | def get(self): 24 | return self._data 25 | 26 | @staticmethod 27 | def serialize(state: "State") -> str: 28 | return ujson.encode(state._data) 29 | 30 | @staticmethod 31 | def deserialize(state_serialized: str) -> "State": 32 | return State(ujson.decode(state_serialized)) 33 | 34 | 35 | class StateDescriptor: 36 | def __init__(self, state_desc: Dict[str, Any]): 37 | self._state_desc = state_desc 38 | 39 | def get_keys(self): 40 | return self._state_desc.keys() 41 | 42 | def match(self, state: State) -> State: 43 | return self.get_keys() == state.get_keys() 44 | 45 | def __str__(self): 46 | return str(list(self._state_desc.keys())) 47 | 48 | def __contains__(self, item): 49 | return item in self._state_desc 50 | 51 | def __getitem__(self, item): 52 | return self._state_desc[item] 53 | 54 | def __setitem__(self, key, value): 55 | self._state_desc[key] = value 56 | -------------------------------------------------------------------------------- /stateflow/descriptors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/descriptors/__init__.py -------------------------------------------------------------------------------- /stateflow/descriptors/class_descriptor.py: -------------------------------------------------------------------------------- 1 | from stateflow.dataflow.state import StateDescriptor 2 | from typing import List, Optional 3 | from stateflow.descriptors.method_descriptor import MethodDescriptor 4 | from stateflow.dataflow.address import FunctionType 5 | import libcst as cst 6 | 7 | 8 | class 
ClassDescriptor: 9 | """A description of a class method.""" 10 | 11 | def __init__( 12 | self, 13 | class_name: str, 14 | module_node: cst.Module, 15 | class_node: cst.ClassDef, 16 | state_desc: StateDescriptor, 17 | methods_dec: List[MethodDescriptor], 18 | expression_provider, 19 | ): 20 | self.class_name: str = class_name 21 | self.module_node: cst.Module = module_node 22 | self.class_node: cst.ClassDef = class_node 23 | self.state_desc: StateDescriptor = state_desc 24 | self.methods_dec: List[MethodDescriptor] = methods_dec 25 | self.expression_provider = expression_provider 26 | 27 | def to_function_type(self) -> FunctionType: 28 | return FunctionType.create(self) 29 | 30 | def get_method_by_name(self, name: str) -> Optional[MethodDescriptor]: 31 | filter = [desc for desc in self.methods_dec if desc.method_name == name] 32 | 33 | if len(filter) == 0: 34 | return None 35 | 36 | return filter[0] 37 | 38 | def link_to_other_classes(self, descriptors: List["ClassDescriptor"]): 39 | for method in self.methods_dec: 40 | method.link_to_other_classes(descriptors) 41 | -------------------------------------------------------------------------------- /stateflow/descriptors/method_descriptor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Set, Tuple 2 | 3 | import libcst as cst 4 | 5 | from stateflow.dataflow.args import Arguments 6 | import re 7 | 8 | 9 | class MethodDescriptor: 10 | """A description of a class method.""" 11 | 12 | def __init__( 13 | self, 14 | method_name: str, 15 | read_only: bool, 16 | method_node: cst.FunctionDef, 17 | input_desc: "InputDescriptor", 18 | output_desc: "OutputDescriptor", 19 | external_attributes: Set[str], 20 | typed_declarations: Dict[str, str], 21 | write_to_self_attributes: Set[str], 22 | ): 23 | self.method_name: str = method_name 24 | self.read_only: bool = read_only 25 | self.method_node: cst.FunctionDef = method_node 26 | self.input_desc: "InputDescriptor" = input_desc 27 | self.output_desc: "OutputDescriptor" = output_desc 28 | 29 | self._external_attributes = external_attributes 30 | self._typed_declarations = typed_declarations 31 | 32 | self.write_to_self_attributes: Set[str] = write_to_self_attributes 33 | 34 | self.other_class_links: List = [] 35 | 36 | self.statement_blocks = [] 37 | self.flow_list = [] 38 | 39 | def is_splitted_function(self) -> bool: 40 | return len(self.statement_blocks) > 0 41 | 42 | def split_function(self, blocks, fun_addr): 43 | from stateflow.dataflow.event_flow import ( 44 | EventFlowNode, 45 | StartNode, 46 | ) 47 | 48 | self.statement_blocks = blocks.copy() 49 | self.flow_list: List[EventFlowNode] = [] 50 | 51 | # Build start of the flow. 52 | flow_start: EventFlowNode = StartNode(0, fun_addr) 53 | latest_node_id: int = flow_start.id 54 | self.flow_list.append(flow_start) 55 | 56 | # A mapping from Block to EventFlowNode. 
57 | # Used to correctly build a EventFlowGraph 58 | flow_mapping = {block.block_id: None for block in self.statement_blocks} 59 | 60 | # print("Now splitting function.") 61 | for block in self.statement_blocks: 62 | flow_nodes: List[EventFlowNode] = block.build_event_flow_nodes( 63 | latest_node_id 64 | ) 65 | self.flow_list.extend(flow_nodes) 66 | if block.block_id == 0 or block.block_id == 1: 67 | pass 68 | # print(f"flow nodes {flow_nodes}") 69 | flow_mapping[block.block_id] = flow_nodes 70 | 71 | latest_node_id = self.flow_list[-1].id 72 | 73 | # print( 74 | # f"Now computed flow nodes for {block.block_id} with length {len(flow_nodes)}" 75 | # ) 76 | 77 | # print(f"Mapping {flow_mapping}") 78 | 79 | # Now that we got all flow nodes built, we can properly link them to each other. 80 | for block in self.statement_blocks: 81 | flow_nodes = flow_mapping[block.block_id] 82 | # print(f"Now looking at {block.block_id}") 83 | 84 | # Get next block of this current block. 85 | for next_block in block.next_block: 86 | # print(f"Block {block.block_id} is linked to {next_block.block_id}") 87 | next_flow_node_list = flow_mapping[next_block.block_id] 88 | if flow_nodes is None or next_flow_node_list is None: 89 | raise RuntimeError( 90 | f"Empty block flow nodes: {block.block_id}" 91 | f"{block.code()}" 92 | f"{flow_nodes}" 93 | f"Next block: {next_block.block_id}" 94 | f"{next_block.code()}" 95 | f"{next_flow_node_list}" 96 | f"{flow_mapping}" 97 | ) 98 | 99 | flow_nodes[-1].resolve_next(next_flow_node_list, next_block) 100 | 101 | # We won't set the previous, this is what we will do dynamically. 102 | # Based on the path that is traversed, we know what the previous was. 103 | 104 | flow_start.set_next(self.flow_list[1].id) 105 | 106 | def get_typed_params(self): 107 | # TODO Improve this. Very ambiguous name. 108 | params: List[str] = [] 109 | for name, typ in self.input_desc.get().items(): 110 | if typ in self.other_class_links: 111 | params.append(name) 112 | 113 | return params 114 | 115 | def _is_linked(self, name: str, types: Dict[str, str]) -> Tuple[bool, List[str]]: 116 | r = re.compile(f"^({name}|List\\[{name}\\])$") 117 | 118 | linked_vars = [] 119 | for name, typ in types.items(): 120 | if r.match(typ): 121 | linked_vars.append(name) 122 | 123 | return len(linked_vars) > 0, linked_vars 124 | 125 | def link_to_other_classes(self, descriptors: List): 126 | for d in descriptors: 127 | name = d.class_name 128 | 129 | is_linked, links = self._is_linked(name, self._typed_declarations) 130 | if is_linked: 131 | # We now check if this declaration is also attributed (i.e. get state, update state or invoke method). 132 | if len(set(links).intersection(self._external_attributes)) > 0: 133 | # Now we know this method is linked to another class or class method. 134 | self.other_class_links.append(d) 135 | elif set(links).intersection(set(self.input_desc.keys())): 136 | # The List is given as parameter. 137 | self.other_class_links.append(d) 138 | else: 139 | # TODO; we have a type decl to another class, but it is not used? Maybe throw a warning/error. 140 | pass 141 | 142 | def has_links(self) -> bool: 143 | return len(self.other_class_links) > 0 144 | 145 | 146 | class InputDescriptor: 147 | """A description of the input parameters of a function. 148 | Includes types if declared. This class works like a dictionary. 
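    For example, a method `def bid(self, amount: int)` is described with its
    parameter names as keys and their declared type hints as values:
    {"amount": "int"}.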
149 | """ 150 | 151 | def __init__(self, input_desc: Dict[str, Any]): 152 | self._input_desc: Dict[str, Any] = input_desc 153 | 154 | def __contains__(self, item): 155 | return item in self._input_desc 156 | 157 | def __delitem__(self, key): 158 | del self._input_desc[key] 159 | 160 | def __getitem__(self, item): 161 | return self._input_desc[item] 162 | 163 | def __setitem__(self, key, value): 164 | self._input_desc[key] = value 165 | 166 | def __str__(self): 167 | return self._input_desc.__str__() 168 | 169 | def __hash__(self): 170 | return self._input_desc.__hash__() 171 | 172 | def __eq__(self, other): 173 | return self._input_desc == other 174 | 175 | def keys(self): 176 | return list(self._input_desc.keys()) 177 | 178 | def get(self) -> Dict[str, Any]: 179 | return self._input_desc 180 | 181 | def match(self, args: Arguments) -> bool: 182 | return args.get_keys() == self._input_desc.keys() 183 | 184 | def __str__(self): 185 | return str(list(self._input_desc.keys())) 186 | 187 | 188 | class OutputDescriptor: 189 | """A description of the output of a function. 190 | Includes types if declared. Since a function can have multiple returns, 191 | we store each return in a list. 192 | 193 | A return is stored as a List of types. We don't store the return variable, 194 | because we do not care about it. We only care about the amount of return variables 195 | and potentially its type. 196 | """ 197 | 198 | def __init__(self, output_desc: List[List[Any]]): 199 | self.output_desc: List[List[Any]] = output_desc 200 | 201 | def num_returns(self): 202 | """The amount of (potential) outputs. 203 | 204 | If a method has multiple return paths, these are stored separately. 205 | This function returns the amount of these paths. 206 | 207 | :return: the amount of returns. 208 | """ 209 | return len(self.output_desc) 210 | -------------------------------------------------------------------------------- /stateflow/runtime/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/runtime/__init__.py -------------------------------------------------------------------------------- /stateflow/runtime/aws/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/runtime/aws/__init__.py -------------------------------------------------------------------------------- /stateflow/runtime/aws/abstract_lambda.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from stateflow.dataflow.dataflow import ( 3 | Dataflow, 4 | IngressRouter, 5 | EgressRouter, 6 | Event, 7 | Route, 8 | RouteDirection, 9 | EventFlowGraph, 10 | EventFlowNode, 11 | ) 12 | from stateflow.dataflow.stateful_operator import StatefulOperator 13 | from stateflow.dataflow.event import EventType 14 | from stateflow.serialization.pickle_serializer import SerDe, PickleSerializer 15 | from stateflow.runtime.runtime import Runtime 16 | from python_dynamodb_lock.python_dynamodb_lock import * 17 | import boto3 18 | from pynamodb.models import Model 19 | from pynamodb.attributes import UnicodeAttribute, BinaryAttribute 20 | from botocore.config import Config 21 | import datetime 22 | 23 | """Base class for implementing Lambda handlers as classes. 24 | Used across multiple Lambda functions (included in each zip file). 
25 | Add additional features here common to all your Lambdas, like logging.""" 26 | 27 | 28 | class LambdaBase(object): 29 | @classmethod 30 | def get_handler(cls, *args, **kwargs): 31 | inst = cls(*args, **kwargs) 32 | 33 | def handler(event, context): 34 | return inst.handle(event, context) 35 | 36 | return inst, handler 37 | 38 | def handle(self, event, context): 39 | raise NotImplementedError 40 | 41 | 42 | class StateflowRecord(Model): 43 | """ 44 | A Stateflow Record 45 | """ 46 | 47 | class Meta: 48 | table_name = "stateflow" 49 | region = "eu-west-2" 50 | 51 | key = UnicodeAttribute(hash_key=True) 52 | state = BinaryAttribute(null=True) 53 | 54 | 55 | class AWSLambdaRuntime(LambdaBase, Runtime): 56 | def __init__( 57 | self, 58 | flow: Dataflow, 59 | table_name="stateflow", 60 | serializer: SerDe = PickleSerializer(), 61 | config: Config = Config(region_name="eu-west-2"), 62 | ): 63 | self.flow: Dataflow = flow 64 | self.serializer: SerDe = serializer 65 | 66 | self.ingress_router = IngressRouter(self.serializer) 67 | self.egress_router = EgressRouter(self.serializer, serialize_on_return=False) 68 | 69 | self.operators = { 70 | operator.function_type.get_full_name(): operator 71 | for operator in self.flow.operators 72 | } 73 | 74 | self.dynamodb = self._setup_dynamodb(config) 75 | self.lock_client: DynamoDBLockClient = self._setup_lock_client(3) 76 | 77 | def _setup_dynamodb(self, config: Config): 78 | return boto3.resource("dynamodb", config=config) 79 | 80 | def _setup_lock_client(self, expiry_period: int) -> DynamoDBLockClient: 81 | return DynamoDBLockClient( 82 | self.dynamodb, expiry_period=datetime.timedelta(seconds=expiry_period) 83 | ) 84 | 85 | def lock_key(self, key: str): 86 | return self.lock_client.acquire_lock(key) 87 | 88 | def get_state(self, key: str): 89 | try: 90 | record = StateflowRecord.get(key) 91 | return record.state 92 | except StateflowRecord.DoesNotExist: 93 | print(f"{key} does not exist yet") 94 | return None 95 | 96 | def save_state(self, key: str, state): 97 | record = StateflowRecord(key, state=state) 98 | record.save() 99 | 100 | def is_request_state(self, event: Event) -> bool: 101 | if event.event_type == EventType.Request.GetState: 102 | return True 103 | 104 | if event.event_type != EventType.Request.EventFlow: 105 | return False 106 | 107 | flow_graph: EventFlowGraph = event.payload["flow"] 108 | current_node = flow_graph.current_node 109 | 110 | if current_node.typ == EventFlowNode.REQUEST_STATE: 111 | return True 112 | 113 | return False 114 | 115 | def invoke_operator(self, route: Route) -> Event: 116 | event: Event = route.value 117 | 118 | operator_name: str = route.route_name 119 | operator: StatefulOperator = self.operators[operator_name] 120 | 121 | if event.event_type == EventType.Request.InitClass and route.key is None: 122 | new_event = operator.handle_create(event) 123 | return self.invoke_operator( 124 | Route( 125 | RouteDirection.INTERNAL, 126 | operator_name, 127 | new_event.fun_address.key, 128 | new_event, 129 | ) 130 | ) 131 | else: 132 | full_key: str = f"{operator_name}_{route.key}" 133 | 134 | # Lock the key in DynamoDB. 
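            # python_dynamodb_lock leases the key with the ~3 second expiry set in
            # _setup_lock_client, so a crashed invocation cannot hold it forever.
            # Read-only state requests skip locking entirely.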
135 |             if not self.is_request_state(event):
136 |                 # Lock the key so concurrent invocations cannot update it at the same time.
137 |                 lock = self.lock_key(full_key)
138 |             else:
139 |                 lock = None
140 | 
141 |             operator_state = self.get_state(full_key)
142 | 
143 |             return_event, updated_state = operator.handle(event, operator_state)
144 | 
145 |             if updated_state is not operator_state:
146 |                 self.save_state(full_key, updated_state)
147 | 
148 |             if lock:
149 |                 lock.release()
150 |             return return_event
151 | 
152 |     def handle_invocation(self, event: Event) -> Route:
153 |         route: Route = self.ingress_router.route(event)
154 | 
155 |         if route.direction == RouteDirection.INTERNAL:
156 |             return self.egress_router.route_and_serialize(self.invoke_operator(route))
157 |         elif route.direction == RouteDirection.EGRESS:
158 |             return self.egress_router.route_and_serialize(route.value)
159 |         else:
160 |             return route
161 | 
162 |     def handle(self, event, context):
163 |         raise NotImplementedError("Needs to be implemented by subclasses.")
164 | 
--------------------------------------------------------------------------------
/stateflow/runtime/aws/gateway_lambda.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | from stateflow.runtime.aws.abstract_lambda import (
4 |     AWSLambdaRuntime,
5 |     Dataflow,
6 |     SerDe,
7 |     Config,
8 |     PickleSerializer,
9 |     Event,
10 |     RouteDirection,
11 |     Route,
12 | )
13 | import base64
14 | 
15 | 
16 | class AWSGatewayLambdaRuntime(AWSLambdaRuntime):
17 |     def __init__(
18 |         self,
19 |         flow: Dataflow,
20 |         table_name="stateflow",
21 |         gateway: bool = True,
22 |         serializer: SerDe = PickleSerializer(),
23 |         config: Config = Config(region_name="eu-west-2"),
24 |     ):
25 |         super().__init__(flow, table_name, serializer, config)
26 |         self.gateway = gateway
27 | 
28 |     def handle(self, event, context):
29 |         if self.gateway:
30 |             event_body = json.loads(event["body"])
31 |             event_encoded = event_body["event"]
32 |             event_serialized = base64.b64decode(event_encoded)
33 |         else:
34 |             event_body = event["event"]
35 |             event_serialized = base64.b64decode(event_body)
36 | 
37 |         parsed_event: Event = self.ingress_router.parse(event_serialized)
38 |         return_route: Route = self.handle_invocation(parsed_event)
39 | 
40 |         while return_route.direction != RouteDirection.CLIENT:
41 |             return_route = self.handle_invocation(return_route.value)
42 | 
43 |         return_event_serialized = self.egress_router.serialize(return_route.value)
44 |         return_event_encoded = base64.b64encode(return_event_serialized)
45 | 
46 |         return {
47 |             "statusCode": 200,
48 |             "body": json.dumps({"event": return_event_encoded.decode()}),
49 |         }
50 | 
--------------------------------------------------------------------------------
/stateflow/runtime/aws/kinesis_lambda.py:
--------------------------------------------------------------------------------
1 | from stateflow.runtime.aws.abstract_lambda import (
2 |     AWSLambdaRuntime,
3 |     Route,
4 |     RouteDirection,
5 |     Event,
6 |     SerDe,
7 |     PickleSerializer,
8 |     Dataflow,
9 |     Config,
10 | )
11 | import base64
12 | import boto3
13 | 
14 | 
15 | class AWSKinesisLambdaRuntime(AWSLambdaRuntime):
16 |     def __init__(
17 |         self,
18 |         flow: Dataflow,
19 |         table_name="stateflow",
20 |         request_stream="stateflow-request",
21 |         reply_stream="stateflow-reply",
22 |         serializer: SerDe = PickleSerializer(),
23 |         config: Config = Config(region_name="eu-west-1"),
24 |     ):
25 |         super().__init__(flow, table_name, serializer, config)
26 | 
27 |         self.kinesis = self._setup_kinesis(config)
28 |         self.request_stream: str = request_stream
29 |         self.reply_stream: str
= reply_stream 30 | 31 | def _setup_kinesis(self, config: Config): 32 | return boto3.client("kinesis", config=config) 33 | 34 | def handle(self, event, context): 35 | for record in event["Records"]: 36 | event = base64.b64decode(record["kinesis"]["data"]) 37 | 38 | parsed_event: Event = self.ingress_router.parse(event) 39 | return_route: Route = self.handle_invocation(parsed_event) 40 | 41 | while return_route.direction != RouteDirection.CLIENT: 42 | return_route = self.handle_invocation(return_route.value) 43 | 44 | serialized_event = self.egress_router.serialize(return_route.value) 45 | self.kinesis.put_record( 46 | StreamName=self.reply_stream, 47 | Data=serialized_event, 48 | PartitionKey=return_route.value.event_id, 49 | ) 50 | -------------------------------------------------------------------------------- /stateflow/runtime/cloudburst/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/runtime/cloudburst/__init__.py -------------------------------------------------------------------------------- /stateflow/runtime/cloudburst/cloudburst.py: -------------------------------------------------------------------------------- 1 | from cloudburst.client.client import CloudburstConnection, CloudburstFuture 2 | from stateflow.runtime.runtime import Runtime 3 | from stateflow.dataflow.dataflow import ( 4 | Dataflow, 5 | Route, 6 | IngressRouter, 7 | EgressRouter, 8 | Event, 9 | RouteDirection, 10 | ) 11 | from stateflow.dataflow.stateful_operator import StatefulOperator 12 | from stateflow.serialization.pickle_serializer import SerDe, PickleSerializer 13 | 14 | 15 | class CloudBurstRuntime(Runtime): 16 | def __init__( 17 | self, 18 | flow: Dataflow, 19 | cloudburst_ip: str, 20 | client_ip: str, 21 | serializer: SerDe = PickleSerializer(), 22 | ): 23 | self.dataflow: Dataflow = flow 24 | self.operators = self.dataflow.operators 25 | 26 | self.cloudburst_ip: str = cloudburst_ip 27 | self.serializer: SerDe = serializer 28 | 29 | self.cloudb: CloudburstConnection = CloudburstConnection( 30 | cloudburst_ip, client_ip 31 | ) 32 | 33 | # All (incoming) events are send to a router, 34 | # which then calls all functions/dags and finally returns to the client. 35 | self._register_classes() 36 | self._register_routers() 37 | 38 | def _register_classes(self): 39 | for operator in self.operators: 40 | op: StatefulOperator = operator 41 | 42 | class CloudBurstCreateOperator: 43 | def __init__(self, op: StatefulOperator): 44 | self.operator = op 45 | 46 | def run(self, cloudburst_client, route: Route): 47 | return self.operator.handle_create(route.event) 48 | 49 | class CloudBurstOperator: 50 | def __init__(self, op: StatefulOperator): 51 | self.operator = op 52 | 53 | def run(self, cloudburst_client, route: Route): 54 | operator_state = cloudburst_client.get_object(route.key) 55 | return_event, updated_state = self.operator.handle( 56 | route.value, operator_state 57 | ) 58 | 59 | if updated_state is not operator_state: 60 | cloudburst_client.put_object(route.key, updated_state) 61 | 62 | return return_event 63 | 64 | self.cloudb.register( 65 | (CloudBurstCreateOperator, op), 66 | f"{op.function_type.get_full_name()}_create", 67 | ) 68 | self.cloudb.register( 69 | (CloudBurstOperator, op), f"{op.function_type.get_full_name()}" 70 | ) 71 | 72 | # We have a DAG for the create flow of an actor/function. 
73 | # For an existing actor we just invoke the operator directly. 74 | self.cloudb.register_dag( 75 | f"{op.function_type.get_full_name()}_create_dag", 76 | [ 77 | f"{op.function_type.get_full_name()}_create", 78 | f"{op.function_type.get_full_name()}", 79 | ], 80 | [ 81 | ( 82 | f"{op.function_type.get_full_name()}_create", 83 | f"{op.function_type.get_full_name()}", 84 | ) 85 | ], 86 | ) 87 | 88 | def _register_routers(self): 89 | ingress = IngressRouter(self.serializer) 90 | egress = EgressRouter(self.serializer, False) 91 | 92 | class RouterOperator: 93 | def __init__(self, ingress: IngressRouter, egress: EgressRouter): 94 | self.ingress = ingress 95 | self.egress = egress 96 | 97 | def handle_invocation(self, cloudburst_client, route: Route) -> Route: 98 | # If this is a 'create' event, we have no key 99 | if not route.key: 100 | return_event: Event = cloudburst_client.call_dag( 101 | f"{route.route_name}_create", {"route": route} 102 | ).get() 103 | return self.egress.route_and_serialize(return_event) 104 | elif route.direction == RouteDirection.EGRESS: 105 | return self.egress_router.route_and_serialize(route.value) 106 | else: 107 | return_event: Event = CloudburstFuture( 108 | cloudburst_client.exec_func( 109 | f"{route.route_name}", {"route": route} 110 | ), 111 | cloudburst_client.kvs_client, 112 | cloudburst_client.serializer, 113 | ).get() 114 | return self.egress.route_and_serialize(return_event) 115 | 116 | def run(self, cloudburst_client, incoming_event: bytes) -> bytes: 117 | # All incoming events are handled by this stateless, preferably scaled operator. 118 | incoming_route: Route = self.ingress.parse_and_route(incoming_event) 119 | return_route: Route = self.handle_invocation(incoming_route) 120 | 121 | while return_route.direction != RouteDirection.CLIENT: 122 | return_route = self.handle_invocation(return_route) 123 | 124 | return_event_serialized = self.egress_router.serialize( 125 | return_route.value 126 | ) 127 | return return_event_serialized 128 | 129 | self.cloudb.register((RouterOperator, ingress, egress), f"stateflow") 130 | -------------------------------------------------------------------------------- /stateflow/runtime/dataflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/runtime/dataflow/__init__.py -------------------------------------------------------------------------------- /stateflow/runtime/dataflow/remote_lambda.py: -------------------------------------------------------------------------------- 1 | import base64 2 | 3 | from stateflow.runtime.aws.abstract_lambda import ( 4 | LambdaBase, 5 | Runtime, 6 | Dataflow, 7 | SerDe, 8 | Event, 9 | EventType, 10 | StatefulOperator, 11 | EventFlowGraph, 12 | EventFlowNode, 13 | ) 14 | import json 15 | from stateflow.serialization.proto.proto_serde import ProtoSerializer 16 | 17 | 18 | class RemoteLambda(LambdaBase, Runtime): 19 | def __init__(self, flow: Dataflow, serializer: ProtoSerializer = ProtoSerializer()): 20 | self.flow: Dataflow = flow 21 | self.serializer: ProtoSerializer = serializer 22 | 23 | self.operators = { 24 | operator.function_type.get_full_name(): operator 25 | for operator in self.flow.operators 26 | } 27 | 28 | def _handle_return(self, event: Event) -> Event: 29 | if event.event_type != EventType.Request.EventFlow: 30 | return event 31 | 32 | flow_graph: EventFlowGraph = event.payload["flow"] 33 | current_node = 
flow_graph.current_node 34 | 35 | if current_node.typ != EventFlowNode.RETURN: 36 | return event 37 | 38 | if not current_node.next: 39 | return event.copy( 40 | event_type=EventType.Reply.SuccessfulInvocation, 41 | payload={"return_results": current_node.get_results()}, 42 | ) 43 | 44 | else: 45 | for next_node_id in current_node.next: 46 | next_node = flow_graph.get_node_by_id(next_node_id) 47 | 48 | # Get next node and set proper input. 49 | next_node.input[current_node.return_name] = current_node.get_results() 50 | 51 | return event 52 | 53 | def handle(self, event, context): 54 | parsed_event, state, operator_name = self.serializer.deserialize_request( 55 | base64.b64decode(event["request"]) 56 | ) 57 | current_operator: StatefulOperator = self.operators[operator_name] 58 | 59 | if ( 60 | parsed_event.event_type == EventType.Request.InitClass 61 | and not parsed_event.fun_address.key 62 | ): 63 | return_event = current_operator.handle_create(parsed_event) 64 | return_state = state 65 | else: 66 | return_event, return_state = current_operator.handle(parsed_event, state) 67 | return_event = self._handle_return(return_event) 68 | 69 | return { 70 | "reply": base64.b64encode( 71 | self.serializer.serialize_request(return_event, return_state) 72 | ).decode() 73 | } 74 | -------------------------------------------------------------------------------- /stateflow/runtime/flink/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/runtime/flink/__init__.py -------------------------------------------------------------------------------- /stateflow/runtime/flink/statefun.py: -------------------------------------------------------------------------------- 1 | from stateflow.runtime.runtime import Runtime 2 | from stateflow.serialization.pickle_serializer import PickleSerializer, SerDe 3 | from stateflow.dataflow.dataflow import ( 4 | Dataflow, 5 | Event, 6 | EgressRouter, 7 | IngressRouter, 8 | RouteDirection, 9 | Route, 10 | EventType, 11 | ) 12 | from statefun import StatefulFunctions, simple_type, RequestReplyHandler, kafka_egress_message, \ 13 | Context, Message, message_builder, ValueSpec 14 | from aiohttp import web 15 | 16 | 17 | class StatefunRuntime(Runtime): 18 | def __init__( 19 | self, 20 | dataflow: Dataflow, 21 | serializer: SerDe = PickleSerializer(), 22 | reply_topic: str = "client_reply", 23 | ): 24 | self.dataflow = dataflow 25 | self.stateful_functions = StatefulFunctions() 26 | self.serializer = serializer 27 | self.reply_topic = reply_topic 28 | 29 | self.ingress_router = IngressRouter(self.serializer) 30 | self.egress_router = EgressRouter(self.serializer, serialize_on_return=False) 31 | 32 | self.byte_type = simple_type( 33 | "stateflow/byte_type", serialize_fn=lambda _: _, deserialize_fn=lambda _: _ 34 | ) 35 | 36 | self.operators_dict = {} 37 | 38 | self._add_ping_endpoint() 39 | self._add_operator_endpoints() 40 | self._add_create_endpoints() 41 | 42 | self.handle = self._build_handler() 43 | 44 | self.app = web.Application() 45 | self.app.add_routes([web.post("/statefun", self.handle)]) 46 | 47 | def _build_handler(self): 48 | self.handler = RequestReplyHandler(self.stateful_functions) 49 | 50 | async def handle(request): 51 | req = await request.read() 52 | res = await self.handler.handle_async(req) 53 | return web.Response(body=res, content_type="application/octet-stream") 54 | 55 | return handle 56 | 57 | def 
_add_ping_endpoint(self): 58 | async def ping_endpoint(ctx: Context, msg: Message): 59 | event_serialized = msg.raw_value() 60 | incoming_event = self.serializer.deserialize_event(event_serialized) 61 | 62 | if incoming_event.event_type != EventType.Request.Ping: 63 | raise AttributeError( 64 | f"Expected a Ping but got an {incoming_event.event_type}." 65 | ) 66 | 67 | outgoing_route: Route = self.ingress_router.route(incoming_event) 68 | 69 | ctx.send_egress( 70 | kafka_egress_message( 71 | typename="stateflow/kafka-egress", 72 | topic=self.reply_topic, 73 | key=outgoing_route.value.event_id, 74 | value=self.egress_router.serialize(outgoing_route.value), 75 | ) 76 | ) 77 | 78 | self.stateful_functions.register("globals/ping", ping_endpoint) 79 | 80 | def _route(self, ctx: Context, outgoing_event: Event): 81 | egress_route: Route = self.egress_router.route_and_serialize(outgoing_event) 82 | 83 | if egress_route.direction == RouteDirection.CLIENT: 84 | ctx.send_egress( 85 | kafka_egress_message( 86 | typename="stateflow/kafka-egress", 87 | topic=self.reply_topic, 88 | key=egress_route.value.event_id, 89 | value=self.egress_router.serialize(egress_route.value), 90 | ) 91 | ) 92 | return 93 | 94 | ingress_route: Route = self.ingress_router.route(egress_route.value) 95 | 96 | if ingress_route.direction == RouteDirection.INTERNAL: 97 | ctx.send( 98 | message_builder( 99 | target_typename=ingress_route.route_name, 100 | target_id=ingress_route.key, 101 | value=self.egress_router.serialize(ingress_route.value), 102 | value_type=self.byte_type, 103 | ) 104 | ) 105 | elif ingress_route.direction == RouteDirection.EGRESS: 106 | ctx.send_egress( 107 | kafka_egress_message( 108 | typename="stateflow/kafka-egress", 109 | topic=self.reply_topic, 110 | key=ingress_route.value.event_id, 111 | value=self.egress_router.serialize(ingress_route.value), 112 | ) 113 | ) 114 | 115 | def _add_operator_endpoints(self): 116 | for operator in self.dataflow.operators: 117 | self.operators_dict[operator.function_type.get_full_name()] = operator 118 | 119 | async def endpoint(ctx: Context, msg: Message): 120 | event_serialized = msg.raw_value() 121 | current_state = ctx.storage.state 122 | incoming_event = self.serializer.deserialize_event(event_serialized) 123 | 124 | outgoing_event, updated_state = self.operators_dict[ 125 | ctx.address.typename 126 | ].handle(incoming_event, current_state) 127 | 128 | if current_state != updated_state: 129 | ctx.storage.state = updated_state 130 | 131 | self._route(ctx, outgoing_event) 132 | 133 | self.stateful_functions.register( 134 | f"{operator.function_type.get_full_name()}", 135 | endpoint, 136 | specs=[ValueSpec(name="state", type=self.byte_type)], 137 | ) 138 | 139 | def _add_create_endpoints(self): 140 | for operator in self.dataflow.operators: 141 | self.operators_dict[ 142 | f"{operator.function_type.get_full_name()}_create" 143 | ] = operator 144 | 145 | async def endpoint(ctx: Context, msg: Message): 146 | event_serialized = msg.raw_value() 147 | incoming_event = self.serializer.deserialize_event(event_serialized) 148 | print( 149 | f"Now got a create request {incoming_event} for operator {operator.class_wrapper}" 150 | ) 151 | print(f"{ctx.address}") 152 | 153 | outgoing_event: Event = self.operators_dict[ 154 | ctx.address.typename 155 | ].handle_create(incoming_event) 156 | 157 | ctx.send( 158 | message_builder( 159 | target_typename=outgoing_event.fun_address.function_type.get_full_name(), 160 | target_id=outgoing_event.fun_address.key, 161 | 
value=self.egress_router.serialize(outgoing_event), 162 | value_type=self.byte_type, 163 | ) 164 | ) 165 | 166 | self.stateful_functions.register( 167 | f"{operator.function_type.get_full_name()}_create", endpoint 168 | ) 169 | 170 | def get_app(self): 171 | return self.app 172 | -------------------------------------------------------------------------------- /stateflow/runtime/runtime.py: -------------------------------------------------------------------------------- 1 | class Runtime: 2 | def __init__(self): 3 | pass 4 | 5 | def _setup_pipeline(self): 6 | raise NotImplementedError() 7 | 8 | def run(self, async_execution=False): 9 | raise NotImplementedError() 10 | -------------------------------------------------------------------------------- /stateflow/runtime/universalis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/runtime/universalis/__init__.py -------------------------------------------------------------------------------- /stateflow/serialization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/serialization/__init__.py -------------------------------------------------------------------------------- /stateflow/serialization/cloudpickle_serializer.py: -------------------------------------------------------------------------------- 1 | import struct 2 | 3 | from stateflow.serialization.serde import SerDe 4 | from stateflow.dataflow.event import Event 5 | from typing import Dict 6 | import cloudpickle 7 | 8 | 9 | class CloudpickleSerializer(SerDe): 10 | def serialize_event(self, event: Event) -> bytes: 11 | return struct.pack('>H', 0) + cloudpickle.dumps(event) 12 | 13 | def deserialize_event(self, event: bytes) -> Event: 14 | return cloudpickle.loads(event[2:])  # skip the 2-byte header prepended by serialize_event 15 | 16 | def serialize_dict(self, dictionary: Dict) -> bytes: 17 | return struct.pack('>H', 0) + cloudpickle.dumps(dictionary) 18 | 19 | def deserialize_dict(self, dictionary: bytes) -> Dict: 20 | return cloudpickle.loads(dictionary[2:])  # skip the 2-byte header prepended by serialize_dict 21 | -------------------------------------------------------------------------------- /stateflow/serialization/json_serde.py: -------------------------------------------------------------------------------- 1 | from stateflow.serialization.serde import SerDe, Event, Dict 2 | from stateflow.dataflow.args import Arguments 3 | from stateflow.dataflow.event import EventType, FunctionAddress 4 | from stateflow.dataflow.event_flow import EventFlowGraph 5 | import ujson 6 | 7 | 8 | class JsonSerializer(SerDe): 9 | def serialize_event(self, event: Event) -> bytes: 10 | event_id: str = event.event_id 11 | event_type: str = event.event_type.value 12 | fun_address: dict = event.fun_address.to_dict() 13 | payload: dict = event.payload 14 | 15 | for item in payload: 16 | if hasattr(payload[item], "to_dict"): 17 | payload[item] = payload[item].to_dict() 18 | 19 | event = { 20 | "event_id": event_id, 21 | "event_type": event_type, 22 | "fun_address": fun_address, 23 | "payload": payload, 24 | } 25 | 26 | return self.serialize_dict(event) 27 | 28 | def deserialize_event(self, event: bytes) -> Event: 29 | json = self.deserialize_dict(event) 30 | 31 | event_id: str = json["event_id"] 32 | event_type: EventType = EventType.from_str(json["event_type"]) 33 | fun_address: FunctionAddress =
FunctionAddress.from_dict(json["fun_address"]) 34 | payload: dict = json["payload"] 35 | 36 | if "args" in payload: 37 | payload["args"] = Arguments.from_dict(json["payload"]["args"]) 38 | 39 | if "flow" in payload: 40 | payload["flow"] = EventFlowGraph.from_dict(payload["flow"]) 41 | 42 | return Event(event_id, fun_address, event_type, payload) 43 | 44 | def serialize_dict(self, dictionary: Dict) -> bytes: 45 | return ujson.encode(dictionary).encode("utf-8")  # ujson.encode returns a str; encode to honor the bytes contract 46 | 47 | def deserialize_dict(self, dictionary: bytes) -> Dict: 48 | return ujson.decode(dictionary) 49 | -------------------------------------------------------------------------------- /stateflow/serialization/pickle_serializer.py: -------------------------------------------------------------------------------- 1 | from stateflow.serialization.serde import SerDe 2 | from stateflow.dataflow.event import Event 3 | from typing import Dict 4 | import pickle 5 | 6 | 7 | class PickleSerializer(SerDe): 8 | def serialize_event(self, event: Event) -> bytes: 9 | return pickle.dumps(event) 10 | 11 | def deserialize_event(self, event: bytes) -> Event: 12 | return pickle.loads(event) 13 | 14 | def serialize_dict(self, dictionary: Dict) -> bytes: 15 | return pickle.dumps(dictionary) 16 | 17 | def deserialize_dict(self, dictionary: bytes) -> Dict: 18 | return pickle.loads(dictionary) 19 | -------------------------------------------------------------------------------- /stateflow/serialization/proto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/serialization/proto/__init__.py -------------------------------------------------------------------------------- /stateflow/serialization/proto/event.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message FunctionType { 4 | string namespace = 1; 5 | string name = 2; 6 | bool stateful = 3; 7 | } 8 | 9 | message FunctionAddress { 10 | FunctionType fun_type = 1; 11 | string key = 2; 12 | } 13 | 14 | enum Reply { 15 | SuccessfulInvocation = 0; 16 | SuccessfulCreateClass = 1; 17 | 18 | FoundClass = 2; 19 | KeyNotFound = 3; 20 | 21 | SuccessfulStateRequest = 4; 22 | FailedInvocation = 5; 23 | 24 | Pong = 6; 25 | } 26 | 27 | enum Request { 28 | InvokeStateless = 0; 29 | InvokeStateful = 1; 30 | InitClass = 2; 31 | 32 | FindClass = 3; 33 | 34 | GetState = 4; 35 | SetState = 5; 36 | UpdateState = 6; 37 | DeleteState = 7; 38 | 39 | EventFlow = 8; 40 | 41 | Ping = 9; 42 | } 43 | 44 | message EventFlowNode { 45 | FunctionAddress current_fun = 5; 46 | string current_node_type = 6; 47 | } 48 | 49 | message Event { 50 | string event_id = 1; 51 | FunctionAddress fun_address = 2; 52 | oneof event_type { 53 | Request request = 3; 54 | Reply reply = 4; 55 | } 56 | bytes payload = 5; 57 | EventFlowNode current = 6; 58 | } 59 | 60 | enum RouteDirection { 61 | EGRESS = 0; 62 | INTERNAL = 1; 63 | CLIENT = 2; 64 | } 65 | 66 | message Route { 67 | RouteDirection direction = 1; 68 | string route_name = 2; 69 | string key = 3; 70 | 71 | oneof value { 72 | Event event_value = 4; 73 | bytes bytes_value = 5; 74 | } 75 | } 76 | 77 | message EventRequestReply { 78 | Event event = 1; 79 | bytes state = 2; 80 | string operator_name = 3; 81 | } -------------------------------------------------------------------------------- /stateflow/serialization/proto/proto_serde.py:
-------------------------------------------------------------------------------- 1 | from stateflow.serialization.serde import SerDe, Dict 2 | from stateflow.dataflow.address import FunctionAddress, FunctionType 3 | from stateflow.dataflow.event import Event, EventType 4 | from stateflow.dataflow.state import State 5 | from stateflow.dataflow.event_flow import EventFlowGraph 6 | from stateflow.serialization.proto import event_pb2 7 | from typing import Tuple 8 | import pickle 9 | 10 | 11 | class ProtoSerializer(SerDe): 12 | def event_to_proto(self, event: Event) -> event_pb2.Event: 13 | proto_event: event_pb2.Event = event_pb2.Event() 14 | 15 | # Set id. 16 | proto_event.event_id = event.event_id 17 | proto_event.fun_address.key = ( 18 | event.fun_address.key if event.fun_address.key else "" 19 | ) 20 | proto_event.fun_address.fun_type.namespace = ( 21 | event.fun_address.function_type.namespace 22 | ) 23 | proto_event.fun_address.fun_type.name = event.fun_address.function_type.name 24 | proto_event.fun_address.fun_type.stateful = ( 25 | event.fun_address.function_type.stateful 26 | ) 27 | 28 | # Set event type. 29 | if isinstance(event.event_type, EventType.Request): 30 | proto_event.request = event_pb2.Request.Value(event.event_type.value) 31 | elif isinstance(event.event_type, EventType.Reply): 32 | proto_event.reply = event_pb2.Reply.Value(event.event_type.value) 33 | else: 34 | raise AttributeError(f"Unknown event type! {event.event_type}") 35 | 36 | proto_event.payload = pickle.dumps(event.payload) 37 | 38 | # If we're dealing with event flow: 39 | if event.event_type == EventType.Request.EventFlow: 40 | flow_graph: EventFlowGraph = event.payload["flow"] 41 | current_node = flow_graph.current_node 42 | current_fun: FunctionAddress = current_node.fun_addr 43 | 44 | proto_event.current.current_fun.key = ( 45 | current_fun.key if current_fun.key else "" 46 | ) 47 | proto_event.current.current_fun.fun_type.namespace = ( 48 | current_fun.function_type.namespace 49 | ) 50 | proto_event.current.current_fun.fun_type.name = ( 51 | current_fun.function_type.name 52 | ) 53 | proto_event.current.current_fun.fun_type.stateful = ( 54 | current_fun.function_type.stateful 55 | ) 56 | proto_event.current.current_node_type = current_node.typ 57 | 58 | return proto_event 59 | 60 | def serialize_event(self, event: Event) -> bytes: 61 | proto_event: event_pb2.Event = self.event_to_proto(event) 62 | return proto_event.SerializeToString() 63 | 64 | def parse_event(self, event: event_pb2.Event) -> Event: 65 | event_id = event.event_id 66 | 67 | if event.fun_address.key: 68 | fun_addr = FunctionAddress( 69 | FunctionType( 70 | event.fun_address.fun_type.namespace, 71 | event.fun_address.fun_type.name, 72 | event.fun_address.fun_type.stateful, 73 | ), 74 | event.fun_address.key, 75 | ) 76 | else: 77 | fun_addr = FunctionAddress( 78 | FunctionType( 79 | event.fun_address.fun_type.namespace, 80 | event.fun_address.fun_type.name, 81 | event.fun_address.fun_type.stateful, 82 | ), 83 | "", 84 | ) 85 | 86 | # Set event type. 87 | if event.HasField("request"): 88 | event_type = EventType.Request[event_pb2.Request.Name(event.request)] 89 | elif event.HasField("reply"): 90 | event_type = EventType.Reply[event_pb2.Reply.Name(event.reply)] 91 | else: 92 | raise AttributeError(f"Unknown event type!
{event.WhichOneof('event_type')}") 93 | 94 | payload = pickle.loads(event.payload) 95 | 96 | return Event(event_id, fun_addr, event_type, payload) 97 | 98 | def deserialize_event(self, raw_event: bytes) -> Event: 99 | event: event_pb2.Event = event_pb2.Event() 100 | event.ParseFromString(raw_event) 101 | 102 | return self.parse_event(event) 103 | 104 | def deserialize_request(self, raw_request: bytes) -> Tuple[Event, bytes, str]:  # event, serialized state, operator name 105 | request: event_pb2.EventRequestReply = event_pb2.EventRequestReply() 106 | request.ParseFromString(raw_request) 107 | 108 | return self.parse_event(request.event), request.state, request.operator_name 109 | 110 | def serialize_request(self, event: Event, state: bytes) -> bytes: 111 | request: event_pb2.EventRequestReply = event_pb2.EventRequestReply() 112 | request.event.CopyFrom(self.event_to_proto(event)) 113 | request.state = state 114 | request.operator_name = "" 115 | 116 | return request.SerializeToString() 117 | 118 | def serialize_dict(self, dictionary: Dict) -> bytes: 119 | return pickle.dumps(dictionary) 120 | 121 | def deserialize_dict(self, dictionary: bytes) -> Dict: 122 | return pickle.loads(dictionary) 123 | -------------------------------------------------------------------------------- /stateflow/serialization/serde.py: -------------------------------------------------------------------------------- 1 | from stateflow.dataflow.event import Event 2 | import abc 3 | from typing import Dict 4 | 5 | 6 | class SerDe(metaclass=abc.ABCMeta): 7 | @abc.abstractmethod 8 | def serialize_event(self, event: Event) -> bytes: 9 | pass 10 | 11 | @abc.abstractmethod 12 | def deserialize_event(self, event: bytes) -> Event: 13 | pass 14 | 15 | @abc.abstractmethod 16 | def serialize_dict(self, dictionary: Dict) -> bytes: 17 | pass 18 | 19 | @abc.abstractmethod 20 | def deserialize_dict(self, dictionary: bytes) -> Dict: 21 | pass 22 | -------------------------------------------------------------------------------- /stateflow/split/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /stateflow/split/conditional_block.py: -------------------------------------------------------------------------------- 1 | import libcst as cst 2 | from stateflow.split.split_block import ( 3 | SplitContext, 4 | Use, 5 | Block, 6 | InvocationContext, 7 | ReplaceCall, 8 | EventFlowNode, 9 | ) 10 | from stateflow.dataflow.event_flow import InvokeConditional 11 | from typing import List, Optional, Tuple 12 | from stateflow.descriptors.class_descriptor import ClassDescriptor 13 | from dataclasses import dataclass 14 | 15 | 16 | @dataclass 17 | class ConditionalBlockContext(SplitContext): 18 | """This is the context for a conditional block. 19 | 20 | This block may or may not have a previous_invocation. 21 | If that is the case, we will add it as a parameter and replace the call.
22 | """ 23 | 24 | previous_invocation: Optional[InvocationContext] = None 25 | 26 | 27 | class ConditionalExpressionAnalyzer(cst.CSTVisitor): 28 | def __init__(self, expression_provider): 29 | self.expression_provider = expression_provider 30 | self.usages: List[Use] = [] 31 | 32 | def visit_Name(self, node: cst.Name): 33 | if node in self.expression_provider: 34 | expression_context = self.expression_provider[node] 35 | 36 | if ( 37 | expression_context == cst.metadata.ExpressionContext.LOAD 38 | and node.value != "self" 39 | and node.value != "True" 40 | and node.value != "False" 41 | and node.value != "print" 42 | ): 43 | self.usages.append(Use(node.value)) 44 | 45 | 46 | class ConditionalBlock(Block): 47 | def __init__( 48 | self, 49 | block_id: int, 50 | split_context: ConditionalBlockContext, 51 | test: cst.BaseExpression, 52 | previous_block: Optional[Block] = None, 53 | invocation_block: Optional[Block] = None, 54 | label: str = "", 55 | state_request: List[Tuple[str, ClassDescriptor]] = [], 56 | ): 57 | super().__init__(block_id, split_context, previous_block, label, state_request) 58 | self.test_expr: cst.BaseExpression = test 59 | self.invocation_block: Optional[Block] = invocation_block 60 | 61 | if self.invocation_block: 62 | self.invocation_block.set_next_block(self) 63 | 64 | # Get rid of the invocation and replace with the result. 65 | if self.split_context.previous_invocation: 66 | self.test_expr = self.test_expr.visit( 67 | ReplaceCall( 68 | self.split_context.previous_invocation.method_invoked, 69 | self._previous_call_result(), 70 | ) 71 | ) 72 | 73 | # Verify usages of this block. 74 | analyzer: ConditionalExpressionAnalyzer = ConditionalExpressionAnalyzer( 75 | split_context.expression_provider 76 | ) 77 | self.test_expr.visit(analyzer) 78 | 79 | self.true_block: Optional[Block] = None 80 | self.false_block: Optional[Block] = None 81 | 82 | self.dependencies: List[str] = [u.name for u in analyzer.usages] 83 | self.new_function: cst.FunctionDef = self.build_definition() 84 | 85 | def fun_name(self) -> str: 86 | """Get the name of this function given the block id. 87 | :return: the (unique) name of this block. 88 | """ 89 | return ( 90 | f"{self.split_context.original_method_node.name.value}_cond_{self.block_id}" 91 | ) 92 | 93 | def _get_true_block(self) -> Optional[Block]: 94 | return self.true_block 95 | 96 | def _get_false_block(self) -> Optional[Block]: 97 | return self.false_block 98 | 99 | def set_true_block(self, block: Block): 100 | self.true_block = block 101 | 102 | def set_false_block(self, block: Block): 103 | self.false_block = block 104 | 105 | def get_start_block(self) -> Block: 106 | return self if not self.invocation_block else self.invocation_block 107 | 108 | def _build_return(self) -> cst.SimpleStatementLine: 109 | return cst.SimpleStatementLine(body=[cst.Return(self.test_expr)]) 110 | 111 | def build_event_flow_nodes(self, node_id: int) -> List[EventFlowNode]: 112 | nodes_block = super().build_event_flow_nodes(node_id) 113 | 114 | # Initialize id. 115 | flow_node_id = node_id + len(nodes_block) + 1 # Offset the id. 116 | 117 | latest_node: Optional[EventFlowNode] = ( 118 | None if len(nodes_block) == 0 else nodes_block[-1] 119 | ) 120 | 121 | # For re-use purposes, we define the FunctionType of the class this StatementBlock belongs to. 122 | class_type = self.split_context.class_desc.to_function_type().to_address() 123 | 124 | """ For an conditional node, we assume the following scenario: 125 | 1. 
We assume that if the conditional relies on an external function, 126 | the previous node has that result. 127 | 2. The conditional is simply evaluated and based on the result, 128 | the path is decided. 129 | """ 130 | invoke_conditional: InvokeConditional = InvokeConditional( 131 | class_type, 132 | flow_node_id, 133 | self.fun_name(), 134 | self.dependencies, 135 | if_true_block_id=self.true_block.block_id, 136 | if_false_block_id=self.false_block.block_id, 137 | ) 138 | 139 | if latest_node: 140 | latest_node.set_next(invoke_conditional.id) 141 | invoke_conditional.set_previous(latest_node.id) 142 | 143 | # The 'true' and 'false' blocks are updated later on. 144 | # We don't know their id's yet. 145 | 146 | return nodes_block + [invoke_conditional] 147 | 148 | def build_definition(self) -> cst.FunctionDef: 149 | fun_name: cst.Name = cst.Name(self.fun_name()) 150 | param_node: cst.Parameters = self._build_params() 151 | 152 | return_node: cst.SimpleStatementLine = self._build_return() 153 | return self.split_context.original_method_node.with_changes( 154 | name=fun_name, 155 | params=param_node, 156 | body=self.split_context.original_method_node.body.with_changes( 157 | body=[return_node] 158 | ), 159 | ) 160 | -------------------------------------------------------------------------------- /stateflow/split/for_block.py: -------------------------------------------------------------------------------- 1 | from stateflow.split.split_block import ( 2 | SplitContext, 3 | Block, 4 | EventFlowNode, 5 | ClassDescriptor, 6 | ) 7 | from stateflow.dataflow.event_flow import InvokeFor 8 | import libcst as cst 9 | import libcst.matchers as m 10 | from typing import Optional, List, Tuple 11 | 12 | 13 | class ForBlock(Block): 14 | def __init__( 15 | self, 16 | block_id: int, 17 | iter_name: str, 18 | target: cst.BaseAssignTargetExpression, 19 | split_context: SplitContext, 20 | previous_block: Optional[Block] = None, 21 | label: str = "", 22 | state_request: List[Tuple[str, ClassDescriptor]] = [], 23 | ): 24 | super().__init__(block_id, split_context, previous_block, label, state_request) 25 | self.iter_name: str = iter_name 26 | self.target: cst.BaseAssignTargetExpression = target 27 | self.else_block: Optional[Block] = None 28 | self.body_start_block: Optional[Block] = None 29 | 30 | self.dependencies.append(self.iter_name) 31 | self.new_function: cst.FunctionDef = self.build_definition() 32 | 33 | def _get_target_name(self): 34 | if isinstance(self.target, str): 35 | return self.target 36 | elif m.matches(self.target, m.Name()): 37 | return self.target.value 38 | else: 39 | raise AttributeError(f"Cannot convert target node to string {self.target}.") 40 | 41 | def set_else_block(self, block: Block): 42 | self.else_block = block 43 | 44 | def set_body_start_block(self, block: Block): 45 | self.body_start_block = block 46 | 47 | def fun_name(self) -> str: 48 | """Get the name of this function given the block id. 49 | :return: the (unique) name of this block.
50 | """ 51 | return ( 52 | f"{self.split_context.original_method_node.name.value}_iter_{self.block_id}" 53 | ) 54 | 55 | def _build_params(self) -> cst.Parameters: 56 | params: List[cst.Param] = [cst.Param(cst.Name(value="self"))] 57 | for usage in self.dependencies: 58 | params.append(cst.Param(cst.Name(value=usage))) 59 | 60 | param_node: cst.Parameters = cst.Parameters(tuple(params)) 61 | 62 | return param_node 63 | 64 | def _build_body(self): 65 | return cst.helpers.parse_template_module( 66 | """ 67 | try: 68 | {iter_target} = next({it}) 69 | except StopIteration: 70 | return {'_type': 'StopIteration'} 71 | """, 72 | iter_target=self.target, 73 | it=cst.Name(self.iter_name), 74 | ).body 75 | 76 | def _build_return(self) -> cst.SimpleStatementLine: 77 | return cst.SimpleStatementLine( 78 | body=[ 79 | cst.Return( 80 | cst.Tuple( 81 | [ 82 | cst.Element(self.target), 83 | cst.Element(cst.Name(self.iter_name)), 84 | ] 85 | ) 86 | ) 87 | ] 88 | ) 89 | 90 | def build_event_flow_nodes(self, node_id: int): 91 | # We can only have an else block, a for body block and a next statement block. 92 | assert len(self.next_block) <= 3 93 | 94 | nodes_block = super().build_event_flow_nodes(node_id) 95 | 96 | latest_node: Optional[EventFlowNode] = ( 97 | None if len(nodes_block) == 0 else nodes_block[-1] 98 | ) 99 | 100 | # Initialize id. 101 | flow_node_id = node_id + len(nodes_block) + 1 # Offset the id. 102 | 103 | # For re-use purposes, we define the FunctionType of the class this StatementBlock belongs to. 104 | class_type = self.split_context.class_desc.to_function_type().to_address() 105 | 106 | invoke_for: InvokeFor = InvokeFor( 107 | class_type, 108 | flow_node_id, 109 | self.fun_name(), 110 | self.iter_name, 111 | self._get_target_name(), 112 | for_body_block_id=self.body_start_block.block_id, 113 | else_block_id=self.else_block.block_id 114 | if self.else_block is not None 115 | else -1, 116 | ) 117 | 118 | if latest_node: 119 | latest_node.set_next(invoke_for.id) 120 | invoke_for.set_previous(latest_node.id) 121 | 122 | return nodes_block + [invoke_for] 123 | 124 | def build_definition(self) -> cst.FunctionDef: 125 | fun_name: cst.Name = cst.Name(self.fun_name()) 126 | param_node: cst.Paramaters = self._build_params() 127 | body = self._build_body() 128 | return_node: cst.SimpleStatementLine = self._build_return() 129 | 130 | return self.split_context.original_method_node.with_changes( 131 | name=fun_name, 132 | params=param_node, 133 | body=self.split_context.original_method_node.body.with_changes( 134 | body=list(body) + [return_node] 135 | ), 136 | ) 137 | -------------------------------------------------------------------------------- /stateflow/split/split_transform.py: -------------------------------------------------------------------------------- 1 | import libcst as cst 2 | from libcst import matchers as m 3 | from typing import Union, List, Dict 4 | from stateflow.split.split_block import StatementBlock 5 | 6 | 7 | class RemoveAfterClassDefinition(cst.CSTTransformer): 8 | def __init__(self, class_name: str): 9 | self.class_name: str = class_name 10 | self.is_defined = False 11 | 12 | def leave_ClassDef( 13 | self, original_node: cst.ClassDef, updated_node: cst.ClassDef 14 | ) -> Union[cst.BaseStatement, cst.RemovalSentinel]: 15 | 16 | new_decorators = [] 17 | for decorator in updated_node.decorators: 18 | if "stateflow" not in cst.helpers.get_full_name_for_node(decorator): 19 | new_decorators.append(decorator) 20 | 21 | if original_node.name == self.class_name: 22 | 
self.is_defined = True 23 | return updated_node.with_changes(decorators=tuple(new_decorators)) 24 | 25 | return updated_node.with_changes(decorators=tuple(new_decorators)) 26 | 27 | # def on_leave( 28 | # self, original_node: cst.CSTNodeT, updated_node: cst.CSTNodeT 29 | # ) -> Union[cst.CSTNodeT, cst.RemovalSentinel]: 30 | # if self.is_defined: 31 | # return cst.RemovalSentinel() 32 | 33 | 34 | class SplitTransformer(cst.CSTTransformer): 35 | def __init__( 36 | self, class_name: str, updated_methods: Dict[str, List[StatementBlock]] 37 | ): 38 | self.class_name: str = class_name 39 | self.updated_methods = updated_methods 40 | 41 | def visit_ClassDef(self, node: cst.ClassDef): 42 | if node.name.value != self.class_name: 43 | return False 44 | return True 45 | 46 | def leave_ClassDef( 47 | self, original_node: cst.ClassDef, updated_node: cst.ClassDef 48 | ) -> Union[cst.BaseStatement, cst.RemovalSentinel]: 49 | if updated_node.name.value != self.class_name: 50 | return updated_node 51 | 52 | class_body = updated_node.body.body 53 | 54 | for method in self.updated_methods.values(): 55 | for split in method: 56 | class_body = (*class_body, split.new_function) 57 | 58 | return updated_node.with_changes( 59 | body=updated_node.body.with_changes(body=class_body) 60 | ) 61 | 62 | def leave_FunctionDef( 63 | self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef 64 | ) -> cst.CSTNode: 65 | if ( 66 | m.matches(original_node.body, m.IndentedBlock()) 67 | and updated_node.name.value in self.updated_methods 68 | ): 69 | pass_node = cst.SimpleStatementLine(body=[cst.Pass()]) 70 | new_block = original_node.body.with_changes(body=[pass_node]) 71 | return updated_node.with_changes(body=new_block) 72 | 73 | return updated_node 74 | -------------------------------------------------------------------------------- /stateflow/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/util/__init__.py -------------------------------------------------------------------------------- /stateflow/util/dataflow_operator_generator.py: -------------------------------------------------------------------------------- 1 | from stateflow.dataflow.dataflow import Dataflow 2 | 3 | 4 | def generate_operators(flow: Dataflow): 5 | operators = ",".join( 6 | [operator.function_type.get_full_name() for operator in flow.operators] 7 | ) 8 | return operators 9 | -------------------------------------------------------------------------------- /stateflow/util/dataflow_visualizer.py: -------------------------------------------------------------------------------- 1 | from stateflow.split.split_block import Block 2 | from stateflow.split.conditional_block import ConditionalBlock 3 | from stateflow.dataflow.event_flow import ( 4 | EventFlowNode, 5 | InvokeConditional, 6 | InvokeExternal, 7 | ) 8 | from stateflow.client.class_ref import ClassRef 9 | from typing import List 10 | from graphviz import Digraph 11 | 12 | 13 | def visualize(blocks: List[Block], code=False): 14 | dot = Digraph(comment="Visualized dataflow") 15 | 16 | nodes = [] 17 | 18 | for b in blocks: 19 | if isinstance(b, ConditionalBlock): 20 | block_code = b.code().replace("\n", "\l").replace("\l", "", 1) 21 | dot_node = dot.node( 22 | str(b.block_id), 23 | label=f"{b.block_id} - {b.label}" if not code else block_code, 24 | _attributes={ 25 | "shape": "rectangle", 26 | "fillcolor": "lightskyblue", 27 | "style": 
"filled", 28 | }, 29 | ) 30 | else: 31 | block_code = b.code().replace("\n", "\l").replace("\l", "", 1) 32 | dot_node = dot.node( 33 | str(b.block_id), 34 | label=f"{b.block_id} - {b.label}" if not code else block_code, 35 | shape="rectangle", 36 | ) 37 | nodes.append(dot_node) 38 | 39 | for b in blocks: 40 | if isinstance(b, ConditionalBlock): 41 | dot.edge( 42 | str(b.block_id), 43 | str(b.true_block.block_id), 44 | label="T", 45 | color="darkgreen", 46 | style="dotted", 47 | ) 48 | if b.false_block: 49 | dot.edge( 50 | str(b.block_id), 51 | str(b.false_block.block_id), 52 | label="F", 53 | color="crimson", 54 | style="dotted", 55 | ) 56 | else: 57 | for next in b.next_block: 58 | dot.edge(str(b.block_id), str(next.block_id)) 59 | 60 | return dot.source 61 | 62 | 63 | def visualize_ref(ref: ClassRef, name: str): 64 | class_desc = ref._class_desc 65 | method = class_desc.get_method_by_name(name) 66 | 67 | return visualize_flow(method.flow_list) 68 | 69 | 70 | def visualize_flow(flow: List[EventFlowNode]): 71 | dot = Digraph(comment="Visualized dataflow") 72 | 73 | nodes = [] 74 | colors = [ 75 | "black", 76 | "purple3", 77 | "seagreen", 78 | "royalblue4", 79 | "orangered3", 80 | "yellow4", 81 | "webmaroon", 82 | ] 83 | 84 | for n in flow: 85 | if isinstance(n, InvokeConditional): 86 | nodes.append( 87 | dot.node( 88 | str(n.id), 89 | label=str(n.typ), 90 | _attributes={ 91 | "shape": "rectangle", 92 | "fillcolor": "lightskyblue", 93 | "style": "filled", 94 | }, 95 | fontcolor=colors[n.method_id], 96 | ) 97 | ) 98 | elif isinstance(n, InvokeExternal): 99 | nodes.append( 100 | dot.node( 101 | str(n.id), 102 | label=str(n.typ), 103 | style="filled", 104 | shape="box", 105 | fillcolor="darkseagreen3", 106 | fontcolor=colors[n.method_id], 107 | ) 108 | ) 109 | else: 110 | nodes.append( 111 | dot.node(str(n.id), label=str(n.typ), fontcolor=colors[n.method_id]) 112 | ) 113 | 114 | for n in flow: 115 | if isinstance(n, InvokeConditional): 116 | conditional: InvokeConditional = n 117 | 118 | dot.edge( 119 | str(conditional.id), 120 | str(conditional.if_true_node), 121 | label="T", 122 | color="darkgreen", 123 | style="dotted", 124 | ) 125 | dot.edge( 126 | str(conditional.id), 127 | str(conditional.if_false_node), 128 | label="F", 129 | color="crimson", 130 | style="dotted", 131 | ) 132 | else: 133 | for next in n.next: 134 | dot.edge(str(n.id), str(next)) 135 | 136 | return dot.source 137 | -------------------------------------------------------------------------------- /stateflow/util/local_runtime.py: -------------------------------------------------------------------------------- 1 | from stateflow.client.stateflow_client import StateflowClient, StateflowFuture, T 2 | from stateflow.dataflow.dataflow import Dataflow 3 | from stateflow.dataflow.stateful_operator import StatefulOperator 4 | from stateflow.serialization.pickle_serializer import SerDe, PickleSerializer 5 | from stateflow.dataflow.dataflow import ( 6 | IngressRouter, 7 | EgressRouter, 8 | Route, 9 | RouteDirection, 10 | EventType, 11 | ) 12 | from stateflow.dataflow.event import Event 13 | from typing import Dict, ByteString 14 | import time 15 | 16 | 17 | class LocalRuntime(StateflowClient): 18 | def __init__( 19 | self, 20 | flow: Dataflow, 21 | serializer: SerDe = PickleSerializer(), 22 | return_future: bool = False, 23 | ): 24 | super().__init__(flow, serializer) 25 | 26 | self.flow: Dataflow = flow 27 | self.serializer: SerDe = serializer 28 | 29 | self.ingress_router = IngressRouter(self.serializer) 30 | self.egress_router = 
EgressRouter(self.serializer, serialize_on_return=False) 31 | 32 | self.operators = { 33 | operator.function_type.get_full_name(): operator 34 | for operator in self.flow.operators 35 | } 36 | 37 | # Set the wrapper. 38 | for op in flow.operators: op.meta_wrapper.set_client(self) 39 | 40 | self.state: Dict[str, ByteString] = {} 41 | self.return_future: bool = return_future 42 | 43 | def invoke_operator(self, route: Route) -> Event: 44 | event: Event = route.value 45 | 46 | operator_name: str = route.route_name 47 | operator: StatefulOperator = self.operators[operator_name] 48 | 49 | if event.event_type == EventType.Request.InitClass and route.key is None: 50 | new_event = operator.handle_create(event) 51 | return self.invoke_operator( 52 | Route( 53 | RouteDirection.INTERNAL, 54 | operator_name, 55 | new_event.fun_address.key, 56 | new_event, 57 | ) 58 | ) 59 | else: 60 | full_key: str = f"{operator_name}_{route.key}" 61 | operator_state = self.state.get(full_key) 62 | return_event, updated_state = operator.handle(event, operator_state) 63 | self.state[full_key] = updated_state 64 | 65 | return return_event 66 | 67 | def handle_invocation(self, event: Event) -> Route: 68 | route: Route = self.ingress_router.route(event) 69 | 70 | if route.direction == RouteDirection.INTERNAL: 71 | return self.egress_router.route_and_serialize(self.invoke_operator(route)) 72 | elif route.direction == RouteDirection.EGRESS: 73 | return self.egress_router.route_and_serialize(route.value) 74 | else: 75 | return route 76 | 77 | def execute_event(self, event: bytes) -> Event:  # takes a serialized event, as sent by send() below 78 | parsed_event: Event = self.ingress_router.parse(event) 79 | return_route: Route = self.handle_invocation(parsed_event) 80 | 81 | while return_route.direction != RouteDirection.CLIENT: 82 | return_route = self.handle_invocation(return_route.value) 83 | 84 | return return_route.value 85 | 86 | def send(self, event: Event, return_type: T = None) -> T: 87 | return_event = self.execute_event(self.serializer.serialize_event(event)) 88 | future = StateflowFuture( 89 | event.event_id, time.time(), event.fun_address, return_type 90 | ) 91 | 92 | future.complete(return_event) 93 | if self.return_future: 94 | return future 95 | else: 96 | return future.get() 97 | -------------------------------------------------------------------------------- /stateflow/util/stateflow_test.py: -------------------------------------------------------------------------------- 1 | from pytest import fixture 2 | from stateflow.util.local_runtime import LocalRuntime 3 | import stateflow.core as stateflow 4 | 5 | 6 | @fixture(autouse=True) 7 | def stateflow_test(): 8 | client = LocalRuntime(stateflow.init()) 9 | yield client 10 | -------------------------------------------------------------------------------- /stateflow/util/statefun_module_generator.py: -------------------------------------------------------------------------------- 1 | from stateflow.dataflow.dataflow import Dataflow 2 | from typing import List 3 | import yaml 4 | 5 | 6 | def generate( 7 | flow: Dataflow, statefun_cluster_url: str = "http://localhost:8000/statefun" 8 | ): 9 | spec = {"endpoints": [], "ingresses": {}, "egresses": {}} 10 | 11 | spec["endpoints"].append( 12 | {"endpoint": _generate_endpoint_dict("globals/ping", statefun_cluster_url)} 13 | ) 14 | 15 | all_functions: List[str] = [] 16 | for operator in flow.operators: 17 | fun_name: str = operator.function_type.get_full_name() 18 | spec["endpoints"].append( 19 | {"endpoint": _generate_endpoint_dict(fun_name, statefun_cluster_url)} 20 | )
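# Each operator is registered under two endpoints: the one above for regular
# invocations, and a '<name>_create' endpoint (appended below) that routes
# instance-creation events to a dedicated function.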
21 | 22 | spec["endpoints"].append( 23 | { 24 | "endpoint": _generate_endpoint_dict( 25 | f"{fun_name}_create", statefun_cluster_url 26 | ) 27 | } 28 | ) 29 | 30 | all_functions.append(fun_name) 31 | all_functions.append(f"{fun_name}_create") 32 | 33 | all_functions.append("globals/ping") 34 | 35 | spec["ingresses"] = [{"ingress": _generate_kafka_ingress(all_functions)}] 36 | spec["egresses"] = [ 37 | { 38 | "egress": { 39 | "meta": { 40 | "type": "io.statefun.kafka/egress", 41 | "id": "stateflow/kafka-egress", 42 | }, 43 | "spec": {"address": "localhost:9092"}, 44 | } 45 | } 46 | ] 47 | 48 | basic_setup = { 49 | "version": "3.0", 50 | "module": {"meta": {"type": "remote"}, "spec": spec}, 51 | } 52 | 53 | return yaml.dump(basic_setup) 54 | 55 | 56 | def _generate_endpoint_dict(function_name: str, statefun_cluster_url: str): 57 | return { 58 | "meta": {"kind": "http"}, 59 | "spec": {"functions": function_name, "urlPathTemplate": statefun_cluster_url}, 60 | } 61 | 62 | 63 | def _generate_kafka_ingress( 64 | all_functions: List[str], kafka_broker: str = "localhost:9092" 65 | ): 66 | topics = [] 67 | 68 | for topic in all_functions: 69 | topics.append( 70 | { 71 | "topic": topic.replace("/", "_"), 72 | "valueType": "stateflow/byte_type", 73 | "targets": [topic], 74 | } 75 | ) 76 | 77 | return { 78 | "meta": {"type": "io.statefun.kafka/ingress", "id": "stateflow/kafka-ingress"}, 79 | "spec": { 80 | "address": kafka_broker, 81 | "consumerGroupId": "stateflow-statefun-consumer", 82 | "topics": topics, 83 | }, 84 | } 85 | -------------------------------------------------------------------------------- /stateflow/wrappers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delftdata/stateflow/2230d5a382914c660daf5e09e0c4ff37089e050e/stateflow/wrappers/__init__.py -------------------------------------------------------------------------------- /stateflow/wrappers/meta_wrapper.py: -------------------------------------------------------------------------------- 1 | from stateflow.client.class_ref import ( 2 | ClassRef, 3 | StateflowClient, 4 | StateflowFuture, 5 | AsyncClassRef, 6 | ) 7 | from stateflow.dataflow.address import FunctionType 8 | from stateflow.descriptors.class_descriptor import ClassDescriptor 9 | from stateflow.dataflow.event import Event, FunctionAddress, EventType 10 | from stateflow.dataflow.args import Arguments 11 | from typing import Union 12 | import uuid 13 | 14 | 15 | class MetaWrapper(type): 16 | """A meta-class around the client-side class definition. 17 | We use this meta-implementation to intercept interaction with a class. 18 | This interception is used to generate events to the back-end/runtime. 19 | 20 | For example, when the class is constructed an event is sent to the runtime to generate this instance there. 21 | This wrapper is responsible for two kind of behaviours: 22 | 1. Sending an event to the runtime to instantiate a class 23 | 2. Creating a ClientRef based on a created instance. 24 | """ 25 | 26 | def __new__(msc, name, bases, dct, descriptor: ClassDescriptor = None): 27 | """Constructs a meta-class for a certain class definition. 28 | 29 | :param name: name of the original class. 30 | :param bases: bases of the original class. 31 | :param dct: dct of the original class. 32 | :param descriptor: the class descriptor of this class. 
33 | """ 34 | msc.client: StateflowClient = None 35 | msc.asynchronous: bool = False 36 | dct["descriptor"]: ClassDescriptor = descriptor 37 | return super(MetaWrapper, msc).__new__(msc, name, bases, dct) 38 | 39 | def to_asynchronous_wrapper(msc): 40 | msc.asynchronous = True 41 | 42 | def by_key(msc, key: str): 43 | if msc.asynchronous: 44 | return msc.__async_call__(**{"__key": key}) 45 | else: 46 | return msc(**{"__key": key}) 47 | 48 | async def __async_call__(msc, *args, **kwargs) -> Union[ClassRef, StateflowFuture]: 49 | if "__key" in kwargs: 50 | return AsyncClassRef( 51 | FunctionAddress(FunctionType.create(msc.descriptor), kwargs["__key"]), 52 | msc.descriptor, 53 | msc.client, 54 | ) 55 | 56 | fun_address = FunctionAddress(FunctionType.create(msc.descriptor), None) 57 | 58 | event_id: str = str(uuid.uuid4()) 59 | 60 | # Build arguments. 61 | # print(args) 62 | # print(kwargs) 63 | args = Arguments.from_args_and_kwargs( 64 | msc.descriptor.get_method_by_name("__init__").input_desc.get(), 65 | *args, 66 | **kwargs, 67 | ) 68 | 69 | payload = {"args": args} 70 | 71 | # Creates a class event. 72 | create_class_event = Event( 73 | event_id, fun_address, EventType.Request.InitClass, payload 74 | ) 75 | 76 | # print(f"Now sending events with payload {args.get()}") 77 | result = await msc.client.send(create_class_event, msc) 78 | return result 79 | 80 | def __call__(msc, *args, **kwargs) -> Union[ClassRef, StateflowFuture]: 81 | """Invoked on constructing an instance of class. 82 | We cover two scenarios here: 83 | 84 | 1. The instance is _not_ yet created on the server (we don't verify this here), and we send 85 | an event via the client to the runtime to create this object with the given args and kwargs. 86 | Therefore, we verify if the args + kwargs _matches_ the InputDescriptor of the __init__ method of the class. 87 | 2. The instance is created on the server and therefore we know the _key_ of the object. In that case, 88 | a ClassRef is returned. This a reference that the client-side can interact with (i.e. call methods, get and 89 | update attributes). 90 | 91 | We differentiate between both scenarios by looking for the "__key" attribute in kwargs. If this _is_ given 92 | we 'know' the instance has been created on the server and we can safely create and return a ClassRef. 93 | We might get conflicts if a user defines a "__key" argument in its __init__ method, but we assume this is 94 | very unlikely. We don't explicitly check if this is the case right now. 95 | 96 | :param args: invocation args. 97 | :param kwargs: invocation kwargs. 98 | :return: either a StateflowFuture or ClassRef. 99 | """ 100 | if "__call__" in vars(msc): 101 | return vars(msc)["__async_call__"](args, kwargs) 102 | 103 | if "__key" in kwargs: 104 | return ClassRef( 105 | FunctionAddress(FunctionType.create(msc.descriptor), kwargs["__key"]), 106 | msc.descriptor, 107 | msc.client, 108 | ) 109 | 110 | fun_address = FunctionAddress(FunctionType.create(msc.descriptor), None) 111 | 112 | event_id: str = str(uuid.uuid4()) 113 | 114 | # Build arguments. 115 | args = Arguments.from_args_and_kwargs( 116 | msc.descriptor.get_method_by_name("__init__").input_desc.get(), 117 | *args, 118 | **kwargs, 119 | ) 120 | 121 | payload = {"args": args} 122 | 123 | # Creates a class event. 
124 | create_class_event = Event( 125 | event_id, fun_address, EventType.Request.InitClass, payload 126 | ) 127 | 128 | return msc.client.send(create_class_event, msc) 129 | 130 | def set_client(msc, client: StateflowClient): 131 | """Sets the client of this class. 132 | The reason we have to set it explicitly (and not add it as constructor argument) 133 | is because we initialize the meta class _before_ the client is initialized. 134 | 135 | I.e. 136 | stateflow.init() 137 | is called before 138 | StateFlowClient() 139 | 140 | :param client: the client to set. 141 | """ 142 | msc.client = client 143 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from tests.kafka.KafkaImage import KafkaImage  # assumes the KafkaImage class lives in tests/kafka/KafkaImage.py; the bare module import is not callable 3 | 4 | 5 | @pytest.fixture(scope="session") 6 | def kafka(): 7 | kafka_image = KafkaImage() 8 | yield kafka_image.run() 9 | kafka_image.stop() 10 | -------------------------------------------------------------------------------- /tests/analysis/ast_utils_test.py: -------------------------------------------------------------------------------- 1 | from tests.context import stateflow 2 | from stateflow.analysis.ast_utils import * 3 | 4 | 5 | def test_self_positive(): 6 | stmt = "self.x" 7 | parsed = cst.parse_statement(stmt) 8 | 9 | assert is_self(parsed.body[0].value) 10 | 11 | 12 | def test_self_negative(): 13 | stmt = "not_self.x" 14 | parsed = cst.parse_statement(stmt) 15 | 16 | assert not is_self(parsed.body[0].value) 17 | 18 | 19 | def test_self_not_attribute(): 20 | stmt = "x + 3" 21 | parsed = cst.parse_statement(stmt) 22 | 23 | assert not is_self(parsed.body[0].value) 24 | 25 | 26 | def test_type_positive_primitive(): 27 | stmt = "x: int" 28 | parsed = cst.parse_module(stmt) 29 | assert extract_types(parsed, parsed.body[0].body[0].annotation) == "int" 30 | 31 | stmt = "x: str" 32 | parsed = cst.parse_module(stmt) 33 | assert extract_types(parsed, parsed.body[0].body[0].annotation) == "str" 34 | 35 | stmt = "x: bool" 36 | parsed = cst.parse_module(stmt) 37 | assert extract_types(parsed, parsed.body[0].body[0].annotation) == "bool" 38 | 39 | stmt = "x: float" 40 | parsed = cst.parse_module(stmt) 41 | assert extract_types(parsed, parsed.body[0].body[0].annotation) == "float" 42 | 43 | stmt = "x: bytes" 44 | parsed = cst.parse_module(stmt) 45 | assert extract_types(parsed, parsed.body[0].body[0].annotation) == "bytes" 46 | 47 | 48 | def test_type_positive_complex_types(): 49 | stmt = """ 50 | from collections.abc import Sequence 51 | ConnectionOptions = dict[str, str] 52 | Address = tuple[str, int] 53 | Server = tuple[Address, ConnectionOptions] 54 | 55 | x: Sequence[Server] 56 | """ 57 | parsed = cst.parse_module(stmt) 58 | assert ( 59 | extract_types(parsed, parsed.body[-1].body[0].annotation) == "Sequence[Server]" 60 | ) 61 | 62 | stmt = """ 63 | class TestClass: 64 | pass 65 | 66 | x: TestClass 67 | """ 68 | parsed = cst.parse_module(stmt) 69 | assert extract_types(parsed, parsed.body[-1].body[0].annotation) == "TestClass" 70 | 71 | stmt = """ 72 | x: List[int] 73 | """ 74 | parsed = cst.parse_module(stmt) 75 | assert extract_types(parsed, parsed.body[-1].body[0].annotation) == "List[int]" 76 | 77 | stmt = """ 78 | x: List[OtherClass] 79 | """ 80 | parsed = cst.parse_module(stmt) 81 | assert ( 82 | extract_types(parsed, parsed.body[-1].body[0].annotation) == "List[OtherClass]" 83 | ) 84 | 85 | 86 | def test_type_unpacking(): 87 | stmt = "x:
Tuple[str, int, bytes]" 88 | parsed = cst.parse_module(stmt) 89 | assert extract_types(parsed, parsed.body[0].body[0].annotation, unpack=True) == [ 90 | "str", 91 | "int", 92 | "bytes", 93 | ] 94 | 95 | 96 | def test_type_complex(): 97 | stmt = "x: Tuple[str, int, bytes, List[str]]" 98 | parsed = cst.parse_module(stmt) 99 | assert extract_types(parsed, parsed.body[0].body[0].annotation, unpack=True) == [ 100 | "str", 101 | "int", 102 | "bytes", 103 | "List[str]", 104 | ] 105 | -------------------------------------------------------------------------------- /tests/client/class_ref_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from tests.context import stateflow 3 | from tests.common.common_classes import User, stateflow 4 | from stateflow.client.class_ref import ( 5 | MethodRef, 6 | ClassRef, 7 | StateflowClient, 8 | Arguments, 9 | EventType, 10 | ) 11 | from stateflow.dataflow.address import FunctionType, FunctionAddress 12 | from unittest import mock 13 | 14 | 15 | class TestClassRef: 16 | def setup(self): 17 | flow = stateflow.init() 18 | self.item_desc = stateflow.core.registered_classes[0].class_desc 19 | self.user_desc = stateflow.core.registered_classes[1].class_desc 20 | 21 | def test_method_ref_simple_call(self): 22 | update_balance_method = self.user_desc.get_method_by_name("update_balance") 23 | 24 | class_ref_mock = mock.MagicMock(ClassRef) 25 | method_ref = MethodRef("update_balance", class_ref_mock, update_balance_method) 26 | 27 | method_ref(x=1) 28 | 29 | class_ref_mock._invoke_method.assert_called_once() 30 | name, args = class_ref_mock._invoke_method.call_args[0] 31 | 32 | assert name == "update_balance" 33 | assert args.get() == {"x": 1} 34 | 35 | def test_method_ref_call_flow(self): 36 | buy_item_method = self.user_desc.get_method_by_name("buy_item") 37 | 38 | class_ref_mock = mock.MagicMock(ClassRef) 39 | method_ref = MethodRef("buy_item", class_ref_mock, buy_item_method) 40 | 41 | method_ref(amount=1, item=None) 42 | 43 | class_ref_mock._invoke_flow.assert_called_once() 44 | flow, args = class_ref_mock._invoke_flow.call_args[0] 45 | 46 | assert isinstance(flow, list) 47 | assert list([(x.id, x.typ) for x in flow]) == list( 48 | [(x.id, x.typ) for x in buy_item_method.flow_list] 49 | ) 50 | assert args.get() == {"amount": 1, "item": None} 51 | 52 | def test_class_ref_simple_invoke(self): 53 | client_mock = mock.MagicMock(StateflowClient) 54 | class_ref = ClassRef( 55 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 56 | self.user_desc, 57 | client_mock, 58 | ) 59 | 60 | class_ref._invoke_method("update_balance", Arguments({"x": 1})) 61 | 62 | client_mock.send.assert_called_once() 63 | event = client_mock.send.call_args[0][0] 64 | 65 | assert event.event_type == EventType.Request.InvokeStateful 66 | assert event.payload["args"].get() == {"x": 1} 67 | assert event.payload["method_name"] == "update_balance" 68 | assert event.fun_address == FunctionAddress( 69 | FunctionType("global", "User", True), "test-user" 70 | ) 71 | 72 | def test_class_ref_invoke_flow(self): 73 | client_mock = mock.MagicMock(StateflowClient) 74 | class_ref = ClassRef( 75 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 76 | self.user_desc, 77 | client_mock, 78 | ) 79 | 80 | class_ref._invoke_flow( 81 | self.user_desc.get_method_by_name("buy_item").flow_list, 82 | Arguments({"amount": 1, "item": class_ref}), 83 | ) 84 | 85 | client_mock.send.assert_called_once() 86 | event = 
client_mock.send.call_args[0][0] 87 | 88 | assert event.event_type == EventType.Request.EventFlow 89 | assert event.fun_address == FunctionAddress( 90 | FunctionType("global", "User", True), "test-user" 91 | ) 92 | 93 | # TODO This test needs to be way more extensive, checking if parameters are properly matched. 94 | 95 | def test_class_ref_test_to_str(self): 96 | class_ref = ClassRef( 97 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 98 | self.user_desc, 99 | None, 100 | ) 101 | 102 | assert str(class_ref) == "Class reference for User with key test-user." 103 | 104 | def test_class_ref_get_attribute(self): 105 | client_mock = mock.MagicMock(StateflowClient) 106 | class_ref = ClassRef( 107 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 108 | self.user_desc, 109 | client_mock, 110 | ) 111 | 112 | class_ref.balance 113 | 114 | client_mock.send.assert_called_once() 115 | event = client_mock.send.call_args[0][0] 116 | 117 | assert event.event_type == EventType.Request.GetState 118 | assert event.payload["attribute"] == "balance" 119 | assert event.fun_address == FunctionAddress( 120 | FunctionType("global", "User", True), "test-user" 121 | ) 122 | 123 | def test_class_ref_set_attribute(self): 124 | client_mock = mock.MagicMock(StateflowClient) 125 | class_ref = ClassRef( 126 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 127 | self.user_desc, 128 | client_mock, 129 | ) 130 | 131 | class_ref.balance = 10 132 | 133 | client_mock.send.assert_called_once() 134 | event = client_mock.send.call_args[0][0] 135 | 136 | assert event.event_type == EventType.Request.UpdateState 137 | assert event.payload["attribute"] == "balance" 138 | assert event.payload["attribute_value"] == 10 139 | assert event.fun_address == FunctionAddress( 140 | FunctionType("global", "User", True), "test-user" 141 | ) 142 | 143 | def test_class_ref_get_method_ref(self): 144 | client_mock = mock.MagicMock(StateflowClient) 145 | class_ref = ClassRef( 146 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 147 | self.user_desc, 148 | client_mock, 149 | ) 150 | 151 | m_ref = class_ref.update_balance 152 | 153 | assert isinstance(m_ref, MethodRef) 154 | assert m_ref.method_name == "update_balance" 155 | assert m_ref.method_desc == self.user_desc.get_method_by_name("update_balance") 156 | assert m_ref._class_ref == class_ref 157 | 158 | def test_class_ref_get_self_attr(self): 159 | client_mock = mock.MagicMock(StateflowClient) 160 | class_ref = ClassRef( 161 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 162 | self.user_desc, 163 | client_mock, 164 | ) 165 | 166 | client_ref = class_ref._client 167 | 168 | assert client_ref == client_mock 169 | 170 | def test_class_ref_get_self_attr_non_existing(self): 171 | client_mock = mock.MagicMock(StateflowClient) 172 | class_ref = ClassRef( 173 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 174 | self.user_desc, 175 | client_mock, 176 | ) 177 | 178 | with pytest.raises(AttributeError): 179 | class_ref._doesnt_exist 180 | -------------------------------------------------------------------------------- /tests/client/future_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from tests.context import stateflow 3 | from tests.common.common_classes import User 4 | from stateflow.client.future import StateflowFuture, StateflowFailure 5 | from stateflow.dataflow.event import EventType, Event 6 | from 
stateflow.dataflow.address import FunctionAddress, FunctionType 7 | from stateflow.client.class_ref import ClassRef 8 | from typing import List 9 | 10 | 11 | def test_simple_future_complete(): 12 | flow_future = StateflowFuture( 13 | "123", 123, FunctionAddress(FunctionType("", "", True), "key"), bool 14 | ) 15 | 16 | flow_future.is_completed = True 17 | flow_future.result = True 18 | 19 | assert flow_future.get() is True 20 | assert flow_future.id == "123" 21 | assert flow_future.timestamp == 123 22 | 23 | 24 | def test_simple_future_complete_failure(): 25 | flow_future = StateflowFuture( 26 | "123", 123, FunctionAddress(FunctionType("", "", True), "key"), bool 27 | ) 28 | 29 | flow_future.is_completed = True 30 | flow_future.result = StateflowFailure("error!") 31 | 32 | with pytest.raises(StateflowFailure): 33 | flow_future.get() 34 | assert flow_future.id == "123" 35 | assert flow_future.timestamp == 123 36 | 37 | 38 | def test_future_complete_unknown_event(): 39 | flow_future = StateflowFuture( 40 | "123", 123, FunctionAddress(FunctionType("", "", True), "key"), bool 41 | ) 42 | event = Event( 43 | "id", FunctionAddress(FunctionType("", "", True), "key"), "UNKNOWN", {} 44 | ) 45 | with pytest.raises(AttributeError): 46 | flow_future.complete(event) 47 | 48 | 49 | def test_future_complete_failed_invocation_event(): 50 | flow_future = StateflowFuture( 51 | "123", 123, FunctionAddress(FunctionType("", "", True), "key"), bool 52 | ) 53 | event = Event( 54 | "id", 55 | FunctionAddress(FunctionType("", "", True), "key"), 56 | EventType.Reply.FailedInvocation, 57 | {"error_message": "this is an error!"}, 58 | ) 59 | 60 | with pytest.raises(StateflowFailure) as excep: 61 | flow_future.complete(event) 62 | flow_future.get() 63 | 64 | assert excep.value.error_msg == "this is an error!"
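# A minimal sketch of what the test above relies on (a hedged reading, not
# code from this file): complete() only records the failure, and it is the
# later get() that re-raises it -- hence both calls sit inside the
# pytest.raises block. Assuming that behavior, the same check can also be
# written without raising at all:
#
#     flow_future.complete(event)  # records a StateflowFailure as the result
#     assert isinstance(flow_future.result, StateflowFailure)
#     assert flow_future.result.error_msg == "this is an error!"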
65 | 66 | 67 | def test_future_complete_found_class_event(): 68 | stateflow.init() 69 | flow_future = StateflowFuture( 70 | "123", 71 | 123, 72 | FunctionAddress(FunctionType("", "", True), "test-user"), 73 | stateflow.core.meta_classes[1], 74 | ) 75 | event = Event( 76 | "id", 77 | FunctionAddress(FunctionType("", "", True), "test-user"), 78 | EventType.Reply.FoundClass, 79 | {}, 80 | ) 81 | 82 | flow_future.complete(event) 83 | res = flow_future.get() 84 | 85 | assert isinstance(res, ClassRef) 86 | assert res._fun_addr.key == "test-user" 87 | 88 | 89 | def test_future_complete_successful_create_class_event(): 90 | stateflow.init() 91 | flow_future = StateflowFuture( 92 | "123", 93 | 123, 94 | FunctionAddress(FunctionType("", "", True), "test-user"), 95 | stateflow.core.meta_classes[1], 96 | ) 97 | event = Event( 98 | "id", 99 | FunctionAddress(FunctionType("", "", True), "test-user"), 100 | EventType.Reply.SuccessfulCreateClass, 101 | {}, 102 | ) 103 | 104 | flow_future.complete(event) 105 | res = flow_future.get() 106 | 107 | assert isinstance(res, ClassRef) 108 | assert res._fun_addr.key == "test-user" 109 | 110 | 111 | def test_future_complete_found_invocation_event(): 112 | flow_future = StateflowFuture( 113 | "123", 114 | 123, 115 | FunctionAddress(FunctionType("", "", True), "test-user"), 116 | int, 117 | ) 118 | event = Event( 119 | "id", 120 | FunctionAddress(FunctionType("", "", True), "test-user"), 121 | EventType.Reply.SuccessfulInvocation, 122 | {"return_results": 1}, 123 | ) 124 | 125 | flow_future.complete(event) 126 | res = flow_future.get() 127 | 128 | assert isinstance(res, int) 129 | assert res == 1 130 | 131 | 132 | def test_future_complete_found_invocation_event_list(): 133 | flow_future = StateflowFuture( 134 | "123", 135 | 123, 136 | FunctionAddress(FunctionType("", "", True), "test-user"), 137 | int, 138 | ) 139 | event = Event( 140 | "id", 141 | FunctionAddress(FunctionType("", "", True), "test-user"), 142 | EventType.Reply.SuccessfulInvocation, 143 | {"return_results": [1]}, 144 | ) 145 | 146 | flow_future.complete(event) 147 | res = flow_future.get() 148 | 149 | assert isinstance(res, int) 150 | assert res == 1 151 | 152 | 153 | def test_future_complete_found_invocation_event_list_multiple(): 154 | flow_future = StateflowFuture( 155 | "123", 156 | 123, 157 | FunctionAddress(FunctionType("", "", True), "test-user"), 158 | List[int], 159 | ) 160 | event = Event( 161 | "id", 162 | FunctionAddress(FunctionType("", "", True), "test-user"), 163 | EventType.Reply.SuccessfulInvocation, 164 | {"return_results": [1, 3, 4]}, 165 | ) 166 | 167 | flow_future.complete(event) 168 | res = flow_future.get() 169 | 170 | assert res == (1, 3, 4) 171 | 172 | 173 | def test_future_complete_state_request(): 174 | flow_future = StateflowFuture( 175 | "123", 176 | 123, 177 | FunctionAddress(FunctionType("", "", True), "test-user"), 178 | List[int], 179 | ) 180 | event = Event( 181 | "id", 182 | FunctionAddress(FunctionType("", "", True), "test-user"), 183 | EventType.Reply.SuccessfulStateRequest, 184 | {"state": [1, 3, 4]}, 185 | ) 186 | 187 | flow_future.complete(event) 188 | res = flow_future.get() 189 | 190 | assert res == (1, 3, 4) 191 | 192 | 193 | def test_failure_object(): 194 | failure = StateflowFailure("an error") 195 | 196 | assert str(failure) == "StateflowFailure: an error" 197 | assert repr(failure) == "StateflowFailure: an error" 198 | -------------------------------------------------------------------------------- /tests/client/kafka_client_test.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from tests.common.common_classes import User, stateflow 3 | from tests.context import stateflow 4 | from stateflow.client.kafka_client import ( 5 | StateflowKafkaClient, 6 | Producer, 7 | Consumer, 8 | Event, 9 | EventType, 10 | FunctionAddress, 11 | FunctionType, 12 | ) 13 | from stateflow.client.class_ref import ClassRef 14 | from unittest import mock 15 | 16 | 17 | class TestStateflowKafkaClient: 18 | def setup(self): 19 | flow = stateflow.init() 20 | self.client: StateflowKafkaClient = StateflowKafkaClient.__new__( 21 | StateflowKafkaClient 22 | ) 23 | 24 | self.producer_mock = mock.MagicMock(Producer) 25 | self.consumer_mock = mock.MagicMock(Consumer) 26 | 27 | self.client._set_producer = lambda _: self.producer_mock 28 | self.client._set_consumer = lambda _: self.consumer_mock 29 | 30 | self.consumer_mock.poll = lambda _: None 31 | 32 | self.client.__init__(flow, "") 33 | 34 | def test_simple_send(self): 35 | self.client.send( 36 | Event( 37 | "123", 38 | FunctionAddress(FunctionType("global", "User", True), "test-user"), 39 | EventType.Request.InvokeStateful, 40 | {}, 41 | ), 42 | ClassRef, 43 | ) 44 | 45 | self.client.running = False 46 | 47 | self.producer_mock.produce.assert_called_once() 48 | -------------------------------------------------------------------------------- /tests/common/common_classes.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import List 3 | from tests.context import stateflow 4 | 5 | 6 | @stateflow.stateflow 7 | class Item: 8 | def __init__(self, item_name: str, price: int): 9 | self.item_name: str = item_name 10 | self.stock: int = 0 11 | self.price: int = price 12 | 13 | def update_stock(self, amount: int) -> bool: 14 | if (self.stock + amount) < 0: # We can't get a stock < 0. 15 | return False 16 | 17 | self.stock += amount 18 | return True 19 | 20 | def __key__(self): 21 | return self.item_name 22 | 23 | 24 | @stateflow.stateflow 25 | class User: 26 | def __init__(self, username: str): 27 | self.username: str = username 28 | self.balance: int = 0 29 | self.items: List[Item] = [] 30 | 31 | def update_balance(self, x: int): 32 | self.balance += x 33 | 34 | def buy_item(self, amount: int, item: Item) -> bool: 35 | total_price = amount * item.price 36 | 37 | if self.balance < total_price: 38 | return False 39 | 40 | if not item.update_stock(-amount): 41 | return False # For some reason, stock couldn't be decreased. 
42 | 43 | self.balance -= total_price 44 | return True 45 | 46 | def simple_for_loops(self, users: List["User"]): 47 | i = 0 48 | for user in users: 49 | if i > 0: 50 | user.update_balance(9) 51 | else: 52 | user.update_balance(4) 53 | i += 1 54 | 55 | return i 56 | 57 | def __key__(self): 58 | return self.username 59 | 60 | 61 | @stateflow.stateflow 62 | class ExperimentalB: 63 | def __init__(self, name: str): 64 | self.name = name 65 | self.balance = 0 66 | 67 | def add_balance(self, balance: int): 68 | self.balance += balance 69 | 70 | def set_balance(self, balance: int): 71 | self.balance = balance 72 | 73 | def balance_equal_to(self, equal_balance: int) -> bool: 74 | return self.balance == equal_balance 75 | 76 | def __key__(self): 77 | return self.name 78 | 79 | 80 | @stateflow.stateflow 81 | class ExperimentalA: 82 | def __init__(self, name: str): 83 | self.name = name 84 | self.balance = 0 85 | 86 | def complex_method(self, balance: int, other: ExperimentalB) -> bool: 87 | self.balance += balance * 2 88 | other.add_balance(balance * 2) 89 | self.balance -= balance 90 | other.add_balance(-balance) 91 | self.balance -= balance 92 | is_equal = other.balance_equal_to(balance) 93 | return is_equal 94 | 95 | def complex_if(self, balance: int, b_ins: ExperimentalB): 96 | self.balance = balance 97 | 98 | if self.balance > 10: 99 | b_ins.add_balance(balance) 100 | self.balance = 0 101 | elif b_ins.balance_equal_to(5): 102 | self.balance = 1 103 | else: 104 | self.balance = 2 105 | 106 | return self.balance 107 | 108 | def more_complex_if(self, balance: int, b_ins: ExperimentalB) -> int: 109 | self.balance = balance 110 | if balance >= 0: 111 | self.balance = balance 112 | if b_ins.balance_equal_to(balance * 2): 113 | self.balance = 1 114 | else: 115 | return -1 116 | 117 | return self.balance 118 | 119 | def test_no_return(self, balance: int, b_ins: ExperimentalB): 120 | if balance >= self.balance: 121 | self.balance = 0 122 | b_ins.add_balance(balance) 123 | else: 124 | self.balance = 1 125 | 126 | def work_with_list(self, x: int, others: List[ExperimentalB]): 127 | other_one: ExperimentalB = others[0] 128 | other_one.add_balance(10) 129 | 130 | if x > 0: 131 | others[-1].add_balance(10) 132 | else: 133 | other_one.add_balance(-10) 134 | 135 | def for_loops(self, x: int, others: List[ExperimentalB]): 136 | for y in others: 137 | y.add_balance(5) 138 | 139 | if x > 0: 140 | z = x 141 | else: 142 | z = -1 143 | 144 | return z 145 | 146 | def state_requests(self, items: List[ExperimentalB]): 147 | total: int = 0 148 | first_item: ExperimentalB = items[0] 149 | print(f"Total is now {total}.") 150 | total += first_item.balance # Total = 0 151 | first_item.set_balance(10) 152 | total += first_item.balance # total = 10 153 | first_item.set_balance(0) 154 | for x in items: 155 | total += x.balance # total = 10 156 | x.set_balance(5) 157 | total += x.balance # total = 10 + 5 + 5 = 20 158 | 159 | print(f"Total is now {total}.") 160 | total += first_item.balance # total = 25 161 | if total > 0: 162 | first_item.set_balance(1) 163 | 164 | print(f"Total is now {total}.") 165 | 166 | total += first_item.balance # total = 26 167 | return total 168 | 169 | def __key__(self): 170 | return self.name 171 | 172 | 173 | @stateflow.stateflow 174 | class OtherNestClass: 175 | def __init__(self, x: int): 176 | self.id = str(uuid.uuid4()) 177 | self.x = x 178 | 179 | print(f"Im {type(self)} with id {self.id}") 180 | 181 | def is_really_true(self): 182 | return True 183 | 184 | def is_true(self, other: 
"OtherNestClass"): 185 | is_really_true: bool = other.is_really_true() 186 | return is_really_true 187 | 188 | def nest_calll(self, other: "OtherNestClass") -> bool: 189 | z = 0 190 | is_true = other.is_true(other) 191 | return is_true 192 | 193 | def __key__(self): 194 | return self.id 195 | 196 | 197 | @stateflow.stateflow 198 | class NestClass: 199 | def __init__(self, x: int): 200 | self.id = str(uuid.uuid4()) 201 | self.x = x 202 | 203 | print(f"Im {type(self)} with id {self.id}") 204 | 205 | def nest_call(self, other: OtherNestClass): 206 | y = other.x 207 | z = 3 208 | 209 | if other.nest_calll(other): 210 | p = 3 211 | 212 | other.nest_calll(other) 213 | 214 | return y, z, p 215 | 216 | def __key__(self): 217 | return self.id 218 | -------------------------------------------------------------------------------- /tests/context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 5 | 6 | import stateflow 7 | -------------------------------------------------------------------------------- /tests/dataflow/arguments_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from tests.context import stateflow 3 | from stateflow.dataflow.args import Arguments 4 | 5 | 6 | def test_match_simple_args(): 7 | desc = {"x": int} 8 | args = Arguments.from_args_and_kwargs(desc, 1) 9 | 10 | assert args.get() == {"x": 1} 11 | 12 | 13 | def test_match_more_args(): 14 | desc = {"x": int, "y": str, "z": list} 15 | args = Arguments.from_args_and_kwargs(desc, 1, "hi", ["no"]) 16 | 17 | assert args.get() == {"x": 1, "y": "hi", "z": ["no"]} 18 | 19 | 20 | def test_match_simple_kwargs(): 21 | desc = {"x": int} 22 | args = Arguments.from_args_and_kwargs(desc, x=1) 23 | 24 | assert args.get() == {"x": 1} 25 | 26 | 27 | def test_match_more_kwargs(): 28 | desc = {"x": int, "y": str, "z": list} 29 | args = Arguments.from_args_and_kwargs(desc, x=1, y="hi", z=["no"]) 30 | 31 | assert args.get() == {"x": 1, "y": "hi", "z": ["no"]} 32 | 33 | 34 | def test_match_mix_args_kwargs(): 35 | desc = {"x": int, "y": str, "z": list} 36 | args = Arguments.from_args_and_kwargs(desc, 1, y="hi", z=["no"]) 37 | 38 | assert args.get() == {"x": 1, "y": "hi", "z": ["no"]} 39 | 40 | 41 | def test_match_mix_args_kwargs_more(): 42 | desc = {"x": int, "y": str, "z": list} 43 | args = Arguments.from_args_and_kwargs(desc, 1, z=["no"], y="hi") 44 | 45 | assert args.get() == {"x": 1, "y": "hi", "z": ["no"]} 46 | 47 | 48 | def test_match_wrong_desc(): 49 | desc = {"x": int, "y": str, "z": list} 50 | 51 | with pytest.raises(AttributeError): 52 | Arguments.from_args_and_kwargs(desc, x=1, y="hi", z=["no"], not_exist=1) 53 | -------------------------------------------------------------------------------- /tests/kafka/KafkaImage.py: -------------------------------------------------------------------------------- 1 | import docker 2 | from time import sleep 3 | import os 4 | from confluent_kafka.admin import AdminClient 5 | 6 | 7 | class KafkaImage: 8 | def __init__(self): 9 | self.image = "spotify/kafka" 10 | self.name = "kafka" 11 | self.host = "localhost" 12 | self.port = 9092 13 | 14 | def get_image_options(self): 15 | image_options = { 16 | "version": "latest", 17 | "environment": { 18 | "ADVERTISED_PORT": "9092", 19 | "ADVERTISED_HOST": "0.0.0.0", 20 | }, 21 | "ports": {"9092": "9092", "2181": "2181"}, 22 | } 23 | 24 | return image_options 
25 | 26 | def check(self): 27 | client = AdminClient({"bootstrap.servers": "localhost:9092"}) 28 | try: 29 | client.list_topics(timeout=10) 30 | del client 31 | return True 32 | except Exception as excep: 33 | print(excep) 34 | return False 35 | 36 | # This is based on the 'pytest_docker_fixtures' library 37 | def run(self): 38 | docker_client = docker.from_env() 39 | image_options = self.get_image_options() 40 | 41 | max_wait_s = image_options.pop("max_wait_s", None) or 30 42 | 43 | # Create a new one 44 | container = docker_client.containers.run( 45 | image=self.image, **image_options, detach=True 46 | ) 47 | ident = container.id 48 | count = 1 49 | 50 | self.container_obj = docker_client.containers.get(ident) 51 | 52 | opened = False 53 | 54 | print(f"starting {self.name}") 55 | while count < max_wait_s and not opened: 56 | if count > 0: 57 | sleep(1) 58 | count += 1 59 | try: 60 | self.container_obj = docker_client.containers.get(ident) 61 | except docker.errors.NotFound: 62 | print(f"Container not found for {self.name}") 63 | continue 64 | if self.container_obj.status == "exited": 65 | logs = self.container_obj.logs() 66 | self.stop() 67 | raise Exception(f"Container failed to start {logs}") 68 | 69 | if self.container_obj.attrs["NetworkSettings"]["IPAddress"] != "": 70 | if os.environ.get("TESTING", "") == "jenkins": 71 | network = self.container_obj.attrs["NetworkSettings"] 72 | self.host = network["IPAddress"] 73 | else: 74 | self.host = "localhost" 75 | 76 | if self.host != "": 77 | opened = self.check() 78 | if not opened: 79 | logs = self.container_obj.logs().decode("utf-8") 80 | self.stop() 81 | raise Exception( 82 | f"Could not start {self.name}: {logs}\n" 83 | f"Image: {self.image}\n" 84 | f"Options:\n{(image_options)}" 85 | ) 86 | print(f"{self.name} started") 87 | return self.host, self.port 88 | 89 | def stop(self): 90 | if self.container_obj is not None: 91 | try: 92 | self.container_obj.kill() 93 | print(f"{self.name} stopped") 94 | except docker.errors.APIError: 95 | pass 96 | try: 97 | self.container_obj.remove(v=True, force=True) 98 | except docker.errors.APIError: 99 | pass 100 | -------------------------------------------------------------------------------- /tests/local_runtime_test.py: -------------------------------------------------------------------------------- 1 | from .context import stateflow 2 | 3 | 4 | from stateflow import stateflow_test 5 | from tests.common.common_classes import User, Item 6 | 7 | 8 | def test_user(): 9 | user = User("kyriakos") 10 | user.update_balance(10) 11 | 12 | assert user.balance == 10 13 | assert user.username == "kyriakos" 14 | -------------------------------------------------------------------------------- /tests/runtime/aws_runtime_test.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from tests.context import stateflow 3 | from tests.common.common_classes import stateflow 4 | from stateflow.runtime.aws.abstract_lambda import AWSLambdaRuntime 5 | from stateflow.runtime.aws.kinesis_lambda import AWSKinesisLambdaRuntime 6 | from stateflow.runtime.aws.gateway_lambda import AWSGatewayLambdaRuntime 7 | from stateflow.dataflow.event import Event, EventType 8 | from stateflow.dataflow.event_flow import InternalClassRef 9 | from stateflow.dataflow.state import State 10 | from stateflow.serialization.pickle_serializer import PickleSerializer 11 | from stateflow.serialization.json_serde import JsonSerializer 12 | from stateflow.dataflow.args import Arguments 13 | from 
stateflow.dataflow.address import FunctionType, FunctionAddress 14 | from python_dynamodb_lock.python_dynamodb_lock import * 15 | import uuid 16 | from unittest import mock 17 | import json 18 | 19 | 20 | class TestAWSRuntime: 21 | def setup_handle(self): 22 | return AWSKinesisLambdaRuntime.get_handler(stateflow.init()) 23 | 24 | def setup_gateway_handle(self): 25 | return AWSGatewayLambdaRuntime.get_handler(stateflow.init()) 26 | 27 | def test_simple_event(self): 28 | kinesis_mock = mock.MagicMock() 29 | lock_mock = mock.MagicMock(DynamoDBLockClient) 30 | 31 | AWSKinesisLambdaRuntime._setup_dynamodb = lambda x, y: None 32 | AWSKinesisLambdaRuntime._setup_kinesis = lambda x, y: kinesis_mock 33 | AWSKinesisLambdaRuntime._setup_lock_client = lambda x, y: lock_mock 34 | 35 | event_id: str = str(uuid.uuid4()) 36 | event: Event = Event( 37 | event_id, 38 | FunctionAddress(FunctionType("global", "User", True), None), 39 | EventType.Request.InitClass, 40 | {"args": Arguments({"username": "wouter"})}, 41 | ) 42 | 43 | serialized_event = PickleSerializer().serialize_event(event) 44 | json_event = { 45 | "Records": [{"kinesis": {"data": base64.b64encode(serialized_event)}}] 46 | } 47 | 48 | inst, handler = self.setup_handle() 49 | 50 | inst.get_state = lambda x: None 51 | inst.save_state = lambda x, y: None 52 | 53 | handler(json_event, None) 54 | 55 | lock_mock.acquire_lock.assert_called_once() 56 | kinesis_mock.put_record.assert_called_once() 57 | 58 | lock_mock.reset_mock() 59 | kinesis_mock.reset_mock() 60 | 61 | event: Event = Event( 62 | event_id, 63 | FunctionAddress(FunctionType("global", "User", True), "wouter"), 64 | EventType.Request.InvokeStateful, 65 | { 66 | "args": Arguments( 67 | { 68 | "items": [ 69 | InternalClassRef( 70 | FunctionAddress( 71 | FunctionType("globals", "Item", True), "coke" 72 | ) 73 | ), 74 | InternalClassRef( 75 | FunctionAddress( 76 | FunctionType("globals", "Item", True), "pepsi" 77 | ) 78 | ), 79 | ] 80 | } 81 | ), 82 | "method_name": "state_requests", 83 | }, 84 | ) 85 | 86 | serialized_event = PickleSerializer().serialize_event(event) 87 | json_event = { 88 | "Records": [{"kinesis": {"data": base64.b64encode(serialized_event)}}] 89 | } 90 | 91 | inst.get_state = lambda x: PickleSerializer().serialize_dict( 92 | State({"username": "wouter", "x": 5}).get() 93 | ) 94 | 95 | handler(json_event, None) 96 | 97 | lock_mock.acquire_lock.assert_called_once() 98 | kinesis_mock.put_record.assert_called_once() 99 | 100 | def test_simple_event_gateway(self): 101 | lock_mock = mock.MagicMock(DynamoDBLockClient) 102 | 103 | AWSGatewayLambdaRuntime._setup_dynamodb = lambda x, y: None 104 | AWSGatewayLambdaRuntime._setup_lock_client = lambda x, y: lock_mock 105 | 106 | event_id: str = str(uuid.uuid4()) 107 | event: Event = Event( 108 | event_id, 109 | FunctionAddress(FunctionType("global", "User", True), None), 110 | EventType.Request.InitClass, 111 | {"args": Arguments({"username": "wouter"})}, 112 | ) 113 | 114 | serialized_event = PickleSerializer().serialize_event(event) 115 | json_event = { 116 | "body": json.dumps({"event": base64.b64encode(serialized_event).decode()}) 117 | } 118 | 119 | inst, handler = self.setup_gateway_handle() 120 | 121 | inst.get_state = lambda x: None 122 | inst.save_state = lambda x, y: None 123 | 124 | handler(json_event, None) 125 | 126 | lock_mock.acquire_lock.assert_called_once() 127 | 128 | lock_mock.reset_mock() 129 | 130 | event: Event = Event( 131 | event_id, 132 | FunctionAddress(FunctionType("global", "User", True), "wouter"), 133 | 
EventType.Request.InvokeStateful, 134 | { 135 | "args": Arguments( 136 | { 137 | "items": [ 138 | InternalClassRef( 139 | FunctionAddress( 140 | FunctionType("globals", "Item", True), "coke" 141 | ) 142 | ), 143 | InternalClassRef( 144 | FunctionAddress( 145 | FunctionType("globals", "Item", True), "pepsi" 146 | ) 147 | ), 148 | ] 149 | } 150 | ), 151 | "method_name": "state_requests", 152 | }, 153 | ) 154 | 155 | serialized_event = PickleSerializer().serialize_event(event) 156 | json_event = { 157 | "body": json.dumps({"event": base64.b64encode(serialized_event).decode()}) 158 | } 159 | 160 | inst.get_state = lambda x: PickleSerializer().serialize_dict( 161 | State({"username": "wouter", "x": 5}).get() 162 | ) 163 | 164 | handler(json_event, None) 165 | 166 | lock_mock.acquire_lock.assert_called_once() 167 | -------------------------------------------------------------------------------- /tests/serialization/proto_serializer_test.py: -------------------------------------------------------------------------------- 1 | from stateflow.serialization.proto.proto_serde import ( 2 | ProtoSerializer, 3 | Event, 4 | FunctionAddress, 5 | FunctionType, 6 | EventType, 7 | ) 8 | 9 | 10 | def test_simple_event(): 11 | event = Event( 12 | "123", 13 | FunctionAddress(FunctionType("global", "User", True), "wouter"), 14 | EventType.Request.InitClass, 15 | {"test": "test"}, 16 | ) 17 | serde = ProtoSerializer() 18 | 19 | event_ser = serde.serialize_event(event) 20 | event_deser = serde.deserialize_event(event_ser) 21 | 22 | print(event_ser) 23 | assert event_deser.event_id == event.event_id 24 | assert event_deser.fun_address == event.fun_address 25 | assert event_deser.event_type == event.event_type 26 | assert event_deser.payload == event.payload 27 | -------------------------------------------------------------------------------- /tests/wrapper/meta_wrapper_test.py: -------------------------------------------------------------------------------- 1 | from tests.context import stateflow 2 | from stateflow.client.stateflow_client import StateflowClient 3 | from stateflow.wrappers.meta_wrapper import MetaWrapper 4 | import inspect 5 | import libcst as cst 6 | from stateflow.analysis.extract_class_descriptor import ExtractClassDescriptor 7 | from stateflow.descriptors.class_descriptor import ClassDescriptor 8 | from stateflow.dataflow.event import EventType 9 | from unittest import mock 10 | from stateflow.client.class_ref import ClassRef 11 | 12 | 13 | class SimpleClass: 14 | def __init__(self, name: str): 15 | self.name = name 16 | self.x = 10 17 | 18 | def update(self, x: int) -> int: 19 | self.x -= x 20 | return self.x 21 | 22 | def __key__(self): 23 | return self.name 24 | 25 | 26 | class TestMetaWrapper: 27 | def get_meta_wrapper(self) -> MetaWrapper: 28 | # Parse 29 | code = inspect.getsource(SimpleClass) 30 | parsed_class = cst.parse_module(code) 31 | 32 | wrapper = cst.metadata.MetadataWrapper(parsed_class) 33 | expression_provider = wrapper.resolve(cst.metadata.ExpressionContextProvider) 34 | 35 | # Extract 36 | extraction: ExtractClassDescriptor = ExtractClassDescriptor( 37 | parsed_class, "SimpleClass", expression_provider 38 | ) 39 | parsed_class.visit(extraction) 40 | 41 | # Create ClassDescriptor 42 | class_desc: ClassDescriptor = ExtractClassDescriptor.create_class_descriptor( 43 | extraction 44 | ) 45 | 46 | # Create a meta class.. 
47 | meta_class = MetaWrapper( 48 | str(SimpleClass.__name__), 49 | tuple(SimpleClass.__bases__), 50 | dict(SimpleClass.__dict__), 51 | descriptor=class_desc, 52 | ) 53 | 54 | return meta_class 55 | 56 | def test_initialize(self): 57 | mock_client = mock.MagicMock(StateflowClient) 58 | 59 | wrapper = self.get_meta_wrapper() 60 | wrapper.set_client(mock_client) 61 | 62 | # We expect to create a new instance here, by sending an event. 63 | created_class = wrapper("wouter") 64 | 65 | # Verify the send is called. 66 | mock_client.send.assert_called_once() 67 | 68 | event, wrapper_ret = mock_client.send.call_args[0] 69 | 70 | assert wrapper == wrapper_ret 71 | assert event.event_type == EventType.Request.InitClass 72 | assert not event.fun_address.key 73 | assert event.fun_address.function_type.name == "SimpleClass" 74 | assert "args" in event.payload 75 | assert event.payload["args"]["name"] == "wouter" 76 | 77 | def test_create_class_ref(self): 78 | mock_client = mock.MagicMock(StateflowClient) 79 | 80 | wrapper = self.get_meta_wrapper() 81 | wrapper.set_client(mock_client) 82 | 83 | # We expect to create a new instance here, by sending an event. 84 | created_class = wrapper("wouter", __key="wouter") 85 | 86 | assert isinstance(created_class, ClassRef) 87 | assert created_class._client == mock_client 88 | assert created_class._fun_addr.key == "wouter" 89 | -------------------------------------------------------------------------------- /zipfian_generator.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | ZIPF_CONSTANT = 0.99 4 | 5 | 6 | class ZipfGenerator: 7 | """ 8 | Adapted from YCSB's ZipfGenerator here: 9 | https://github.com/brianfrankcooper/YCSB/blob/master/core/src/main/java/site/ycsb/generator/ZipfianGenerator.java 10 | That's an implementation from "Quickly Generating Billion-Record Synthetic Databases", Jim Gray et al, SIGMOD 1994. 
11 | """ 12 | 13 | def __init__(self, items: int = None, mn: int = None, mx: int = None, zipf_const: float = None): 14 | 15 | if items is not None: 16 | self.__max = items - 1 17 | self.__min = 0 18 | self.__items = items 19 | else: 20 | self.__max = mx 21 | self.__min = mn 22 | self.__items = self.__max - self.__min + 1 23 | 24 | if zipf_const is not None: 25 | self.__zipf_constant: float = zipf_const 26 | else: 27 | self.__zipf_constant: float = ZIPF_CONSTANT 28 | 29 | self.__zeta = self.zeta_static(self.__max - self.__min + 1, self.__zipf_constant) 30 | self.__base = self.__min 31 | self.__theta: float = self.__zipf_constant 32 | zeta2theta = self.zeta(2, self.__theta) 33 | self.__alpha: float = 1.0 / (1.0 - self.__theta) 34 | self.__count_for_zeta: int = items 35 | self.__eta: float = (1 - pow(2.0 / items, 1 - self.__theta)) / (1 - zeta2theta / self.__zeta) 36 | self.__allow_item_count_decrease: bool = False 37 | 38 | def __next__(self): 39 | u: float = random.random() 40 | uz: float = u * self.__zeta 41 | if uz < 1.0: 42 | return self.__base 43 | if uz < 1.0 + pow(0.5, self.__theta): 44 | return self.__base + 1 45 | return self.__base + int(self.__items * pow(self.__eta * u - self.__eta + 1, self.__alpha)) 46 | 47 | def __iter__(self): 48 | return self 49 | 50 | def zeta(self, *params): 51 | if len(params) == 2: 52 | n, theta_val = params 53 | self.__count_for_zeta = n 54 | return self.zeta_static(n, theta_val) 55 | elif len(params) == 4: 56 | st, n, theta_val, initial_sum = params 57 | self.__count_for_zeta = n 58 | return self.zeta_static(n, theta_val, theta_val, initial_sum) 59 | 60 | def zeta_static(self, *params): 61 | if len(params) == 2: 62 | n, theta = params 63 | st = 0 64 | initial_sum = 0 65 | return self.zeta_sum(st, n, theta, initial_sum) 66 | elif len(params) == 4: 67 | st, n, theta, initial_sum = params 68 | return self.zeta_sum(st, n, theta, initial_sum) 69 | 70 | @staticmethod 71 | def zeta_sum(st, n, theta, initial_sum): 72 | s = initial_sum 73 | for i in range(st, n): 74 | s += 1 / (pow(i + 1, theta)) 75 | return s 76 | 77 | 78 | if __name__ == "__main__": 79 | counts = {} 80 | g = ZipfGenerator(items=10) 81 | for _ in range(200): 82 | num = next(g) 83 | # print(num) 84 | if num in counts: 85 | counts[num] += 1 86 | else: 87 | counts[num] = 1 88 | print(dict(sorted(counts.items(), key=lambda item: -item[1]))) 89 | --------------------------------------------------------------------------------